Source code for torch_measure.metrics.scalability
# Copyright (c) 2026 AIMS Foundations. MIT License.
"""Mokken scalability analysis.
Consolidated from fantastic-bugs/src/metrics.py.
"""
from __future__ import annotations
import torch
[docs]
def mokken_scalability(data: torch.Tensor, mask: torch.Tensor | None = None) -> dict:
    """Compute Mokken scalability coefficients.

    Mokken scaling is a non-parametric IRT approach that tests whether
    items form a unidimensional scale. The H coefficient measures how
    well item pairs conform to the Guttman pattern: for every pair the
    "easy" item is the one with the larger mean, and a Guttman error is a
    subject who fails the easy item (== 0) but passes the hard one (== 1).
    Each coefficient is ``1 - observed_errors / expected_errors`` where the
    expected count assumes independent items.

    Conventional interpretation::

        H >= 0.5        strong scale
        0.4 <= H < 0.5  medium scale
        0.3 <= H < 0.4  weak scale
        H < 0.3         not a scale

    Parameters
    ----------
    data : torch.Tensor
        Binary response matrix (n_subjects, n_items).
    mask : torch.Tensor | None
        Mask of the same shape as ``data``; truthy entries mark observed
        responses. Any dtype convertible with ``.bool()`` is accepted.
        NaN entries of ``data`` are always treated as missing, whether or
        not a mask is given. With ``None``, every non-NaN entry is observed.

    Returns
    -------
    dict
        Dictionary with:
        - 'H': Overall scalability coefficient (Python float)
        - 'H_items': Per-item scalability coefficients, shape (n_items,)
        - 'H_pairs': Pairwise scalability matrix, shape (n_items, n_items)

        Both tensors are float32 on the CPU. Item pairs with fewer than 5
        co-observed subjects are ignored everywhere; pairs whose expected
        error count is not positive are left at 0 in 'H_pairs' and excluded
        from 'H'. When no pair contributes, 'H' (and the affected
        'H_items' entries) evaluate to 1.0.
    """
    # A NaN can never be a response, even if the caller's mask marks it as
    # observed; ``.bool()`` keeps 0/1 integer or float masks from being
    # misread by ``~`` or used as an integer index.
    observed = ~torch.isnan(data)
    if mask is not None:
        observed = observed & mask.bool()
    observed_f = observed.float()
    responses = torch.where(observed, data, torch.zeros_like(data))

    n_items = data.shape[1]

    # Item means (facility/easiness) over each item's own observed entries.
    item_counts = observed_f.sum(dim=0).clamp(min=1)
    item_means = (responses * observed_f).sum(dim=0) / item_counts

    # co_observed[a, b]: number of subjects with both items observed.
    co_observed = observed_f.T @ observed_f

    # zero_then_one[a, b]: subjects with item a == 0 and item b == 1, both
    # observed. The ``& observed`` matters because missing cells were
    # zero-filled above and must not count as failures.
    is_zero = ((responses == 0) & observed).float()
    is_one = ((responses == 1) & observed).float()
    zero_then_one = is_zero.T @ is_one

    # Orient every ordered pair (row, col) by difficulty. On a tie the row
    # item is treated as the easy one.
    p_row = item_means.unsqueeze(1)
    p_col = item_means.unsqueeze(0)
    errors = torch.where(p_row < p_col, zero_then_one.T, zero_then_one)
    p_easy = torch.maximum(p_row, p_col)
    p_hard = torch.minimum(p_row, p_col)

    # Expected Guttman errors under independence.
    expected = co_observed * (1 - p_easy) * p_hard

    # The remaining work is on small (n_items, n_items) matrices; do it in
    # float64 on the CPU so the sums are accumulated in double precision.
    errors = errors.cpu().double()
    expected = expected.cpu().double()
    index = torch.arange(n_items)
    off_diagonal = index.unsqueeze(1) != index.unsqueeze(0)
    upper = index.unsqueeze(1) < index.unsqueeze(0)

    usable = (co_observed.cpu() >= 5) & off_diagonal
    # Each unordered pair is scored once (upper triangle) and only when the
    # expected error count is positive.
    scored = usable & upper & (expected > 0)

    # Pairwise H, mirrored into a symmetric matrix with a zero diagonal.
    h_pairs = torch.zeros(n_items, n_items, dtype=torch.float64)
    h_pairs[scored] = 1 - errors[scored] / expected[scored]
    h_pairs = h_pairs + h_pairs.T

    # Overall H pools observed and expected errors across scored pairs.
    total_obs_error = errors[scored].sum().item()
    total_exp_error = expected[scored].sum().item()
    h_overall = 1 - total_obs_error / max(total_exp_error, 1e-10)

    # Per-item H pools every usable pair that involves the item.
    nothing = torch.zeros_like(errors)
    item_obs = torch.where(usable, errors, nothing).sum(dim=1)
    item_exp = torch.where(usable, expected, nothing).sum(dim=1)
    h_items = 1 - item_obs / item_exp.clamp(min=1e-10)

    return {
        "H": h_overall,
        "H_items": h_items.float(),
        "H_pairs": h_pairs.float(),
    }