plug_grid() specifies the type of grid used in model tuning. It accepts a
function .f that is fed the tuning parameters defined in the model and the
recipe. .f must be either expand.grid or a grid function that returns a
data.frame object. See the details section for how expand.grid can be used,
and the package dials for the grid functions. If a model has been fit before
adding the grid, it will need to be refit.
drop_grid() removes the grid specification from the tidyflow. Note that it
keeps the other steps of the tidyflow, such as the recipe and model.
replace_grid() first removes the grid, then adds a new grid specification.
Any model that has already been fit based on the previous grid will need to
be refit.
plug_grid(x, .f, ...)
drop_grid(x)
replace_grid(x, .f, ...)
x: A tidyflow.
.f: A function to which the tuning parameters from the model and recipe are
passed. Two types of functions can be used here. For generating grids with
the grid_* functions from the package dials, .f must return an object of
class data.frame. In this case, the user doesn't need to specify the
parameters in the ... argument, since they are extracted and passed directly
to the grid function. The other type of function that can be used is
expand.grid. If .f is expand.grid, all tuning arguments must be specified in
the ... argument. This does not support parameter objects like mixture() but
rather the raw values to be expanded by expand.grid. For example, instead of
mixture = dials::mixture() it should be mixture = c(0, 0.5, 1). See the
details section and example section for a more thorough description.
...: Arguments passed to .f. The processing of these arguments respects the
quotation rules of .f. In other words, if the function allows variables as
strings and as names, the user can specify both. See the example section.
The tidyflow x, updated with either a new grid specification or with the
grid specification removed.
The grid specification is an optional step in the tidyflow. You can add
the data, prepare a recipe and fit the model without adding a grid
specification. However, to run a grid search, the user needs to specify
both a resample and a grid specification, with plug_resample and plug_grid
respectively.
plug_grid accepts two types of functions:

expand.grid: Using expand.grid creates a grid of all possible combinations.
For example, to create a grid of all possible values of penalty and mixture,
we can write plug_grid(expand.grid, penalty = seq(0.01, 0.05, 0.01),
mixture = seq(0, 1, 0.1)). Defining the grid this way requires the user to
define all tuning parameters explicitly in this step. For example, instead
of specifying mixture = mixture() from dials, the user should specify the
raw values used to expand: mixture = c(0, 0.5, 1). This applies to all
tuning parameters defined in the model and recipe.
grid_*: If one of the grid_* functions from dials is specified, the user
only needs to specify the function in .f and all tuning parameters are
extracted automatically. To override the default values for the parameters,
specify them in the ... argument. For example, limiting the range of mixture
can be specified as plug_grid(grid_regular,
mixture = mixture(range = c(0, 0.5))). The benefit of this approach is that
the user can hand-pick some parameters to change manually while the
remaining ones are assigned sensible values based on dials. Parameters such
as mtry, which need to be estimated from the data, are assigned default
values through finalize, so the user doesn't have to set them manually. For
more details see the example section.
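To make the expand.grid case concrete, here is a minimal sketch of the data frame it builds from raw values. It uses only base R, so it does not show how tidyflow consumes the grid, and the values are illustrative ones borrowed from the example section.

```r
# Raw values for each tuning parameter (illustrative values only).
penalty_values <- seq(0.01, 0.02, 0.005)  # 3 values
mixture_values <- c(0, 0.5, 1)            # 3 values

# expand.grid() crosses every penalty value with every mixture value.
grid <- expand.grid(penalty = penalty_values, mixture = mixture_values)

nrow(grid)   # number of candidate combinations: 3 * 3
names(grid)  # one column per tuning parameter
```

Because every combination is enumerated, the number of rows grows multiplicatively with each additional parameter, which is worth keeping in mind before resampling over the grid.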
Regardless of the type of function used in plug_grid, if a tuning parameter
in the model/recipe is assigned a name (for example, tune("new_name")) and
the user wants to specify the tuning values for that parameter using
plug_grid or replace_grid, then the parameter name in the ... argument
should be the custom name. See the example section for a concrete example.
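As a rough illustration of the naming rule when expand.grid is used, the sketch below builds a grid whose column names are the custom ids. The ids my_penalty and my_mixture are the hypothetical names used in the example section, and the sketch assumes, based only on the paragraph above, that the grid columns are matched to tuning parameters by name.

```r
# Hypothetical custom ids, as in tune("my_penalty") and tune("my_mixture").
custom_grid <- expand.grid(
  my_penalty = c(0.01, 0.1),
  my_mixture = c(0, 0.5, 1)
)

# The columns carry the custom ids rather than the model argument names.
names(custom_grid)
nrow(custom_grid)  # 2 * 3 combinations
```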
if (FALSE) {
library(tidyflow)
library(parsnip)
library(rsample)
library(tune)
library(dials)
library(recipes)
# Grid search:
# No need to define the values of the tuning parameters
# as they have defaults. For example, see the output of dials::penalty()
# `plug_grid` defines the grid. You can pass all of the arguments of
# `grid_regular`:
mod <-
  mtcars %>%
  tidyflow() %>%
  plug_split(initial_split) %>%
  plug_formula(mpg ~ .) %>%
  plug_resample(vfold_cv) %>%
  plug_model(set_engine(linear_reg(penalty = tune(), mixture = tune()), "glmnet")) %>%
  plug_grid(grid_regular, levels = 5)
res <- fit(mod)
# See the grid that was generated after the fit:
res %>%
  pull_tflow_grid()
# The argument `levels = 5` generates 5 values per tuning parameter and
# all of their combinations. With two parameters, that gives 5 x 5 = 25 rows.
# You can extract the tuning result with `pull_tflow_fit_tuning`:
pull_tflow_fit_tuning(res)
# Visualize it:
pull_tflow_fit_tuning(res) %>%
  autoplot()
# And explore it:
pull_tflow_fit_tuning(res) %>%
  collect_metrics()
# If you want to specify tuning values, you can do so with
# `plug_grid` or `replace_grid` but they must have the same
# name as the tuning parameter
res2 <-
  mod %>%
  replace_grid(grid_regular, penalty = penalty(c(-1, 0)), levels = 2) %>%
  fit()
res2 %>%
  pull_tflow_fit_tuning() %>%
  show_best(metric = "rsq")
# If `tune()` is given a custom name, then `plug_grid` or `replace_grid`
# must use that name to refer to the parameter
model <-
  set_engine(
    linear_reg(penalty = tune("my_penalty"), mixture = tune("my_mixture")),
    "glmnet"
  )
# You must use `my_penalty`
res3 <-
  mod %>%
  replace_model(model) %>%
  replace_grid(grid_regular, my_penalty = penalty(c(-1, 0)), levels = 2) %>%
  fit()
res3 %>%
  pull_tflow_fit_tuning() %>%
  show_best(metric = "rsq")
# If you want to create a grid of all possible combinations of the tuning
# parameters, use `expand.grid` and name every single tuning parameter:
res4 <-
  mod %>%
  replace_grid(expand.grid,
               penalty = seq(0.01, 0.02, 0.005),
               mixture = c(0, 0.5, 1)) %>%
  fit()
# The resulting grid is all of the possible combinations
# from the values defined above:
res4 %>%
  pull_tflow_grid()
# See how the values are not random, but rather
# all combinations of the supplied values
res4 %>%
  pull_tflow_fit_tuning() %>%
  collect_metrics()
# You can also tune values from a recipe directly
res5 <-
  res3 %>%
  drop_formula() %>%
  plug_recipe(~ recipe(mpg ~ ., data = .) %>% step_ns(hp, deg_free = tune())) %>%
  fit()
res5 %>%
  pull_tflow_fit_tuning() %>%
  show_best(metric = "rsq")
}