Skip to content

Commit

Permalink
add prediction() method for coxph models (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
leeper committed Aug 29, 2016
1 parent 920f57f commit 56d4a56
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 5 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 16,7 @@ S3method(persp,lm)
S3method(persp,loess)
S3method(plot,margins)
S3method(plot,marginslist)
S3method(prediction,coxph)
S3method(prediction,glm)
S3method(prediction,lm)
S3method(prediction,loess)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 3,7 @@
## margins 0.2.7

* Added a `prediction()` method for "polr" objects (from `MASS::polr()`). (#3)
* Added a `prediction()` method for "coxph" objects (from `survival::coxph()`). (#3)

## margins 0.2.7

Expand Down
47 changes: 42 additions & 5 deletions R/prediction.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 41,8 @@ prediction.lm <- function(model, data, type = "response", ...) {
structure(list(fitted = pred[["fit"]],
se.fitted = pred[["se.fit"]]),
class = c("prediction", "data.frame"),
row.names = seq_len(length(pred[["fit"]])))
row.names = seq_len(length(pred[["fit"]])),
type = type)
}

#' @rdname prediction
Expand All @@ -67,7 68,9 @@ prediction.glm <- function(model, data, type = c("response", "link"), ...) {
structure(list(fitted = pred[["fit"]],
se.fitted = pred[["se.fit"]]),
class = c("prediction", "data.frame"),
row.names = seq_len(length(pred[["fit"]])))
row.names = seq_len(length(pred[["fit"]])),
model.class = class(model),
type = type)
}

#' @rdname prediction
Expand All @@ -93,7 96,9 @@ prediction.loess <- function(model, data, type = "response", ...) {
structure(list(fitted = pred[["fit"]],
se.fitted = pred[["se.fit"]]),
class = c("prediction", "data.frame"),
row.names = seq_len(length(pred[["fit"]])))
row.names = seq_len(length(pred[["fit"]])),
model.class = class(model),
type = type)
}

#' @rdname prediction
Expand All @@ -118,7 123,37 @@ prediction.nls <- function(model, data, ...) {
structure(list(fitted = pred[["fit"]],
se.fitted = pred[["se.fit"]]),
class = c("prediction", "data.frame"),
row.names = seq_len(length(pred[["fit"]])))
row.names = seq_len(length(pred[["fit"]])),
model.class = class(model),
type = type)
}

#' @rdname prediction
#' @export
prediction.coxph <- function(model, data, type = c("risk", "expected", "lp"), ...) {
# setup data
if (missing(data)) {
if (!is.null(model[["call"]][["data"]])) {
data <- eval(model[["call"]][["data"]], parent.frame())
} else {
data <- get_all_vars(model[["terms"]], data = model[["model"]])
}
}

type <- match.arg(type)

# extract predicted value at input value (value can only be 1 number)
pred <- predict(model, newdata = data, type = type, se.fit = TRUE, ...)
class(pred[["fit"]]) <- c("fit", "numeric")
class(pred[["se.fit"]]) <- c("se.fit", "numeric")

# obs-x-2 data.frame of predictions
structure(list(fitted = pred[["fit"]],
se.fitted = pred[["se.fit"]]),
class = c("prediction", "data.frame"),
row.names = seq_len(length(pred[["fit"]])),
model.class = class(model),
type = type)
}

#' @rdname prediction
Expand Down Expand Up @@ -146,7 181,9 @@ prediction.polr <- function(model, data, ...) {
se.fitted = pred[["se.fit"]]),
probs),
class = c("prediction", "data.frame"),
row.names = seq_len(length(pred[["fit"]])))
row.names = seq_len(length(pred[["fit"]])),
model.class = class(model),
type = type)
}

#' @export
Expand Down
4 changes: 4 additions & 0 deletions man/prediction.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 56d4a56

Please sign in to comment.