Source code for torch_measure.metrics.validity

# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Validity analysis metrics including Differential Item Functioning (DIF)."""

from __future__ import annotations

import torch


[docs] def differential_item_functioning( data: torch.Tensor, group: torch.Tensor, mask: torch.Tensor | None = None, method: str = "mh", ) -> dict: """Detect Differential Item Functioning (DIF). DIF occurs when subjects of equal ability from different groups have different probabilities of answering an item correctly. Parameters ---------- data : torch.Tensor Binary response matrix (n_subjects, n_items). group : torch.Tensor Group membership for each subject (n_subjects,). Binary (0/1). mask : torch.Tensor | None Boolean mask. method : str DIF detection method. Currently supports "mh" (Mantel-Haenszel). Returns ------- dict Dictionary with: - 'mh_statistic': Mantel-Haenszel chi-square per item, shape (n_items,) - 'effect_size': MH odds ratio (Delta-MH) per item, shape (n_items,) - 'flagged': Boolean mask of items flagged for DIF """ if mask is None: mask = ~torch.isnan(data) data_clean = data.clone() data_clean[~mask] = 0.0 n_items = data.shape[1] total_scores = (data_clean * mask.float()).sum(dim=1) # Bin by total score (quintiles) n_bins = 5 percentiles = torch.quantile(total_scores, torch.linspace(0, 1, n_bins + 1)) odds_ratios = torch.zeros(n_items) for j in range(n_items): alpha_sum = 0.0 beta_sum = 0.0 for k in range(n_bins): low = percentiles[k] high = percentiles[k + 1] if k < n_bins - 1: in_bin = (total_scores >= low) & (total_scores < high) else: in_bin = (total_scores >= low) & (total_scores <= high) in_bin = in_bin & mask[:, j] if in_bin.sum() < 2: continue g0 = in_bin & (group == 0) g1 = in_bin & (group == 1) n0 = g0.sum().float() n1 = g1.sum().float() if n0 == 0 or n1 == 0: continue # 2x2 table within this score stratum a = (data_clean[g1, j] == 1).sum().float() # focal correct b = (data_clean[g1, j] == 0).sum().float() # focal incorrect c = (data_clean[g0, j] == 1).sum().float() # reference correct d = (data_clean[g0, j] == 0).sum().float() # reference incorrect n_total = (a + b + c + d).clamp(min=1) alpha_sum += (a * d / n_total).item() beta_sum += (b * c / n_total).item() if beta_sum > 0: odds_ratios[j] = alpha_sum / beta_sum else: odds_ratios[j] = 1.0 # Delta-MH effect size (ETS scale) delta_mh = -2.35 * torch.log(odds_ratios.clamp(min=1e-10)) # Flag items with |Delta-MH| > 1.0 (moderate to large DIF) flagged = delta_mh.abs() > 1.0 return { "effect_size": delta_mh, "odds_ratio": odds_ratios, "flagged": flagged, }