Basic Example

This page shows a basic example for using SIRUS.jl on a dataset via the Machine Learning Julia (MLJ.jl) interface. For more details on what SIRUS is and how it works, see the Advanced Example.

    using CategoricalArrays: categorical
    using CSV: CSV
    using DataDeps: DataDeps, DataDep, @datadep_str
    using DataFrames
    using MLJ
    using StableRNGs: StableRNG
    using SIRUS: StableRulesClassifier

To show the algorithm, we'll use Haberman's survival dataset. We load it via DataDeps.jl, which verifies the downloaded file against a checksum and caches the dataset locally.

    function register_haberman()
        name = "Haberman"
        message = "Haberman's Survival Data Set"
        remote_path = ""  # URL from which DataDeps downloads haberman.csv
        checksum = "a7e9aeb249e11ac17c2b8ea4fdafd5c9392219d27cb819ffaeb8a869eb727a0f"
        DataDeps.register(DataDep(name, message, remote_path, checksum))
    end
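
The checksum passed to `DataDep` is a 64-hex-digit string, the shape of a SHA-256 digest (SHA-256 is, as far as we know, the default hash in DataDeps). As a minimal sketch of what such a verification involves, the standard-library SHA package can compute the same kind of digest for any file. The helpers `file_sha256` and `verify_checksum` and the temporary demo file are hypothetical, for illustration only; DataDeps performs its own check internally.

```julia
using SHA  # standard library; provides sha256

# Hypothetical helper: hex-encoded SHA-256 digest of a file's raw bytes.
file_sha256(path::AbstractString) = bytes2hex(sha256(read(path)))

# Hypothetical helper: true when the file matches an expected hex checksum.
verify_checksum(path::AbstractString, expected::AbstractString) =
    file_sha256(path) == lowercase(expected)

# Demo on a temporary file whose contents are the three bytes "abc".
path, io = mktemp()
write(io, "abc")
close(io)

digest = file_sha256(path)
verify_checksum(path, digest)  # true
```

Comparing lowercased hex strings keeps the check insensitive to how the expected checksum was capitalized.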

We can call this registration function and then load the dataset:

    function load_haberman()::DataFrame
        register_haberman()
        path = joinpath(datadep"Haberman", "haberman.csv")
        df = CSV.read(path, DataFrame)
        df[!, :survival] = categorical(df.survival)
        return df
    end

    data = load_haberman()

And split it into features (X) and outcomes (y):

    X = select(data, Not(:survival));
    y = data.survival;

Next, we define the model that we want to use. Since Haberman's outcome column (survival) contains 0s and 1s, we use the StableRulesClassifier.

The hyperparameters below were tuned manually for this dataset:

  • The number of split points q was lowered to 4. This means that, for each feature, the random forest can choose from only 4 locations at which to split the data. With so few split points, accuracy typically goes down, but the rules become more interpretable.

  • The max tree depth max_depth was set to 2. This is the highest that SIRUS can go, as is discussed in the original paper. The reason is that rules from deeper trees almost never end up in the final rule set, so it is computationally cheaper not to generate them in the first place.

  • The lambda parameter was set to 1. This parameter is used in a ridge regression that determines the weights of the rules in the final rule set. SIRUS is very sensitive to the choice of this hyperparameter, so try the full range from 10^-4 to 10^4 (e.g., 0.0001, 0.001, ..., 10000).

  • The max number of rules max_rules was set to 8. Typically, the model becomes more accurate with more rules, but less interpretable. If the model becomes less accurate as you add rules, check that lambda is set appropriately.
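
Two of these hyperparameters can be illustrated without SIRUS itself. The sketch below assumes that split points are empirical quantiles at the levels 1/(q+1), ..., q/(q+1) of each feature (a simplification, not SIRUS.jl's own code), and shows a plain closed-form ridge regression on made-up 0/1 rule indicators to show how a larger lambda shrinks the rule weights. The helpers `split_candidates` and `ridge` and all data values are hypothetical.

```julia
using LinearAlgebra: I, norm
using Statistics: quantile

# Hypothetical helper: q candidate split points for one feature, taken as the
# empirical quantiles at levels 1/(q+1), ..., q/(q+1). A simplified stand-in
# for how SIRUS restricts split locations, not SIRUS.jl's implementation.
function split_candidates(x::AbstractVector{<:Real}, q::Int)
    levels = collect(1:q) ./ (q + 1)
    return unique(quantile(x, levels))
end

nodes = Float64.(0:30)            # made-up feature values
c = split_candidates(nodes, 4)    # 4 thresholds, roughly at 6, 12, 18, 24

# Plain ridge regression without intercept: w = (R'R + λI) \ R'y, where the
# columns of R are 0/1 rule indicators. Larger λ shrinks weights toward zero.
ridge(R, y, λ) = (R' * R + λ * I) \ (R' * y)

R = Float64[1 0; 1 1; 0 1; 0 0; 1 0]   # made-up indicators for 2 rules
y = [1.0, 1.0, 0.0, 0.0, 1.0]          # made-up binary outcomes

lambdas = 10.0 .^ (-4:4)               # 0.0001, 0.001, ..., 10000
weight_norms = [norm(ridge(R, y, λ)) for λ in lambdas]
```

The weight norms shrink as lambda grows, which is why a poorly chosen lambda can silence rules that would otherwise help.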

    model = StableRulesClassifier(; rng=StableRNG(1), q=4, max_depth=2, lambda=1, max_rules=8);

Next, we will show two common use cases:

  1. fit the model to the full dataset and

  2. fit the model to cross-validation folds (to evaluate model performance).

Fitting to the full dataset

    mach = let
        mach = machine(model, X, y)
        MLJ.fit!(mach)
    end;

And inspect the fitted model:

    StableRules model with 8 rules:
     if X[i, :nodes] < 7.0 then 0.238 else 0.046 +
     if X[i, :nodes] < 2.0 then 0.183 else 0.055 +
     if X[i, :age] ≥ 62.0 & X[i, :year] < 1959.0 then 0.0 else 0.001 +
     if X[i, :year] < 1959.0 & X[i, :nodes] ≥ 2.0 then 0.0 else 0.006 +
     if X[i, :nodes] ≥ 7.0 & X[i, :age] ≥ 62.0 then 0.0 else 0.008 +
     if X[i, :year] < 1959.0 & X[i, :nodes] ≥ 7.0 then 0.0 else 0.003 +
     if X[i, :year] ≥ 1966.0 & X[i, :age] < 42.0 then 0.0 else 0.008 +
     if X[i, :nodes] ≥ 7.0 & X[i, :age] ≥ 42.0 then 0.014 else 0.045
    and 2 classes: [0, 1].
    Note: showing only the probability for class 1 since class 0 has probability 1 - p.

This shows that the model contains 8 rules, where the first rule, for example, can be interpreted as:

If the number of detected axillary nodes is less than 7, then take 0.238 and otherwise take 0.046.

The same calculation is done for all 8 rules, and the resulting values are summed to get the prediction, which is the probability of class 1.
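
To make the summation concrete, here is a plain-Julia sketch (no SIRUS calls) that hand-copies the 8 printed rules and adds up their contributions for one made-up patient. The names `rules` and `score` and the patient's values are hypothetical; because the printed values are rounded to 3 decimals, the result only approximates the model's own prediction.

```julia
# The 8 rules printed above, hand-copied (values rounded to 3 decimals).
# Each rule is (condition, value if true, value if false).
rules = [
    (r -> r.nodes < 7,                   0.238, 0.046),
    (r -> r.nodes < 2,                   0.183, 0.055),
    (r -> r.age >= 62 && r.year < 1959,  0.0,   0.001),
    (r -> r.year < 1959 && r.nodes >= 2, 0.0,   0.006),
    (r -> r.nodes >= 7 && r.age >= 62,   0.0,   0.008),
    (r -> r.year < 1959 && r.nodes >= 7, 0.0,   0.003),
    (r -> r.year >= 1966 && r.age < 42,  0.0,   0.008),
    (r -> r.nodes >= 7 && r.age >= 42,   0.014, 0.045),
]

# Sum the contribution of every rule for one observation.
score(row) = sum(cond(row) ? then_val : else_val for (cond, then_val, else_val) in rules)

# Made-up patient: 30 years old, operated in 1964, 1 positive axillary node.
row = (age = 30, year = 1964, nodes = 1)
p = score(row)
round(p; digits = 3)  # 0.492
```

For this patient only the first two rules take their "then" branch; the other six contribute their small "else" values.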

The predictions are of the type UnivariateFinite from MLJ's CategoricalDistributions.jl:

    predictions = predict(mach, X)
    306-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{2}, Int64, UInt8, Float64}:
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.233, 1=>0.365)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.233, 1=>0.365)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.233, 1=>0.365)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.434, 1=>0.163)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.233, 1=>0.365)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.105, 1=>0.492)
     UnivariateFinite{ScientificTypesBase.Multiclass{2}}(0=>0.24, 1=>0.358)

To get the underlying probabilities out of these objects, use pdf. For example, to get the probability of class 0 for the first data point, use:

    pdf(predictions[1], 0)

Model Evaluation via Cross-Validation

Let's define our cross-validation (CV) strategy with 10 folds. We also enable shuffling so that each fold is more likely to contain cases from both survival classes:

    resampling = CV(; rng=StableRNG(1), nfolds=10, shuffle=true);
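
To see what shuffled 10-fold CV does to the rows, here is a minimal stdlib-only sketch that shuffles the row indices and cuts them into 10 nearly equal test folds, each paired with the remaining rows for training. The helper `kfold_indices` is hypothetical and uses `Random.Xoshiro` (Julia 1.7 or later) instead of StableRNGs; MLJ's `CV` may assign rows to folds differently.

```julia
using Random  # standard library: randperm, Xoshiro

# Hypothetical helper: split rows 1:n into `nfolds` (train, test) pairs.
# Rows are optionally shuffled first; fold sizes differ by at most one.
function kfold_indices(n::Int, nfolds::Int; rng=Random.default_rng(), shuffle::Bool=true)
    rows = shuffle ? randperm(rng, n) : collect(1:n)
    sizes = fill(div(n, nfolds), nfolds)
    sizes[1:rem(n, nfolds)] .+= 1        # spread the remainder over the first folds
    folds = Tuple{Vector{Int},Vector{Int}}[]
    start = 1
    for s in sizes
        test = rows[start:start+s-1]
        train = setdiff(rows, test)      # every row not in this test fold
        push!(folds, (train, test))
        start += s
    end
    return folds
end

folds = kfold_indices(306, 10; rng=Xoshiro(1))
length.(last.(folds))   # test-fold sizes
```

With 306 rows, six test folds hold 31 rows and four hold 30, and every row lands in exactly one test fold. Without shuffling, a dataset sorted by outcome could produce test folds containing a single class.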

We use the Area Under the Curve (AUC) measure since it is appropriate for binary classification tasks. More specifically, it gives the area under the receiver operating characteristic (ROC) curve. A score of 0.5 means that the model is as good (or bad, actually) as random guessing, a score of 1.0 means that it ranks every positive case above every negative case, and a score of 0.0 means that it gets every such ranking wrong.
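
The AUC can also be read as a ranking statistic: it is the probability that a randomly chosen positive case receives a higher score than a randomly chosen negative one. A minimal, quadratic-time sketch of that definition (not MLJ's implementation) shows where the values 0.0, 0.5, and 1.0 come from. The helper `pairwise_auc` and the toy scores are hypothetical.

```julia
# Hypothetical helper: ROC AUC as the fraction of (positive, negative) pairs
# in which the positive case gets the higher score; ties count as one half.
function pairwise_auc(scores::AbstractVector{<:Real}, labels::AbstractVector{Bool})
    pos = scores[labels]
    neg = scores[.!labels]
    wins = 0.0
    for p in pos, n in neg
        wins += p > n ? 1.0 : (p == n ? 0.5 : 0.0)
    end
    return wins / (length(pos) * length(neg))
end

labels = [false, false, true, true]
pairwise_auc([0.1, 0.2, 0.8, 0.9], labels)   # 1.0: every pair ranked correctly
pairwise_auc([0.1, 0.4, 0.35, 0.8], labels)  # 0.75: 3 of 4 pairs ranked correctly
pairwise_auc([0.5, 0.5, 0.5, 0.5], labels)   # 0.5: no better than guessing
pairwise_auc([0.9, 0.8, 0.2, 0.1], labels)   # 0.0: every pair ranked wrongly
```

Only the ordering of the scores matters, so AUC does not require the predicted probabilities to be well calibrated.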

    evaluate(model, X, y; resampling, measure=auc)
    PerformanceEvaluation object with these fields:
      model, measure, operation,
      measurement, per_fold, per_observation,
      fitted_params_per_fold, report_per_fold,
      train_test_rows, resampling, repeats
    │ measure          │ operation │ measurement │
    │ AreaUnderCurve() │ predict   │ 0.704       │
    │ per_fold                                                            │ 1.96*SE │
    │ [0.542, 0.771, 0.78, 0.598, 0.604, 0.88, 0.856, 0.72, 0.686, 0.605] │ 0.0762  │
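
The summary numbers can be recomputed from the per-fold values: the measurement is their mean, and the 1.96*SE column is a normal-approximation half-width. In this sketch, the formula for SE (sample standard deviation divided by sqrt(nfolds - 1)) is our assumption about how MLJ computes it. With the rounded per-fold values it gives about 0.0761, close to the printed 0.0762.

```julia
using Statistics: mean, std

# Per-fold AUC values copied from the output above (rounded to 3 decimals).
per_fold = [0.542, 0.771, 0.78, 0.598, 0.604, 0.88, 0.856, 0.72, 0.686, 0.605]

measurement = mean(per_fold)   # about 0.704

# Assumption: standard error = sample standard deviation / sqrt(nfolds - 1).
se = std(per_fold) / sqrt(length(per_fold) - 1)
halfwidth = 1.96 * se          # about 0.076

# Rough 95% interval for the AUC across folds.
interval = (measurement - halfwidth, measurement + halfwidth)
```

The interval, roughly 0.63 to 0.78, shows how much the AUC varies between folds on a dataset of only 306 rows.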