-
-
Notifications
You must be signed in to change notification settings - Fork 168
/
misc_loss_functions.R
165 lines (143 loc) · 6.38 KB
/
misc_loss_functions.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#' Calculate Loss Functions
#'
#' @param predicted predicted scores, either vector or matrix, these are returned from the model specific \code{predict_function()}
#' @param observed observed scores or labels, these are supplied as explainer specific \code{y}
#' @param p_min for cross entropy, minimal value for probability to make sure that \code{log} will not explode
#' @param cutoff classification threshold for the accuracy loss functions
#' @param na.rm logical, should missing values be removed?
#' @param x either an explainer or type of the model. One of "regression", "classification", "multiclass".
#'
#' @return numeric - value of the loss function
#'
#' @aliases loss_cross_entropy loss_sum_of_squares loss_root_mean_square loss_accuracy loss_one_minus_auc
#' @export
#' @examples
#' \donttest{
#' library("ranger")
#' titanic_ranger_model <- ranger(survived~., data = titanic_imputed, num.trees = 50,
#' probability = TRUE)
#' loss_one_minus_auc(titanic_imputed$survived, yhat(titanic_ranger_model, titanic_imputed))
#'
#' HR_ranger_model_multi <- ranger(status~., data = HR, num.trees = 50, probability = TRUE)
#' loss_cross_entropy(as.numeric(HR$status), yhat(HR_ranger_model_multi, HR))
#'
#' }
#' @rdname loss_functions
#' @export
loss_cross_entropy <- function(observed, predicted, p_min = 0.0001, na.rm = TRUE) {
  # For observation i take the predicted probability in row i, column
  # observed[i] (a class index, factor code or column name).
  # vapply() guarantees a numeric vector, also for zero-length input.
  p <- vapply(
    seq_along(observed),
    function(i) predicted[i, observed[i]],
    numeric(1)
  )
  # Clamp probabilities from below at p_min so that log() stays finite;
  # honour the caller's na.rm instead of always dropping NAs.
  sum(-log(pmax(p, p_min)), na.rm = na.rm)
}
attr(loss_cross_entropy, "loss_name") <- "Cross entropy"
#' @rdname loss_functions
#' @export
loss_sum_of_squares <- function(observed, predicted, na.rm = TRUE) {
  # Sum of squared residuals; NA residuals are skipped when na.rm is TRUE.
  res <- observed - predicted
  sum(res^2, na.rm = na.rm)
}
attr(loss_sum_of_squares, "loss_name") <- "Sum of squared residuals (SSR)"
#' @rdname loss_functions
#' @export
loss_root_mean_square <- function(observed, predicted, na.rm = TRUE) {
  # Square root of the mean squared residual (RMSE).
  squared_errors <- (observed - predicted)^2
  mse <- mean(squared_errors, na.rm = na.rm)
  sqrt(mse)
}
attr(loss_root_mean_square, "loss_name") <- "Root mean square error (RMSE)"
#' @rdname loss_functions
#' @export
loss_accuracy <- function(observed, predicted, na.rm = TRUE) {
  # Share of exact matches between labels; intended for predicted classes,
  # not for predicted probabilities.
  hits <- observed == predicted
  mean(hits, na.rm = na.rm)
}
attr(loss_accuracy, "loss_name") <- "Accuracy"
#' @rdname loss_functions
#' @export
loss_one_minus_accuracy <- function(observed, predicted, cutoff = 0.5, na.rm = TRUE) {
  # Binarize the scores at `cutoff`. Observations labelled neither 0 nor 1
  # fall into none of the four confusion-matrix cells below.
  predicted_positive <- predicted >= cutoff
  predicted_negative <- predicted < cutoff
  tp <- sum((observed == 1) * predicted_positive, na.rm = na.rm)
  fp <- sum((observed == 0) * predicted_positive, na.rm = na.rm)
  tn <- sum((observed == 0) * predicted_negative, na.rm = na.rm)
  fn <- sum((observed == 1) * predicted_negative, na.rm = na.rm)
  # Accuracy = correct / all counted; NaN if no observation was counted.
  acc <- (tp + tn) / (tp + fp + tn + fn)
  1 - acc
}
attr(loss_one_minus_accuracy, "loss_name") <- "One minus Accuracy"
#' @rdname loss_functions
#' @export
get_loss_one_minus_accuracy <- function(cutoff = 0.5, na.rm = TRUE) {
  # Force the arguments now. Otherwise the returned closure would evaluate
  # them lazily at its first call and could pick up values the caller has
  # changed in the meantime (e.g. when building several losses in a loop).
  force(cutoff)
  force(na.rm)
  loss_fun <- function(o, p) {
    loss_one_minus_accuracy(o, p, cutoff = cutoff, na.rm = na.rm)
  }
  # Label the loss like the other loss functions in this file.
  attr(loss_fun, "loss_name") <- "One minus Accuracy"
  loss_fun
}
#' @rdname loss_functions
#' @export
loss_one_minus_auc <- function(observed, predicted){
  # Count positives and negatives at each distinct score; tapply() groups by
  # the sorted unique values of `predicted`, so tied scores share one point.
  positives_per_score <- tapply(observed, predicted, sum)
  negatives_per_score <- tapply(1 - observed, predicted, sum)
  # Lowering the threshold from the highest score downwards, the cumulative
  # counts trace the ROC curve, which starts at the origin.
  tpr <- c(0, cumsum(rev(positives_per_score))) / sum(observed)
  fpr <- c(0, cumsum(rev(negatives_per_score))) / sum(1 - observed)
  # Trapezoidal rule for the area under the ROC curve.
  auc <- sum(diff(fpr) * (tpr[-1] + tpr[-length(tpr)]) / 2)
  1 - auc
}
attr(loss_one_minus_auc, "loss_name") <- "One minus AUC"
#' @rdname loss_functions
#' @export
get_loss_default <- function(x) {
  type_error <- "`explainer$model_info$type` should be one of ['regression', 'classification', 'multiclass'] - pass `model_info = list(type = $type$)` to the `explain` function. Submit an issue on https://github.com/ModelOriented/DALEX/issues if you think that this model should be covered by default."
  # `x` is either an explainer or the type of an explainer.
  if (inherits(x, "explainer")) x <- x$model_info$type
  # switch() needs a single non-NA string: NULL would fail with an
  # uninformative error, and a number would silently pick an alternative
  # by position.
  if (!is.character(x) || length(x) != 1L || is.na(x)) stop(type_error)
  switch(x,
    "regression" = loss_root_mean_square,
    "classification" = loss_one_minus_auc,
    "multiclass" = loss_cross_entropy,
    stop(type_error)
  )
}
#' @rdname loss_functions
#' @export
loss_default <- function(x) {
  # Deprecated alias kept for backward compatibility; delegates unchanged.
  deprecation_msg <- "`loss_default()` is deprecated; use `get_loss_default()` instead."
  warning(deprecation_msg)
  get_loss_default(x)
}
#' Wrapper for Loss Functions from the yardstick Package
#'
#' The yardstick package provides many auxiliary functions for calculating
#' the predictive performance of a model. However, their interface follows
#' the tidyverse philosophy, which differs from the one expected by DALEX.
#' The \code{get_loss_yardstick} function adapts loss functions from the
#' yardstick package to functions understood by DALEX. Type compatibility
#' for y-values and for predictions must be guaranteed by the user.
#'
#' @param loss loss function from the \code{yardstick} package
#' @param reverse shall the metric be reversed? For loss metrics lower values are better. \code{reverse = TRUE} is useful for accuracy-like metrics
#' @param reference if the metric is reversed then it is calculated as \code{reference - loss}. The default value is 1.
#'
#' @return loss function that can be used in the model_parts function
#'
#' @export
#' @examples
#' \donttest{
#' titanic_glm_model <- glm(survived~., data = titanic_imputed, family = "binomial")
#' explainer_glm <- DALEX::explain(titanic_glm_model,
#' data = titanic_imputed[,-8],
#' y = factor(titanic_imputed$survived))
#' # See the 'How to use DALEX with the yardstick package' vignette
#' # which explains this model with measures implemented in the 'yardstick' package
#' }
#'
#' @rdname get_loss_yardstick
#' @export
get_loss_yardstick <- function(loss, reverse = FALSE, reference = 1) {
  # Capture the caller's expression for the label first (collapsed, since
  # deparse() may split long expressions), then force the promises so the
  # returned closure does not depend on variables in the caller's frame
  # that may change later (e.g. when building several losses in a loop).
  loss_label <- paste(deparse(substitute(loss)), collapse = " ")
  force(loss)
  force(reference)

  # Evaluate the yardstick metric on a data frame whose columns are named
  # `observed` and `predicted`, so that the bare names passed to `loss`
  # resolve to those columns.
  estimate <- function(observed, predicted) {
    df <- data.frame(observed, predicted)
    colnames(df) <- c("observed", "predicted")
    loss(df, observed, predicted)$.estimate
  }

  if (reverse) {
    custom_loss <- function(observed, predicted) {
      reference - estimate(observed, predicted)
    }
    attr(custom_loss, "loss_name") <- paste0(reference, " - ", loss_label)
  } else {
    custom_loss <- estimate
    attr(custom_loss, "loss_name") <- loss_label
  }
  custom_loss
}
#' @rdname get_loss_yardstick
#' @export
loss_yardstick <- function(loss, reverse = FALSE, reference = 1) {
  # Deprecated alias of get_loss_yardstick(); all arguments are forwarded.
  deprecation_msg <- "`loss_yardstick()` is deprecated; use `get_loss_yardstick()` instead."
  warning(deprecation_msg)
  get_loss_yardstick(
    loss = loss,
    reverse = reverse,
    reference = reference
  )
}