# Source code for torch_measure.data.masking

# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Masking strategies for train/test splitting response matrices.

Consolidated from factor-model/calibration/util.py and predictive-eval/utils/util.py.
"""

from __future__ import annotations

import torch


def random_mask(observed: torch.Tensor, train_frac: float = 0.8) -> tuple[torch.Tensor, torch.Tensor]:
    """Split the observed entries into train and test masks entry by entry.

    One uniform random number is drawn per entry of ``observed``; an observed
    entry goes to the training mask when its draw is below ``train_frac`` and
    to the test mask otherwise.

    Parameters
    ----------
    observed : torch.Tensor
        Boolean mask of observed entries (n_subjects x n_items).
    train_frac : float
        Probability with which each observed entry is assigned to training.

    Returns
    -------
    train_mask, test_mask : tuple[torch.Tensor, torch.Tensor]
        Masks with the same shape as ``observed``. Every observed entry is in
        exactly one of them; unobserved entries are in neither.
    """
    # One uniform draw in [0, 1) per entry, on the same device as `observed`.
    draws = torch.rand_like(observed, dtype=torch.float32)
    selected = draws < train_frac

    train_mask = observed & selected
    # Whatever is observed but not selected for training is held out.
    test_mask = observed & ~train_mask
    return train_mask, test_mask
def l_mask(observed: torch.Tensor, row_frac: float = 0.8, col_frac: float = 0.8) -> tuple[torch.Tensor, torch.Tensor]:
    """L-shaped masking: fully observe a subset of rows AND columns for training.

    A random subset of rows and a random subset of columns are selected. Every
    observed entry that lies in a selected row or in a selected column is used
    for training. The test set is the remainder: observed entries in the
    intersection of held-out rows and held-out columns.

    Parameters
    ----------
    observed : torch.Tensor
        Boolean mask of observed entries (2-D, rows x columns).
    row_frac : float
        Fraction of rows to fully observe in training. The number of selected
        rows is ``int(n_rows * row_frac)``.
    col_frac : float
        Fraction of columns to fully observe in training. The number of
        selected columns is ``int(n_cols * col_frac)``.

    Returns
    -------
    train_mask, test_mask : tuple[torch.Tensor, torch.Tensor]
        Masks with the same shape and device as ``observed``. Every observed
        entry is in exactly one of them.
    """
    n_rows, n_cols = observed.shape
    n_train_rows = int(n_rows * row_frac)
    n_train_cols = int(n_cols * col_frac)

    # Permutations are drawn from the default (CPU) generator, so a given seed
    # selects the same rows and columns whatever device `observed` lives on.
    row_perm = torch.randperm(n_rows)
    col_perm = torch.randperm(n_cols)

    train_rows = torch.zeros(n_rows, dtype=torch.bool)
    train_rows[row_perm[:n_train_rows]] = True
    train_cols = torch.zeros(n_cols, dtype=torch.bool)
    train_cols[col_perm[:n_train_cols]] = True

    # The indicators must live on the same device as `observed`; combining a
    # CPU indicator with a tensor on another device raises a device-mismatch
    # error.
    train_rows = train_rows.to(observed.device)
    train_cols = train_cols.to(observed.device)

    # L-shape: train on (train_rows x all_cols) union (all_rows x train_cols)
    train_mask = observed & (train_rows[:, None] | train_cols[None, :])
    test_mask = observed & ~train_mask
    return train_mask, test_mask
def row_mask(
    observed: torch.Tensor,
    train_frac: float = 0.8,
    exposure_rate: float = 0.3,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Row-based masking: keep some rows whole, thin out the remaining rows.

    The rows are shuffled. The first ``int(n_rows * train_frac)`` rows of the
    shuffle keep all of their observed entries for training. In every other
    (held-out) row, each observed entry stays in training only when an
    independent uniform draw falls below ``exposure_rate``; the rest of that
    row's observed entries become test entries.

    Parameters
    ----------
    observed : torch.Tensor
        Boolean mask of observed entries.
    train_frac : float
        Fraction of rows to fully observe.
    exposure_rate : float
        Probability with which each observed entry of a held-out row is
        exposed to training.

    Returns
    -------
    train_mask, test_mask : tuple[torch.Tensor, torch.Tensor]
        Masks with the same shape as ``observed``. Every observed entry is in
        exactly one of them.
    """
    total_rows = observed.shape[0]
    n_full = int(total_rows * train_frac)
    shuffled = torch.randperm(total_rows)
    held_out = shuffled[n_full:]

    # Start from "everything observed is training" and thin the held-out rows.
    train_mask = observed.clone()
    for row_idx in held_out:
        row = observed[row_idx]
        exposed = torch.rand_like(row.float()) < exposure_rate
        train_mask[row_idx] = row & exposed

    return train_mask, observed & ~train_mask
def col_mask(
    observed: torch.Tensor,
    train_frac: float = 0.8,
    exposure_rate: float = 0.3,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Column-based masking: keep some columns whole, thin out the others.

    Implemented by applying :func:`row_mask` to the transpose of ``observed``
    and transposing both resulting masks back.

    Parameters
    ----------
    observed : torch.Tensor
        Boolean mask of observed entries.
    train_frac : float
        Fraction of columns to fully observe.
    exposure_rate : float
        Fraction of entries to observe in held-out columns.

    Returns
    -------
    train_mask, test_mask : tuple[torch.Tensor, torch.Tensor]
        Masks in the original (untransposed) orientation of ``observed``.
    """
    # Columns of `observed` are rows of its transpose, so reuse the row logic.
    flipped = observed.T
    train_flipped, test_flipped = row_mask(
        flipped,
        train_frac=train_frac,
        exposure_rate=exposure_rate,
    )
    return train_flipped.T, test_flipped.T
def model_mask(
    observed: torch.Tensor,
    train_frac: float = 0.8,
    exposure_rate: float = 0.3,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Model-based masking (alias for :func:`row_mask`).

    Fully observe ``train_frac`` of the models and partially observe the rest.
    All arguments are forwarded unchanged to :func:`row_mask`, whose result is
    returned as is.

    Parameters
    ----------
    observed : torch.Tensor
        Boolean mask of observed entries.
    train_frac : float
        Fraction of models to fully observe.
    exposure_rate : float
        Fraction of entries to observe for held-out models.

    Returns
    -------
    train_mask, test_mask : tuple[torch.Tensor, torch.Tensor]
    """
    return row_mask(
        observed,
        train_frac=train_frac,
        exposure_rate=exposure_rate,
    )
def item_mask(
    observed: torch.Tensor,
    train_frac: float = 0.8,
    exposure_rate: float = 0.3,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Item-based masking (alias for :func:`col_mask`).

    Fully observe ``train_frac`` of the items and partially observe the rest.
    All arguments are forwarded unchanged to :func:`col_mask`, whose result is
    returned as is.

    Parameters
    ----------
    observed : torch.Tensor
        Boolean mask of observed entries.
    train_frac : float
        Fraction of items to fully observe.
    exposure_rate : float
        Fraction of entries to observe for held-out items.

    Returns
    -------
    train_mask, test_mask : tuple[torch.Tensor, torch.Tensor]
    """
    return col_mask(
        observed,
        train_frac=train_frac,
        exposure_rate=exposure_rate,
    )