Source code for torch_measure.models.bifactor
# Copyright (c) 2026 AIMS Foundations. MIT License.
"""Bifactor model with general and group-specific factors."""
from __future__ import annotations
import torch
from torch import nn
from torch_measure.models._base import IRTModel
[docs]
class Bifactor(IRTModel):
"""Bifactor Model.
A constrained factor model with one general factor and multiple
group-specific factors. The general factor loads on all items,
while group factors load only on items in their cluster.
P(correct) = sigmoid(g_n * lambda_g_j + sum_k(s_nk * lambda_sk_j) + z_j)
Parameters
----------
n_subjects : int
Number of subjects.
n_items : int
Number of items.
n_groups : int
Number of group-specific factors.
item_groups : torch.Tensor
Group assignment for each item, shape (n_items,). Values in [0, n_groups).
device : str
Device.
"""
def __init__(
self,
n_subjects: int,
n_items: int,
n_groups: int,
item_groups: torch.Tensor,
device: str = "cpu",
) -> None:
super().__init__(n_subjects, n_items, device)
self.n_groups = n_groups
self.register_buffer("item_groups", item_groups.to(self._device))
# General factor
self.general_ability = nn.Parameter(torch.randn(n_subjects, device=self._device))
self.general_loading = nn.Parameter(torch.randn(n_items, device=self._device))
# Group-specific factors
self.group_ability = nn.Parameter(torch.randn(n_subjects, n_groups, device=self._device))
self.group_loading = nn.Parameter(torch.randn(n_items, device=self._device))
# Intercept
self.Z = nn.Parameter(torch.randn(n_items, device=self._device))
@property
def ability(self) -> torch.Tensor:
return self.general_ability.detach()
@property
def difficulty(self) -> torch.Tensor:
return -self.Z.detach()
[docs]
def predict(self, query: dict[str, torch.Tensor]) -> torch.Tensor:
"""Compute P(correct) at query rows using general + group factors."""
s = query["subject_idx"]
i = query["item_idx"]
g_per_item = self.item_groups[i] # (N,) — group index of each query item
general = self.general_ability[s] * self.general_loading[i]
group = self.group_ability[s, g_per_item] * self.group_loading[i]
logit = general + group + self.Z[i]
return torch.sigmoid(logit)