# Source code for torch_measure.metrics.reliability
# Copyright (c) 2026 AIMS Foundations. MIT License.
"""Reliability and fit statistics for measurement models.
Consolidated from predictive-eval/utils/metrics.py and factor-model/calibration/metrics.py.
"""
from __future__ import annotations
import torch
# [docs]
def infit_statistics(
    predicted: torch.Tensor, observed: torch.Tensor, mask: torch.Tensor | None = None
) -> torch.Tensor:
    """Compute Rasch infit (information-weighted) mean square statistics per item.

    Infit is sensitive to unexpected responses near item difficulty.
    Values near 1.0 indicate good fit. Values > 1.3 indicate underfit (noise),
    values < 0.7 indicate overfit (Guttman pattern).

    Parameters
    ----------
    predicted : torch.Tensor
        Predicted probabilities (n_subjects, n_items). Clamped to
        ``[1e-7, 1 - 1e-7]`` before use.
    observed : torch.Tensor
        Observed binary responses (n_subjects, n_items). May contain NaN for
        missing responses.
    mask : torch.Tensor | None
        Boolean mask of entries to include. Defaults to the non-NaN entries
        of ``observed``.

    Returns
    -------
    torch.Tensor
        Infit statistics per item, shape (n_items,). An item with no included
        entries yields 0.0.
    """
    if mask is None:
        mask = ~torch.isnan(observed)
    excluded = ~mask.bool()

    p = predicted.clamp(1e-7, 1 - 1e-7)
    variance = p * (1 - p)  # Bernoulli variance
    residual_sq = (observed - p) ** 2

    # Excluded cells must be zero-filled rather than multiplied by a 0/1
    # weight: NaN * 0 is still NaN and would poison the whole column sum.
    numerator = residual_sq.masked_fill(excluded, 0.0).sum(dim=0)
    # Clamp so that an item with no included entries divides by a tiny
    # positive number instead of zero.
    denominator = variance.masked_fill(excluded, 0.0).sum(dim=0).clamp(min=1e-10)

    # Information-weighted mean square: sum(residual^2) / sum(variance).
    return numerator / denominator
# [docs]
def outfit_statistics(
    predicted: torch.Tensor, observed: torch.Tensor, mask: torch.Tensor | None = None
) -> torch.Tensor:
    """Compute Rasch outfit (unweighted) mean square statistics per item.

    Outfit is sensitive to unexpected responses far from item difficulty.

    Parameters
    ----------
    predicted : torch.Tensor
        Predicted probabilities (n_subjects, n_items). Clamped to
        ``[1e-7, 1 - 1e-7]`` before use.
    observed : torch.Tensor
        Observed binary responses (n_subjects, n_items). May contain NaN for
        missing responses.
    mask : torch.Tensor | None
        Boolean mask of entries to include. Defaults to the non-NaN entries
        of ``observed``.

    Returns
    -------
    torch.Tensor
        Outfit statistics per item, shape (n_items,). An item with no included
        entries yields 0.0.
    """
    if mask is None:
        mask = ~torch.isnan(observed)
    included = mask.bool()

    p = predicted.clamp(1e-7, 1 - 1e-7)
    variance = p * (1 - p)
    standardized_residual_sq = ((observed - p) ** 2) / variance

    # Zero-fill excluded cells instead of multiplying by a 0/1 weight:
    # NaN * 0 is still NaN and would turn the whole item statistic into NaN.
    numerator = standardized_residual_sq.masked_fill(~included, 0.0).sum(dim=0)
    # At least 1 so that an item with no included entries returns 0, not NaN.
    count = included.float().sum(dim=0).clamp(min=1)

    # Plain mean of the squared standardized residuals per item.
    return numerator / count
# [docs]
def item_total_correlation(data: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    """Compute corrected item-total correlation for each item.

    For each item, computes the Pearson correlation between the item
    responses and the total score excluding that item. Only subjects with an
    observed response on the item contribute to that item's correlation; a
    subject's total is the sum of their observed responses.

    Parameters
    ----------
    data : torch.Tensor
        Binary response matrix (n_subjects, n_items). Non-floating inputs
        (integer or bool) are converted to float.
    mask : torch.Tensor | None
        Boolean mask of entries to include. Defaults to the non-NaN entries
        of ``data``.

    Returns
    -------
    torch.Tensor
        Item-total correlations, shape (n_items,). Items with fewer than
        three observed responses, or with (near-)constant item or rest
        scores, get 0.0.
    """
    if mask is None:
        mask = ~torch.isnan(data)
    mask = mask.bool()

    # std() is only defined for floating dtypes.
    values = data if data.is_floating_point() else data.float()

    n_items = values.shape[1]
    if n_items == 0:
        # torch.stack rejects an empty list; an empty result is the natural answer.
        return values.new_zeros(0)

    total = values.masked_fill(~mask, 0.0).sum(dim=1)  # (N,)

    # Placeholder for degenerate items, created on the input's device and
    # dtype so that torch.stack never mixes devices.
    zero = values.new_zeros(())

    correlations = []
    for j in range(n_items):
        item_mask = mask[:, j]
        item_vals = values[item_mask, j]
        corrected_total = total[item_mask] - item_vals  # exclude item j

        if item_vals.numel() < 3 or item_vals.std() < 1e-10 or corrected_total.std() < 1e-10:
            correlations.append(zero)
            continue

        # Pearson correlation; the epsilon guards against a zero denominator.
        x = item_vals - item_vals.mean()
        y = corrected_total - corrected_total.mean()
        correlations.append((x * y).sum() / (x.norm() * y.norm() + 1e-10))

    return torch.stack(correlations)
# [docs]
def cronbach_alpha(data: torch.Tensor, mask: torch.Tensor | None = None) -> float:
    """Compute Cronbach's alpha reliability coefficient.

    Item variances are computed over each item's observed responses. The
    total-score variance uses only subjects with every item observed when
    there are at least three such subjects; otherwise it falls back to all
    subjects, with missing responses counted as 0.

    Parameters
    ----------
    data : torch.Tensor
        Response matrix (n_subjects, n_items). Non-floating inputs (integer
        or bool) are converted to float.
    mask : torch.Tensor | None
        Boolean mask of entries to include. Defaults to the non-NaN entries
        of ``data``.

    Returns
    -------
    float
        Cronbach's alpha. Returns 0.0 when alpha is undefined: fewer than two
        items, fewer than two scored subjects, or (near-)zero total-score
        variance.
    """
    if mask is None:
        mask = ~torch.isnan(data)
    mask = mask.bool()

    # var() is only defined for floating dtypes.
    values = data if data.is_floating_point() else data.float()

    k = values.shape[1]
    if k < 2:
        # k / (k - 1) is undefined for a single item (and meaningless for none).
        return 0.0

    sum_item_var = 0.0
    for j in range(k):
        vals = values[mask[:, j], j]
        # The unbiased variance needs at least two observations; otherwise
        # the item contributes 0.
        if vals.numel() > 1:
            sum_item_var += vals.var().item()

    total = values.masked_fill(~mask, 0.0).sum(dim=1)

    # Prefer subjects with all items observed for the total variance; fall
    # back to every subject when there are too few complete cases.
    all_observed = mask.all(dim=1)
    scores = total[all_observed] if int(all_observed.sum()) >= 3 else total
    if scores.numel() < 2:
        # var() of fewer than two values is NaN, which would slip past the
        # threshold check below and be returned to the caller.
        return 0.0

    total_var = scores.var().item()
    if total_var < 1e-10:
        return 0.0
    return (k / (k - 1)) * (1 - sum_item_var / total_var)