# Source code for torch_measure.models.threepl
# Copyright (c) 2026 AIMS Foundations. MIT License.
"""Three-Parameter Logistic (3PL) IRT model."""
from __future__ import annotations
import torch
from torch import nn
from torch_measure.models._base import IRTModel
[docs]
class ThreePL(IRTModel):
    """Three-parameter logistic (3PL) item response model.

    The probability of a correct response is modelled as

        P(correct) = c + (1 - c) * sigmoid(a * (theta - b))

    where:

    - theta: subject ability
    - b: item difficulty
    - a: item discrimination
    - c: guessing parameter (lower asymptote)

    Discrimination and guessing are stored as unconstrained raw parameters
    (``_discrimination_raw`` and ``_guessing_raw``) and exposed in their
    constrained form through the ``discrimination`` and ``guessing``
    properties.

    Parameters
    ----------
    n_subjects : int
        Number of subjects.
    n_items : int
        Number of items.
    device : str
        Device to place parameters on.
    """

    def __init__(self, n_subjects: int, n_items: int, device: str = "cpu") -> None:
        super().__init__(n_subjects, n_items, device)

        # NOTE(review): ``_device`` is not assigned in this class; presumably it
        # is set by ``IRTModel.__init__`` -- confirm against the base class.
        target = self._device

        # Keep this creation order: it determines both the order of random
        # draws and the order in which the parameters are registered.
        self.ability = nn.Parameter(torch.randn(n_subjects, device=target))
        self.difficulty = nn.Parameter(torch.randn(n_items, device=target))
        self._discrimination_raw = nn.Parameter(torch.randn(n_items, device=target))
        # Raw value -2.0 gives an initial guessing rate of sigmoid(-2) ~= 0.12.
        self._guessing_raw = nn.Parameter(
            torch.full((n_items,), fill_value=-2.0, device=target)
        )

    @property
    def discrimination(self) -> torch.Tensor:
        """Item discriminations: exponential of the raw parameter, hence positive."""
        return self._discrimination_raw.exp()

    @property
    def guessing(self) -> torch.Tensor:
        """Item guessing rates: sigmoid of the raw parameter, within the unit interval."""
        return self._guessing_raw.sigmoid()

    def predict(self, query: dict[str, torch.Tensor]) -> torch.Tensor:
        """Evaluate P(correct) = c + (1-c) * sigmoid(a * (theta - b)) at query rows.

        Parameters
        ----------
        query : dict[str, torch.Tensor]
            Must contain ``"subject_idx"`` and ``"item_idx"``; these are used
            to index the ability tensor and the per-item parameter tensors
            respectively. Presumably integer index tensors of equal length --
            confirm against callers.

        Returns
        -------
        torch.Tensor
            The result of ``self._irt_probability`` for the gathered
            parameters.
        """
        subject_rows, item_rows = query["subject_idx"], query["item_idx"]

        # Gather the per-row parameters before delegating to the shared
        # probability helper (defined outside this class).
        theta = self.ability[subject_rows]
        b = self.difficulty[item_rows]
        a = self.discrimination[item_rows]
        c = self.guessing[item_rows]

        return self._irt_probability(theta, b, discrimination=a, guessing=c)