parsnip constructs models and predictions by representing those actions in expressions. There are a few reasons for this:

  • It eliminates a lot of duplicate code.
  • Since the expressions are not evaluated until fitting, it greatly reduces the number of package dependencies.

A parsnip model function is itself very general. For example, the logistic_reg function doesn't contain any model code. Instead, each model function is associated with one or more computational engines. These might be different R packages or a function in another language (that can be evaluated by R).

This vignette describes the process of creating a new model function. Before proceeding, take a minute and read our guidelines on creating modeling packages to get the general themes and conventions that we use.

As an example, we’ll create a function for mixture discriminant analysis. There are a few packages that do this but we’ll focus on mda::mda:
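A quick way to inspect the fitting function's interface (assuming the mda package is installed; in mda::mda, the number of subclasses is controlled by the subclasses argument):

str(mda::mda)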

The main hyper-parameter is the number of subclasses. We’ll name our function mixture_da.

Step 1. Make the objects for the general method

There are three objects that define the parameters and other characteristics of the model function.

First is the object that describes the model's mode(s). A mode is the type of model; the two main values are "classification" and "regression". A third mode, "unknown", is used for initializing objects, but models will fail if it is used further.

The convention in parsnip is to use the name {model name}_modes. In our case, we have:
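A minimal version, following that convention (the model is classification-only, plus the "unknown" mode used for initialization):

mixture_da_modes <- c("classification", "unknown")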

Next, we define the engines used by the model and the associated mode. Here, the columns correspond to the engine names and rows are the modes (via row names). We have two engines and one effective mode, so our object will have the suffix _engines:
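A sketch; mda is the real engine, and the second column is a hypothetical placeholder engine used only to illustrate the multi-engine layout:

mixture_da_engines <- data.frame(
  mda         = TRUE,
  otherengine = TRUE,   # hypothetical second engine, for illustration only
  row.names   = c("classification")
)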

A row for “unknown” modes is not needed in this object.

Now, we enumerate the main arguments for each engine. parsnip standardizes the names of arguments across different models and engines. For example, random forest and boosting use multiple trees to create the ensemble. Instead of using different argument names, parsnip standardizes on trees and the underlying code translates to the actual arguments used by the different functions.

In our case, the standardized argument name for MDA will be "sub_classes".

Here, the object name will have the suffix _arg_key and will have columns for the engines and rows for the arguments. The entries of the data frame are the actual argument names for each engine (and are NA when an engine doesn't have that argument). Ours:
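A sketch, matching the engines above (mda::mda's underlying argument is subclasses; the hypothetical placeholder engine gets NA):

mixture_da_arg_key <- data.frame(
  mda         = "subclasses",
  otherengine = NA,              # hypothetical engine without this argument
  row.names   = c("sub_classes"),
  stringsAsFactors = FALSE
)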

As an example of a model with multiple engines, here is the object for logistic regression:
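An approximate reconstruction (the exact engine argument names can vary across parsnip versions):

logistic_reg_arg_key <- data.frame(
  glm    = c(NA, NA),
  glmnet = c("lambda", "alpha"),
  spark  = c("reg_param", "elastic_net_param"),
  stan   = c(NA, NA),
  row.names = c("penalty", "mixture"),
  stringsAsFactors = FALSE
)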

The internals of parsnip will use these objects during the creation of the model code.

Step 2. Create the model function

This is a fairly simple function that can follow a basic template. The main arguments to our function will be:

  • The mode. If the model can do more than one mode, you might default this to “unknown”. In our case, since it is only a classification model, it makes sense to default it to that mode.
  • The argument names (sub_classes here). These should be defaulted to NULL.
  • The ellipsis (...) is not used in the main model function.

A basic version of the function is:
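A sketch, assuming parsnip's internal make_classes() helper for setting the class vector:

mixture_da <- function(mode = "classification", sub_classes = NULL) {
  # Check for a valid mode
  if (!(mode %in% mixture_da_modes))
    stop("`mode` should be one of: ",
         paste0("'", mixture_da_modes, "'", collapse = ", "),
         call. = FALSE)

  # Capture the argument in a quosure so that it is not evaluated yet
  args <- list(sub_classes = rlang::enquo(sub_classes))

  # Save empty slots for the parts of the specification set later
  out <- list(args = args, eng_args = NULL,
              mode = mode, method = NULL, engine = NULL)

  # Set the model classes in the correct order
  class(out) <- parsnip:::make_classes("mixture_da")
  out
}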

This is pretty simple since the data are not exposed to this function.

Step 3. Make the model object

This is where the details of the models are specified. This will be a list that has a few different elements:

  • libs is a character vector of any package names that will be required for the model fit.
  • fit has details for the model fit function.
  • numeric, class, and classprob. These are lists of details for making numeric predictions, hard class predictions, or class probabilities (respectively).

We’ll look at each. The convention here is to name this {model name}_{engine}_data. We’ll start with:

The fit module

The main arguments are:

  • interface is a single character value that could be "formula", "data.frame", or "matrix". This defines the type of interface used by the underlying fit function (mda::mda, in this case) and helps translate the data into an appropriate format for that function.
  • protect is an optional list of function arguments that should not be changeable by the user. In this case, we probably don't want users to pass data values to these arguments (until the fit function is called).
  • func is the package and name of the function that will be called. If you are using a locally defined function, only fun is required.
  • defaults is an optional list of arguments to the fit function that the user can change, but whose defaults can be set here. This isn't needed in this case, but is described later in this document.

For the first engine:
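A sketch; formula, data, and weights are protected since users shouldn't pass those directly:

mixture_da_mda_data$fit <-
  list(
    interface = "formula",
    protect = c("formula", "data", "weights"),
    func = c(pkg = "mda", fun = "mda"),
    defaults = list()
  )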

The numeric module

This is defined only for regression models (so is not added to the list). The convention used here is very similar to the two that are detailed in the next section. For numeric, the model requires an unnamed numeric vector output (usually).

For multivariate models, the return value should be a matrix or data frame (otherwise, the result should be a vector).
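For illustration only, a numeric module for a hypothetical regression engine might look like this (every name here is a placeholder; our classification-only model doesn't define one):

some_reg_model_engine_data$numeric <-
  list(
    pre = NULL,
    post = NULL,
    func = c(fun = "predict"),
    args = list(object = quote(object$fit), newdata = quote(new_data))
  )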

Note that the numeric module maps to the predict_numeric function in parsnip. However, the user-facing predict function is used to generate predictions and returns a tibble with a column named .pred (see the example below). When creating new models, you don’t have to write code for that part.

The class module

To make hard class predictions, the class object contains the details. The elements of the list are:

  • pre and post are optional functions that can preprocess the data being fed to the prediction code and postprocess the raw output of the predictions. These won't be needed for this example, but a section below has examples of how they can be used when the model code is not easy to use. If the data being predicted have a simple type requirement, you can avoid using a pre function with the args below.
  • func is the prediction function (in the same format as above). In many cases, packages have a predict method for their model's class, but this is typically not exported. In this case (and the example below), it is simple enough to make a generic call to predict with no associated package.
  • args is a list of arguments to pass to the prediction function. These will most likely be wrapped in rlang::expr so that they are not evaluated when defining the method. For mda, the code would be predict(object, newdata, type = "class"). What is actually given to the function is the parsnip model fit object, which includes a sub-object called fit that houses the mda model object. If the data need to be a matrix or data frame, you could also use newdata = quote(as.data.frame(new_data)) and so on (see the sketch after this list).
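Putting these together for the mda engine, a sketch (the element name is assumed to match the module name, and mda's predict method takes newdata and type = "class"):

mixture_da_mda_data$class <-
  list(
    pre = NULL,
    post = NULL,
    func = c(fun = "predict"),
    args =
      list(
        object = quote(object$fit),
        newdata = quote(new_data),
        type = "class"
      )
  )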

The predict_class function will expect the result to be an unnamed character string or factor. This will be coerced to a factor with the same levels as the original data. As with the numeric module, the user doesn't call predict_class directly but uses predict instead, and this produces a tibble with a column named .pred_class per the model guidelines.

The classprob module

This defines the class probabilities (if they can be computed). The format is identical to the class module but the output is expected to be a tibble with columns for each factor level.

As an example of the post function, the data frame created by mda:::predict.mda will be converted to a tibble. The arguments are x (the raw results coming from the predict method) and object (the parsnip model fit object). The latter has a sub-object called lvl, which is a character vector of the outcome's factor levels (if any).
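A sketch (for mda, type = "posterior" produces the class probabilities; the element name is again assumed to match the module name):

mixture_da_mda_data$classprob <-
  list(
    pre = NULL,
    post = function(x, object) {
      tibble::as_tibble(x)
    },
    func = c(fun = "predict"),
    args =
      list(
        object = quote(object$fit),
        newdata = quote(new_data),
        type = "posterior"
      )
  )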

The post element converts the output to a tibble but the main predict method does proper naming of the column names.

Does it Work?

As a developer, one thing that may come in handy is the translate function. This will tell you what the model’s eventual syntax will be.

For example:
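A sketch using the specification developed above; the printed output should show the mda::mda call with sub_classes = 2 in place:

mixture_da(sub_classes = 2) %>%
  set_engine("mda") %>%
  translate()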

Let’s try it on the iris data:

set.seed(4622)
iris_split <- initial_split(iris, prop = 0.90)
iris_train <- training(iris_split)
iris_test  <-  testing(iris_split)

mda_spec <- mixture_da(sub_classes = 2)

mda_fit <- mda_spec %>%
  set_engine("mda") %>%
  fit(Species ~ ., data = iris_train)
mda_fit
#> parsnip model object
#> 
#> Call:
#> mda::mda(formula = formula, data = data, sub_classes = ~2)
#> 
#> Dimension: 4 
#> 
#> Percent Between-Group Variance Explained:
#>    v1    v2    v3    v4 
#>  94.9  97.9  99.8 100.0 
#> 
#> Degrees of Freedom (per dimension): 5 
#> 
#> Training Misclassification Error: 0.0221 ( N = 136 )
#> 
#> Deviance: 12.3

predict(mda_fit, new_data = iris_test) %>%
  bind_cols(iris_test %>% select(Species))
#> # A tibble: 14 x 2
#>    .pred_class Species   
#>    <fct>       <fct>     
#>  1 setosa      setosa    
#>  2 setosa      setosa    
#>  3 setosa      setosa    
#>  4 versicolor  versicolor
#>  5 versicolor  versicolor
#>  6 versicolor  versicolor
#>  7 versicolor  versicolor
#>  8 versicolor  versicolor
#>  9 versicolor  versicolor
#> 10 versicolor  versicolor
#> 11 versicolor  versicolor
#> 12 virginica   virginica 
#> 13 virginica   virginica 
#> 14 virginica   virginica

predict(mda_fit, new_data = iris_test, type = "prob") %>% 
  bind_cols(iris_test %>% select(Species))
#> # A tibble: 14 x 4
#>    .pred_setosa .pred_versicolor .pred_virginica Species   
#>           <dbl>            <dbl>           <dbl> <fct>     
#>  1     1.00e+ 0         2.62e-32        7.10e-65 setosa    
#>  2     1.00e+ 0         1.36e-25        2.36e-56 setosa    
#>  3     1.00e+ 0         9.11e-29        1.33e-60 setosa    
#>  4     1.76e-38        10.00e- 1        1.97e- 7 versicolor
#>  5     5.64e-36         9.95e- 1        5.03e- 3 versicolor
#>  6     6.84e-22        10.00e- 1        9.83e- 9 versicolor
#>  7     2.54e-37         9.22e- 1        7.83e- 2 versicolor
#>  8     2.70e-37         9.99e- 1        1.34e- 3 versicolor
#>  9     1.81e-37         8.06e- 1        1.94e- 1 versicolor
#> 10     9.83e-35         9.93e- 1        7.27e- 3 versicolor
#> 11     4.04e-37         9.97e- 1        3.00e- 3 versicolor
#> 12     1.93e-55         1.44e- 1        8.56e- 1 virginica 
#> 13     1.21e-50         4.19e- 1        5.81e- 1 virginica 
#> 14     2.08e-50         2.07e- 1        7.93e- 1 virginica

Pro-tips, what-ifs, exceptions, FAQ, and minutiae

There are various things that came to mind while writing this document.

Do I have to return a simple vector for predict_numeric and predict_class?

Previously, when discussing the numeric information:

For numeric, the model requires an unnamed numeric vector output (usually).

There are some occasions where a prediction for a single new sample may be multidimensional. Some easy examples are:

  • confidence or prediction intervals
  • quantile regression predictions.

and so on. These can be accommodated via predict.model_fit using different type arguments.

However, there are some models (e.g., glmnet, plsr, Cubist, etc.) that can make predictions for multiple submodels from the same fitted model object. The regular predict method returns predictions for a single model, but multi_predict can return them for many. The guideline is to always return the same number of rows as in new_data, which means that the .pred column is a list-column of tibbles.

For example, for a multinomial glmnet model, we leave penalty unspecified when fitting and get predictions on a sequence of values:
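A sketch, reusing the iris split from above (the penalty values are arbitrary):

mod <- multinom_reg() %>%
  set_engine("glmnet") %>%
  fit(Species ~ ., data = iris_train)

preds <- multi_predict(
  mod,
  new_data = iris_test,
  penalty = 10^seq(-3, -1, length.out = 5),
  type = "class"
)
preds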

This can be easily expanded to remove the list columns:
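For example, with tidyr (a sketch; newer tidyr versions may require unnest(cols = .pred), and .row just tracks the original row number):

preds %>%
  dplyr::mutate(.row = dplyr::row_number()) %>%
  tidyr::unnest(.pred)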

multi_predict doesn't exist for every model and needs to be implemented by the developer. See methods("multi_predict") for examples in this package.

What is the defaults slot and why do I need it?

You might want to set defaults that can be overridden by the user. For example, for logistic regression with glm, it makes sense to default family = binomial. However, if someone wants to use a different link function, they should be able to do that. For that model/engine definition, the fit module has:
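An approximate reconstruction (internal details can differ across parsnip versions):

logistic_reg_glm_data$fit <-
  list(
    interface = "formula",
    protect = c("formula", "data", "weights"),
    func = c(pkg = "stats", fun = "glm"),
    defaults = list(family = rlang::expr(binomial))
  )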

so that is the default:
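A sketch; set_engine() is also how a user would override the default:

logistic_reg() %>%
  set_engine("glm") %>%
  translate()

# Overriding the default link function:
logistic_reg() %>%
  set_engine("glm", family = binomial(link = "probit")) %>%
  translate()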

That’s what defaults are for.

Note that I wrapped binomial inside of expr. If I didn’t, it would substitute the results of executing binomial inside of the expression (and that’s a mess). Using namespaces is a good idea here.

What if I want more complex defaults?

The translate function can be used to check values or set defaults once the model's mode is known. To do this, you can create a model-specific S3 method that first calls the general method (translate.default) and then makes modifications or conducts error traps.

For example, the ranger and randomForest package functions have arguments for calculating importance. One is a logical and the other is a string. Since this is likely to lead to a bunch of frustration and GH issues, we can put in a check:
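A sketch along these lines; the slot path x$method$fit$args is an assumption about the translated object's internal structure:

translate.rand_forest <- function(x, engine, ...) {
  # Run the general method first to get the actual arguments in place
  x <- translate.default(x, engine, ...)

  # Error-trap a value that makes sense for randomForest but not ranger
  if (x$engine == "ranger") {
    if (any(names(x$method$fit$args) == "importance"))
      if (is.logical(x$method$fit$args$importance))
        stop("`importance` should be a character value. See ?ranger::ranger.",
             call. = FALSE)
  }
  x
}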

As another example, nnet::nnet has an option for the final layer to be linear (called linout). If mode = "regression", that should probably be set to TRUE. You couldn’t do this with the args (described above) since you need the function translated first.

In cases where the model requires different defaults, the translate method can also be used. See the code for the mars function to see how to check and potentially switch arguments for classification models.

My model fit requires more than one function call. So….?

The best course of action is to write a wrapper so that the fit can happen in one call. This was the case with xgboost, C5.0, and keras.
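A hypothetical illustration (some_pkg and its functions are invented for this example): the engine needs a setup call before the actual fit, so both steps are wrapped into one callable function:

two_step_fit <- function(formula, data, ...) {
  prepped <- some_pkg::prepare(data)           # hypothetical first call
  some_pkg::fit_model(formula, prepped, ...)   # hypothetical second call
}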

Why would I preprocess my data?

There might be non-trivial transformations that the model prediction code requires (such as converting to a sparse matrix representation).

This would not include making dummy variables and model.matrix stuff. parsnip already does that for you.
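As a sketch, assuming (by analogy with the post function above) that a pre function receives the incoming data as x along with the model object:

pre = function(x, object) {
  # Convert predictors to the sparse format the engine expects
  Matrix::Matrix(as.matrix(x), sparse = TRUE)
}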

Why would I postprocess my predictions?

What comes back from some R functions may be somewhat… arcane or problematic. As an example, for xgboost, if you fit a multiclass boosted tree, you might expect the class probabilities to come back as a matrix¹. If you have four classes and make predictions on three samples, you get a vector of 12 probability values. You need to convert these to a rectangular data set.
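A sketch of a post-processor for that situation; object$lvl (described earlier) supplies the class names:

post = function(x, object) {
  # xgboost returns class probabilities sample-by-sample in one flat
  # vector, so reshape it into one column per class
  probs <- matrix(x, ncol = length(object$lvl), byrow = TRUE)
  colnames(probs) <- object$lvl
  tibble::as_tibble(probs)
}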

Another example is the predict method for ranger, which encapsulates the actual predictions in a more complex object structure.

These are the types of problems that the postprocessor will solve.

Are there other modes?

Not yet but there will be. For example, it might make sense to have a different mode when doing risk-based modeling via Cox regression models. That would enable different classes of objects, and those might be needed since these types of models don't make direct predictions of the outcome.

If you have a suggestion, please open a GitHub issue to discuss it.


  1. narrator: they don’t