Source code for torch_measure.models.logistic_fm

# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Logistic Factor Model for measurement.

Consolidated from factor-model/calibration/model.py LogisticFM.
"""

from __future__ import annotations

from typing import ClassVar

import torch
from torch import nn

from torch_measure.models._base import IRTModel


[docs] class LogisticFM(IRTModel): """K-factor Logistic Factor Model. P(correct) = sigmoid(U @ V^T + Z^T) where: - U: (n_subjects, K) latent ability factors - V: (n_items, K) item loadings on factors - Z: (n_items,) item intercepts (easiness) When K=1, this is equivalent to the Rasch model. Parameters ---------- n_subjects : int Number of subjects. n_items : int Number of items. n_factors : int Number of latent factors (K). device : str Device to place parameters on. """ _ability_param_name: ClassVar[str] = "U" def __init__(self, n_subjects: int, n_items: int, n_factors: int = 2, device: str = "cpu") -> None: super().__init__(n_subjects, n_items, device) self.n_factors = n_factors self.U = nn.Parameter(torch.randn(n_subjects, n_factors, device=self._device)) self.V = nn.Parameter(torch.randn(n_items, n_factors, device=self._device)) self.Z = nn.Parameter(torch.randn(n_items, device=self._device)) @property def ability(self) -> torch.Tensor: """Subject ability factors (n_subjects, K).""" return self.U.detach() @property def difficulty(self) -> torch.Tensor: """Item intercepts (n_items,). Negative Z = harder items.""" return -self.Z.detach() @property def loadings(self) -> torch.Tensor: """Item factor loadings (n_items, K).""" return self.V.detach()
[docs] def predict(self, query: dict[str, torch.Tensor]) -> torch.Tensor: """Compute P(correct) = sigmoid(U_s · V_i + Z_i) at query rows.""" s = query["subject_idx"] i = query["item_idx"] # (N, K) · (N, K) → (N,) logit = (self.U[s] * self.V[i]).sum(dim=-1) + self.Z[i] return torch.sigmoid(logit)