# Basic Example

This page shows a basic example of using SIRUS.jl on a dataset via the Machine Learning in Julia (MLJ.jl) interface. For more details on what SIRUS is and how it works, see the Advanced Example.
```julia
begin
    using CategoricalArrays: categorical
    using CSV: CSV
    using DataDeps: DataDeps, DataDep, @datadep_str
    using DataFrames
    using MLJ
    using StableRNGs: StableRNG
    using SIRUS: StableRulesClassifier
end
```
To show the algorithm, we'll use Haberman's survival dataset. We load it via DataDeps.jl so that the download is verified with a checksum and cached.
```julia
function register_haberman()
    name = "Haberman"
    message = "Haberman's Survival Data Set"
    remote_path = "https://github.com/rikhuijzer/haberman-survival-dataset/releases/download/v1.0.0/haberman.csv"
    checksum = "a7e9aeb249e11ac17c2b8ea4fdafd5c9392219d27cb819ffaeb8a869eb727a0f"
    DataDeps.register(DataDep(name, message, remote_path, checksum))
end;
```
We can call this registration function and then load the dataset:
```julia
function load_haberman()::DataFrame
    register_haberman()
    path = joinpath(datadep"Haberman", "haberman.csv")
    df = CSV.read(path, DataFrame)
    df[!, :survival] = categorical(df.survival)
    df
end;
```
```julia
data = load_haberman()
```
| | age | year | nodes | survival |
|---|---|---|---|---|
| 1 | 30 | 1964 | 1 | 1 |
| 2 | 30 | 1962 | 3 | 1 |
| 3 | 30 | 1965 | 0 | 1 |
| 4 | 31 | 1959 | 2 | 1 |
| 5 | 31 | 1965 | 4 | 1 |
| 6 | 33 | 1958 | 10 | 1 |
| 7 | 33 | 1960 | 0 | 1 |
| 8 | 34 | 1959 | 0 | 0 |
| 9 | 34 | 1966 | 9 | 0 |
| 10 | 34 | 1958 | 30 | 1 |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| 306 | 83 | 1958 | 2 | 0 |
And split it into features (`X`) and outcomes (`y`):
```julia
X = select(data, Not(:survival));
y = data.survival;
```
Next, we can load the model that we want to use. Since Haberman's outcome column (`survival`) contains 0s and 1s, we use the `StableRulesClassifier`.
Here, the hyperparameters were manually tuned for this dataset. On this dataset:

- The number of split points `q` was lowered to 4. This means that the random forest can choose from only 4 locations to split the data. With such a low number, the accuracy typically goes down, but the rules become more interpretable.
- The max tree depth `max_depth` was set to 2. This is the highest that SIRUS can go, as is discussed in the original paper. The reason is that rules for greater tree depths almost certainly do not end up in the final rule set, so it is computationally cheaper not to generate them in the first place.
- The `lambda` parameter was set to 1. This parameter is used with a ridge regression to determine the weights in the final rule set. SIRUS is very sensitive to the choice of this hyperparameter. Ensure that you try the full range from 10^-4 to 10^4 (e.g., 0.001, 0.01, ..., 100).
- The max number of rules `max_rules` was set to 8. Typically, the model becomes more accurate with more rules, but less interpretable. If the model becomes less accurate with more rules, then ensure that you have set the right `lambda`.
```julia
model = StableRulesClassifier(; rng=StableRNG(1), q=4, max_depth=2, lambda=1, max_rules=8);
```
Next, we will show two common use cases:

1. fit the model to the full dataset, and
2. fit the model to cross-validation folds (to evaluate model performance).
## Fitting to the full dataset
```julia
mach = let
    mach = machine(model, X, y)
    MLJ.fit!(mach)
end;
```
And inspect the fitted model:
```julia
mach.fitresult
```

```
StableRules model with 8 rules:
 if X[i, :nodes] < 7.0 then 0.238 else 0.046 +
 if X[i, :nodes] < 2.0 then 0.183 else 0.055 +
 if X[i, :age] ≥ 62.0 & X[i, :year] < 1959.0 then 0.0 else 0.001 +
 if X[i, :year] < 1959.0 & X[i, :nodes] ≥ 2.0 then 0.0 else 0.006 +
 if X[i, :nodes] ≥ 7.0 & X[i, :age] ≥ 62.0 then 0.0 else 0.008 +
 if X[i, :year] < 1959.0 & X[i, :nodes] ≥ 7.0 then 0.0 else 0.003 +
 if X[i, :year] ≥ 1966.0 & X[i, :age] < 42.0 then 0.0 else 0.008 +
 if X[i, :nodes] ≥ 7.0 & X[i, :age] ≥ 42.0 then 0.014 else 0.045
and 2 classes: [0, 1].
Note: showing only the probability for class 1 since class 0 has probability 1 - p.
```
This shows that the model contains 8 rules. The first rule, for example, can be interpreted as:

> If the number of detected axillary nodes is lower than 7, then take 0.238, otherwise take 0.046.

This calculation is done for all 8 rules and the scores are summed to get a prediction.
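The score-summing step can be illustrated with a minimal standalone sketch. Note that this is not SIRUS's internal implementation, just a hand-written illustration using the first two fitted rules shown above; the `rules`, `score`, and `x` names are ours:

```julia
# Each rule is (condition, value if the condition holds, value otherwise).
# These are the first two rules from the fitted model above.
rules = [
    (x -> x.nodes < 7.0, 0.238, 0.046),
    (x -> x.nodes < 2.0, 0.183, 0.055),
]

# The prediction score is the sum of the per-rule contributions.
score(x) = sum(cond(x) ? yes : no for (cond, yes, no) in rules)

# First row of the dataset: age 30, year 1964, 1 node.
x = (age = 30, year = 1964, nodes = 1)
score(x)  # 0.238 + 0.183 = 0.421 (first two rules only)
```

The full model sums all 8 rule contributions in the same way to obtain the probability of class 1.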
The predictions are of the type `UnivariateFinite` from MLJ's CategoricalDistributions.jl:
```julia
predictions = predict(mach, X)
```

```
306-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{2}, Int64, UInt8, Float64}:
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.233, 1=>0.365)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.233, 1=>0.365)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.233, 1=>0.365)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.434, 1=>0.163)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
 ⋮
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.233, 1=>0.365)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
 UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.24, 1=>0.358)
```
To get the underlying probabilities out of these objects, use `pdf`. For example, to get the predicted probability of class 0 for the first datapoint, use:
```julia
pdf(predictions[1], 0)
```

```
0.10530822372436523
```
See the [MLJ getting started guide](https://alan-turing-institute.github.io/MLJ.jl/dev/getting_started/#Fit-and-predict) for more information.
## Model Evaluation via Cross-Validation
Let's define our cross-validation (CV) strategy with 10 folds. Also, we enable shuffling to make it more likely that our model sees cases from both `survival` classes:
```julia
resampling = CV(; rng=StableRNG(1), nfolds=10, shuffle=true);
```
We use the Area Under the Curve (AUC) measure since it is appropriate for binary classification tasks. More specifically, the measure gives the area under the receiver operating characteristic (ROC) curve. A score of 0.5 means that our model performs no better (or worse) than random guessing, 0.0 means that all cases are predicted incorrectly, and 1.0 means that all cases are predicted correctly.
```julia
evaluate(model, X, y; resampling, measure=auc)
```

```
PerformanceEvaluation object with these fields:
  model, measure, operation, measurement, per_fold,
  per_observation, fitted_params_per_fold, report_per_fold,
  train_test_rows, resampling, repeats
Extract:
┌──────────────────┬───────────┬─────────────┬─────────┬───────────────────────────────
│ measure          │ operation │ measurement │ 1.96*SE │ per_fold                     ⋯
├──────────────────┼───────────┼─────────────┼─────────┼───────────────────────────────
│ AreaUnderCurve() │ predict   │ 0.704       │ 0.0762  │ [0.542, 0.771, 0.78, 0.598, 0 ⋯
└──────────────────┴───────────┴─────────────┴─────────┴───────────────────────────────
                                                                       1 column omitted
```
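The AUC reported above can also be understood via its probabilistic interpretation: it estimates the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative case. A minimal standalone sketch with made-up toy scores (not the Haberman data):

```julia
# Toy model scores; assumed data for illustration only.
scores_pos = [0.8, 0.6, 0.7]   # scores for cases that are actually positive
scores_neg = [0.3, 0.5, 0.9]   # scores for cases that are actually negative

# Count positive/negative pairs where the positive case is ranked higher.
n_correct = sum(p > n for p in scores_pos, n in scores_neg)
auc_estimate = n_correct / (length(scores_pos) * length(scores_neg))
auc_estimate  # 6 of 9 pairs ranked correctly ≈ 0.667
```

An estimate of 0.667 on these toy scores means the model ranks a positive above a negative two times out of three, which matches the intuition that 0.5 is random ranking and 1.0 is perfect ranking.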