# Source code for torch_measure.cat.strategies

# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Item selection strategies for Computerized Adaptive Testing.

Consolidated from irsl/utils.py _select_next_1pl, _select_next_2pl.
"""

from __future__ import annotations

from abc import ABC, abstractmethod

import torch

from torch_measure.cat.fisher import fisher_information


class SelectionStrategy(ABC):
    """Common interface for policies that pick the next item in an adaptive test."""

    @abstractmethod
    def select(
        self,
        ability_estimate: torch.Tensor,
        difficulty: torch.Tensor,
        discrimination: torch.Tensor | None,
        administered: torch.Tensor,
    ) -> int:
        """Choose which item should be presented next.

        Concrete strategies must override this method; the base
        implementation has no body beyond this docstring.

        Parameters
        ----------
        ability_estimate : torch.Tensor
            Scalar tensor holding the current ability estimate.
        difficulty : torch.Tensor
            Difficulty of every item, shape (M,).
        discrimination : torch.Tensor | None
            Discrimination of every item, shape (M,); may be ``None``.
        administered : torch.Tensor
            Boolean mask of shape (M,) flagging items that were already
            administered.

        Returns
        -------
        int
            Index of the chosen item.
        """


class MaxInfoStrategy(SelectionStrategy):
    """Pick the not-yet-administered item with the largest Fisher information.

    The information of every item is evaluated at the current ability
    estimate by :func:`fisher_information`.
    """

    def select(self, ability_estimate, difficulty, discrimination, administered):
        """Return the index of the most informative remaining item.

        Parameters
        ----------
        ability_estimate, difficulty, discrimination
            Forwarded unchanged to :func:`fisher_information`.
        administered
            Mask of items to exclude from the choice.

        Returns
        -------
        int
            Index of the item with the highest information after masking.
        """
        item_info = fisher_information(ability_estimate, difficulty, discrimination)
        # Only results with more than one dimension are squeezed; this drops
        # a leading singleton dimension so the mask lines up with the items.
        # presumably the result is (1, M) in that case — confirm against
        # fisher_information.
        if item_info.ndim > 1:
            item_info = item_info.squeeze(0)

        # Administered items get -inf so they can never win the argmax.
        # The write happens in place on the tensor returned above.
        # NOTE(review): when every item is masked, argmax still returns an
        # index, so callers are expected to stop before the pool runs out.
        item_info[administered] = float("-inf")
        return torch.argmax(item_info).item()


class SpanningStrategy(SelectionStrategy):
    """Select items spanning the difficulty range before switching to max-info.

    The first ``n_spanning`` calls to :meth:`select` pick items at
    increasing quantiles of the difficulties of the items that are still
    available, which yields a rough ability estimate. Every later call is
    delegated to :class:`MaxInfoStrategy`.

    The current phase is tracked with an internal call counter, so
    :meth:`reset` has to be called before the same instance is reused for
    another test session.

    Parameters
    ----------
    n_spanning : int
        Number of items to select in the spanning phase.
    """

    def __init__(self, n_spanning: int = 10) -> None:
        self.n_spanning = n_spanning
        # Number of spanning selections made since construction / last reset.
        self._spanning_count = 0
        # Strategy used once the spanning phase is over.
        self._max_info = MaxInfoStrategy()

    def select(
        self,
        ability_estimate: torch.Tensor,
        difficulty: torch.Tensor,
        discrimination: torch.Tensor | None,
        administered: torch.Tensor,
    ) -> int:
        """Return the next item: a spanning pick first, max-info afterwards.

        Parameters
        ----------
        ability_estimate : torch.Tensor
            Current ability estimate; only used after the spanning phase.
        difficulty : torch.Tensor
            Item difficulties (M,).
        discrimination : torch.Tensor | None
            Item discriminations; only used after the spanning phase.
        administered : torch.Tensor
            Boolean mask of already-administered items (M,).

        Returns
        -------
        int
            Index of the selected item.
        """
        if self._spanning_count < self.n_spanning:
            # The counter is advanced *before* the pick, so the quantile
            # targeted by _select_spanning runs from 1/n_spanning up to 1.
            self._spanning_count += 1
            return self._select_spanning(difficulty, administered)
        return self._max_info.select(ability_estimate, difficulty, discrimination, administered)

    def _select_spanning(self, difficulty: torch.Tensor, administered: torch.Tensor) -> int:
        """Pick the available item at the current target difficulty quantile.

        Returns ``0`` when no item is available.
        """
        available = ~administered
        if not available.any():
            return 0

        # Positions (in the full item pool) of the items still available.
        available_indices = available.nonzero(as_tuple=True)[0]
        n_available = available_indices.numel()

        # Rank the available items from easiest to hardest. Only the
        # ordering is needed, so argsort replaces sort and its unused values.
        difficulty_order = difficulty[available].argsort()

        # NOTE(review): with the pre-incremented counter the targets are
        # 1/n, 2/n, ..., 1, so the lowest quantile is never targeted and the
        # last spanning pick is the hardest available item. Confirm that
        # this schedule is intended.
        target_quantile = self._spanning_count / max(self.n_spanning, 1)
        # Clamp so that a quantile of 1.0 maps to the last valid position.
        target_idx = min(int(target_quantile * n_available), n_available - 1)
        return available_indices[difficulty_order[target_idx]].item()

    def reset(self) -> None:
        """Reset spanning count for a new test session."""
        self._spanning_count = 0


class RandomStrategy(SelectionStrategy):
    """Draw the next item at random from those not yet administered."""

    def select(self, ability_estimate, difficulty, discrimination, administered):
        """Return the index of a randomly drawn unadministered item.

        Only ``administered`` is used; the other arguments are accepted to
        satisfy the :class:`SelectionStrategy` interface. Returns ``0`` when
        no item is left. The draw uses ``torch.randint`` with torch's
        default random number generator.
        """
        remaining = torch.nonzero(~administered, as_tuple=True)[0]
        n_remaining = len(remaining)
        if n_remaining == 0:
            return 0
        # Draw a position within the remaining pool, then map it back to
        # the index of that item in the full pool.
        draw = torch.randint(n_remaining, (1,)).item()
        return remaining[draw].item()