Making a parsnip model from scratch
parsnip constructs models and predictions by representing those actions in expressions. There are a few reasons for this:
A parsnip model function is itself very general. For example, the logistic_reg function itself doesn't have any model code within it. Instead, each model function is associated with one or more computational engines. These might be different R packages or some function in another language (that can be evaluated by R).
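To make "representing actions as expressions" concrete, here is a minimal base-R sketch (an illustration, not parsnip's internal code). The model call is built as an unevaluated expression first and evaluated only once the data exist. stats::lm stands in for an engine's fit function so the example runs without extra packages, and train_data is a placeholder name chosen for this sketch.

```r
# Build the fit call as an unevaluated expression. Nothing is fitted yet, and
# `train_data` does not need to exist at this point.
fit_expr <- as.call(list(
  quote(stats::lm),          # the "engine" function, namespaced
  formula = quote(mpg ~ wt), # model formula
  data = quote(train_data)   # placeholder, resolved at evaluation time
))
print(fit_expr)
#> stats::lm(formula = mpg ~ wt, data = train_data)

# Later, once data are available, evaluate the expression in an environment
# that supplies them.
env <- new.env()
env$train_data <- mtcars
fit <- eval(fit_expr, envir = env)
coef(fit)
```

Deferring evaluation this way means the package defining the model function never has to call the engine's code directly when the specification is created.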
This vignette describes the process of creating a new model function. Before proceeding, take a minute and read our guidelines on creating modeling packages to get the general themes and conventions that we use.
As an example, we'll create a function for mixture discriminant analysis. There are a few packages that do this, but we'll focus on mda::mda:
str(mda::mda)
#> function (formula = formula(data), data = sys.frame(sys.parent()),
#> subclasses = 3, sub.df = NULL, tot.df = NULL, dimension = sum(subclasses) -
#> 1, eps = .Machine$double.eps, iter = 5, weights = mda.start(x,
#> g, subclasses, trace, ...), method = polyreg, keep.fitted = (n *
#> dimension < 5000), trace = FALSE, ...)
The main hyperparameter is the number of subclasses. We'll name our function mixture_da.
There are three objects that define the parameters and other characteristics of the model function.
First is the object that describes the model's mode(s). The modes are the types of model, and the two main values are "classification" and "regression". A third mode, "unknown", is used when initializing objects, but models will fail if it is used beyond that.
The convention in parsnip is to use the name {model name}_modes. In our case, we have:
mixture_da_modes <- c("classification", "unknown")
Next, we define the engines used by the model and the mode(s) each one supports. The columns correspond to the engine names and the rows are the modes (via row names). We have one engine and one effective mode, and the object name uses the suffix _engines:
mixture_da_engines <- data.frame(
mda = TRUE,
row.names = c("classification")
)
mixture_da_engines
#> mda
#> classification TRUE
A row for “unknown” modes is not needed in this object.
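A table in this layout makes it easy to ask which engines support a given mode. The sketch below is a hypothetical helper (engines_for_mode is not a parsnip function). It also adds a made-up second engine, other_engine, only so the lookup has something to filter out.

```r
# Engines-by-mode table in the same layout as mixture_da_engines:
# columns are engines, row names are modes.
engines <- data.frame(
  mda = TRUE,
  other_engine = FALSE,   # hypothetical engine without classification support
  row.names = "classification"
)

# Hypothetical helper: return the engines available for a given mode
engines_for_mode <- function(engine_tbl, mode) {
  if (!mode %in% rownames(engine_tbl)) {
    return(character(0))
  }
  # One-row data frame -> named logical vector (names are engine names)
  available <- unlist(engine_tbl[mode, , drop = FALSE])
  names(engine_tbl)[available]
}

engines_for_mode(engines, "classification")
#> [1] "mda"
engines_for_mode(engines, "regression")
#> character(0)
```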
Now, we enumerate the main arguments for each engine. parsnip standardizes the names of arguments across different models and engines. For example, random forest and boosting both use multiple trees to create the ensemble. Instead of using different argument names, parsnip standardizes on trees, and the underlying code translates it to the actual arguments used by the different functions.
In our case, the parsnip argument name will be sub_classes; the corresponding argument of mda::mda is subclasses (see the str() output above).
Here, the object name has the suffix _arg_key, with columns for the engines and rows for the arguments. The entries of the data frame are the actual argument names used by each engine (and are NA when an engine doesn't have that argument). Ours:
mixture_da_arg_key <- data.frame(
  # mda::mda() calls this argument `subclasses` (see str(mda::mda) above)
  mda = "subclasses",
  row.names = "sub_classes",
  stringsAsFactors = FALSE
)
As an example of a model with multiple engines, here is the object for logistic regression:
parsnip:::logistic_reg_arg_key
#> glm glmnet spark stan keras
#> penalty NA lambda reg_param NA decay
#> mixture NA alpha elastic_net_param NA <NA>
The internals of parsnip will use these objects during the creation of the model code.
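As a rough illustration of how such a key can be used, the sketch below renames standardized argument names to one engine's names and drops arguments the engine lacks. translate_arg_names is a hypothetical helper, not parsnip's internal code; the key values are copied from the logistic_reg key printed above (only the glm and glmnet columns).

```r
# Row names: standardized parsnip names. Columns: engines.
# Entries: the engine's own argument names (NA if unsupported).
arg_key <- data.frame(
  glm = c(NA_character_, NA_character_),
  glmnet = c("lambda", "alpha"),
  row.names = c("penalty", "mixture"),
  stringsAsFactors = FALSE
)

# Hypothetical helper: map user-facing names to one engine's names
translate_arg_names <- function(args, key, engine) {
  engine_names <- key[names(args), engine]
  keep <- !is.na(engine_names)   # drop arguments the engine doesn't have
  stats::setNames(args[keep], engine_names[keep])
}

user_args <- list(penalty = 0.01, mixture = 0.5)
str(translate_arg_names(user_args, arg_key, "glmnet"))
str(translate_arg_names(user_args, arg_key, "glm"))
```

With glmnet, the result is a list named lambda and alpha; with glm, both arguments are dropped because that engine has neither.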
This is a fairly simple function that can follow a basic template. The main arguments to our function will be:
- The mode. Since this model only does classification, it defaults to "classification".
- The main model arguments (sub_classes here). These should default to NULL.
- ... is not used in the main model function.

A basic version of the function is:
mixture_da <-
  function(mode = "classification", sub_classes = NULL) {
    # Check for correct mode
    if (!(mode %in% mixture_da_modes))
      stop("`mode` should be one of: ",
           paste0("'", mixture_da_modes, "'", collapse = ", "),
           call. = FALSE)

    # Capture the arguments in quosures
    args <- list(sub_classes = rlang::enquo(sub_classes))

    # Save some empty slots for future parts of the specification
    out <- list(args = args, eng_args = NULL,
                mode = mode, method = NULL, engine = NULL)

    # Set classes in the correct order
    class(out) <- make_classes("mixture_da")
    out
  }
This is pretty simple since the data are not exposed to this function.
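Because the arguments are captured rather than evaluated, a specification can refer to objects that do not exist yet. The sketch below mirrors the structure of mixture_da() using only base R, so it is an approximation: substitute() stands in for rlang::enquo() (a quosure also records the calling environment, which this sketch drops), and the class vector is written out by hand instead of using make_classes(). The name toy_da is hypothetical.

```r
toy_da <- function(mode = "classification", sub_classes = NULL) {
  modes <- c("classification", "unknown")
  # Check for correct mode
  if (!(mode %in% modes)) {
    stop("`mode` should be one of: ",
         paste0("'", modes, "'", collapse = ", "),
         call. = FALSE)
  }
  # Capture the expression the user typed, without evaluating it
  args <- list(sub_classes = substitute(sub_classes))
  out <- list(args = args, eng_args = NULL,
              mode = mode, method = NULL, engine = NULL)
  # Hand-written class vector (assumed order: specific class first)
  class(out) <- c("toy_da", "model_spec")
  out
}

# `n_sub` does not exist, yet the specification is created without error
spec <- toy_da(sub_classes = n_sub)
spec$args$sub_classes
#> n_sub
```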
This is where the details of the models are specified. It is a list with a few different elements:
- libs is a character vector with the names of any packages required for the model fit.
- fit has details for the model fit function.
- numeric, class, and classprob are lists of details for making numeric predictions, hard class predictions, or class probabilities (respectively).

We'll look at each. The convention here is to name this object {model name}_{engine}_data. We'll start with:
mixture_da_mda_data <- list(libs = "mda")
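One thing a libs element enables is failing early with a clear message when a required package is missing. The helper below is a hypothetical sketch of such a check (it is not how parsnip itself handles libs) and uses only base R's requireNamespace().

```r
# Hypothetical helper: report which packages in `libs` are not installed
missing_libs <- function(libs) {
  installed <- vapply(libs, requireNamespace, logical(1), quietly = TRUE)
  libs[!installed]
}

missing_libs(c("stats", "notARealPackage123"))
#> [1] "notARealPackage123"
```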
The fit module

The main arguments are:
- interface is a single character value that could be "formula", "data.frame", or "matrix". It defines the type of interface used by the underlying fit function (mda::mda, in this case) and helps translate the data into an appropriate format for that function.
- protect is an optional list of function arguments that should not be changeable by the user. In this case, we probably don't want users to pass data values to these arguments (until the fit function is called).
- func is the package and name of the function that will be called. If you are using a locally defined function, only fun is required.
- defaults is an optional list of arguments to the fit function that the user can change, but whose defaults can be set here. This isn't needed in this case, but is described later in this document.

For the mda engine:
mixture_da_mda_data$fit <-
list(
interface = "formula",
protect = c("formula", "data", "weights"),
func = c(pkg = "mda", fun = "mda"),
defaults = list()
)
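The fit module is pure metadata. One plausible way its pieces could be combined into a call is sketched below. build_fit_call is a hypothetical helper, not parsnip's internals, and stats::lm stands in for mda::mda so the example runs without mda installed. Protected arguments always point at the placeholders formula and data, so a user-supplied data value is ignored.

```r
fit_module <- list(
  interface = "formula",
  protect = c("formula", "data"),
  func = c(pkg = "stats", fun = "lm"),
  defaults = list()
)

# Hypothetical helper: build an unevaluated call from a fit module plus the
# user's (engine-named) arguments.
build_fit_call <- function(module, user_args = list()) {
  # pkg::fun if a package is given, otherwise a bare (local) function name
  fn <- if (is.na(module$func["pkg"])) {
    as.name(module$func[["fun"]])
  } else {
    call("::", as.name(module$func[["pkg"]]), as.name(module$func[["fun"]]))
  }
  # Protected arguments become placeholders the user cannot override
  protected <- stats::setNames(lapply(module$protect, as.name), module$protect)
  user_args <- user_args[setdiff(names(user_args), module$protect)]
  # Defaults first, then user values replacing them
  args <- utils::modifyList(module$defaults, user_args)
  as.call(c(list(fn), protected, args))
}

fit_call <- build_fit_call(fit_module, list(data = "ignored", x = TRUE))
fit_call
#> stats::lm(formula = formula, data = data, x = TRUE)

# Evaluate once the placeholders have values
fit <- eval(fit_call, list(formula = mpg ~ wt, data = mtcars))
```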
The numeric module

This is defined only for regression models (so it is not added to the list here). Its structure is very similar to the two modules detailed in the next sections. For numeric, the model usually requires an unnamed numeric vector as output. For multivariate models, the return value should be a matrix or data frame; otherwise, it should be a vector.

Note that the numeric module maps to the predict_numeric function in parsnip. However, users call the predict function to generate predictions, and it returns a tibble with a column named .pred (see the example below). When creating new models, you don't have to write code for that part.
The class module

To make hard class predictions, the class object contains the details. The elements of the list are:
- pre and post are optional functions that can preprocess the data being fed to the prediction code and postprocess the raw output of the predictions. These won't be needed for this example, but a section below has examples of how they can be used when the model code is not easy to use. If the data being predicted have a simple type requirement, you can avoid using a pre function by using the args below.
- func is the prediction function (in the same format as above). In many cases, packages have a predict method for their model's class, but it is typically not exported. In this case (and the example below), it is simple enough to make a generic call to predict with no associated package.
- args is a list of arguments to pass to the prediction function. These will most likely be wrapped in rlang::expr (or quote, as below) so that they are not evaluated when defining the method. For mda, the code would be predict(object, newdata, type = "class"). What is actually given to the function is the parsnip model fit object, which includes a sub-object called fit that houses the mda model object. If the data need to be a matrix or data frame, you could also use newdata = quote(as.data.frame(new_data)) and so on.

mixture_da_mda_data$class <-
list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args =
# These lists should be of the form:
# {predict.mda argument name} = {values provided from parsnip objects}
list(
# We don't want the first two arguments evaluated right now
# since they don't exist yet. `type` is a simple object that
# doesn't need to have its evaluation deferred.
object = quote(object$fit),
newdata = quote(new_data),
type = "class"
)
)
The predict_class function will expect the result to be an unnamed character vector or factor. This will be coerced to a factor with the same levels as the original data. As with the numeric module, the user doesn't call predict_class but uses predict instead, and this produces a tibble with a column named .pred_class per the model guidelines.
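To see how a prediction module's pieces could fit together at prediction time, here is a hypothetical runner (not parsnip's internal code). It binds object and new_data, evaluates the deferred args, calls func, and applies post when present. A list with an lm fit in its fit element stands in for a parsnip model fit object, and a numeric module is used because lm has no class predictions; the (x, object) signature assumed for pre mirrors the post function described in the classprob section.

```r
run_pred_module <- function(module, object, new_data) {
  if (!is.null(module$pre)) new_data <- module$pre(new_data, object)
  # Environment in which the quoted arguments are evaluated
  env <- list2env(list(object = object, new_data = new_data))
  args <- lapply(module$args, function(a) if (is.language(a)) eval(a, env) else a)
  # Look up the prediction function: generic if no package is given
  fn <- if (is.na(module$func["pkg"])) {
    get(module$func[["fun"]], mode = "function")
  } else {
    getExportedValue(module$func[["pkg"]], module$func[["fun"]])
  }
  res <- do.call(fn, args)
  if (!is.null(module$post)) res <- module$post(res, object)
  res
}

numeric_module <- list(
  pre = NULL,
  post = function(x, object) unname(x),   # return an unnamed numeric vector
  func = c(fun = "predict"),
  args = list(object = quote(object$fit), newdata = quote(new_data),
              type = "response")
)

fake_fit <- list(fit = lm(mpg ~ wt, data = mtcars))
preds <- run_pred_module(numeric_module, fake_fit, head(mtcars, 3))
```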
The classprob module

This defines the class probabilities (if they can be computed). The format is identical to the class module, but the output is expected to be a tibble with columns for each factor level.
As an example of the post function, the output created by mda:::predict.mda will be converted to a tibble. The arguments are x (the raw results coming from the predict method) and object (the parsnip model fit object). The latter has a sub-object called lvl, which is a character vector of the outcome's factor levels (if any).
mixture_da_mda_data$classprob <-
list(
pre = NULL,
post = function(x, object) {
tibble::as_tibble(x)
},
func = c(fun = "predict"),
args =
list(
object = quote(object$fit),
newdata = quote(new_data),
type = "posterior"
)
)
The post element converts the output to a tibble, and the main predict method takes care of naming the columns.
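The two steps (rectangular output from post, then column naming) can be imitated in base R on a made-up posterior matrix. This sketch uses as.data.frame() instead of tibble::as_tibble() so no extra package is needed, and the .pred_{level} naming is written by hand to match the column names shown in the predictions later in this document; it is not the code predict actually runs. The probabilities are invented.

```r
# Made-up posterior probabilities: two samples, three classes
raw_probs <- matrix(
  c(0.90, 0.08, 0.02,
    0.10, 0.70, 0.20),
  nrow = 2, byrow = TRUE
)
fake_fit <- list(lvl = c("setosa", "versicolor", "virginica"))

# Analogue of the `post` element: just make the output rectangular
post_fn <- function(x, object) as.data.frame(x)

probs <- post_fn(raw_probs, fake_fit)
# Analogue of the naming done afterwards by predict()
names(probs) <- paste0(".pred_", fake_fit$lvl)
probs
```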
As a developer, one thing that may come in handy is the translate function. It tells you what the model's eventual syntax will be. For example:
library(tidymodels)
mixture_da(sub_classes = 2) %>%
set_engine("mda") %>%
translate()
#> Model Specification (classification)
#>
#> Main Arguments:
#> sub_classes = 2
#>
#> Computational engine: mda
#>
#> Model fit template:
#> mda::mda(formula = missing_arg(), data = missing_arg(), weights = missing_arg(),
#> sub_classes = 2)
Let’s try it on the iris data:
set.seed(4622)
iris_split <- initial_split(iris, prop = 0.90)
iris_train <- training(iris_split)
iris_test <- testing(iris_split)
mda_spec <- mixture_da(sub_classes = 2)
mda_fit <- mda_spec %>%
set_engine("mda") %>%
fit(Species ~ ., data = iris_train)
mda_fit
#> parsnip model object
#>
#> Call:
#> mda::mda(formula = formula, data = data, sub_classes = ~2)
#>
#> Dimension: 4
#>
#> Percent Between-Group Variance Explained:
#> v1 v2 v3 v4
#> 95.8 98.3 99.9 100.0
#>
#> Degrees of Freedom (per dimension): 5
#>
#> Training Misclassification Error: 0.0221 ( N = 136 )
#>
#> Deviance: 13.4
predict(mda_fit, new_data = iris_test) %>%
bind_cols(iris_test %>% select(Species))
#> # A tibble: 14 x 2
#> .pred_class Species
#> <fct> <fct>
#> 1 setosa setosa
#> 2 setosa setosa
#> 3 setosa setosa
#> 4 setosa setosa
#> 5 versicolor versicolor
#> 6 versicolor versicolor
#> 7 versicolor versicolor
#> 8 versicolor versicolor
#> 9 versicolor versicolor
#> 10 virginica virginica
#> 11 virginica virginica
#> 12 virginica virginica
#> 13 virginica virginica
#> 14 virginica virginica
predict(mda_fit, new_data = iris_test, type = "prob") %>%
bind_cols(iris_test %>% select(Species))
#> # A tibble: 14 x 4
#> .pred_setosa .pred_versicolor .pred_virginica Species
#> <dbl> <dbl> <dbl> <fct>
#> 1 1.00e+ 0 5.97e-24 1.64e-65 setosa
#> 2 1.00e+ 0 8.08e-28 2.99e-65 setosa
#> 3 1.00e+ 0 1.09e-22 1.16e-60 setosa
#> 4 1.00e+ 0 6.82e-29 1.09e-72 setosa
#> 5 2.21e- 53 9.99e- 1 9.37e- 4 versicolor
#> 6 3.02e- 30 10.00e- 1 4.74e- 8 versicolor
#> 7 4.61e- 33 10.00e- 1 1.77e- 6 versicolor
#> 8 2.78e- 45 9.99e- 1 1.03e- 3 versicolor
#> 9 1.33e- 21 10.00e- 1 3.46e-14 versicolor
#> 10 2.87e- 76 3.93e- 4 10.00e- 1 virginica
#> 11 3.73e- 71 4.69e- 5 10.00e- 1 virginica
#> 12 2.26e-111 4.48e-15 10.00e- 1 virginica
#> 13 1.12e- 58 2.24e- 1 7.76e- 1 virginica
#> 14 2.54e- 56 1.86e- 1 8.14e- 1 virginica
There are various things that came to mind while writing this document.

Do I have to return a simple vector for predict_numeric and predict_class?

Previously, when discussing the numeric information:

"For numeric, the model requires an unnamed numeric vector output (usually)."
There are some occasions where a prediction for a single new sample may be multidimensional. Examples are enumerated here, but some easy examples are:
and so on. These can be accommodated via predict.model_fit using different type arguments.
However, there are some models (e.g. glmnet, plsr, Cubist, etc.) that can make predictions for different models from the same fitted model object. The regular predict method only produces predictions from a single model, but the multi_predict method can produce them for several at once. The guideline is to always return the same number of rows as in new_data. This means that the .pred column is a list-column of tibbles.
For example, for a multinomial glmnet model, we leave penalty unspecified when fitting and get predictions on a sequence of values:
mod <- multinom_reg(mixture = 1/3) %>%
set_engine("glmnet")
mod_fit <- fit(mod, Species ~ ., data = iris)
preds <- multi_predict(mod_fit, iris[1:3, -5], penalty = c(0, 0.01, 0.1), type = "prob")
preds
#> # A tibble: 3 x 1
#> .pred
#> <list>
#> 1 <tibble [3 × 4]>
#> 2 <tibble [3 × 4]>
#> 3 <tibble [3 × 4]>
preds[[".pred"]][1]
#> [[1]]
#> # A tibble: 3 x 4
#> .pred_setosa .pred_versicolor .pred_virginica penalty
#> <dbl> <dbl> <dbl> <dbl>
#> 1 1.000 0.000231 5.70e-22 0
#> 2 0.980 0.0197 3.13e- 7 0.01
#> 3 0.876 0.113 1.12e- 2 0.1
This can easily be unnested to remove the list column:
preds %>%
  mutate(.row = 1:nrow(preds)) %>%
  tidyr::unnest(cols = .pred)
#> # A tibble: 9 x 5
#> .row .pred_setosa .pred_versicolor .pred_virginica penalty
#> <int> <dbl> <dbl> <dbl> <dbl>
#> 1 1 1.000 0.000231 5.70e-22 0
#> 2 1 0.980 0.0197 3.13e- 7 0.01
#> 3 1 0.876 0.113 1.12e- 2 0.1
#> 4 2 0.998 0.00201 6.26e-20 0
#> 5 2 0.940 0.0601 1.41e- 6 0.01
#> 6 2 0.780 0.207 1.35e- 2 0.1
#> 7 3 1.000 0.000218 2.24e-21 0
#> 8 3 0.975 0.0247 4.41e- 7 0.01
#> 9 3 0.847 0.143 1.00e- 2 0.1
multi_predict doesn't exist for every model and needs to be implemented by the developer. See methods("multi_predict") for examples in this package.
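The one-row-per-sample rule can be illustrated without glmnet. The base-R sketch below uses a made-up predictor, x - 10 * penalty (purely hypothetical, standing in for a real multi-model prediction), to build a data frame whose .pred column is a list holding one small data frame per row of new_data. It then flattens the result with do.call(rbind, ...), in the spirit of tidyr::unnest().

```r
new_data <- data.frame(x = c(1, 2, 3))
penalties <- c(0, 0.01, 0.1)

# Hypothetical multi-model predictor: one prediction per penalty value
predict_one_row <- function(x, penalty) {
  data.frame(.pred_value = x - 10 * penalty, penalty = penalty)
}

# One list element per row of new_data, each a data frame over penalties
preds <- data.frame(.row = seq_len(nrow(new_data)))
preds$.pred <- lapply(new_data$x, predict_one_row, penalty = penalties)

nrow(preds) == nrow(new_data)
#> [1] TRUE

# Flatten: attach .row to each nested data frame and stack them
flat <- do.call(rbind, Map(function(r, d) cbind(.row = r, d),
                           preds$.row, preds$.pred))
nrow(flat)
#> [1] 9
```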
What is the defaults slot and why do I need it?

You might want to set defaults that can be overridden by the user. For example, for logistic regression with glm, it makes sense to default to family = binomial. However, if someone wants to use a different link function, they should be able to do that. For that model/engine definition, it has

defaults = list(family = expr(stats::binomial))

so that is the default:
logistic_reg() %>% translate(engine = "glm")
#> Logistic Regression Model Specification (classification)
#>
#> Computational engine: glm
#>
#> Model fit template:
#> stats::glm(formula = missing_arg(), data = missing_arg(), weights = missing_arg(),
#> family = stats::binomial)
# but you can change it:
logistic_reg() %>%
set_engine("glm", family = stats::binomial(link = "probit")) %>%
translate()
#> Logistic Regression Model Specification (classification)
#>
#> Engine-Specific Arguments:
#> family = stats::binomial(link = "probit")
#>
#> Computational engine: glm
#>
#> Model fit template:
#> stats::glm(formula = missing_arg(), data = missing_arg(), weights = missing_arg(),
#> family = stats::binomial(link = "probit"))
That's what defaults are for.

Note that I wrapped binomial inside of expr. If I hadn't, the result of evaluating binomial (the function object itself) would be substituted into the expression, and that's a mess. Using namespaces is a good idea here.
The translate function can be used to check values or set defaults once the model's mode is known. To do this, you can create a model-specific S3 method that first calls the general method (translate.default) and then makes modifications or conducts error traps.
For example, the ranger and randomForest package functions have arguments for calculating importance. One is a logical and the other is a string. Since this is likely to lead to a lot of frustration and GitHub issues, we can put in a check:
# Simplified version
translate.rand_forest <- function(x, engine = x$engine, ...) {
  # Run the general method to get the real arguments in place
  x <- translate.default(x, engine, ...)

  # Make code easier to read
  arg_vals <- x$method$fit$args

  # Check and see if they make sense for the engine and/or mode:
  if (engine == "ranger") {
    if (any(names(arg_vals) == "importance")) {
      # We want to check the type of `importance` but it is a quosure. We first
      # get the expression. If it is logical, the value of `quo_get_expr` will
      # not be an expression but the actual logical. The wrapping of `isTRUE`
      # is there in case it is not an atomic value.
      if (isTRUE(is.logical(rlang::quo_get_expr(arg_vals$importance)))) {
        stop("`importance` should be a character value. See ?ranger::ranger.",
             call. = FALSE)
      }
    }
    # ranger needs `probability = TRUE` to produce class probabilities
    if (x$mode == "classification" && !any(names(arg_vals) == "probability")) {
      arg_vals$probability <- TRUE
    }
  }
  x$method$fit$args <- arg_vals
  x
}
As another example, nnet::nnet has an option for the final layer to be linear (called linout). If mode = "regression", that should probably be set to TRUE. You couldn't do this with the args (described above) since you need the function translated first.
In cases where the model requires different defaults, the translate method can also be used. See the code for the mars function to see how to check and potentially switch arguments for classification models.
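The linout case can be sketched with a small S3 hierarchy. Everything here is hypothetical: toy_translate(), the toy_mlp class, and the list layout only loosely mimic a parsnip model specification. The general method fills in the fit arguments, and the model-specific method runs it via NextMethod() before adjusting an argument that depends on the mode, leaving any value the user set explicitly alone.

```r
toy_translate <- function(x, ...) UseMethod("toy_translate")

# General method: copy the user's arguments into the fit arguments
toy_translate.default <- function(x, ...) {
  x$fit_args <- x$args
  x
}

# Model-specific method: run the general method, then adjust for the mode
toy_translate.toy_mlp <- function(x, ...) {
  x <- NextMethod()
  if (x$mode == "regression" && is.null(x$fit_args$linout)) {
    x$fit_args$linout <- TRUE  # linear output layer for numeric outcomes
  }
  x
}

reg_spec <- structure(list(mode = "regression", args = list(size = 5)),
                      class = "toy_mlp")
cls_spec <- structure(list(mode = "classification", args = list(size = 5)),
                      class = "toy_mlp")

toy_translate(reg_spec)$fit_args$linout
#> [1] TRUE
is.null(toy_translate(cls_spec)$fit_args$linout)
#> [1] TRUE
```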
The best course of action is to write a wrapper so that the fit can be done in one call. This was the case with xgboost, C5.0, and keras.
There might be non-trivial transformations that the model prediction code requires (such as converting to a sparse matrix representation, etc.). This would not include making dummy variables and model.matrix stuff; parsnip already does that for you.
What comes back from some R functions may be somewhat… arcane or problematic. As an example, for xgboost, if you fit a multiclass boosted tree, you might expect the class probabilities to come back as a matrix[1]. If you have four classes and make predictions on three samples, you get a vector of 12 probability values. You need to convert these to a rectangular data set.

Another example is the predict method for ranger, which encapsulates the actual predictions in a more complex object structure.

These are the types of problems that the postprocessor will solve.
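For the xgboost case, a postprocessor mainly has to reshape a flat vector. The sketch below assumes the probabilities arrive grouped by sample (all classes for the first sample, then all classes for the second, and so on), which is why it fills the matrix with byrow = TRUE; check that ordering against the xgboost version you target. The function name reshape_probs and the probability values are made up.

```r
# Hypothetical post-processor: flat vector -> one row per sample
reshape_probs <- function(x, lvl) {
  n_class <- length(lvl)
  stopifnot(length(x) %% n_class == 0)
  # Assumes values are grouped by sample, so fill row by row
  probs <- matrix(x, ncol = n_class, byrow = TRUE)
  colnames(probs) <- lvl
  as.data.frame(probs)
}

# Three samples, four classes: 12 values, grouped by sample
flat <- c(0.70, 0.10, 0.10, 0.10,
          0.25, 0.25, 0.25, 0.25,
          0.05, 0.05, 0.10, 0.80)
reshape_probs(flat, lvl = c("a", "b", "c", "d"))
```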
Not yet, but there will be. For example, it might make sense to have a different mode when doing risk-based modeling via Cox regression models. That would enable different classes of objects, and those might be needed since these types of models don't make direct predictions of the outcome.

If you have a suggestion, please add a GitHub issue to discuss it.
[1] Narrator: they don't.