`earth` package

| Function | Works |
|---|---|
| `tidypredict_fit()`, `tidypredict_sql()`, `parse_model()` | ✔ |
| `tidypredict_to_column()` | ✔ |
| `tidypredict_test()` | ✔ |
| `tidypredict_interval()`, `tidypredict_sql_interval()` | ✗ |
| `parsnip` | ✔ |
`tidypredict_` functions

library(earth)
data("etitanic", package = "earth")
model <- earth(age ~ sibsp + parch, data = etitanic, degree = 3)
Create the R formula
tidypredict_fit(model)
#> 22.2918960405403 + (ifelse(parch < 2, 2 - parch, 0) * 4.85356462114366) +
#> (ifelse(parch > 2, parch - 2, 0) * 13.0493891423277) + (ifelse(parch >
#> 4, parch - 4, 0) * -18.8998708031826) + (ifelse(sibsp > 1,
#> sibsp - 1, 0) * -7.71566779782023) + (ifelse(sibsp > 1, sibsp -
#> 1, 0) * ifelse(parch < 1, 1 - parch, 0) * 7.40395975552272) +
#> (ifelse(sibsp > 1, sibsp - 1, 0) * ifelse(parch > 1, parch -
#> 1, 0) * 4.41874354843212)
SQL output example
tidypredict_sql(model, dbplyr::simulate_odbc())
#> <SQL> 22.2918960405403 + (CASE WHEN (`parch` < 2.0) THEN (2.0 - `parch`) WHEN NOT(`parch` < 2.0) THEN (0.0) END * 4.85356462114366) + (CASE WHEN (`parch` > 2.0) THEN (`parch` - 2.0) WHEN NOT(`parch` > 2.0) THEN (0.0) END * 13.0493891423277) + (CASE WHEN (`parch` > 4.0) THEN (`parch` - 4.0) WHEN NOT(`parch` > 4.0) THEN (0.0) END * -18.8998708031826) + (CASE WHEN (`sibsp` > 1.0) THEN (`sibsp` - 1.0) WHEN NOT(`sibsp` > 1.0) THEN (0.0) END * -7.71566779782023) + (CASE WHEN (`sibsp` > 1.0) THEN (`sibsp` - 1.0) WHEN NOT(`sibsp` > 1.0) THEN (0.0) END * CASE WHEN (`parch` < 1.0) THEN (1.0 - `parch`) WHEN NOT(`parch` < 1.0) THEN (0.0) END * 7.40395975552272) + (CASE WHEN (`sibsp` > 1.0) THEN (`sibsp` - 1.0) WHEN NOT(`sibsp` > 1.0) THEN (0.0) END * CASE WHEN (`parch` > 1.0) THEN (`parch` - 1.0) WHEN NOT(`parch` > 1.0) THEN (0.0) END * 4.41874354843212)
Add the prediction to the original table
library(dplyr)

etitanic %>%
  tidypredict_to_column(model) %>%
  glimpse()
#> Rows: 1,046
#> Columns: 7
#> $ pclass <fct> 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1…
#> $ survived <int> 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1…
#> $ sex <fct> female, male, female, male, female, male, female, male, femal…
#> $ age <dbl> 29.0000, 0.9167, 2.0000, 30.0000, 25.0000, 48.0000, 63.0000, …
#> $ sibsp <int> 0, 1, 1, 1, 1, 0, 1, 0, 2, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1…
#> $ parch <int> 0, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1…
#> $ fit <dbl> 31.99903, 22.29190, 22.29190, 22.29190, 22.29190, 31.99903, 3…
Confirm that `tidypredict` results match the model’s `predict()` results
tidypredict_test(model, etitanic)
#> tidypredict test results
#> Difference threshold: 1e-12
#>
#> All results are within the difference threshold
`tidypredict` supports the `glm` argument as well:
model <- earth(survived ~ .,
  data = etitanic,
  glm = list(family = binomial), degree = 2
)
tidypredict_fit(model)
#> 1 - 1/(1 + exp(0.961709503833137 + (ifelse(age > 32, age - 32,
#> 0) * -0.00471937845111412) + (ifelse(pclass == "2nd", 1,
#> 0) * ifelse(sex == "male", 1, 0) * -0.265689204831234) +
#> (ifelse(pclass == "3rd", 1, 0) * -0.815453520799443) + (ifelse(pclass ==
#> "3rd", 1, 0) * ifelse(sibsp < 4, 4 - sibsp, 0) * 0.102221808806423) +
#> (ifelse(pclass == "3rd", 1, 0) * ifelse(sex == "male", 1,
#> 0) * 0.19310203418243) + (ifelse(sex == "male", 1, 0) *
#> -0.570034960114421) + (ifelse(sex == "male", 1, 0) * ifelse(age <
#> 16, 16 - age, 0) * 0.0450523226133542)))
The spec sets the `is_glm` entry to 1, as well as the `family` and `link` entries.
str(parse_model(model), 2)
#> List of 2
#> $ general:List of 6
#> ..$ model : chr "earth"
#> ..$ type : chr "tree"
#> ..$ version: num 2
#> ..$ is_glm : num 1
#> ..$ family : chr "binomial"
#> ..$ link : chr "logit"
#> $ terms :List of 8
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> - attr(*, "class")= chr [1:3] "parsed_model" "pm_tree" "list"
`parsnip` fitted models are also supported by `tidypredict`:
library(parsnip)
p_model <- mars(mode = "regression", prod_degree = 3) %>%
  set_engine("earth") %>%
  fit(age ~ sibsp + parch, data = etitanic)
tidypredict_fit(p_model)
#> 22.2918960405403 + (ifelse(parch < 2, 2 - parch, 0) * 4.85356462114366) +
#> (ifelse(parch > 2, parch - 2, 0) * 13.0493891423277) + (ifelse(parch >
#> 4, parch - 4, 0) * -18.8998708031826) + (ifelse(sibsp > 1,
#> sibsp - 1, 0) * -7.71566779782023) + (ifelse(sibsp > 1, sibsp -
#> 1, 0) * ifelse(parch < 1, 1 - parch, 0) * 7.40395975552272) +
#> (ifelse(sibsp > 1, sibsp - 1, 0) * ifelse(parch > 1, parch -
#> 1, 0) * 4.41874354843212)
Here is an example of the model spec:
pm <- parse_model(model)
str(pm, 2)
#> List of 2
#> $ general:List of 6
#> ..$ model : chr "earth"
#> ..$ type : chr "tree"
#> ..$ version: num 2
#> ..$ is_glm : num 1
#> ..$ family : chr "binomial"
#> ..$ link : chr "logit"
#> $ terms :List of 8
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> - attr(*, "class")= chr [1:3] "parsed_model" "pm_tree" "list"
str(pm$terms[1:2])
#> List of 2
#> $ :List of 4
#> ..$ label : chr "(Intercept)"
#> ..$ coef : num 0.962
#> ..$ is_intercept: num 1
#> ..$ fields :List of 1
#> .. ..$ : list()
#> $ :List of 4
#> ..$ label : chr "h(age-32)"
#> ..$ coef : num -0.00472
#> ..$ is_intercept: num 0
#> ..$ fields :List of 1
#> .. ..$ :List of 4
#> .. .. ..$ type: chr "operation"
#> .. .. ..$ col : chr "age"
#> .. .. ..$ val : num 32
#> .. .. ..$ op : chr "morethan"