# Source code for torch_measure.models._predictor

# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Abstract base for any model that predicts P(correct) over long-form observations."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import ClassVar

import torch
from torch import nn


class Predictor(nn.Module, ABC):
    """Base class for any model producing P(correct) over (subject, item) cells.

    Subclasses implement :meth:`predict`, which accepts a long-form query
    (a dict of 1-D index tensors) and returns one probability per row.
    :meth:`forward` is a thin wrapper that delegates to :meth:`predict`, so
    ``model(query)`` works via :meth:`nn.Module.__call__`.

    Each subclass declares the keys it consumes via :attr:`expected_keys`.
    The default ``("subject_idx", "item_idx")`` covers every IRT-style model;
    condition-aware or trial-aware models extend it.

    Parameters
    ----------
    n_subjects : int
        Number of subjects in the model's universe.
    n_items : int
        Number of items in the model's universe.
    device : str or torch.device
        Device the model is declared to live on. Defaults to ``"cpu"``.
    """

    # Names of the query columns this predictor reads; subclasses may override.
    expected_keys: ClassVar[tuple[str, ...]] = ("subject_idx", "item_idx")

    def __init__(self, n_subjects: int, n_items: int, device: str | torch.device = "cpu") -> None:
        super().__init__()
        self._n_subjects = n_subjects
        self._n_items = n_items
        # Declared device; only used as a fallback when the module owns no tensors.
        self._device = torch.device(device)

    @property
    def n_subjects(self) -> int:
        """Number of subjects, as passed to the constructor."""
        return self._n_subjects

    @property
    def n_items(self) -> int:
        """Number of items, as passed to the constructor."""
        return self._n_items

    @property
    def device(self) -> torch.device:
        """Device the model currently lives on.

        The device of the first registered parameter (or, failing that, the
        first buffer) is reported, so the value stays correct after
        ``model.to(...)``. A module without any parameters or buffers falls
        back to the device given at construction time.
        """
        for param in self.parameters():
            return param.device
        for buffer in self.buffers():
            return buffer.device
        return self._device

    @abstractmethod
    def predict(self, query: dict[str, torch.Tensor]) -> torch.Tensor:
        """Predict P(correct) for each row of ``query``.

        Parameters
        ----------
        query : dict[str, torch.Tensor]
            Must contain a 1-D tensor for each name in
            :attr:`expected_keys`, all of equal length ``N``. Extra keys
            are ignored.

        Returns
        -------
        torch.Tensor
            Predicted probabilities, shape ``(N,)`` on the model's device.
        """
        ...

    def forward(self, query: dict[str, torch.Tensor]) -> torch.Tensor:
        """Delegate to :meth:`predict` so that ``model(query)`` works."""
        return self.predict(query)
def cartesian_query(
    n_subjects: int,
    n_items: int,
    device: str | torch.device | None = None,
) -> dict[str, torch.Tensor]:
    """Build the (subject, item) Cartesian-product query of size ``n_subjects * n_items``.

    Useful when a caller wants the dense matrix of predictions; see
    :func:`predict_dense` for the common shortcut.

    Parameters
    ----------
    n_subjects, n_items : int
        Universe sizes.
    device : str or torch.device or None
        Device for the returned tensors. ``None`` uses the torch default.

    Returns
    -------
    dict[str, torch.Tensor]
        ``{"subject_idx": LongTensor (n_subjects*n_items,),
        "item_idx": LongTensor (n_subjects*n_items,)}``. Row order is
        subject-major: all of subject 0's items first, then subject 1's
        items, etc.
    """
    # dtype is spelled out so the documented LongTensor contract is explicit.
    subjects = torch.arange(n_subjects, dtype=torch.long, device=device)
    items = torch.arange(n_items, dtype=torch.long, device=device)
    return {
        # Each subject index is held for a full run of items (subject-major) ...
        "subject_idx": subjects.repeat_interleave(n_items),
        # ... while the item indices cycle once per subject.
        "item_idx": items.repeat(n_subjects),
    }
def predict_dense(model: Predictor, **extra_keys: torch.Tensor) -> torch.Tensor:
    """Predict over the full ``(n_subjects, n_items)`` Cartesian grid.

    Convenience wrapper around :func:`cartesian_query` + ``model.predict``,
    reshaped back to a ``(n_subjects, n_items)`` matrix. Use this for
    visualization, EM quadrature, and other callers that genuinely want the
    dense view.

    Parameters
    ----------
    model : Predictor
        Any predictor with a ``(n_subjects, n_items)`` universe.
    **extra_keys : torch.Tensor
        Additional query columns required by the model's
        :attr:`expected_keys` beyond ``subject_idx`` / ``item_idx``. Each
        must be 1-D of length ``n_subjects * n_items``.

    Returns
    -------
    torch.Tensor
        Probability matrix of shape ``(n_subjects, n_items)``.
    """
    query = cartesian_query(model.n_subjects, model.n_items, device=model.device)
    query.update(extra_keys)
    flat = model.predict(query)
    # reshape (not view): a predictor may return a non-contiguous tensor, on
    # which view would raise; reshape copies only when it has to.
    return flat.reshape(model.n_subjects, model.n_items)