# Source code for torch_measure.cat.fisher
# Copyright (c) 2026 AIMS Foundations. MIT License.
"""Fisher information for adaptive item selection.
Consolidated from irsl/utils.py compute_fisher_info_2pl.
"""
from __future__ import annotations
import torch
[docs]
def fisher_information(
    ability: torch.Tensor,
    difficulty: torch.Tensor,
    discrimination: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute Fisher information for each item at given ability levels.

    For the 2PL model: I(theta) = a^2 * P(theta) * (1 - P(theta))
    For the Rasch model (a=1): I(theta) = P(theta) * (1 - P(theta))

    where P(theta) = sigmoid(a * (theta - b)). Higher information means the
    item is more useful for estimating ability at that level.

    Parameters
    ----------
    ability : torch.Tensor
        Subject ability values, shape (N,) or scalar.
    difficulty : torch.Tensor
        Item difficulty values, shape (M,).
    discrimination : torch.Tensor | None
        Item discrimination values, shape (M,). Defaults to 1 (Rasch).

    Returns
    -------
    torch.Tensor
        Fisher information matrix, shape (N, M) or (M,) if ability is scalar.

    Notes
    -----
    ``1 - P`` is evaluated as ``sigmoid(-logit)`` rather than by subtraction.
    Once the sigmoid saturates to exactly 1.0, the subtraction yields 0 and
    the information collapses to zero on the high-ability side only, which
    breaks the symmetry of the information curve around ``theta == b``.
    """
    # Promote a scalar ability to shape (1,) so the broadcasting below always
    # yields a 2-D result; the extra axis is dropped again before returning.
    scalar_ability = ability.ndim == 0
    if scalar_ability:
        ability = ability.unsqueeze(0)

    # logit[n, m] = a[m] * (theta[n] - b[m]), with a == 1 for the Rasch model.
    logit = ability.unsqueeze(1) - difficulty.unsqueeze(0)
    slope = None
    if discrimination is not None:
        slope = discrimination.unsqueeze(0)
        logit = slope * logit

    p_correct = torch.sigmoid(logit)
    # 1 - sigmoid(x) == sigmoid(-x); this form avoids cancellation in the tail.
    p_incorrect = torch.sigmoid(-logit)
    info = p_correct * p_incorrect
    if slope is not None:
        info = slope**2 * info

    return info.squeeze(0) if scalar_ability else info