# Copyright (c) 2026 AIMS Foundations. MIT License.
"""Amortized IRT model that predicts item parameters from embeddings.
Consolidated from agent-eval/model/amortized_irt.py and predictive-eval/train/amortized_irt/irt.py.
"""
from __future__ import annotations
import torch
from torch import nn
from torch_measure.models._base import IRTModel
from torch_measure.models._network import MLP
[docs]
class AmortizedIRT(IRTModel):
"""Amortized IRT model.
Instead of learning independent parameters for each item, this model
learns a mapping from item embeddings to item parameters (difficulty,
discrimination, guessing). This enables zero-shot prediction on new
items given their embeddings.
P(correct) = c + (1-c) * sigmoid(a * (theta - b))
where b, a, c = f(embedding) are predicted by a neural network.
Parameters
----------
n_subjects : int
Number of subjects.
n_items : int
Number of items.
embedding_dim : int
Dimension of item embeddings.
hidden_dim : int
Hidden dimension for the embedding projection network.
n_layers : int
Number of layers in the projection network.
pl : int
Number of IRT parameters: 1 (Rasch), 2 (+discrimination), 3 (+guessing).
dropout : float
Dropout rate in the projection network.
device : str
Device to place parameters on.
"""
def __init__(
self,
n_subjects: int,
n_items: int,
embedding_dim: int,
hidden_dim: int = 256,
n_layers: int = 3,
pl: int = 2,
dropout: float = 0.1,
device: str = "cpu",
) -> None:
super().__init__(n_subjects, n_items, device)
self.pl = pl
self.embedding_dim = embedding_dim
# Subject ability parameters (learned directly)
self.ability = nn.Parameter(torch.randn(n_subjects, device=self._device))
# Item parameter projection network
# Output: difficulty + (discrimination if pl>=2) + (guessing if pl==3)
output_dim = 1 + (1 if pl >= 2 else 0) + (1 if pl == 3 else 0)
self.item_net = MLP(
input_dim=embedding_dim,
hidden_dim=hidden_dim,
output_dim=output_dim,
n_layers=n_layers,
dropout=dropout,
).to(self._device)
self._embeddings: torch.Tensor | None = None
[docs]
def set_embeddings(self, embeddings: torch.Tensor) -> None:
"""Set item embeddings for parameter prediction.
Parameters
----------
embeddings : torch.Tensor
Item embeddings of shape (n_items, embedding_dim).
"""
if embeddings.shape[0] != self.n_items:
raise ValueError(f"Expected {self.n_items} embeddings, got {embeddings.shape[0]}")
self._embeddings = embeddings.to(self._device)
def _compute_item_params(self) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""Compute item parameters from embeddings via the projection network."""
if self._embeddings is None:
raise RuntimeError("Call set_embeddings() before predict()")
params = self.item_net(self._embeddings)
difficulty = params[:, 0]
discrimination = torch.exp(params[:, 1]) if self.pl >= 2 else None
guessing = torch.sigmoid(params[:, 2]) if self.pl == 3 else None
return difficulty, discrimination, guessing
@property
def difficulty(self) -> torch.Tensor:
"""Predicted item difficulties from embeddings."""
d, _, _ = self._compute_item_params()
return d.detach()
@property
def discrimination(self) -> torch.Tensor | None:
"""Predicted item discriminations from embeddings (2PL/3PL only)."""
_, a, _ = self._compute_item_params()
return a.detach() if a is not None else None
@property
def guessing(self) -> torch.Tensor | None:
"""Predicted item guessing parameters from embeddings (3PL only)."""
_, _, c = self._compute_item_params()
return c.detach() if c is not None else None
[docs]
def predict(self, query: dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute P(correct) at query rows using amortized item parameters."""
s = query["subject_idx"]
i = query["item_idx"]
difficulty, discrimination, guessing = self._compute_item_params()
return self._irt_probability(
self.ability[s],
difficulty[i],
discrimination=discrimination[i] if discrimination is not None else None,
guessing=guessing[i] if guessing is not None else None,
)
[docs]
def fit(
self,
data,
embeddings: torch.Tensor,
mask: torch.Tensor | None = None,
max_epochs: int = 1000,
lr: float = 1e-3,
weight_decay: float = 1e-4,
verbose: bool = True,
**kwargs,
) -> dict:
"""Fit the amortized IRT model.
Parameters
----------
data : LongFormData | torch.Tensor
Long-form dataset (preferred) or wide-form response tensor.
embeddings : torch.Tensor
Item embeddings ``(n_items, embedding_dim)``.
mask : torch.Tensor | None
Boolean mask for observed entries (only used with wide-form input).
max_epochs : int
Maximum training epochs.
lr : float
Learning rate.
weight_decay : float
Weight decay for Adam optimizer.
verbose : bool
Show progress bar.
Returns
-------
dict
Training history.
"""
self.set_embeddings(embeddings)
from torch_measure.fitting.mle import mle_fit
subject_idx, item_idx, response = self._normalize_fit_inputs(data, mask)
return mle_fit(
self,
subject_idx,
item_idx,
response,
max_epochs=max_epochs,
lr=lr,
weight_decay=weight_decay,
verbose=verbose,
**kwargs,
)