Source code for torch_measure.models.multifacet_twopl

# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Many-Facet 2PL IRT model with anchoring (requires pyro-ppl).

Consolidated from safety-irt/model/irt.py.

Extends :class:`MultiFacetRasch` with per-item discrimination (alpha) and a
sparse Student-t prior on the item-by-facet interaction (tau). Supports
anchor items whose tau is forced near zero — useful for cross-lingual safety
analysis where a set of "ground truth" items have known invariant difficulty
across languages.

Example used in code:
P(correct) = sigmoid(alpha_i * ((theta_n + delta_nl) - (beta_i + gamma_l + tau_il)))

where:
- theta_n: subject ability
- delta_nl: subject-facet aptitude (e.g., model-language ability)
- alpha_i: item discrimination
- beta_i: base item difficulty (reference facet level)
- gamma_l: global facet shift (reference level = 0)
- tau_il: item-facet interaction (reference level = 0; anchors near 0)
"""

from __future__ import annotations

from collections.abc import Sequence

import torch
from torch import nn

from torch_measure.models._base import IRTModel


[docs] class MultiFacet2PL(IRTModel): """Many-Facet 2PL IRT Model with anchoring (Bayesian SVI only). Parameters ---------- n_subjects : int Number of subjects. n_items : int Number of items. n_facet_levels : int Number of levels in the additional facet (e.g., number of languages). device : str Device to place parameters on. Notes ----- Estimation is Bayesian SVI via Pyro. Install with: ``pip install torch_measure[bayesian]`` """ def __init__( self, n_subjects: int, n_items: int, n_facet_levels: int, device: str = "cpu", ) -> None: super().__init__(n_subjects, n_items, device) self.n_facet_levels = n_facet_levels # Posterior-mean storage (filled by .fit(); also used by .predict()) self.ability = nn.Parameter(torch.zeros(n_subjects, device=self._device)) self.difficulty = nn.Parameter(torch.zeros(n_items, device=self._device)) self._discrimination_raw = nn.Parameter(torch.zeros(n_items, device=self._device)) self.gamma = nn.Parameter(torch.zeros(n_facet_levels, device=self._device)) self.tau = nn.Parameter(torch.zeros(n_items, n_facet_levels, device=self._device)) self.delta = nn.Parameter(torch.zeros(n_subjects, n_facet_levels, device=self._device)) # Reference-level + anchor masks self.register_buffer("gamma_mask", torch.ones(n_facet_levels, device=self._device)) self.register_buffer("tau_mask", torch.ones(n_items, n_facet_levels, device=self._device)) self.register_buffer("anchor_mask", torch.zeros(n_items, n_facet_levels, device=self._device)) @property def discrimination(self) -> torch.Tensor: """Per-item discrimination, constrained positive via ``exp``.""" return torch.exp(self._discrimination_raw)
[docs] def set_reference_level(self, level_idx: int) -> None: """Anchor a facet level to zero (e.g., English baseline). Forces ``gamma[level_idx] = 0`` and ``tau[:, level_idx] = 0`` at both fit and predict time. Also zeros ``delta[:, level_idx]`` (the subject intercept under the reference facet is absorbed by ``ability``). """ self.gamma_mask[level_idx] = 0.0 self.tau_mask[:, level_idx] = 0.0
[docs] def set_anchor_items(self, item_indices: Sequence[int] | torch.Tensor) -> None: """Mark items whose tau should be near zero across all facet levels. Anchor items get a tight (sd=0.01) Student-t prior on tau, encoding the assumption that their difficulty is invariant across the facet. """ idx = torch.as_tensor(item_indices, dtype=torch.long, device=self._device) self.anchor_mask[idx, :] = 1.0
[docs] def predict(self, facet_indices: torch.Tensor | None = None) -> torch.Tensor: """Compute response probabilities for one facet level. Parameters ---------- facet_indices : torch.Tensor | None Single facet level index. If None, uses level 0. Returns ------- torch.Tensor Probability matrix of shape ``(n_subjects, n_items)``. """ if facet_indices is None: facet_indices = torch.zeros(1, dtype=torch.long, device=self._device) gamma = self.gamma * self.gamma_mask tau = self.tau * self.tau_mask delta = self.delta * self.gamma_mask.unsqueeze(0) if facet_indices.numel() == 1: fl = int(facet_indices.item()) difficulty_l = self.difficulty + gamma[fl] + tau[:, fl] ability_l = self.ability + delta[:, fl] logit = self.discrimination.unsqueeze(0) * (ability_l.unsqueeze(1) - difficulty_l.unsqueeze(0)) return torch.sigmoid(logit) raise NotImplementedError("Batch facet indices not yet supported. Pass a single facet level.")
[docs] def fit( self, subject_idx: torch.Tensor, item_idx: torch.Tensor, facet_idx: torch.Tensor, response: torch.Tensor, max_epochs: int = 4000, lr: float = 0.01, clip_norm: float = 10.0, verbose: bool = True, num_posterior_samples: int = 500, ) -> dict: """Fit via Bayesian SVI (Pyro). Long-form quadruple input: each row is one observation ``(subject_idx[k], item_idx[k], facet_idx[k]) -> response[k]``. Priors: - ``ability ~ Normal(0, 1)`` - ``difficulty ~ Normal(0, 1)`` - ``discrimination ~ LogNormal(0.5, 0.5)`` (positive) - ``gamma_raw ~ Normal(0, 1)``, then ``gamma = gamma_raw * gamma_mask`` - ``tau_scale ~ HalfCauchy(1)``; ``tau_raw ~ StudentT(1, 0, scale)`` with scale=0.01 at anchor cells, ``tau_scale`` elsewhere; then ``tau = tau_raw * tau_mask`` - ``delta_raw ~ Normal(0, 0.5)``, then ``delta = delta_raw * gamma_mask`` Parameters ---------- subject_idx, item_idx, facet_idx : torch.LongTensor Long-form indices, each shape ``(n_obs,)``. response : torch.Tensor Binary observations, shape ``(n_obs,)``. max_epochs : int Number of SVI steps. lr : float Learning rate for ClippedAdam. clip_norm : float Gradient clipping norm. verbose : bool Show tqdm progress bar if available. num_posterior_samples : int Posterior samples for parameter extraction. Returns ------- dict ``{"losses": list[float], "posterior": {param_name: Tensor}}`` where ``posterior`` holds the posterior means used to populate the model's parameter slots. """ try: import pyro import pyro.distributions as dist import pyro.poutine from pyro.infer import SVI, Predictive, Trace_ELBO from pyro.infer.autoguide import AutoNormal from pyro.optim import ClippedAdam except ImportError as err: raise ImportError( "Bayesian SVI fitting requires pyro-ppl. Install with: pip install torch_measure[bayesian]" ) from err device = self._device n_subjects = self.n_subjects n_items = self.n_items n_facets = self.n_facet_levels subject_idx = subject_idx.to(device=device, dtype=torch.long) item_idx = item_idx.to(device=device, dtype=torch.long) facet_idx = facet_idx.to(device=device, dtype=torch.long) response = response.to(device=device, dtype=torch.float32) gamma_mask = self.gamma_mask tau_mask = self.tau_mask anchor_mask = self.anchor_mask def pyro_model(s_idx, i_idx, f_idx, obs): theta = pyro.sample( "ability", dist.Normal(torch.zeros(n_subjects, device=device), 1.0).to_event(1), ) beta = pyro.sample( "difficulty", dist.Normal(torch.zeros(n_items, device=device), 1.0).to_event(1), ) alpha = pyro.sample( "discrimination", dist.LogNormal( torch.full((n_items,), 0.5, device=device), torch.full((n_items,), 0.5, device=device), ).to_event(1), ) gamma_raw = pyro.sample( "gamma_raw", dist.Normal(torch.zeros(n_facets, device=device), 1.0).to_event(1), ) gamma = pyro.deterministic("gamma", gamma_raw * gamma_mask) tau_scale = pyro.sample( "tau_scale", dist.HalfCauchy(torch.ones(1, device=device)).to_event(1), ) tau_scale_per = torch.where( anchor_mask > 0.5, torch.full((n_items, n_facets), 0.01, device=device), tau_scale.expand(n_items, n_facets), ) tau_raw = pyro.sample( "tau_raw", dist.StudentT( 1.0, torch.zeros(n_items, n_facets, device=device), tau_scale_per, ).to_event(2), ) tau = pyro.deterministic("tau", tau_raw * tau_mask) delta_raw = pyro.sample( "delta_raw", dist.Normal(torch.zeros(n_subjects, n_facets, device=device), 0.5).to_event(2), ) delta_mask = gamma_mask.unsqueeze(0).expand(n_subjects, -1) delta = pyro.deterministic("delta", delta_raw * delta_mask) with pyro.plate("data", s_idx.shape[0]): ability_eff = theta[s_idx] + delta[s_idx, f_idx] difficulty_eff = beta[i_idx] + gamma[f_idx] + tau[i_idx, f_idx] logit = alpha[i_idx] * (ability_eff - difficulty_eff) pyro.sample("response", dist.Bernoulli(logits=logit), obs=obs) pyro.clear_param_store() guide = AutoNormal(pyro.poutine.block(pyro_model, hide=["response", "tau", "gamma", "delta"])) optimizer = ClippedAdam({"lr": lr, "clip_norm": clip_norm}) svi = SVI(pyro_model, guide, optimizer, loss=Trace_ELBO()) history: dict = {"losses": []} iterator = range(max_epochs) if verbose: try: from tqdm import tqdm iterator = tqdm(iterator, desc="SVI fitting (MultiFacet2PL)") except ImportError: pass for _ in iterator: loss = svi.step(subject_idx, item_idx, facet_idx, response) history["losses"].append(loss) if verbose and hasattr(iterator, "set_postfix"): iterator.set_postfix({"ELBO": f"{loss:.2f}"}) predictive = Predictive( pyro_model, guide=guide, num_samples=num_posterior_samples, return_sites=["ability", "difficulty", "discrimination", "gamma", "tau", "delta"], ) samples = predictive(subject_idx, item_idx, facet_idx, None) with torch.no_grad(): ability_mean = samples["ability"].mean(dim=0).reshape(n_subjects) difficulty_mean = samples["difficulty"].mean(dim=0).reshape(n_items) disc_mean = samples["discrimination"].mean(dim=0).reshape(n_items) gamma_mean = samples["gamma"].mean(dim=0).reshape(n_facets) tau_mean = samples["tau"].mean(dim=0).reshape(n_items, n_facets) delta_mean = samples["delta"].mean(dim=0).reshape(n_subjects, n_facets) self.ability.copy_(ability_mean) self.difficulty.copy_(difficulty_mean) self._discrimination_raw.copy_(torch.log(disc_mean.clamp_min(1e-8))) self.gamma.copy_(gamma_mean) self.tau.copy_(tau_mean) self.delta.copy_(delta_mean) history["posterior"] = { "ability": ability_mean.detach().clone(), "difficulty": difficulty_mean.detach().clone(), "discrimination": disc_mean.detach().clone(), "gamma": gamma_mean.detach().clone(), "tau": tau_mean.detach().clone(), "delta": delta_mean.detach().clone(), } return history