
Performs (weighted) prediction averaging of survival predictions (PredictionSurv objects) by connecting PipeOpSurvAvg to the outputs of multiple PipeOpLearners.
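As a minimal sketch of that wiring (assuming mlr3, mlr3pipelines and mlr3proba are installed, and using `gunion()` and `%>>%` from mlr3pipelines), two survival learners can be placed side by side and their prediction outputs fed into a single PipeOpSurvAvg. The choice of learners is illustrative only.

```r
gr = NULL
if (requireNamespace("mlr3proba", quietly = TRUE) &&
    requireNamespace("mlr3pipelines", quietly = TRUE)) {
  library(mlr3)
  library(mlr3pipelines)
  library(mlr3proba)  # registers po("survavg") and the survival learners

  # Two learners in parallel, each wrapped in a PipeOpLearner
  learners = gunion(list(
    po("learner", lrn("surv.coxph")),
    po("learner", lrn("surv.kaplan"))
  ))
  # Connect both prediction outputs to the averaging operator (two input channels)
  gr = learners %>>% po("survavg", innum = 2)
  print(gr$ids())
}
```

The resulting graph can then be wrapped with `as_learner(gr)` and trained or resampled like any other learner.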

The resulting prediction will aggregate any predict types that are contained within all inputs. Any predict types missing from at least one input will be set to NULL. These are aggregated as follows:

Weights can be set via the weights parameter; if none are provided, each prediction receives equal weight.
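The two rules above (only predict types present in every input are kept, and weights default to equal) can be illustrated with a small base R sketch. The helper `average_predictions()` is hypothetical and works on plain named lists rather than PredictionSurv objects; rescaling the weights to sum to 1 is an assumption of this sketch, not something stated above.

```r
# Hypothetical helper: weighted average of the predict types shared by all inputs
average_predictions = function(preds, weights = NULL) {
  n = length(preds)
  # Default: equal weight for each prediction
  if (is.null(weights)) weights = rep(1, n)
  stopifnot(length(weights) == n)
  weights = weights / sum(weights)  # assumption: rescale weights to sum to 1
  # Keep only predict types present in every input; the rest are dropped
  types = Reduce(intersect, lapply(preds, names))
  out = lapply(types, function(type) {
    vals = lapply(preds, `[[`, type)
    Reduce(`+`, Map(`*`, weights, vals))
  })
  names(out) = types
  out
}

p1 = list(crank = c(1.2, 0.4, 2.0), lp = c(0.3, -0.1, 0.8))
p2 = list(crank = c(0.8, 0.6, 1.0))  # no lp, so lp is not aggregated
avg = average_predictions(list(p1, p2), weights = c(0.2, 0.8))
eq = average_predictions(list(p1, p2))  # equal weights
```

Here `avg$crank` is `0.2 * p1$crank + 0.8 * p2$crank`, and `avg$lp` is `NULL` because `p2` has no `lp`.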

Input and Output Channels

Input and output channels are inherited from PipeOpEnsemble, with PredictionSurv as the type of both inputs and outputs.

State

The $state is left empty (list()).

Parameters

The parameters are those inherited from PipeOpEnsemble.

Internals

Inherits from PipeOpEnsemble by implementing the private$weighted_avg_predictions() method.
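For survival curves, a weighted average is the survival function of a mixture: with weights w_i summing to 1, S(t) = sum_i w_i S_i(t), which is again non-increasing in t and lies in [0, 1]. The base R sketch below applies this to survival probability matrices (rows are observations, columns are time points). It assumes all inputs share the same time grid, and the helper `mix_survival()` is hypothetical; it is not the package's internal code.

```r
# Survival probabilities: rows = observations, columns = shared time points
times = c(1, 2, 3)
S1 = matrix(c(0.9, 0.7, 0.5,
              0.8, 0.6, 0.3), nrow = 2, byrow = TRUE,
            dimnames = list(NULL, times))
S2 = matrix(c(0.95, 0.85, 0.75,
              0.90, 0.80, 0.70), nrow = 2, byrow = TRUE,
            dimnames = list(NULL, times))

# Hypothetical helper: survival function of a weighted mixture
mix_survival = function(surv_list, weights = rep(1, length(surv_list))) {
  weights = weights / sum(weights)  # mixture weights must sum to 1
  Reduce(`+`, Map(`*`, weights, surv_list))
}

S_avg = mix_survival(list(S1, S2), weights = c(0.2, 0.8))
print(S_avg)
```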

Super classes

mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpEnsemble -> PipeOpSurvAvg

Methods


Method new()

Creates a new instance of this R6 class.

Usage

PipeOpSurvAvg$new(innum = 0, id = "survavg", param_vals = list(), ...)

Arguments

innum

(numeric(1))
Determines the number of input channels. If innum is 0 (default), a vararg input channel is created that can take an arbitrary number of inputs.
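The effect of innum can be checked by inspecting the `$input` table of the constructed operator. This is a sketch assuming mlr3pipelines and mlr3proba are installed; the channel names it should print (a single vararg channel `...` versus numbered channels such as `input1`, `input2`) follow mlr3pipelines' naming convention for PipeOpEnsemble.

```r
op_var = op_two = NULL
if (requireNamespace("mlr3proba", quietly = TRUE) &&
    requireNamespace("mlr3pipelines", quietly = TRUE)) {
  library(mlr3pipelines)
  library(mlr3proba)

  # innum = 0 (default): one vararg channel accepting any number of inputs
  op_var = po("survavg")
  print(op_var$input$name)

  # innum = 2: exactly two fixed input channels
  op_two = po("survavg", innum = 2)
  print(op_two$input$name)
}
```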

id

(character(1))
Identifier of the resulting object.

param_vals

(list())
List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.

...

ANY
Additional arguments passed to mlr3pipelines::PipeOpEnsemble.


Method clone()

The objects of this class are cloneable with this method.

Usage

PipeOpSurvAvg$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.
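A deep clone also copies the operator's hyperparameter settings, so the copy can be reconfigured without touching the original. A minimal sketch, assuming mlr3pipelines and mlr3proba are installed:

```r
op_orig = op_copy = NULL
if (requireNamespace("mlr3proba", quietly = TRUE) &&
    requireNamespace("mlr3pipelines", quietly = TRUE)) {
  library(mlr3pipelines)
  library(mlr3proba)

  op_orig = po("survavg", param_vals = list(weights = c(0.2, 0.8)))
  # deep = TRUE copies the parameter set as well, not just a reference to it
  op_copy = op_orig$clone(deep = TRUE)
  op_copy$param_set$values$weights = c(0.5, 0.5)

  # The original keeps its own weights
  print(op_orig$param_set$values$weights)
}
```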

Examples

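if (requireNamespace("mlr3pipelines", quietly = TRUE) &&
    requireNamespace("mlr3proba", quietly = TRUE)) {
  library(mlr3)
  library(mlr3pipelines)
  library(mlr3proba)  # provides tsk("rats"), the survival learners and po("survavg")

  task = tsk("rats")
  # Train two survival learners and predict on the same task
  p1 = lrn("surv.coxph")$train(task)$predict(task)
  p2 = lrn("surv.kaplan")$train(task)$predict(task)
  # Average the two predictions with weights 0.2 and 0.8
  poc = po("survavg", param_vals = list(weights = c(0.2, 0.8)))
  poc$predict(list(p1, p2))
}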