Perform (weighted) prediction averaging from survival PredictionSurvs by connecting PipeOpSurvAvg to multiple PipeOpLearner outputs.
The resulting prediction will aggregate any predict types that are contained within all inputs. Any predict types missing from at least one input will be set to NULL. The predict types shared by all inputs are aggregated as follows:
"response", "crank", and "lp" are each a weighted average of the incoming predictions.
"distr" is a distr6::VectorDistribution containing distr6::MixtureDistributions.
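For intuition about the "distr" aggregation: a mixture of distributions with weights w_i (summing to 1) has survival function S(t) = sum_i w_i S_i(t). The sketch below illustrates that identity on plain vectors of survival probabilities evaluated at shared time points. It does not use distr6 and is not how distr6::MixtureDistribution is implemented; the function name mixture_survival and the example probabilities are hypothetical.

```r
# Illustration only (does not use distr6): the survival function of a
# mixture is the weighted sum of the component survival functions.
mixture_survival <- function(surv_list, weights) {
  stopifnot(length(surv_list) == length(weights))
  weights <- weights / sum(weights)  # mixture weights must sum to 1
  # Works for vectors (one observation) or matrices (rows = observations,
  # columns = time points), provided all inputs share the same shape
  Reduce(`+`, Map(function(S, w) w * S, surv_list, weights))
}

# Hypothetical survival probabilities at three common time points
s_cox <- c(0.9, 0.7, 0.5)
s_km  <- c(0.8, 0.4, 0.2)
mixture_survival(list(s_cox, s_km), c(0.25, 0.75))
#> [1] 0.825 0.475 0.275
```

Because each component curve is non-increasing and lies in [0, 1], the mixture curve has the same properties.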
Weights can be set via the weights parameter; if none are provided, each prediction receives equal weight.
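The rule for the numeric predict types ("response", "crank", "lp") can be sketched in base R: a predict type absent from any input yields NULL, otherwise each observation gets the weighted average across inputs, with equal weights by default. The helper weighted_avg is hypothetical and not part of mlr3proba; normalising user-supplied weights to sum to 1 is an assumption of this sketch rather than documented behaviour of PipeOpSurvAvg.

```r
# Hypothetical helper mirroring the aggregation rule for "response",
# "crank" and "lp": NULL if any input lacks the predict type, otherwise
# a weighted average per observation.
weighted_avg <- function(values, weights = NULL) {
  # A predict type missing from at least one input is dropped (NULL)
  if (any(vapply(values, is.null, logical(1)))) return(NULL)
  if (is.null(weights)) weights <- rep(1, length(values))  # equal weights by default
  stopifnot(length(weights) == length(values))
  weights <- weights / sum(weights)  # assumption: weights normalised to sum to 1
  m <- do.call(cbind, values)        # rows = observations, columns = inputs
  drop(m %*% weights)
}

crank_a <- c(1, 2, 3)
crank_b <- c(3, 2, 1)
weighted_avg(list(crank_a, crank_b), c(0.2, 0.8))
#> [1] 2.6 2.0 1.4
weighted_avg(list(crank_a, crank_b))
#> [1] 2 2 2
weighted_avg(list(crank_a, NULL))
#> NULL
```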
Input and Output Channels
Input and output channels are inherited from PipeOpEnsemble with a PredictionSurv for inputs and outputs.
State
The $state is left empty (list()).
Parameters
The parameters are those inherited from PipeOpEnsemble.
Internals
Inherits from PipeOpEnsemble by implementing the private$weighted_avg_predictions() method.
Super classes
mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpEnsemble -> PipeOpSurvAvg
Methods
Method new()
Creates a new instance of this R6 class.
Usage
PipeOpSurvAvg$new(innum = 0, id = "survavg", param_vals = list(), ...)
Arguments
innum (numeric(1))
Determines the number of input channels. If innum is 0 (default), a vararg input channel is created that can take an arbitrary number of inputs.
id (character(1))
Identifier of the resulting object.
param_vals (list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
... (ANY)
Additional arguments passed to mlr3pipelines::PipeOpEnsemble.
Examples
if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
  library(mlr3)
  library(mlr3proba)  # provides the "rats" task, the survival learners and po("survavg")
  library(mlr3pipelines)

  task = tsk("rats")

  # Two survival predictions on the same task
  p1 = lrn("surv.coxph")$train(task)$predict(task)
  p2 = lrn("surv.kaplan")$train(task)$predict(task)

  # Weighted average: 20% Cox PH, 80% Kaplan-Meier
  poc = po("survavg", param_vals = list(weights = c(0.2, 0.8)))
  poc$predict(list(p1, p2))
}
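The example above calls the PipeOp directly on two existing predictions. A sketch of the more typical setup, connecting PipeOpSurvAvg to the outputs of multiple PipeOpLearners in a graph, is shown below. It assumes mlr3proba, mlr3pipelines and the survival package are installed, and that the installed mlr3pipelines provides as_learner() (older versions can use GraphLearner$new(graph) instead).

```r
if (requireNamespace("mlr3proba", quietly = TRUE) &&
    requireNamespace("mlr3pipelines", quietly = TRUE)) {
  library(mlr3)
  library(mlr3proba)
  library(mlr3pipelines)

  task = tsk("rats")

  # Two PipeOpLearner nodes receive the same task; their PredictionSurv
  # outputs feed the vararg input channel of PipeOpSurvAvg (innum = 0)
  graph = gunion(list(
    po("learner", lrn("surv.coxph")),
    po("learner", lrn("surv.kaplan"))
  )) %>>%
    po("survavg", param_vals = list(weights = c(0.2, 0.8)))

  # Wrap the graph as a single learner so it can be trained and predicted
  glrn = as_learner(graph)
  glrn$train(task)
  pred = glrn$predict(task)
  print(pred)
}
```

Wrapping the graph as a learner lets the averaged ensemble be resampled or benchmarked like any other survival learner.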