Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
bf7ac8c
default method for PIT and LOO-PIT
TeemuSailynoja Nov 12, 2024
d9b02be
Add validation for loo_weights.
TeemuSailynoja Mar 6, 2025
359f6b1
Check PIT bounds only once.
TeemuSailynoja Mar 14, 2025
72a3a30
Normalize loo_weights in pit.default. Add tests for internals and pit…
TeemuSailynoja Mar 14, 2025
75d2592
Handle PIT = 0 without warnings. Add documentation for input parameters.
TeemuSailynoja Mar 14, 2025
efc37d6
default method for PIT and LOO-PIT
TeemuSailynoja Nov 12, 2024
0258555
Add validation for loo_weights.
TeemuSailynoja Mar 6, 2025
48dc3a6
Normalize loo_weights in pit.default. Add tests for internals and pit…
TeemuSailynoja Mar 14, 2025
f43023e
Handle PIT = 0 without warnings. Add documentation for input parameters.
TeemuSailynoja Mar 14, 2025
b2ffb7e
Merge remote-tracking branch 'refs/remotes/origin/add_pit_functions' …
TeemuSailynoja Mar 14, 2025
8f2c361
Implement pit for draws and rvar. Add unit tests.
TeemuSailynoja Mar 19, 2025
ab21007
Make PIT description short and add myself as contributor.
TeemuSailynoja Mar 19, 2025
c00b326
typo1 R/pit.R
TeemuSailynoja Mar 19, 2025
2537986
change \(x) to function(x) for backewards compatibility.
TeemuSailynoja Mar 19, 2025
f58e347
Merge branch 'stan-dev:master' into add_pit_functions
TeemuSailynoja Mar 19, 2025
7acb46f
fix typo in test-pit.R
TeemuSailynoja Mar 19, 2025
985a6c5
Update documentation of pit(), use validate_weights, log = FALSE by d…
TeemuSailynoja Mar 27, 2025
203fdb3
make x optional in validate_y
TeemuSailynoja Mar 27, 2025
9787124
fix rvar and draws checking
TeemuSailynoja Mar 27, 2025
69e3560
Add tolerance to warning of pit over 1. Always clip pit at 1.
TeemuSailynoja Mar 27, 2025
af8f26a
remove test for warning when pit over tolerance. Can't reproduce.
TeemuSailynoja Mar 27, 2025
f33a216
fix example in pit
TeemuSailynoja Mar 27, 2025
f8a9894
Update NEWS.md
TeemuSailynoja Mar 28, 2025
006ef16
Update NEWS.md
TeemuSailynoja Mar 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ Authors@R: c(person("Paul-Christian", "Bürkner", email = "paul.buerkner@gmail.c
person("Ben", "Lambert", role = c("ctb")),
person("Ozan", "Adıgüzel", role = c("ctb")),
person("Jacob", "Socolar", role = c("ctb")),
person("Noa", "Kallioinen", role = c("ctb")))
person("Noa", "Kallioinen", role = c("ctb")),
person("Teemu", "Säilynoja", role = c("ctb")))
Description: Provides useful tools for both users and developers of packages
for fitting Bayesian models or working with output from Bayesian models.
The primary goals of the package are to:
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,9 @@ S3method(pareto_min_ss,rvar)
S3method(pareto_smooth,default)
S3method(pareto_smooth,rvar)
S3method(pillar_shaft,rvar)
S3method(pit,default)
S3method(pit,draws_matrix)
S3method(pit,rvar)
S3method(print,draws_array)
S3method(print,draws_df)
S3method(print,draws_list)
Expand Down Expand Up @@ -474,6 +477,7 @@ export(pareto_khat)
export(pareto_khat_threshold)
export(pareto_min_ss)
export(pareto_smooth)
export(pit)
export(ps_convergence_rate)
export(ps_khat_threshold)
export(ps_min_ss)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

* Convert lists of matrices to `draws_array` objects.
* Improve the documentation in various places.
* Implement `pit()` for draws and rvar objects. LOO-PIT can be computed using `weights`.

# posterior 1.6.0

Expand Down
179 changes: 179 additions & 0 deletions R/pit.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
#' Probability integral transform
#'
#' Probability integral transform (PIT). LOO-PIT is given by a weighted sample.
#'
#' @name pit
#'
#' @param x (draws) A [`draws_matrix`] object or one coercible to a
#' `draws_matrix` object, or an [`rvar`] object.
#'
#' @param y (observations) A 1D vector, or an array of dim(x), if x is `rvar`.
#' Each element of `y` corresponds to a variable in `x`.
#'
#' @param weights A matrix of weights for each draw and variable. `weights`
#' should have one column per variable in `x`, and `ndraws(x)` rows.
#'
#' @param log (logical) Are the weights passed already on the log scale? The
#' default is `FALSE`, that is, expecting `weights` to be on the standard
#' (non-log) scale.
#'
#' @template args-methods-dots
#'
#' @details The `pit()` function computes the probability integral transform of
Comment thread
paul-buerkner marked this conversation as resolved.
#' `y` using the empirical cumulative distribution computed from the samples
#' in `x`. For continuous valued `y` and `x`, the PIT for the elements of `y`
#' is computed as the empirical cumulative distribution value:
#'
#' PIT(y_i) = Pr(x_i < y_i),
#'
#' where x_i, is the corresponding set of draws in `x`. For `draws` objects,
#' this corresponds to the draws of the *i*th variable, and for `rvar`
#' the elements of `y` and `x` are matched.
#'
#' The draws in `x` can further be provided (log-)weights in
# `weights`, which enables for example the computation of LOO-PITs.
#'
#' If `y` and `x` are discrete, randomisation is used to obtain continuous PIT
#' values. (see, e.g., Czado, C., Gneiting, T., Held, L.: Predictive model
#' assessment for count data. Biometrics 65(4), 1254–1261 (2009).)
#'
#' @return A numeric vector of length `length(y)` containing the PIT values, or
#' an array of shape `dim(y)`, if `x` is an `rvar`.

#' @examples
#' # PIT for a draws object
#' x <- example_draws()
#' # Create a vector of observations
#' y <- rnorm(nvariables(x), 5, 5)
#' pit(x, y)
#'
#' # Compute weighted PIT (for example LOO-PIT)
#' weights <- matrix(runif(length(x)), ncol = nvariables(x))
#'
#' pit(x, y, weights)
#'
#' # PIT for an rvar
#' x <- rvar(example_draws())
#' # Create an array of observations with the same dimensions as x.
#' y_arr <- array(rnorm(length(x), 5, 5), dim = dim(x))
#' pit(x, y_arr)
#'
NULL

#' @rdname pit
#' @export
pit <- function(x, y, ...) UseMethod("pit")

#' @rdname pit
#' @export
pit.default <- function(x, y, weights = NULL, log = FALSE, ...) {
x <- as_draws_matrix(x)
if (!is.null(weights)) {
weights <- as_draws_matrix(weights)
}
pit(x, y, weights, log)
}

#' @rdname pit
#' @export
pit.draws_matrix <- function(x, y, weights = NULL, log = FALSE, ...) {
y <- validate_y(y, x)
if (!is.null(weights)) {
weights <- sapply(seq_len(nvariables(x)), \(var_idx) {
validate_weights(weights[, var_idx], x[, var_idx], log)
})
weights <- normalize_log_weights(weights)
}
pit <- vapply(seq_len(ncol(x)), function(j) {
sel_min <- x[, j] < y[j]
if (!any(sel_min)) {
pit <- 0
} else {
if (is.null(weights)) {
pit <- mean(sel_min)
} else {
pit <- exp(log_sum_exp(weights[sel_min, j]))
}
}

sel_sup <- x[, j] == y[j]
if (any(sel_sup)) {
# randomized PIT for discrete y (see, e.g., Czado, C., Gneiting, T.,
# Held, L.: Predictive model assessment for count data.
# Biometrics 65(4), 1254–1261 (2009).)
if (is.null(weights)) {
pit_sup <- pit + mean(sel_sup)
} else {
pit_sup <- pit + exp(log_sum_exp(weights[sel_sup, j]))
}

pit <- runif(1, pit, pit_sup)
}
pit
}, FUN.VALUE = 1.0)

if (any(pit > 1 + 1e-10)) {
warning_no_call(
Comment thread
paul-buerkner marked this conversation as resolved.
paste(
"Some PIT values larger than 1. ",
"This is usually due to numerical inaccuracies. ",
"Largest value: ",
max(pit),
"\nRounding PIT > 1 to 1.",
sep = ""
)
)
}

setNames(pmin(1, pit), variables(x))
}

#' @rdname pit
#' @export
pit.rvar <- function(x, y, weights = NULL, log = FALSE, ...) {
y <- validate_y(y, x)
if (is.null(weights)) {
out <- array(
runif(length(y), Pr(x < y), Pr(x <= y)),
dim(x),
dimnames(x)
)
} else {
out <- array(
data = pit(
x = as_draws_matrix(c(x)),
y = c(y),
weights = weights,
log = log
),
dim = dim(x),
dimnames = dimnames(x)
)
}
out
}

# internal ----------------------------------------------------------------

validate_y <- function(y, x = NULL) {
if (!is.numeric(y)) {
stop_no_call("`y` must be numeric.")
}
if (anyNA(y)) {
stop_no_call("NAs not allowed in `y`.")
}
if (is_rvar(x)) {
if (length(x) != length(y) || any(dim(y) != dim(x))) {
stop_no_call("`dim(y)` must match `dim(x)`.")
}
} else if (is_draws(x)) {
if (!is.vector(y, mode = "numeric") || length(y) != nvariables(x)) {
stop_no_call("`y` must be a vector of length `nvariables(x)`.")
}
}
y
}

normalize_log_weights <- function(log_weights) {
apply(log_weights, 2, function(col) col - log_sum_exp(col))
}
3 changes: 3 additions & 0 deletions R/weight_draws.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ validate_weights <- function(weights, draws, log = FALSE) {
if (length(weights) != ndraws(draws)) {
stop_no_call("Number of weights must match the number of draws.")
}
if (any(weights == Inf)) {
stop_no_call("Weights must not be positive infinite.")
}
if (!log) {
if (any(weights < 0)) {
stop_no_call("Weights must be non-negative.")
Expand Down
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ reference:
- entropy
- dissent
- modal_category
- pit
- title: "Functionality specific to the rvar datatype"
desc: >
The `draws_rvar` format (a structured list of `rvar` objects) has the same
Expand Down
78 changes: 78 additions & 0 deletions man/pit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/posterior-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading