Source code for torch_measure.metrics.calibration
# Copyright (c) 2026 AIMS Foundations. MIT License.
"""Calibration metrics for predicted probabilities.
Consolidated from predictive-eval/utils/metrics.py compute_ece.
"""
from __future__ import annotations
import torch
[docs]
def expected_calibration_error(
    predicted: torch.Tensor,
    observed: torch.Tensor,
    mask: torch.Tensor | None = None,
    n_bins: int = 15,
) -> float:
    """Compute Expected Calibration Error (ECE).

    Measures how well predicted probabilities match observed frequencies.
    ECE = 0 means perfectly calibrated.

    Predictions are grouped into ``n_bins`` equal-width bins over [0, 1].
    Every bin is half-open ``[lo, hi)`` except the last, which is closed
    ``[lo, 1]``. The ECE is the count-weighted sum of
    ``|mean(predicted) - mean(observed)|`` over the non-empty bins.

    Parameters
    ----------
    predicted : torch.Tensor
        Predicted probabilities. Integer or bool tensors are promoted to
        float32; floating tensors keep their dtype.
    observed : torch.Tensor
        Observed binary outcomes.
    mask : torch.Tensor | None
        Boolean mask of entries to evaluate. Defaults to the entries where
        ``observed`` is not NaN.
    n_bins : int
        Number of calibration bins. Must be at least 1.

    Returns
    -------
    float
        ECE value in [0, 1] for predictions in [0, 1] and binary outcomes.
        Returns 0.0 when no entries are selected.

    Raises
    ------
    ValueError
        If ``n_bins`` is less than 1.

    Notes
    -----
    Selected predictions outside [0, 1], or NaN predictions, fall into no
    bin. They still count toward the total used to weight each bin.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be >= 1, got {n_bins}")
    if mask is None:
        mask = ~torch.isnan(observed)
    p = predicted[mask].flatten()
    if not p.is_floating_point():
        # Tensor.mean rejects integer/bool dtypes (e.g. hard 0/1 predictions).
        # Promote only non-float inputs so that float64 precision is kept.
        p = p.float()
    y = observed[mask].flatten().float()
    n_total = p.numel()
    if n_total == 0:
        return 0.0
    # Build the edges in p's dtype and on p's device. This keeps boundary
    # comparisons exact for float64 inputs and avoids mixing devices.
    boundaries = torch.linspace(0, 1, n_bins + 1, dtype=p.dtype, device=p.device)
    ece = 0.0
    for i in range(n_bins):
        lo, hi = boundaries[i], boundaries[i + 1]
        # The last bin is closed on the right so that p == 1.0 is counted.
        below_hi = (p <= hi) if i == n_bins - 1 else (p < hi)
        in_bin = (p >= lo) & below_hi
        n_in_bin = int(in_bin.sum().item())
        if n_in_bin == 0:
            continue
        avg_predicted = p[in_bin].mean().item()
        avg_observed = y[in_bin].mean().item()
        ece += abs(avg_predicted - avg_observed) * (n_in_bin / n_total)
    return ece
[docs]
def brier_score(
    predicted: torch.Tensor,
    observed: torch.Tensor,
    mask: torch.Tensor | None = None,
) -> float:
    """Compute the Brier score: the mean squared gap between probability and outcome.

    Parameters
    ----------
    predicted : torch.Tensor
        Predicted probabilities.
    observed : torch.Tensor
        Observed binary outcomes (cast to float32 before comparison).
    mask : torch.Tensor | None
        Boolean mask selecting the entries to score. When omitted, every
        entry whose ``observed`` value is not NaN is used.

    Returns
    -------
    float
        Brier score in [0, 1]. Lower is better. Returns 0.0 when no entries
        are selected.
    """
    keep = ~torch.isnan(observed) if mask is None else mask
    probs = predicted[keep].flatten()
    outcomes = observed[keep].flatten().float()
    if probs.numel() == 0:
        return 0.0
    squared_error = (probs - outcomes).pow(2)
    return squared_error.mean().item()