API
MLJ Interface Types
SIRUS.StableRulesClassifier (Type)
A model type for constructing a stable rules classifier, based on SIRUS.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
StableRulesClassifier = @load StableRulesClassifier pkg=SIRUS
Do model = StableRulesClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in StableRulesClassifier(rng=...).
StableRulesClassifier implements the explainable rule-based model based on a random forest.
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)
where
X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X).
y: the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y).
Train the machine with fit!(mach, rows=...).
Hyperparameters
rng::AbstractRNG=default_rng(): Random number generator. Using a StableRNG from StableRNGs.jl is advised.
partial_sampling::Float64=0.7: Ratio of samples to use in each subset of the data. The default should be fine for most cases.
n_trees::Int=1000: The number of trees to use. It is advisable to use at least a thousand trees for better rule selection and, in turn, better predictive performance.
max_depth::Int=2: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (it reduces overfitting).
q::Int=10: Number of cutpoints to use per feature. The default value should be fine for most situations.
min_data_in_leaf::Int=5: Minimum number of data points per leaf.
max_rules::Int=10: This is the most important hyperparameter after lambda. The more rules, the more accurate the model should be; if this is not the case, tune lambda first. However, more rules also decrease model interpretability, so it is important to find a good balance here. In most cases, 10 to 40 rules should provide reasonable accuracy while remaining interpretable.
lambda::Float64=1.0: The weights of the final rules are determined via a regularized regression over each rule as a binary feature. This hyperparameter specifies the strength of the ridge (L2) regularizer. SIRUS is very sensitive to the choice of this hyperparameter, so try the full range from 10^-4 to 10^4 (e.g., 10^-4, 10^-3, ..., 10^4). While sweeping the range, one useful check is to verify that increasing max_rules increases performance. If it does not, try a different value for lambda.
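To make the role of lambda concrete, here is a minimal sketch of ridge regression over rules used as binary features, together with the recommended log-spaced sweep. The design matrix Z, the target y, and the helper ridge_weights are hypothetical; this uses the textbook closed-form ridge solution without an intercept and is not SIRUS's actual fitting code, which may handle intercepts, scaling, or solvers differently.

```julia
using LinearAlgebra

# Hypothetical toy data: each column of Z is one rule used as a binary feature,
# 1.0 if the row satisfies the rule and 0.0 otherwise.
Z = Float64[1 0 1;
            1 1 0;
            0 1 1;
            1 0 0;
            0 0 1]
y = [1.0, 1.0, 0.0, 1.0, 0.0]

# Textbook closed-form ridge (L2) solution without an intercept:
# w = (Z'Z + λI) \ Z'y. Larger λ shrinks the rule weights towards zero.
ridge_weights(Z, y, λ) = (Z' * Z + λ * I) \ (Z' * y)

# Log-spaced grid from 10^-4 to 10^4, as suggested for lambda above.
lambdas = 10.0 .^ (-4:4)

for λ in lambdas
    w = ridge_weights(Z, y, λ)
    println("lambda = ", λ, "  weights = ", round.(w; digits = 3))
end
```

With a very large lambda all weights are pushed close to zero, which is one way a badly chosen lambda can hide the benefit of adding more rules.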
Fitted parameters
The fields of fitted_params(mach) are:
fitresult: A StableRules object.
Operations
predict(mach, Xnew): Return a vector of predictions for each row of Xnew.
SIRUS.StableRulesRegressor (Type)
A model type for constructing a stable rules regressor, based on SIRUS.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
StableRulesRegressor = @load StableRulesRegressor pkg=SIRUS
Do model = StableRulesRegressor() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in StableRulesRegressor(rng=...).
StableRulesRegressor implements the explainable rule-based regression model based on a random forest.
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)
where
X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X).
y: the target, which can be any AbstractVector whose element scitype is Continuous; check the scitype with scitype(y).
Train the machine with fit!(mach, rows=...).
Hyperparameters
rng::AbstractRNG=default_rng(): Random number generator. Using a StableRNG from StableRNGs.jl is advised.
partial_sampling::Float64=0.7: Ratio of samples to use in each subset of the data. The default should be fine for most cases.
n_trees::Int=1000: The number of trees to use. It is advisable to use at least a thousand trees for better rule selection and, in turn, better predictive performance.
max_depth::Int=2: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (it reduces overfitting).
q::Int=10: Number of cutpoints to use per feature. The default value should be fine for most situations.
min_data_in_leaf::Int=5: Minimum number of data points per leaf.
max_rules::Int=10: This is the most important hyperparameter after lambda. The more rules, the more accurate the model should be; if this is not the case, tune lambda first. However, more rules also decrease model interpretability, so it is important to find a good balance here. In most cases, 10 to 40 rules should provide reasonable accuracy while remaining interpretable.
lambda::Float64=1.0: The weights of the final rules are determined via a regularized regression over each rule as a binary feature. This hyperparameter specifies the strength of the ridge (L2) regularizer. SIRUS is very sensitive to the choice of this hyperparameter, so try the full range from 10^-4 to 10^4 (e.g., 10^-4, 10^-3, ..., 10^4). While sweeping the range, one useful check is to verify that increasing max_rules increases performance. If it does not, try a different value for lambda.
Fitted parameters
The fields of fitted_params(mach) are:
fitresult: A StableRules object.
Operations
predict(mach, Xnew): Return a vector of predictions for each row of Xnew.
SIRUS.StableForestClassifier (Type)
A model type for constructing a stable forest classifier, based on SIRUS.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
StableForestClassifier = @load StableForestClassifier pkg=SIRUS
Do model = StableForestClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in StableForestClassifier(rng=...).
StableForestClassifier implements the random forest classifier with a stabilized forest structure (Bénard et al., 2021). This stabilization increases stability when extracting rules. The impact on the predictive accuracy compared to standard random forests should be relatively small.
Just like normal random forests, this model is not easily explainable. If you are interested in an explainable model, use the StableRulesClassifier or StableRulesRegressor.
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)
where
X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X).
y: the target, which can be any AbstractVector whose element scitype is <:OrderedFactor or <:Multiclass; check the scitype with scitype(y).
Train the machine with fit!(mach, rows=...).
Hyperparameters
rng::AbstractRNG=default_rng(): Random number generator. Using a StableRNG from StableRNGs.jl is advised.
partial_sampling::Float64=0.7: Ratio of samples to use in each subset of the data. The default should be fine for most cases.
n_trees::Int=1000: The number of trees to use. It is advisable to use at least a thousand trees for better rule selection and, in turn, better predictive performance.
max_depth::Int=2: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (it reduces overfitting).
q::Int=10: Number of cutpoints to use per feature. The default value should be fine for most situations.
min_data_in_leaf::Int=5: Minimum number of data points per leaf.
Fitted parameters
The fields of fitted_params(mach) are:
fitresult: A StableForest object.
Operations
predict(mach, Xnew): Return a vector of predictions for each row of Xnew.
SIRUS.StableForestRegressor (Type)
A model type for constructing a stable forest regressor, based on SIRUS.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
StableForestRegressor = @load StableForestRegressor pkg=SIRUS
Do model = StableForestRegressor() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in StableForestRegressor(rng=...).
StableForestRegressor implements the random forest regressor with a stabilized forest structure (Bénard et al., 2021).
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)
where
X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, or <:OrderedFactor; check column scitypes with schema(X).
y: the target, which can be any AbstractVector whose element scitype is Continuous; check the scitype with scitype(y).
Train the machine with fit!(mach, rows=...).
Hyperparameters
rng::AbstractRNG=default_rng(): Random number generator. Using a StableRNG from StableRNGs.jl is advised.
partial_sampling::Float64=0.7: Ratio of samples to use in each subset of the data. The default should be fine for most cases.
n_trees::Int=1000: The number of trees to use. It is advisable to use at least a thousand trees for better rule selection and, in turn, better predictive performance.
max_depth::Int=2: The depth of the tree. A lower depth decreases model complexity and can therefore improve accuracy when the sample size is small (it reduces overfitting).
q::Int=10: Number of cutpoints to use per feature. The default value should be fine for most situations.
min_data_in_leaf::Int=5: Minimum number of data points per leaf.
Fitted parameters
The fields of fitted_params(mach) are:
fitresult: A StableForest object.
Operations
predict(mach, Xnew): Return a vector of predictions for each row of Xnew.
SIRUS Types
SIRUS.SubClause (Type)
SubClause(
feature::Int,
feature_name::AbstractString,
splitval::Number,
direction::Symbol
)
A subclause denotes a conditional on one feature. Each rule contains a clause with one or more subclauses. For example, the rule if X[i, 1] > 3 & X[i, 2] < 4, then ... contains two subclauses.
A subclause is equivalent to a split in a decision tree. In practice, a rule has at most two subclauses, because rules with more than two subclauses do not end up in the final model, as discussed in the original SIRUS paper.
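The data inside a SubClause can be accessed via feature(::SubClause), feature_name(::SubClause), splitval(::SubClause), and direction(::SubClause).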
Note: this name is not perfect. A formally better name would be "predicate atom", but that takes more characters and is also not very intuitive. Instead, the words Clause and SubClause seemed short and clear.
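The subclause semantics can be sketched with a hypothetical ToySubClause struct that mirrors the SubClause constructor fields; it is not SIRUS's type. The sketch assumes the convention documented for direction(::SubClause), where :L means < and :R means ≥, and shows that a clause built from several subclauses holds only when every subclause holds.

```julia
# Hypothetical stand-in for SIRUS.SubClause; the fields mirror the documented
# constructor arguments, but this is not the package's type.
struct ToySubClause
    feature::Int
    feature_name::String
    splitval::Float64
    direction::Symbol   # :L means `<`, :R means `≥`
end

# Evaluate one subclause (one split in a tree) on one data row.
function holds(s::ToySubClause, row::AbstractVector)
    x = row[s.feature]
    if s.direction == :L
        return x < s.splitval
    elseif s.direction == :R
        return x >= s.splitval
    else
        throw(ArgumentError("direction must be :L or :R, got $(s.direction)"))
    end
end

# A clause such as `X[i, 1] ≥ 3 & X[i, 2] < 4` is a conjunction of its subclauses.
holds_all(subs::AbstractVector{ToySubClause}, row) = all(s -> holds(s, row), subs)

subs = [ToySubClause(1, "x1", 3.0, :R), ToySubClause(2, "x2", 4.0, :L)]
println(holds_all(subs, [5.0, 1.0]))   # both subclauses hold
println(holds_all(subs, [5.0, 4.0]))   # 4.0 < 4.0 is false
```

Note that a value exactly equal to splitval falls on the :R (≥) side under this convention.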
SIRUS.Clause (Type)
Clause(subclauses::Vector{SubClause})
A clause denotes a conditional on one or more features. Each rule contains a clause with one or more subclauses.
A clause is equivalent to a path in a decision tree. For example, the clause X[i, 1] > 3 & X[i, 2] < 4 can be interpreted as a path going through two nodes.
In the original SIRUS paper, a path of length d is defined as consisting of d subclauses. As discussed above, in practice the number of subclauses is d ≤ 2.
Note that a path can also lead to an internal node; it does not necessarily end in a leaf.
Data can be accessed via subclauses.
Clauses can be constructed from a textual representation:
Example
julia> Clause(" X[i, 1] < 32000 ")
Clause(" X[i, 1] < 32000.0 ")
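To illustrate how the textual representation maps onto subclauses, here is a minimal parser sketch. The function parse_clause and its regex are hypothetical and not SIRUS's parser; it assumes subclauses are joined by & and written as X[i, k] < v or X[i, k] ≥ v (with >= accepted as an ASCII alternative), and it maps < to :L and ≥ to :R.

```julia
# Hypothetical parser for clauses such as "X[i, 1] < 32000 & X[i, 2] ≥ 4.5".
# A sketch only; it does not reproduce SIRUS's own parsing code.
const SUBCLAUSE_RE = r"X\[i,\s*(\d+)\]\s*(<|≥|>=)\s*(-?[0-9.eE+-]+)"

function parse_clause(text::AbstractString)
    parts = split(text, '&')             # one part per subclause
    map(parts) do part
        m = match(SUBCLAUSE_RE, part)
        m === nothing && throw(ArgumentError("cannot parse subclause: $part"))
        feature = parse(Int, m.captures[1])
        direction = m.captures[2] == "<" ? :L : :R   # `<` is :L, `≥` is :R
        splitval = parse(Float64, m.captures[3])
        (; feature, splitval, direction)
    end
end

c = parse_clause(" X[i, 1] < 32000 & X[i, 2] ≥ 4.5 ")
foreach(println, c)
```

As in the Clause example, the integer split value 32000 is read back as a floating-point number.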
SIRUS.Rule (Type)
Rule(clause::Clause, then::LeafContent, otherwise::LeafContent)
A rule is a clause with a then and otherwise probability. For example, the rule if X[i, 1] > 3 & X[i, 2] < 4, then 0.1 else 0.2 is a rule with two subclauses. The name otherwise is used internally instead of else since else is a reserved keyword.
Data can be accessed via clause(::Rule), subclauses(::Rule), then(::Rule), otherwise(::Rule), feature(::Rule), features(::Rule), feature_name(::Rule), feature_names(::Rule), splitval(::Rule), splitvals(::Rule), direction(::Rule), and directions(::Rule).
Rules can be constructed from a textual representation:
Example
julia> Rule(Clause(" X[i, 1] < 32000 "), [0.1], [0.4])
Rule(Clause(" X[i, 1] < 32000.0 "), [0.1], [0.4])
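The then/otherwise semantics of a single rule can be sketched as follows. ToyRule, cond_holds, and apply_rule are hypothetical names, not SIRUS's API; the sketch assumes a rule returns its then vector when every condition of its clause holds and its otherwise vector when any condition fails, using the :L (<) and :R (≥) convention of direction(::SubClause).

```julia
# Hypothetical stand-in for a SIRUS rule: a clause stored as a list of
# (feature, direction, splitval) conditions plus `then`/`otherwise` vectors.
struct ToyRule
    conditions::Vector{Tuple{Int, Symbol, Float64}}
    then::Vector{Float64}
    otherwise::Vector{Float64}
end

# :L means `<` and :R means `≥`.
cond_holds(row, (f, d, v)) = d == :L ? row[f] < v : row[f] >= v

# Return `then` if the row satisfies every condition, else `otherwise`.
apply_rule(rule::ToyRule, row) =
    all(c -> cond_holds(row, c), rule.conditions) ? rule.then : rule.otherwise

# Same content as the textual example: if X[i, 1] < 32000 then [0.1] else [0.4].
rule = ToyRule([(1, :L, 32000.0)], [0.1], [0.4])
println(apply_rule(rule, [10_000.0]))
println(apply_rule(rule, [50_000.0]))
```

In a fitted StableRules model many such rules are combined through the learned rule weights; this sketch only covers a single rule.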
SIRUS Methods
SIRUS.feature (Function)
feature(s::SubClause) -> Int
Return the feature number for a subclause.
feature(rule::Rule) -> Int
Return the feature number for a rule with one subclause. Throws an error if the rule has multiple subclauses. Use features for rules with multiple subclauses.
SIRUS.features (Function)
features(rule::Rule) -> Vector{Int}
Return a vector of feature numbers; one for each subclause in rule.
SIRUS.feature_name (Function)
feature_name(s::SubClause) -> String
Return the feature name for a subclause.
feature_name(rule::Rule) -> String
Return the feature name for a rule with one subclause. Throws an error if the rule has multiple subclauses. Use feature_names for rules with multiple subclauses.
SIRUS.feature_names (Function)
feature_names(rule::Rule) -> Vector{String}
Return a vector of feature names; one for each subclause in rule.
SIRUS.splitval (Function)
splitval(s::SubClause)
Return the split value for a subclause. The function currently returns a Float32, but this might change in the future.
splitval(rule::Rule)
Return the split value for a rule with one subclause. Throws an error if the rule has multiple subclauses. Use splitvals for rules with multiple subclauses.
SIRUS.splitvals (Function)
splitvals(rule::Rule)
Return the split values for a rule; one for each subclause.
SIRUS.clause (Function)
clause(rule::Rule) -> Clause
Return the clause for a rule. The clause is a path in a decision tree after the conversion to rules. A clause consists of one or more subclauses.
SIRUS.subclauses (Function)
subclauses(c::Clause) -> Vector{SubClause}
Return the subclauses for a clause.
subclauses(rule::Rule) -> Vector{SubClause}
Return the subclauses for a rule.
SIRUS.direction (Function)
direction(s::SubClause) -> Symbol
Return the direction of the comparison for a subclause. Can be either :L or :R, which is equivalent to < or ≥, respectively.
direction(rule::Rule) -> Symbol
Return the direction for a rule with one subclause. Throws an error if the rule has multiple subclauses. Use directions for rules with multiple subclauses.
SIRUS.directions (Function)
directions(rule::Rule) -> Vector{Symbol}
Return a vector of split directions; one for each subclause in rule.
SIRUS.feature_importance (Function)
feature_importance(
models::Union{StableRules, Vector{StableRules}},
feature_name::AbstractString
)
Estimate the importance of the given feature_name. The aim is to satisfy the following property, so that the features can be ordered by importance: given two features A and B, if A has more effect on the outcome, then feature_importance(model, A) > feature_importance(model, B).
This is based on the gap_size function. The gap size is the difference between the then and otherwise (else) probabilities. A smaller gap size implies a smaller CART-splitting criterion, which implies a smaller occurrence frequency (see the appendix at https://proceedings.mlr.press/v130/benard21a.html for an example).
This function provides only an importance estimate because the effect on the outcome depends on the data.
SIRUS.feature_importances (Function)
feature_importances(
models::Union{StableRules, Vector{StableRules}},
feat_names::Vector{String}
)::Vector{NamedTuple{(:feature_name, :importance), Tuple{String, Float64}}}
Return the feature names and importances, sorted by feature importance in descending order.
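The following sketch illustrates a gap-size-based importance and the documented return shape of feature_importances. The rule data are made up, and importance, importances, and gap are hypothetical helpers; the formula (sum of |weight| times gap size over the rules that use the feature) is an assumption for illustration and may differ from the computation inside SIRUS.

```julia
# Hypothetical single-subclause rules with made-up weights and probabilities.
rules = [
    (feature_name = "age",    weight = 0.30, then = 0.10, otherwise = 0.60),
    (feature_name = "income", weight = 0.20, then = 0.35, otherwise = 0.45),
    (feature_name = "age",    weight = 0.10, then = 0.20, otherwise = 0.40),
]

# Gap size: difference between the then and otherwise probabilities.
gap(r) = abs(r.then - r.otherwise)

# ASSUMED importance: sum of |weight| * gap over all rules that use the feature.
importance(rules, name) =
    sum((abs(r.weight) * gap(r) for r in rules if r.feature_name == name); init = 0.0)

# Same shape as the documented return type: named tuples sorted descending.
function importances(rules, names)
    out = [(feature_name = n, importance = importance(rules, n)) for n in names]
    return sort(out; by = x -> x.importance, rev = true)
end

ranked = importances(rules, ["income", "age"])
foreach(println, ranked)
```

Under this assumption, a feature used in rules with large weights and large gaps ranks first, which matches the ordering property stated for feature_importance.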
SIRUS.then (Function)
then(rule::Rule)
Return the then probabilities for a rule. The return type is a vector of probabilities; the exact element type may change over time.
SIRUS.otherwise (Function)
otherwise(rule::Rule)
Return the otherwise probabilities for a rule. The return type is a vector of probabilities; the exact element type may change over time.
SIRUS.gap_size (Function)
gap_size(rule::Rule)
Return the gap size for a rule. The gap size is used by Bénard et al. in the appendix of their PMLR paper (https://proceedings.mlr.press/v130/benard21a.html). Via an example, they specify that the gap size is the difference between the then and otherwise (else) probabilities.
A smaller gap size implies a smaller CART-splitting criterion, which implies a smaller occurrence frequency.
SIRUS.EmpiricalQuantiles.cutpoints (Function)
cutpoints(V::AbstractVector, q::Int)
Return a vector of q cutpoints taken from the empirical distribution of the data V.
cutpoints(X, q::Int)
Return a vector of vectors V containing one inner vector for each feature in the dataset X. Each inner vector contains at most q unique cutpoints, that is, length(V[i]) ≤ q for all i.
unique is used here to avoid checking splits twice.
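A minimal sketch of empirical-quantile cutpoints follows. The name toy_cutpoints is hypothetical, and the probability grid (q evenly spaced probabilities strictly between 0 and 1, so the minimum and maximum are never cutpoints) is an assumption; SIRUS's exact grid and element type may differ. It shows why unique can make an inner vector shorter than q.

```julia
using Statistics

# Hypothetical sketch: q interior empirical quantiles of V, deduplicated.
function toy_cutpoints(V::AbstractVector{<:Real}, q::Int)
    q >= 1 || throw(ArgumentError("q must be positive"))
    # q probabilities strictly between 0 and 1 (0 and 1 themselves are dropped).
    probs = range(0.0, 1.0; length = q + 2)[2:end-1]
    # Ties in the data give repeated quantiles; unique keeps length ≤ q.
    return unique(quantile(V, probs))
end

# One inner vector of cutpoints per feature (column) of a matrix X.
toy_cutpoints(X::AbstractMatrix{<:Real}, q::Int) =
    [toy_cutpoints(col, q) for col in eachcol(X)]

X = [1.0 5.0;
     2.0 5.0;
     3.0 5.0;
     4.0 6.0]
println(toy_cutpoints(X, 3))
```

Restricting candidate splits to a small, fixed set of quantiles is what makes the trees more likely to share the same splits, which is the stabilization the StableForest models rely on.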
SIRUS.satisfies (Function)
satisfies(row::AbstractVector, rule::Rule) -> Bool
Return whether data row satisfies rule.
SIRUS.unpack_rule (Function)
unpack_rule(rule::Rule) -> NamedTuple
Unpack a rule into its components. This is useful for plotting. It returns a named tuple with the following fields:
feature
feature_name
splitval
direction
then
otherwise
SIRUS.unpack_model (Function)
unpack_model(model::StableRules) -> Vector{NamedTuple}
Unpack a model containing only single subclauses (max_depth=1) into its components. This is useful for plotting. It returns a vector of named tuples with the following fields:
weight
feature
feature_name
splitval
direction
then
otherwise
There is one row for each rule in the model.
SIRUS.unpack_models (Function)
unpack_models(
models::Vector{StableRules},
feature_name::String
) -> Vector{NamedTuple}
Unpack a vector of models containing only single subclauses (max_depth=1) into their components. This is useful when plotting the rules that the models have learned for each feature. It returns a vector of named tuples with the following fields for each rule in the models that contains feature_name:
weight
feature
feature_name
splitval
direction
then
otherwise
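As an illustration of how rows with these fields can be consumed for plotting, here is a sketch over hand-written data. The rows below are made up and only assume the field names listed above; rules_for is a hypothetical helper, not part of SIRUS, and the code does not call unpack_model or unpack_models itself.

```julia
# Made-up rows using the field names listed above (not real SIRUS output).
rows = [
    (weight = 0.4, feature = 1, feature_name = "age",    splitval = 35.0,  direction = :L, then = [0.2], otherwise = [0.6]),
    (weight = 0.3, feature = 2, feature_name = "income", splitval = 4.0e4, direction = :L, then = [0.5], otherwise = [0.3]),
    (weight = 0.2, feature = 1, feature_name = "age",    splitval = 50.0,  direction = :L, then = [0.3], otherwise = [0.7]),
]

# Hypothetical helper: keep the rules for one feature, ordered by split value,
# e.g. to plot how the predicted probability changes across the feature's range.
rules_for(rows, name) =
    sort(filter(r -> r.feature_name == name, rows); by = r -> r.splitval)

for r in rules_for(rows, "age")
    println(r.feature_name, " < ", r.splitval,
            ": then = ", r.then, ", otherwise = ", r.otherwise, ", weight = ", r.weight)
end
```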