Source code for torch_measure.models._network_base

# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Abstract base class for network psychometric models."""

from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING

import torch
from torch import nn

if TYPE_CHECKING:
    pass


class NetworkModel(nn.Module):
    """Abstract base class for network psychometric models.

    A network model describes the conditional dependence structure among
    items; it does not estimate per-subject latent traits. Implementations
    expose:

    - ``.fit(response_matrix, ...)`` to estimate network parameters
    - ``.adjacency`` to read the estimated edge weight matrix
    - ``.centrality(measure)`` for common node centrality metrics

    In contrast to :class:`~torch_measure.models.IRTModel`, there is no
    notion of subjects or per-subject ability: the model is defined
    entirely over items.
    """

    # NOTE(review): ``nn.Module`` does not appear to use ``ABCMeta``; if so, the
    # ``@abstractmethod`` markers below document intent but are not enforced
    # when the class is instantiated -- confirm before relying on them.

    def __init__(self, n_items: int, device: str | torch.device = "cpu") -> None:
        """Store the item count and the target device.

        Parameters
        ----------
        n_items : int
            Number of items (network nodes).
        device : str | torch.device
            Device specification; normalised through ``torch.device``.
        """
        super().__init__()
        self._n_items = n_items
        self._device = torch.device(device)

    @property
    def n_items(self) -> int:
        """Number of items the model was constructed with."""
        return self._n_items

    @property
    @abstractmethod
    def adjacency(self) -> torch.Tensor:
        """Edge weight matrix of shape (n_items, n_items).

        The matrix is symmetric with a zero diagonal. A positive entry
        denotes positive conditional dependence between two items, a
        negative entry denotes negative dependence.

        Returns
        -------
        torch.Tensor
            Detached weight matrix, shape (n_items, n_items).
        """
        ...

    @abstractmethod
    def fit(
        self,
        data,
        mask: torch.Tensor | None = None,
        max_epochs: int = 1000,
        lr: float = 0.01,
        verbose: bool = True,
        **kwargs,
    ) -> dict:
        """Estimate network parameters.

        Parameters
        ----------
        data : LongFormData | torch.Tensor
            Long-form dataset (preferred) or wide-form response tensor of
            shape ``(n_subjects, n_items)``. In wide form, missing entries
            may be encoded as NaN or -1.
        mask : torch.Tensor | None
            Wide-form input only: boolean mask selecting the entries to
            use. Inferred from NaNs when None.
        max_epochs : int
            Maximum number of optimisation epochs.
        lr : float
            Learning rate.
        verbose : bool
            Whether to show a progress bar.
        **kwargs
            Extra keyword arguments; this base declaration does not use them.

        Returns
        -------
        dict
            Training history with a ``"losses"`` key.
        """
        ...

    def _normalize_fit_inputs_to_matrix(
        self,
        data,
        mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Turn ``data`` into a dense ``(n_subjects, n_items)`` matrix and a mask.

        Network models work on the full matrix (item-item covariance,
        conditional probabilities via matrix multiply, ...), so long-form
        input is pivoted here. When the same ``(subject, item)`` pair occurs
        more than once, its responses are averaged.

        Parameters
        ----------
        data : LongFormData | torch.Tensor
            Long-form dataset or wide-form response tensor.
        mask : torch.Tensor | None
            Mask for wide-form input. When None it is derived as "not NaN
            and not equal to -1". Not consulted for long-form input.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            The response matrix and the mask of usable cells, both on the
            model's device. For long-form input the matrix is float32 and
            cells without any observation keep the value 0.

        Raises
        ------
        TypeError
            If ``data`` is neither ``LongFormData`` nor ``torch.Tensor``.
        """
        from torch_measure.datasets._long_form import LongFormData

        if isinstance(data, LongFormData):
            tensors = data.to_fit_tensors(device=str(self._device))
            cell = (tensors["subject_idx"], tensors["item_idx"])
            values = tensors["response"]
            shape = (len(tensors["subject_ids"]), len(tensors["item_ids"]))

            totals = torch.zeros(shape, dtype=torch.float32, device=self._device)
            hits = torch.zeros(shape, dtype=torch.float32, device=self._device)
            # accumulate=True sums duplicate (subject, item) cells rather than
            # letting the last write win; ``hits`` counts how many were summed.
            totals.index_put_(cell, values.float(), accumulate=True)
            hits.index_put_(
                cell,
                torch.ones_like(values, dtype=torch.float32),
                accumulate=True,
            )

            observed = hits > 0
            # Sum -> mean, restricted to observed cells so nothing divides by 0.
            totals[observed] /= hits[observed]
            return totals, observed

        if not isinstance(data, torch.Tensor):
            raise TypeError(f"fit() expected LongFormData or torch.Tensor, got {type(data).__name__}")

        dense = data.to(self._device)
        if mask is None:
            # A cell is usable unless it is NaN or carries the -1 sentinel.
            usable = ~(torch.isnan(dense) | (dense == -1))
        else:
            usable = mask
        return dense, usable.to(self._device)

    def centrality(self, measure: str = "strength") -> torch.Tensor:
        """Compute node centrality from the estimated adjacency matrix.

        Parameters
        ----------
        measure : str
            One of ``"strength"``, ``"expected_influence"``,
            ``"closeness"``, or ``"betweenness"``.

        Returns
        -------
        torch.Tensor
            Centrality scores per item, shape (n_items,).

        Raises
        ------
        ValueError
            If ``measure`` is not one of the supported names.
        """
        from torch_measure.metrics.network import (
            betweenness_centrality,
            closeness_centrality,
            expected_influence,
            strength_centrality,
        )

        # Read the adjacency first so it is evaluated even for a bad measure.
        weights = self.adjacency
        scorers = (
            ("strength", strength_centrality),
            ("expected_influence", expected_influence),
            ("closeness", closeness_centrality),
            ("betweenness", betweenness_centrality),
        )
        for name, scorer in scorers:
            if measure == name:
                return scorer(weights)

        raise ValueError(
            f"Unknown centrality measure: {measure!r}. "
            "Choose from 'strength', 'expected_influence', 'closeness', 'betweenness'."
        )