# Source code for torch_measure.models.beta_rasch
# Copyright (c) 2026 AIMS Foundations. MIT License.
"""Beta-Rasch (1PL) IRT model for continuous (0, 1) responses."""
from __future__ import annotations
from functools import partial
import torch
from torch_measure.models.rasch import Rasch
# [docs]
class BetaRasch(Rasch):
    """Beta-Rasch (1PL) IRT model.

    Identical to Rasch in prediction: ``mu = sigmoid(theta - b)``.
    Uses Beta NLL loss instead of Bernoulli NLL for fitting, allowing
    continuous responses in (0, 1) such as empirical probabilities.

    Parameters
    ----------
    n_subjects : int
        Number of subjects (test-takers / models).
    n_items : int
        Number of items (test questions / benchmark tasks).
    phi : float
        Beta distribution precision parameter. Higher values mean
        tighter concentration around the predicted mean. Must be
        strictly positive. Default 10.0.
    device : str
        Device to place parameters on.

    Raises
    ------
    ValueError
        If ``phi`` is a plain number that is not strictly positive
        (zero, negative, or NaN).

    References
    ----------
    .. [1] Item Response Scaling Laws (ICML 2026).
    """

    def __init__(self, n_subjects: int, n_items: int, phi: float = 10.0, device: str = "cpu") -> None:
        # A Beta precision must be positive; reject bad values here instead of
        # letting them surface later as a NaN/invalid loss during fitting.
        # ``not phi > 0`` (rather than ``phi <= 0``) also catches NaN.
        # Only plain numbers are checked so that any non-scalar ``phi`` a
        # caller may already be passing keeps working unchanged.
        if isinstance(phi, (int, float)) and not phi > 0:
            raise ValueError(f"phi must be a positive number, got {phi!r}")
        super().__init__(n_subjects, n_items, device)
        self.phi = phi

    def fit(
        self,
        response_matrix: torch.Tensor,
        mask: torch.Tensor | None = None,
        method: str = "mle",
        max_epochs: int = 1000,
        lr: float = 0.01,
        verbose: bool = True,
        **kwargs,
    ) -> dict:
        """Fit the Beta-Rasch model using Beta NLL loss.

        Delegates to the parent ``fit`` with ``loss_fn`` set to
        ``beta_nll`` bound to this model's ``phi``.

        Parameters
        ----------
        response_matrix : torch.Tensor
            Continuous response matrix with values in (0, 1),
            shape (n_subjects, n_items). Values must be strictly
            between 0 and 1 (exclusive).
        mask : torch.Tensor | None
            Boolean mask of entries to use. If None, uses all non-NaN entries.
        method : str
            Fitting method: "mle", "em", or "jml".
        max_epochs : int
            Maximum optimization epochs.
        lr : float
            Learning rate.
        verbose : bool
            Show progress bar.
        **kwargs
            Forwarded unchanged to the parent ``fit``. Must not contain
            ``loss_fn``, which this method always supplies itself.

        Returns
        -------
        dict
            Training history.
        """
        # Imported at call time, as in the original; presumably to avoid an
        # import cycle or to keep module import cheap -- TODO confirm.
        from torch_measure.fitting._losses import beta_nll

        # Bind the current precision so the parent sees a loss callable that
        # no longer needs ``phi`` passed explicitly.
        loss_fn = partial(beta_nll, phi=self.phi)
        return super().fit(
            response_matrix,
            mask=mask,
            method=method,
            max_epochs=max_epochs,
            lr=lr,
            verbose=verbose,
            loss_fn=loss_fn,
            **kwargs,
        )