Source code for torch_measure.models._base

# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Base class for all IRT models."""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

from torch_measure.models._predictor import Predictor

if TYPE_CHECKING:
    from torch_measure.datasets._long_form import LongFormData


class IRTModel(Predictor):
    """Abstract base for factor-based Item Response Theory models.

    Specialises :class:`Predictor` for models with explicit ``ability`` and
    ``difficulty`` parameters that compose into a per-cell probability via a
    logistic link. Subclasses implement :meth:`predict` (inherited from
    :class:`Predictor`) by gathering parameters at the query indices and
    applying the IRT formula — see :meth:`_irt_probability`.

    For non-factor predictors (TabPFN-style, neural baselines), inherit
    :class:`Predictor` directly instead.
    """

    def fit(
        self,
        data: LongFormData | torch.Tensor,
        mask: torch.Tensor | None = None,
        method: str = "mle",
        max_epochs: int = 1000,
        lr: float = 0.01,
        verbose: bool = True,
        **kwargs,
    ) -> dict:
        """Fit the model.

        Parameters
        ----------
        data : LongFormData | torch.Tensor
            Either a :class:`~torch_measure.datasets.LongFormData` (canonical
            long-form input — every observation is one row) or a wide-form
            response tensor of shape ``(n_subjects, n_items)``. For wide-form,
            missing entries may be encoded as ``NaN`` or ``-1``.
        mask : torch.Tensor | None
            Only used when ``data`` is a wide-form tensor — boolean mask of
            entries to use for fitting. Inferred from NaN/-1 when ``None``.
            Ignored for long-form input (absent rows are absent observations).
        method : str
            Fitting method: ``"mle"``, ``"em"``, ``"jml"``, or ``"svi"``.
        max_epochs : int
            Maximum number of optimization epochs.
        lr : float
            Learning rate.
        verbose : bool
            Whether to show a progress bar.
        **kwargs
            Extra keyword arguments forwarded unchanged to the selected
            fitting routine.

        Returns
        -------
        dict
            Training history with loss values.

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported names, or if a
            wide-form ``data`` tensor / ``mask`` has an invalid shape.
        TypeError
            If ``data`` is neither ``LongFormData`` nor ``torch.Tensor``.
        """
        # Fitting backends are imported lazily, and only the one requested.
        # Resolving the backend first also rejects an unknown ``method``
        # before any data is copied to the model's device.
        if method == "mle":
            from torch_measure.fitting.mle import mle_fit as fit_fn
        elif method == "em":
            from torch_measure.fitting.em import em_fit as fit_fn
        elif method == "jml":
            from torch_measure.fitting.jml import jml_fit as fit_fn
        elif method == "svi":
            from torch_measure.fitting.svi import svi_fit as fit_fn
        else:
            raise ValueError(f"Unknown fitting method: {method!r}. Use 'mle', 'em', 'jml', or 'svi'.")

        subject_idx, item_idx, response = self._normalize_fit_inputs(data, mask)
        return fit_fn(
            self,
            subject_idx,
            item_idx,
            response,
            max_epochs=max_epochs,
            lr=lr,
            verbose=verbose,
            **kwargs,
        )

    def _normalize_fit_inputs(
        self,
        data: LongFormData | torch.Tensor,
        mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Coerce ``data`` (LongFormData or wide-form tensor) to the long-form triple.

        Parameters
        ----------
        data : LongFormData | torch.Tensor
            Long-form dataset, or a 2-D wide-form response tensor.
        mask : torch.Tensor | None
            Wide-form only. Same shape as ``data``; non-zero / ``True`` marks
            an entry to use. A supplied mask is converted to ``bool`` and then
            used as given — it is not intersected with the NaN/-1 sentinel
            cells. When ``None``, every cell that is neither ``NaN`` nor
            ``-1`` is used.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            ``(subject_idx, item_idx, response)`` on ``self._device``.

        Raises
        ------
        TypeError
            If ``data`` is neither ``LongFormData`` nor ``torch.Tensor``.
        ValueError
            If a wide-form tensor is not 2-D, or ``mask`` does not match its
            shape.
        """
        from torch_measure.datasets._long_form import LongFormData

        if isinstance(data, LongFormData):
            fit_inputs = data.to_fit_tensors(device=str(self._device))
            return (
                fit_inputs["subject_idx"],
                fit_inputs["item_idx"],
                fit_inputs["response"],
            )

        if not isinstance(data, torch.Tensor):
            raise TypeError(f"fit() expected LongFormData or torch.Tensor, got {type(data).__name__}")

        # The row/column extraction below is only meaningful for a matrix;
        # other ranks would either fail with an opaque IndexError or yield
        # index vectors whose length differs from the response vector.
        if data.dim() != 2:
            raise ValueError(
                "Wide-form response tensor must be 2-D (n_subjects, n_items), "
                f"got shape {tuple(data.shape)}"
            )

        response_matrix = data.to(self._device)
        if mask is None:
            mask = ~torch.isnan(response_matrix) & (response_matrix != -1)
        else:
            if mask.shape != response_matrix.shape:
                raise ValueError(
                    f"mask shape {tuple(mask.shape)} does not match "
                    f"data shape {tuple(response_matrix.shape)}"
                )
            # Force a boolean mask: indexing with an integer 0/1 tensor would
            # be interpreted as integer (gather) indexing rather than as a
            # selection mask, desynchronising ``response`` from the indices.
            mask = mask.to(device=self._device, dtype=torch.bool)

        # ``nonzero`` and boolean-mask indexing both enumerate the selected
        # cells in row-major order, so the three outputs stay aligned.
        obs_indices = mask.nonzero(as_tuple=False)
        subject_idx = obs_indices[:, 0]
        item_idx = obs_indices[:, 1]
        response = response_matrix[mask].float()
        return subject_idx, item_idx, response

    @staticmethod
    def _irt_probability(
        ability: torch.Tensor,
        difficulty: torch.Tensor,
        discrimination: torch.Tensor | None = None,
        guessing: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Element-wise IRT probability ``P(correct | params)``.

        All inputs must be 1-D tensors of equal length ``N`` (already gathered
        at the query indices). Implements::

            P = c + (1 - c) * sigmoid(a * (theta - b))

        where ``theta=ability``, ``b=difficulty``, ``a=discrimination``,
        ``c=guessing``.

        Parameters
        ----------
        ability : torch.Tensor
            Subject abilities at query rows, shape ``(N,)``.
        difficulty : torch.Tensor
            Item difficulties at query rows, shape ``(N,)``.
        discrimination : torch.Tensor | None
            Item discriminations at query rows, shape ``(N,)``. Defaults to 1.
        guessing : torch.Tensor | None
            Item guessing parameters at query rows, shape ``(N,)``. Defaults to 0.

        Returns
        -------
        torch.Tensor
            Probabilities, shape ``(N,)``.
        """
        logit = ability - difficulty
        if discrimination is not None:
            logit = discrimination * logit
        prob = torch.sigmoid(logit)
        if guessing is not None:
            # Lift the lower asymptote from 0 to the guessing rate ``c``.
            prob = guessing + (1 - guessing) * prob
        return prob