Add pit functions #395

Merged
paul-buerkner merged 24 commits into stan-dev:master from TeemuSailynoja:add_pit_functions
Mar 28, 2025

Conversation

@TeemuSailynoja
Contributor

Summary

Add functions for computing probability integral transforms (PIT). This includes both the regular PIT of observations with respect to posterior predictive draws and a weighted version, for which LOO weights can optionally be supplied.
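To make the summary concrete, here is a minimal base-R sketch of the two quantities it describes. This is not the code added in the PR: `pit_sketch` and its arguments are hypothetical names, the draws are assumed to be a plain matrix with one column per observation, and details such as tie handling that a real implementation may address are ignored.

```r
# Hypothetical helper, NOT the function added in this PR.
# x:       numeric matrix of predictive draws (rows = draws, columns = observations)
# y:       numeric vector with one observation per column of x
# weights: optional matrix of non-negative weights with the same shape as x
pit_sketch <- function(x, y, weights = NULL) {
  stopifnot(is.matrix(x), is.numeric(y), length(y) == ncol(x))
  # Logical matrix: TRUE where a draw is at or below its observation
  below <- sweep(x, 2, y, FUN = "<=")
  if (is.null(weights)) {
    # Unweighted PIT: proportion of draws at or below the observation
    return(colMeans(below))
  }
  stopifnot(all(dim(weights) == dim(x)))
  # Weighted PIT: weighted proportion, normalised within each column
  colSums(weights * below) / colSums(weights)
}

x <- cbind(c(1, 2, 3, 4), c(10, 20, 30, 40))
y <- c(2.5, 25)
w <- cbind(c(1, 1, 1, 1), c(4, 3, 2, 1))

pit_unweighted <- pit_sketch(x, y)
pit_weighted <- pit_sketch(x, y, weights = w)
```

With equal weights the weighted version reduces to the unweighted one, which is why the first column gives the same value in both calls.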

Copyright and Licensing

By submitting this pull request, the copyright holder is agreeing to
license the submitted work under the following licenses:

@TeemuSailynoja
Contributor Author

This would close #338.

@paul-buerkner
Collaborator

Thank you! Are there tests in place already?

@TeemuSailynoja
Contributor Author

No, this is still a draft PR, as I only have the default method and have not tested with rvars or draws objects.

@TeemuSailynoja
Contributor Author

A lot of failing checks. I'm still working on the PR, but should I be worried about those?

@TeemuSailynoja
Contributor Author

@paul-buerkner I think this is now in a state ready for review.

@TeemuSailynoja marked this pull request as ready for review March 19, 2025 11:55
Collaborator

@paul-buerkner left a comment

Thank you! Looks good overall. I have noted down a couple of questions and suggestions.

R/pit.R Outdated

#' @rdname pit
#' @export
pit.default <- function(x, y, loo_weights = NULL, log = TRUE) {
Collaborator

Are weights only valid if they come from loo? Or can they be, in theory, any kind of weights? If so, can we just call them weights?

Contributor Author

I have changed them to weights. The difference from the weights already available in draws objects via .weights is that this function requires a weight for each draw and each variable.
@n-kall and I discussed that, when per-draw .weights are present, they should also be taken into account.

Collaborator

Okay. Makes sense.

Collaborator

Is this incorporation of .weights something you want to do as part of this PR?

Contributor Author

This could be a future extension for the functionality of pit().

R/pit.R Outdated
#' @param y (observations) A 1D vector of observations. Each element of `y`
#' corresponds to a column in `x`.

#' @param loo_weights A [`draws_matrix`] object or one coercible to a
Collaborator

Could you highlight that, by default, these are expected to be log weights? Also, is log = TRUE a good default? Can we assume users will notice this?

Contributor Author

I changed the default to FALSE, and used the same doc-string as weight_draws to highlight the default value.
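The reason the `log` flag matters can be illustrated with a small base-R sketch. The helper name `normalize_weights_sketch` is hypothetical and is not part of posterior; it only assumes that weights arrive as a matrix with one column per variable, either on the natural scale or on the log scale.

```r
# Hypothetical helper, NOT part of posterior: column-normalise a weight matrix.
# weights: numeric matrix (rows = draws, columns = variables)
# log:     are the supplied weights on the log scale?
normalize_weights_sketch <- function(weights, log = FALSE) {
  stopifnot(is.matrix(weights), is.numeric(weights))
  if (log) {
    # Subtract each column's maximum before exponentiating so that very
    # negative log weights do not all underflow to zero
    col_max <- apply(weights, 2, max)
    weights <- exp(sweep(weights, 2, col_max, FUN = "-"))
  }
  # Make every column sum to one
  sweep(weights, 2, colSums(weights), FUN = "/")
}

w <- cbind(c(1, 3), c(2, 2))
w_natural <- normalize_weights_sketch(w)
# The same weights on the log scale, shifted by a large constant
w_from_log <- normalize_weights_sketch(log(w) - 1000, log = TRUE)
```

Passing log weights with `log = FALSE` (or the reverse) would silently give different normalised weights, which is presumably why the default and its documentation were discussed here.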

R/pit.R Outdated

#' @rdname pit
#' @export
pit.default <- function(x, y, loo_weights = NULL, log = TRUE) {
Collaborator

for extensibility, please add ... to all methods and the generic

Contributor Author

added.

R/pit.R Outdated
#' @return A numeric vector of length `length(y)` containing the PIT values.
#'
#' @export
pit <- function(x, y, loo_weights = NULL, log = TRUE) UseMethod("pit")
Collaborator

this signature is very restrictive. I would suggest the signature (x, y, ...) for better extensibility.

Contributor Author

@TeemuSailynoja Mar 28, 2025

I changed the default signature to the suggested one.

pit <- function(x, y, ...) UseMethod("pit")
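As an illustration of why a minimal `(x, y, ...)` generic is easier to extend, the sketch below defines a stand-in generic with two methods. Everything here is hypothetical: `pit_demo`, the `demo_draws` class and the `verbose` argument are made up for the example and are not part of posterior.

```r
# Hypothetical generic, named differently to avoid clashing with posterior::pit
pit_demo <- function(x, y, ...) UseMethod("pit_demo")

# Default method: adds its own optional arguments and still accepts `...`
pit_demo.default <- function(x, y, weights = NULL, log = FALSE, ...) {
  x <- as.matrix(x)
  if (is.null(weights)) {
    weights <- matrix(1, nrow(x), ncol(x))
  } else if (log) {
    weights <- exp(weights)
  }
  # Weighted proportion of draws at or below each observation
  below <- sweep(x, 2, y, FUN = "<=")
  colSums(weights * below) / colSums(weights)
}

# A method for a made-up class can add arguments without changing the generic
pit_demo.demo_draws <- function(x, y, ..., verbose = FALSE) {
  if (verbose) message("forwarding to the default method")
  pit_demo.default(unclass(x$draws), y, ...)
}

obj <- structure(list(draws = cbind(c(1, 2, 3, 4))), class = "demo_draws")
res_plain <- pit_demo(obj, 2.5)
res_weighted <- pit_demo(obj, 2.5, weights = cbind(c(3, 1, 1, 1)))
```

Because the generic only fixes `x` and `y`, method-specific arguments such as `weights` travel through `...` and new methods can introduce their own without breaking existing callers.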

R/pit.R Outdated
#' @rdname pit
#' @export
pit.draws_matrix <- function(x, y, loo_weights = NULL, log = TRUE) {
if (length(y) != ncol(x)) {
Collaborator

Could y be better validated? E.g., check that it is numeric or something along these lines?

Contributor Author

@TeemuSailynoja Mar 27, 2025

I added a validate_y function:

validate_y <- function(y, x = NULL) {
  if (!is.numeric(y)) {
    stop_no_call("`y` must be numeric.")
  }
  if (anyNA(y)) {
    stop_no_call("NAs not allowed in `y`.")
  }
  if (is_rvar(x)) {
    if (length(x) != length(y) || any(dim(y) != dim(x))) {
      stop_no_call("`dim(y)` must match `dim(x)`.")
    }
  } else if (is_draws(x)) {
    if (!is.vector(y, mode = "numeric") || length(y) != nvariables(x)) {
      stop_no_call("`y` must be a vector of length `nvariables(x)`.")
    }
  }
  y
}

Currently, no other function in posterior would use this validation, and the requirement of matching dimensions with x seems specific to pit().
Just in case of future extensions, I left x as an optional argument.

Collaborator

@paul-buerkner left a comment

Approved. Looks good now.

@paul-buerkner merged commit aba4bb8 into stan-dev:master Mar 28, 2025
7 of 11 checks passed