Skip to content
Function Works
tidypredict_fit(), tidypredict_sql(), parse_model()
tidypredict_to_column()
tidypredict_test()
tidypredict_interval(), tidypredict_sql_interval()
parsnip

tidypredict_ functions

library(xgboost)

# Logistic-loss gradient and hessian for a vector of raw scores.
#
# preds:  numeric vector of raw (pre-sigmoid) scores.
# dtrain: object whose "label" info is read with xgboost::getinfo()
#         (presumably an xgb.DMatrix -- confirm against the caller).
#
# Returns a list with `grad` (probability minus label) and `hess`
# (probability times one minus probability).
# NOTE(review): this has the shape of an xgboost custom objective, but it is
# not referenced elsewhere in this snippet -- confirm whether it is needed.
logregobj <- function(preds, dtrain) {
  observed <- xgboost::getinfo(dtrain, "label")
  # Sigmoid maps the raw scores onto probabilities.
  prob <- 1 / (1 + exp(-preds))
  list(
    grad = prob - observed,
    hess = prob * (1 - prob)
  )
}

# Build the xgboost training matrix: predictors are every mtcars column except
# `am`, which is supplied separately as the label.
xgb_bin_data <- xgboost::xgb.DMatrix(
  data = as.matrix(mtcars[, names(mtcars) != "am"]),
  label = mtcars$am
)

# Train 50 rounds of depth-2 trees with the built-in logistic objective.
model <- xgboost::xgb.train(
  data = xgb_bin_data,
  params = list(
    max_depth = 2,
    objective = "binary:logistic",
    base_score = 0.5
  ),
  nrounds = 50
)
  • Create the R formula

    tidypredict_fit(model)
    #> 1 - 1/(1 + exp(0 + case_when(wt >= 3.18000007 ~ -0.436363667, 
    #>     (qsec < 19.1849995 | is.na(qsec)) & (wt < 3.18000007 | is.na(wt)) ~ 
    #>         0.428571463, qsec >= 19.1849995 & (wt < 3.18000007 | 
    #>         is.na(wt)) ~ 0) + case_when((wt < 3.01250005 | is.na(wt)) ~ 
    #>     0.311573088, (hp < 222.5 | is.na(hp)) & wt >= 3.01250005 ~ 
    #>     -0.392053694, hp >= 222.5 & wt >= 3.01250005 ~ -0.0240745768) + 
    #>     case_when((gear < 3.5 | is.na(gear)) ~ -0.355945677, (wt < 
    #>         3.01250005 | is.na(wt)) & gear >= 3.5 ~ 0.325712085, 
    #>         wt >= 3.01250005 & gear >= 3.5 ~ -0.0384863913) + case_when((gear < 
    #>     3.5 | is.na(gear)) ~ -0.309683114, (wt < 3.01250005 | is.na(wt)) & 
    #>     gear >= 3.5 ~ 0.283893973, wt >= 3.01250005 & gear >= 3.5 ~ 
    #>     -0.032039877) + case_when((gear < 3.5 | is.na(gear)) ~ -0.275577009, 
    #>     (wt < 3.01250005 | is.na(wt)) & gear >= 3.5 ~ 0.252453178, 
    #>     wt >= 3.01250005 & gear >= 3.5 ~ -0.0266750772) + case_when((gear < 
    #>     3.5 | is.na(gear)) ~ -0.248323873, (qsec < 17.6599998 | is.na(qsec)) & 
    #>     gear >= 3.5 ~ 0.261978835, qsec >= 17.6599998 & gear >= 3.5 ~ 
    #>     -0.00959526002) + case_when((gear < 3.5 | is.na(gear)) ~ 
    #>     -0.225384533, (wt < 3.01250005 | is.na(wt)) & gear >= 3.5 ~ 
    #>     0.218285918, wt >= 3.01250005 & gear >= 3.5 ~ -0.0373593047) + 
    #>     case_when((gear < 3.5 | is.na(gear)) ~ -0.205454513, (qsec < 
    #>         18.7550011 | is.na(qsec)) & gear >= 3.5 ~ 0.196076646, 
    #>         qsec >= 18.7550011 & gear >= 3.5 ~ -0.0544253439) + case_when((wt < 
    #>     3.01250005 | is.na(wt)) ~ 0.149246693, (qsec < 17.4099998 | 
    #>     is.na(qsec)) & wt >= 3.01250005 ~ 0.0354709327, qsec >= 17.4099998 & 
    #>     wt >= 3.01250005 ~ -0.226075932) + case_when((gear < 3.5 | 
    #>     is.na(gear)) ~ -0.184417158, (wt < 3.01250005 | is.na(wt)) & 
    #>     gear >= 3.5 ~ 0.176768288, wt >= 3.01250005 & gear >= 3.5 ~ 
    #>     -0.0237750355) + case_when((gear < 3.5 | is.na(gear)) ~ -0.168993726, 
    #>     (qsec < 18.6049995 | is.na(qsec)) & gear >= 3.5 ~ 0.155569643, 
    #>     qsec >= 18.6049995 & gear >= 3.5 ~ -0.0325752236) + case_when((wt < 
    #>     3.01250005 | is.na(wt)) ~ 0.119126029, wt >= 3.01250005 ~ 
    #>     -0.105012275) + case_when((qsec < 17.1749992 | is.na(qsec)) ~ 
    #>     0.117254697, qsec >= 17.1749992 ~ -0.0994235724) + case_when((wt < 
    #>     3.18000007 | is.na(wt)) ~ 0.097100094, wt >= 3.18000007 ~ 
    #>     -0.10567718) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
    #>     0.0824323222, wt >= 3.18000007 ~ -0.091120176) + case_when((qsec < 
    #>     17.5100002 | is.na(qsec)) ~ 0.0854752287, qsec >= 17.5100002 ~ 
    #>     -0.0764453933) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
    #>     0.0749477893, wt >= 3.18000007 ~ -0.0799863264) + case_when((qsec < 
    #>     17.7099991 | is.na(qsec)) ~ 0.0728750378, qsec >= 17.7099991 ~ 
    #>     -0.0646049976) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
    #>     0.0682478622, wt >= 3.18000007 ~ -0.0711427554) + case_when((wt < 
    #>     3.18000007 | is.na(wt)) ~ 0.0579533465, wt >= 3.18000007 ~ 
    #>     -0.0613371208) + case_when((qsec < 18.1499996 | is.na(qsec)) ~ 
    #>     0.0595484748, qsec >= 18.1499996 ~ -0.0546668135) + case_when((wt < 
    #>     3.18000007 | is.na(wt)) ~ 0.0535288528, wt >= 3.18000007 ~ 
    #>     -0.0558333211) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
    #>     0.0454574414, wt >= 3.18000007 ~ -0.048143398) + case_when((qsec < 
    #>     18.5600014 | is.na(qsec)) ~ 0.0422042683, qsec >= 18.5600014 ~ 
    #>     -0.0454404354) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
    #>     0.0420555808, wt >= 3.18000007 ~ -0.0449385941) + case_when((qsec < 
    #>     18.5600014 | is.na(qsec)) ~ 0.0393446013, qsec >= 18.5600014 ~ 
    #>     -0.0425945036) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
    #>     0.0391179025, wt >= 3.18000007 ~ -0.0420661867) + case_when((qsec < 
    #>     18.4099998 | is.na(qsec)) ~ 0.0304145869, qsec >= 18.4099998 ~ 
    #>     -0.031833414) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
    #>     0.0362136625, wt >= 3.18000007 ~ -0.038949281) + case_when((qsec < 
    #>     18.4099998 | is.na(qsec)) ~ 0.0295153651, qsec >= 18.4099998 ~ 
    #>     -0.0307046026) + case_when((drat < 3.80999994 | is.na(drat)) ~ 
    #>     -0.0306891855, drat >= 3.80999994 ~ 0.0288283136) + case_when((qsec < 
    #>     18.4099998 | is.na(qsec)) ~ 0.0271221269, qsec >= 18.4099998 ~ 
    #>     -0.0281750448) + case_when((qsec < 18.4099998 | is.na(qsec)) ~ 
    #>     0.0228891298, qsec >= 18.4099998 ~ -0.0238814205) + case_when((drat < 
    #>     3.80999994 | is.na(drat)) ~ -0.0296511576, drat >= 3.80999994 ~ 
    #>     0.0280048084) + case_when((qsec < 18.4099998 | is.na(qsec)) ~ 
    #>     0.0214707125, qsec >= 18.4099998 ~ -0.0224219449) + case_when((qsec < 
    #>     18.4099998 | is.na(qsec)) ~ 0.0181306079, qsec >= 18.4099998 ~ 
    #>     -0.0190209728) + case_when((wt < 3.18000007 | is.na(wt)) ~ 
    #>     0.0379650332, wt >= 3.18000007 ~ -0.0395050682) + case_when((qsec < 
    #>     18.4099998 | is.na(qsec)) ~ 0.0194106717, qsec >= 18.4099998 ~ 
    #>     -0.0202215631) + case_when((qsec < 18.4099998 | is.na(qsec)) ~ 
    #>     0.0164139606, qsec >= 18.4099998 ~ -0.0171694476) + case_when((qsec < 
    #>     18.4099998 | is.na(qsec)) ~ 0.013879573, qsec >= 18.4099998 ~ 
    #>     -0.0145772668) + case_when((qsec < 18.4099998 | is.na(qsec)) ~ 
    #>     0.0117362784, qsec >= 18.4099998 ~ -0.0123759825) + case_when((wt < 
    #>     3.18000007 | is.na(wt)) ~ 0.0388614088, wt >= 3.18000007 ~ 
    #>     -0.0400568396) + log(0.5/(1 - 0.5))))
  • Add the prediction to the original table

    library(dplyr)
    
    mtcars %>%
      tidypredict_to_column(model) %>%
      glimpse()
    #> Rows: 32
    #> Columns: 12
    #> $ mpg  <dbl> 21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22.8, 19.2,…
    #> $ cyl  <dbl> 6, 6, 4, 6, 8, 6, 8, 4, 4, 6, 6, 8, 8, 8, 8, 8, 8, 4, 4, 4,…
    #> $ disp <dbl> 160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, 140…
    #> $ hp   <dbl> 110, 110, 93, 110, 175, 105, 245, 62, 95, 123, 123, 180, 18…
    #> $ drat <dbl> 3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.92,…
    #> $ wt   <dbl> 2.620, 2.875, 2.320, 3.215, 3.440, 3.460, 3.570, 3.190, 3.1…
    #> $ qsec <dbl> 16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, 22.…
    #> $ vs   <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1,…
    #> $ am   <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,…
    #> $ gear <dbl> 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 4,…
    #> $ carb <dbl> 4, 4, 1, 1, 2, 1, 4, 2, 2, 4, 4, 3, 3, 3, 4, 4, 4, 1, 2, 1,…
    #> $ fit  <dbl> 0.98576418, 0.98576418, 0.92735110, 0.01081509, 0.04639094,…
  • Confirm that the tidypredict results match the model’s predict() results. The xg_df argument expects the xgb.DMatrix data set.

    tidypredict_test(model, mtcars, xg_df = xgb_bin_data)
    #> tidypredict test results
    #> Difference threshold: 1e-12
    #> 
    #>  All results are within the difference threshold

parsnip

parsnip fitted models are also supported by tidypredict:

library(parsnip)

# Fit a boosted-tree regression of `am` on all other mtcars columns through
# parsnip, using the xgboost engine.
p_model <- boost_tree(mode = "regression") %>%
  set_engine("xgboost") %>%
  fit(am ~ ., data = mtcars)
# Run tidypredict's comparison test for the parsnip fit against mtcars.
# NOTE(review): xg_df reuses the xgb.DMatrix built for the earlier xgb.train()
# example -- confirm it holds the same predictors parsnip passed to xgboost.
tidypredict_test(p_model, mtcars, xg_df = xgb_bin_data)
#> tidypredict test results
#> Difference threshold: 1e-12
#> 
#> Fitted records above the threshold: 15
#> 
#> Fit max  difference:
#> Lower max difference:
#> Upper max difference:8.06462707725331e-08

Parse model spec

Here is an example of the model spec:

pm <- parse_model(model)
str(pm, 2)
#> List of 2
#>  $ general:List of 7
#>   ..$ model        : chr "xgb.Booster"
#>   ..$ type         : chr "xgb"
#>   ..$ niter        : num 50
#>   ..$ params       :List of 4
#>   ..$ feature_names: chr [1:10] "mpg" "cyl" "disp" "hp" ...
#>   ..$ nfeatures    : int 10
#>   ..$ version      : num 1
#>  $ trees  :List of 42
#>   ..$ 0 :List of 3
#>   ..$ 1 :List of 3
#>   ..$ 2 :List of 3
#>   ..$ 3 :List of 3
#>   ..$ 4 :List of 3
#>   ..$ 5 :List of 3
#>   ..$ 6 :List of 3
#>   ..$ 7 :List of 3
#>   ..$ 8 :List of 3
#>   ..$ 9 :List of 3
#>   ..$ 10:List of 3
#>   ..$ 11:List of 2
#>   ..$ 12:List of 2
#>   ..$ 13:List of 2
#>   ..$ 14:List of 2
#>   ..$ 15:List of 2
#>   ..$ 16:List of 2
#>   ..$ 17:List of 2
#>   ..$ 18:List of 2
#>   ..$ 19:List of 2
#>   ..$ 20:List of 2
#>   ..$ 21:List of 2
#>   ..$ 22:List of 2
#>   ..$ 23:List of 2
#>   ..$ 24:List of 2
#>   ..$ 25:List of 2
#>   ..$ 26:List of 2
#>   ..$ 27:List of 2
#>   ..$ 28:List of 2
#>   ..$ 29:List of 2
#>   ..$ 30:List of 2
#>   ..$ 31:List of 2
#>   ..$ 32:List of 2
#>   ..$ 33:List of 2
#>   ..$ 34:List of 2
#>   ..$ 35:List of 2
#>   ..$ 36:List of 2
#>   ..$ 37:List of 2
#>   ..$ 38:List of 2
#>   ..$ 39:List of 2
#>   ..$ 40:List of 2
#>   ..$ 41:List of 2
#>  - attr(*, "class")= chr [1:3] "parsed_model" "pm_xgb" "list"
str(pm$trees[1])
#> List of 1
#>  $ 0:List of 3
#>   ..$ :List of 2
#>   .. ..$ prediction: num -0.436
#>   .. ..$ path      :List of 1
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "wt"
#>   .. .. .. ..$ val    : num 3.18
#>   .. .. .. ..$ op     : chr "less"
#>   .. .. .. ..$ missing: logi FALSE
#>   ..$ :List of 2
#>   .. ..$ prediction: num 0.429
#>   .. ..$ path      :List of 2
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "qsec"
#>   .. .. .. ..$ val    : num 19.2
#>   .. .. .. ..$ op     : chr "more-equal"
#>   .. .. .. ..$ missing: logi TRUE
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "wt"
#>   .. .. .. ..$ val    : num 3.18
#>   .. .. .. ..$ op     : chr "more-equal"
#>   .. .. .. ..$ missing: logi TRUE
#>   ..$ :List of 2
#>   .. ..$ prediction: num 0
#>   .. ..$ path      :List of 2
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "qsec"
#>   .. .. .. ..$ val    : num 19.2
#>   .. .. .. ..$ op     : chr "less"
#>   .. .. .. ..$ missing: logi FALSE
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "wt"
#>   .. .. .. ..$ val    : num 3.18
#>   .. .. .. ..$ op     : chr "more-equal"
#>   .. .. .. ..$ missing: logi TRUE