How to add estimability checking to your model's `predict` method
estimability package, Version 2.0.0
The goal of this short vignette is to show how you can easily add
estimability checking to your package’s predict() methods.
Suppose that you have developed a model class that has elements
$coefficients, $formula, etc. Suppose it also
has an $env element, an environment that can hold
miscellaneous information. This is not absolutely necessary, but it is handy
if it exists, because an environment can be modified in place. We also
assume that your model involves some kind of linear predictor.
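Why an environment element is handy: R copies lists when they are modified inside a function, but environments are modified in place. The minimal sketch below illustrates this; `make_mymod()` and `try_to_cache()` are hypothetical names used only for the illustration, not part of the estimability package.

```r
# Hypothetical constructor for a model class with an $env element
make_mymod <- function(coefficients, formula) {
    structure(list(coefficients = coefficients, formula = formula,
                   env = new.env()),
              class = "mymod")
}

# Hypothetical helper that tries to cache two things inside a function
try_to_cache <- function(object) {
    object$env$cached <- "kept"   # environments are modified in place
    object$other <- "lost"        # the list is copied, so this change stays local
    invisible(NULL)
}

m <- make_mymod(c(a = 1), y ~ x)
try_to_cache(m)
exists("cached", envir = m$env, inherits = FALSE)  # TRUE: survives the call
is.null(m$other)                                   # TRUE: the change was lost
```

This is what lets a predict method cache expensive results (such as a null-space basis) for later calls without the user having to reassign the model object.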
We are concerned with models that:
- Allow rank deficiencies (where some predictors may be excluded)
- Allow predictions for new data
For any such model, it is important to add estimability checking to your predict method: because the regression coefficients are not unique, predictions may not be unique either. It can be shown that a prediction for new data is unique only when the new case’s row of predictor values falls within the row space of the model matrix. The estimability package is designed to check for this.
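A small base-R illustration of the non-uniqueness, using a hypothetical rank-deficient design (an intercept plus a dummy for each of two groups). Two coefficient vectors that differ by a multiple of the null vector (1, -1, -1) give identical fitted values; a new row inside the row space gets the same prediction from both, while a row outside it does not.

```r
# Hypothetical rank-deficient design: intercept plus dummies for both groups
X <- cbind(int = 1, g1 = c(1, 1, 0, 0), g2 = c(0, 0, 1, 1))

# Two solutions that differ by 2 * (1, -1, -1), a vector in the null space of X
b1 <- c(int = 0, g1 = 3, g2 = 5)
b2 <- b1 + 2 * c(1, -1, -1)                 # (2, 1, 3)
all.equal(drop(X %*% b1), drop(X %*% b2))   # TRUE: same fitted values

x_in  <- c(1, 1, 0)   # group-1 mean: lies in the row space of X
x_out <- c(0, 1, 0)   # the g1 coefficient alone: not in the row space
c(sum(x_in * b1),  sum(x_in * b2))    # 3 3: unique prediction
c(sum(x_out * b1), sum(x_out * b2))   # 3 1: depends on the chosen solution
```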
The recommended design for accommodating rank-deficient models is to
follow the example of stats::lm objects, where any
predictors that are excluded have a corresponding regression coefficient
of NA. Please note that this NA code does not
actually mean that the coefficient is missing; it is a code
indicating that the coefficient has been constrained to be zero. In what
follows, we assume that this convention is used.
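The convention can be seen directly in stats::lm. In this sketch (made-up data, with `x2` an exact multiple of `x1`), the `x2` coefficient comes back as NA, and replacing that NA by zero reproduces the fitted values, which is exactly what "constrained to be zero" means.

```r
# Made-up data in which x2 is exactly collinear with x1
d <- data.frame(x1 = c(1, 2, 3, 4), y = c(2.1, 3.9, 6.2, 7.8))
d$x2 <- 2 * d$x1

fit <- lm(y ~ x1 + x2, data = d)
coef(fit)                  # the x2 coefficient is NA

b <- coef(fit)
b[is.na(b)] <- 0           # NA means "constrained to zero"
all.equal(drop(model.matrix(fit) %*% b), fitted(fit),
          check.attributes = FALSE)   # TRUE
```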
First note that estimability checking is not needed unless you are
predicting for new data, so that is where you need to incorporate it.
The predict method should be coded
something like this:
predict.mymod <- function(object, newdata, ...) {
    # ... some setup code ...
    if (!missing(newdata)) {
        X <- # ... code to set up the model matrix for newdata ...
        b <- coef(object)
        if (any(is.na(b))) {  # we have rank deficiency, so test estimability
            # Reuse a basis cached in object$env if there is one; assigning
            # into the environment makes the cache persist across calls.
            # (Assumes a model.matrix() method exists for your class.)
            if (is.null(nbasis <- object$env$nbasis))
                nbasis <- object$env$nbasis <-
                    estimability::nonest.basis(model.matrix(object))
            b[is.na(b)] <- 0  # NA means "constrained to zero"
            pred <- X %*% b
            pred[!estimability::is.estble(X, nbasis)] <- NA
        } else
            pred <- X %*% coef(object)
    }
    # ... perhaps more code ...
    pred
}
That’s it. This is the fancy version, which saves
nbasis for use with possible future predictions. Any
non-estimable cases are flagged as NA in the
pred vector.
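To see what `nonest.basis()` and `is.estble()` accomplish, here is a base-R sketch of the same idea. It is not the package’s implementation: `null_basis_sketch()` and `is_estble_sketch()` are hypothetical helpers, the null-space basis comes from a QR decomposition of `t(X)` (the same approach as `MASS::Null(t(X))`), and the relative-tolerance rule is an assumption that may differ from the criterion `is.estble()` uses. A row is treated as estimable when it is numerically orthogonal to every null-space vector.

```r
# Hypothetical helper: orthonormal basis N for the null space of X (X %*% N == 0)
null_basis_sketch <- function(X) {
    q <- qr(t(X))
    if (q$rank == 0L) return(diag(ncol(X)))
    # columns of Q beyond the rank are orthogonal to the row space of X
    qr.Q(q, complete = TRUE)[, -seq_len(q$rank), drop = FALSE]
}

# Hypothetical helper: relative-tolerance check of orthogonality to the null space
is_estble_sketch <- function(newX, nbasis, tol = 1e-8) {
    if (ncol(nbasis) == 0L) return(rep(TRUE, nrow(newX)))  # full rank
    ss_null <- rowSums((newX %*% nbasis)^2)
    ss_null <= tol * rowSums(newX^2)
}

X <- cbind(int = 1, g1 = c(1, 1, 0, 0), g2 = c(0, 0, 1, 1))
nb <- null_basis_sketch(X)       # one column, proportional to (1, -1, -1)
newX <- rbind(c(1, 1, 0),        # group-1 mean
              c(0, 1, -1),       # difference between groups
              c(0, 1, 0))        # g1 coefficient alone
is_estble_sketch(newX, nb)       # TRUE TRUE FALSE
```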
An alternative way to compute the predictions would be to exclude the columns of
X and elements of b that correspond to
NAs in b. But be careful: you still need
all the columns of X in order to check
estimability.
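A sketch of why the full X matters, using a hypothetical design and hypothetical coefficients (`b` has its last element constrained to zero). Dropping the NA columns gives the same predictions as setting those coefficients to zero, but the reduced model matrix has full column rank, so it has no null space and a check based on it could never flag anything.

```r
# Hypothetical design and coefficients; g2's coefficient is constrained to zero
X <- cbind(int = 1, g1 = c(1, 1, 0, 0), g2 = c(0, 0, 1, 1))
b <- c(5, 2, NA)
keep <- !is.na(b)

newX <- rbind(c(1, 1, 0),   # estimable
              c(0, 1, 0))   # NOT estimable, yet it still gets a number below

pred_drop <- drop(newX[, keep, drop = FALSE] %*% b[keep])  # drop NA columns
b0 <- b
b0[!keep] <- 0
pred_zero <- drop(newX %*% b0)                             # NA -> 0
all.equal(pred_drop, pred_zero)           # TRUE: both give c(7, 2)

qr(X[, keep, drop = FALSE])$rank == sum(keep)  # TRUE: reduced X is full rank
qr(X)$rank < ncol(X)                           # TRUE: only full X shows the deficiency
```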
The only other thing you need to do is add estimability
to the Imports field of your package’s DESCRIPTION file.
A complete example
Below is a complete model-fitting function that illustrates how to
navigate the various practical aspects of incorporating estimability
checking in your code. This code emphasises readability rather than
computational efficiency, and lacks many aspects of robustness or
checking of inputs. The function returns information about the model
terms and factor levels, the computed coefficients coef and
their variance-covariance matrix vcov, and the basis
nbasis of the null space of the model matrix. It uses
the Cholesky decomposition of X'X with pivoting. Pivoting rearranges
the order of the predictors, handling rank-deficient cases by moving
linearly dependent predictors to the end. We can then ignore,
or “discard”, those dependent predictors; in estimability parlance, this
amounts to obtaining a solution by constraining the coefficients of those
predictors to be zero.
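The pivoted Cholesky step can be tried on its own in base R. In this sketch (the same small hypothetical design used earlier), `chol(..., pivot = TRUE)` warns about rank deficiency and returns "rank" and "pivot" attributes; `chol2inv(ch, size = r)` then inverts X'X restricted to the first `r` pivoted predictors, which is how the fitting function below obtains its coefficients.

```r
X <- cbind(int = 1, g1 = c(1, 1, 0, 0), g2 = c(0, 0, 1, 1))
XtX <- crossprod(X)                               # same as t(X) %*% X
ch <- suppressWarnings(chol(XtX, pivot = TRUE))   # warns: rank-deficient
r <- attr(ch, "rank")                             # 2
piv <- attr(ch, "pivot")                          # order in which predictors were used

# The leading r x r block of the factor gives the inverse of X'X
# for the retained (first r pivoted) predictors
inv <- chol2inv(ch, size = r)
all.equal(inv %*% XtX[piv[1:r], piv[1:r]], diag(r),
          check.attributes = FALSE)               # TRUE

piv[-seq_len(r)]   # index of the predictor constrained to zero
```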
Much of what is shown here is how to deal with the pivot
attribute, which comprises the indices of the reordered predictors. The
rank attribute tells us how many coefficients are not constrained to
zero. We compute coef using the first rank
pivoted predictors, putting NA for the coefficients
constrained to zero, while vcov covers only the
non-NA elements of coef. As is also true in
the standard lm function, the NA elements are
just used to signal which coefficients were constrained to zero. The
nonest.basis() function sees the pivot
attribute and re-orders the predictors accordingly, so we don’t un-pivot
that result. The last few lines of code save the model terms and factor
levels that the predict method needs.
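The pivot bookkeeping described above can be checked in isolation. In this sketch the pivot order, estimates and covariance matrix are all hypothetical, and the dimnames on `V` are added only to make the reordering visible (in the fitting function, `XpXinv` has none). Estimates are stored back in their original positions via `pivot[1:rank]`, and `order(pivot[1:rank])` reorders the covariance matrix to match the non-NA elements of `coef`.

```r
pivot <- c(3L, 1L, 2L)   # hypothetical: x3 used first, then x1; x2 dropped
rank <- 2L
est <- c(30, 10)         # hypothetical estimates, in pivoted order (x3, x1)
V <- matrix(c(9, 1, 1, 4), 2, 2,
            dimnames = list(c("x3", "x1"), c("x3", "x1")))  # also pivoted order

coef <- rep(NA_real_, 3)
names(coef) <- c("x1", "x2", "x3")
coef[pivot[1:rank]] <- est    # x1 = 10, x2 = NA, x3 = 30
nonNA <- which(!is.na(coef))

ord <- order(pivot[1:rank])   # c(2, 1)
V[ord, ord]                   # rows and columns now in the order x1, x3
```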
mylm <- function(formula, data) {
    # response vector and model matrix (formula[-2] drops the response)
    y <- data[[all.vars(formula)[1]]]
    X <- model.matrix(formula[-2], data = data)
    # pivoted Cholesky of X'X; a warning about rank deficiency is expected
    ch <- chol(t(X) %*% X, pivot = TRUE) |> suppressWarnings()
    rank <- attr(ch, "rank")
    pivot <- attr(ch, "pivot")
    # inverse of X'X for the first `rank` pivoted predictors
    XpXinv <- chol2inv(ch, size = rank)
    # NA marks coefficients constrained to zero
    coef <- rep(NA_real_, ncol(X))
    names(coef) <- colnames(X)
    coef[pivot[1:rank]] <- XpXinv %*% (t(X[, pivot[1:rank], drop = FALSE]) %*% y)
    nonNA <- which(!is.na(coef))
    fit <- X[ , nonNA, drop = FALSE] %*% coef[nonNA]
    mse <- sum((y - fit)^2) / (length(y) - rank)
    # un-pivot XpXinv so its rows and columns follow the order of coef[nonNA]
    ord <- order(pivot[1:rank])
    vcov <- mse * XpXinv[ord, ord, drop = FALSE]
    dimnames(vcov) <- list(names(coef)[nonNA], names(coef)[nonNA])
    # null-space basis; nonest.basis() accounts for the pivoting itself
    nbasis <- estimability::nonest.basis(ch)
    # terms and factor levels needed later by the predict method
    terms <- terms(formula) |> delete.response()
    xlev <- list()
    for (v in all.vars(terms))
        if (!is.null(lv <- levels(data[[v]])))
            xlev[[v]] <- lv
    obj <- list(terms = terms, xlev = xlev, coef = coef, vcov = vcov, nbasis = nbasis)
    class(obj) <- "mylm"
    obj
}
The following code is a corresponding S3 method that provides
predictions and standard errors for new data, with estimability
checking. The model.frame call ensures that the new data are coded
with the original factor levels. We then construct the
corresponding model matrix X for the new data and test its
rows for estimability. Once that is determined, we form the matrix
XX containing only the estimable rows and only the columns corresponding to
non-NA elements of coef. The predictions and
SEs of any non-estimable rows are set to NA.
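predict.mylm <- function(object, newdata, ...) {
    # code newdata using the factor levels seen when fitting
    mf <- model.frame(object$terms, data = newdata, xlev = object$xlev)
    X <- model.matrix(object$terms, data = mf)
    pred <- se <- rep(NA_real_, nrow(X))
    estble <- estimability::is.estble(X, object$nbasis)
    # estimable rows, and only the columns with non-NA coefficients
    XX <- X[estble, !is.na(object$coef), drop = FALSE]
    pred[estble] <- XX %*% object$coef[!is.na(object$coef)]
    # SE of each prediction x'b is sqrt(x' V x)
    se[estble] <- sqrt(diag(XX %*% object$vcov %*% t(XX)))
    list(pred = pred, se = se)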
}
Finally, here is a test of these functions, using a subset of the
warpbreaks data in which two complete cells are excluded, along with
a few more observations. We then obtain predictions for all factor
combinations and find that the two excluded cells are non-estimable.
# test code
warp = warpbreaks[11:40, ]
warp.mylm = mylm(breaks ~ wool*tension, warp)
new = do.call(expand.grid, warp.mylm$xlev)
cbind(new, predict(warp.mylm, newdata = new))
##   wool tension     pred       se
## 1    A       L       NA       NA
## 2    B       L 28.22222 3.324365
## 3    A       M 24.75000 3.526021
## 4    B       M 25.75000 4.986547
## 5    A       H 24.55556 3.324365
## 6    B       H       NA       NA
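As an optional cross-check using only base R: the wool*tension model is saturated for the four cells that remain, so the estimable predictions should equal the observed cell means, and they should agree with the pred column above. The empty cells show up as NA in `tapply()`, and `lm()` on the same subset marks the two interaction coefficients that cannot be estimated as NA.

```r
warp <- warpbreaks[11:40, ]
cell_means <- with(warp, tapply(breaks, list(wool, tension), mean))
cell_means          # A-L and B-H are NA (empty cells)

fit <- lm(breaks ~ wool*tension, data = warp)
names(coef(fit))[is.na(coef(fit))]   # the coefficients constrained to zero
```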