Quick Start#
This quick start does the following:
- Install the package
- Simulate a toy dataset with ground truth plots
- Fit GAM models including NodeGAM, EBM, XGB-GAM, and Spline
- Visualize the GAMs’ plots
Installation#
pip install nodegam
Simulate a toy dataset#
To test the package, we simulate a simple binary classification dataset with 3 features, each sampled independently from a uniform distribution on [-5, 5], i.e.
\(x_0, x_1, x_2 \sim U[-5, 5]\)
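The target is simulated from the logit (the code below names the three features x1, x2, x3 and subtracts a constant 8 from the quadratic term):
\(\hat{y} = x_0^2 - 8 + 2 x_1 + \sin(x_2)\)
and
\(y \sim \text{Bern}(\text{sigmoid}(\hat{y}))\)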
[1]:
from nodegam.sklearn import NodeGAMRegressor, NodeGAMClassifier
from nodegam.gams.MySpline import MySplineLogisticGAM, MySplineGAM
from nodegam.gams.MyEBM import MyExplainableBoostingClassifier, MyExplainableBoostingRegressor
from nodegam.gams.MyXGB import MyXGBOnehotClassifier, MyXGBOnehotRegressor
from nodegam.gams.MyBagging import MyBaggingClassifier, MyBaggingRegressor
from nodegam.utils import sigmoid_np, average_GAM_dfs
from nodegam.vis_utils import vis_GAM_effects
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
[2]:
# Generate dataset
N = 25000
x1 = np.random.uniform(-5, 5, size=N)
x2 = np.random.uniform(-5, 5, size=N)
x3 = np.random.uniform(-5, 5, size=N)
f1 = lambda x: (x) ** 2 - 8
f2 = lambda x: x * 2
f3 = lambda x: np.sin(x)
y_prob = sigmoid_np(f1(x1) + f2(x2) + f3(x3))
# Sample
y = (np.random.random(N) < y_prob).astype(int)
X = pd.DataFrame(np.vstack([x1, x2, x3]).T)
X.shape, y.shape
[2]:
((25000, 3), (25000,))
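As a sanity check on the simulation, the fraction of positive labels should sit close to the average of the sigmoid probabilities. The sketch below mirrors the f1, f2, f3 definitions from the cell above using only the standard library. The `sigmoid` and `simulate` helpers and the fixed seed are illustrative assumptions, not part of nodegam.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def simulate(n, seed=0):
    """Draw n samples from the toy model; returns (probabilities, labels)."""
    rng = random.Random(seed)  # fixed seed so the draw is reproducible
    probs, labels = [], []
    for _ in range(n):
        x0, x1, x2 = (rng.uniform(-5, 5) for _ in range(3))
        logit = (x0 ** 2 - 8) + 2 * x1 + math.sin(x2)
        p = sigmoid(logit)
        probs.append(p)
        labels.append(int(rng.random() < p))  # Bernoulli(p) draw
    return probs, labels


probs, labels = simulate(25000)
mean_prob = sum(probs) / len(probs)
pos_rate = sum(labels) / len(labels)
print(f"mean probability: {mean_prob:.3f}, positive rate: {pos_rate:.3f}")
```

If the two numbers differ by more than a few thousandths at this sample size, the sampling step is likely wrong.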
Ground Truth GAM graph
[7]:
x = np.linspace(-5, 5, 1000)
fig, ax = plt.subplots(1, 3, figsize=(18, 4.5))
ax[0].plot(x, f1(x))
ax[1].plot(x, f2(x))
ax[2].plot(x, f3(x))
[7]:
[<matplotlib.lines.Line2D at 0x7ff0a4561390>]
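Shape functions in an additive model are only identified up to an additive constant, and the visualization output further down logs a "Center main effects" step. If the fitted curves are indeed mean-centered, the ground-truth curves need the same shift before a visual comparison. The sketch below is one way to do that, assuming a plain average over an evenly spaced grid on [-5, 5] (reasonable here because the features are uniform); `centered_curve` is a hypothetical helper, not a nodegam function.

```python
import math


def centered_curve(f, lo=-5.0, hi=5.0, n=1001):
    """Evaluate f on an evenly spaced grid and subtract the grid mean."""
    grid = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    values = [f(v) for v in grid]
    mean = sum(values) / n
    return grid, [v - mean for v in values]


# Ground-truth components, matching f1, f2, f3 in the dataset cell
ground_truth = {
    "x0": lambda v: v ** 2 - 8,
    "x1": lambda v: 2 * v,
    "x2": math.sin,
}

centered = {name: centered_curve(f) for name, f in ground_truth.items()}
for name, (grid, values) in centered.items():
    print(name, round(min(values), 3), round(max(values), 3))
```

After centering, the constant -8 in the quadratic term no longer matters, since it only shifts the curve vertically.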
Fit GAM Models#
We fit four GAM implementations: NodeGAM, Spline, EBM, and XGB-GAM.
We train NodeGAM 3 times with different seeds so that the shape graphs can be shown with a standard deviation across runs.
[8]:
n_seeds = 3
# NodeGAM fitting
nodegams = []
for seed in range(n_seeds):
    model = NodeGAMClassifier(
        in_features=3,
        objective='ce_loss',
        verbose=False,  # Set to True to see the training progress
        seed=seed,
    )
    train_record = model.fit(X, y)
    nodegams.append(model)
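To illustrate what "stdev on the shape graphs" means, the sketch below summarizes several runs point by point: for each grid position it takes the mean and the sample standard deviation across seeds. It is not the implementation of `average_GAM_dfs`, whose internals are not shown here; `summarize_runs` and the numbers in `runs` are hypothetical and chosen only for illustration.

```python
from statistics import fmean, stdev


def summarize_runs(curves):
    """Return per-grid-point mean and sample stdev across runs.

    curves: list of equally long lists, one list of shape-function
    values per run, all evaluated on the same grid.
    """
    if len(curves) < 2:
        raise ValueError("need at least two runs to compute a stdev")
    n_points = len(curves[0])
    if any(len(c) != n_points for c in curves):
        raise ValueError("all runs must share the same grid")
    columns = list(zip(*curves))  # one tuple of values per grid point
    means = [fmean(col) for col in columns]
    stds = [stdev(col) for col in columns]
    return means, stds


# Hypothetical shape-function values from three seeds on a shared grid
runs = [
    [0.0, 1.0, 4.0],
    [0.2, 1.1, 3.8],
    [-0.2, 0.9, 4.2],
]
means, stds = summarize_runs(runs)
print(means, stds)
```

A wide band at some grid point indicates that the seeds disagree there, which often happens in regions with few training samples.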
Other GAM baselines
[ ]:
ebm = MyExplainableBoostingClassifier()
ebm.fit(X, y)
spline = MySplineLogisticGAM(verbose=False)
# Bag 3 times to get stdev on the shape graphs
bagged_spline = MyBaggingClassifier(base_estimator=spline, n_estimators=3)
bagged_spline.fit(X, y)
xgb_gam = MyXGBOnehotClassifier()
bagged_xgb = MyBaggingClassifier(base_estimator=xgb_gam, n_estimators=3)
bagged_xgb.fit(X, y)
Visualization#
We visualize the GAM main effects.
[11]:
fig, ax = vis_GAM_effects({
    'nodegam': average_GAM_dfs([m.get_GAM_df(X) for m in nodegams]),
    'ebm': ebm.get_GAM_df(),
    'xgb-gam': bagged_xgb.get_GAM_df(),
    'spline': bagged_spline.get_GAM_df(),
}, top_interactions=0)
100%|██████████| 6/6 [00:00<00:00, 122.51it/s]
bin features 0 with uniq val 24996 to only 256
bin features 1 with uniq val 24993 to only 256
bin features 2 with uniq val 24994 to only 256
Finish "Run values through model" in 99ms
Finish "Extract values" in 53ms
Run "Purify interactions to main effects".........
100%|██████████| 7/7 [00:00<00:00, 46.38it/s]
Finish "Purify interactions to main effects" in 60ms
Finish "Center main effects" in 1ms
Finish "Construct table" in 157ms
bin features 0 with uniq val 24996 to only 256
bin features 1 with uniq val 24993 to only 256
bin features 2 with uniq val 24994 to only 256
Run "Run values through model".........
100%|██████████| 6/6 [00:00<00:00, 123.71it/s]
0%| | 0/7 [00:00<?, ?it/s]
Finish "Run values through model" in 93ms
Finish "Extract values" in 52ms
Finish "Purify interactions to main effects" in 66ms
Finish "Center main effects" in 1ms
Run "Construct table".........
100%|██████████| 7/7 [00:00<00:00, 25.17it/s]
100%|██████████| 6/6 [00:00<00:00, 123.33it/s]
Finish "Construct table" in 283ms
bin features 0 with uniq val 24996 to only 256
bin features 1 with uniq val 24993 to only 256
bin features 2 with uniq val 24994 to only 256
Finish "Run values through model" in 93ms
Finish "Extract values" in 53ms
Run "Purify interactions to main effects".........
100%|██████████| 7/7 [00:00<00:00, 41.94it/s]
Finish "Purify interactions to main effects" in 67ms
Finish "Center main effects" in 1ms
Finish "Construct table" in 172ms