Source code for torch_measure.metrics.correlation

# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Correlation metrics for binary response data.

Consolidated from factor-model/calibration/util.py tetrachoric_matrix_torch.
"""

from __future__ import annotations

import torch


[docs] def tetrachoric_correlation(data: torch.Tensor, min_pairs: int = 5) -> torch.Tensor: """Compute the tetrachoric correlation matrix for binary data. Uses the cosine-pi approximation: r = cos(pi / (1 + sqrt(AD / BC))) where A, B, C, D are the counts in the 2x2 contingency table for each pair of items. Parameters ---------- data : torch.Tensor Binary response matrix (n_subjects, n_items) with values 0, 1, or NaN. min_pairs : int Minimum number of valid pairs required. Pairs with fewer observations get correlation 0. Returns ------- torch.Tensor Tetrachoric correlation matrix of shape (n_items, n_items). """ valid = ~torch.isnan(data) data_clean = data.clone() data_clean[~valid] = 0.0 # Compute 2x2 contingency tables for all pairs # A = both 1, D = both 0, B = (i=1, j=0), C = (i=0, j=1) both_valid = valid.float().T @ valid.float() # (M, M) a = data_clean.T @ data_clean # both correct d = (1 - data_clean).T @ (1 - data_clean) * (valid.float().T @ valid.float()).clamp(min=1).reciprocal() d = d * both_valid # both incorrect # Recompute properly x = data_clean * valid.float() x_not = (1 - data_clean) * valid.float() a = x.T @ x # (1,1) cell d = x_not.T @ x_not # (0,0) cell b = x.T @ x_not # (1,0) cell c = x_not.T @ x # (0,1) cell # Tetrachoric approximation: r = cos(pi / (1 + sqrt(AD/BC))) ad = a * d bc = b * c bc = bc.clamp(min=1e-10) # avoid division by zero ratio = (ad / bc).clamp(min=0) sqrt_ratio = torch.sqrt(ratio) r = torch.cos(torch.pi / (1 + sqrt_ratio)) # Handle edge cases r[both_valid < min_pairs] = 0.0 r.fill_diagonal_(1.0) # Symmetrize r = (r + r.T) / 2 return r
[docs] def point_biserial_correlation( continuous: torch.Tensor, binary: torch.Tensor, ) -> torch.Tensor: """Compute point-biserial correlation between continuous and binary variables. Parameters ---------- continuous : torch.Tensor Continuous variable (e.g., total score) of shape (N,). binary : torch.Tensor Binary variable (e.g., item response) of shape (N,) or (N, M). Returns ------- torch.Tensor Correlation(s). Scalar if binary is 1D, shape (M,) if 2D. """ if binary.ndim == 1: binary = binary.unsqueeze(1) squeeze = True else: squeeze = False mask = ~torch.isnan(binary) & ~torch.isnan(continuous.unsqueeze(1)) results = [] for j in range(binary.shape[1]): m = mask[:, j] x = continuous[m] y = binary[m, j] group1 = x[y == 1] group0 = x[y == 0] if len(group1) < 2 or len(group0) < 2: results.append(torch.tensor(0.0)) continue m1, m0 = group1.mean(), group0.mean() n1, n0 = len(group1), len(group0) n = n1 + n0 s = x.std() if s < 1e-10: results.append(torch.tensor(0.0)) continue rpb = (m1 - m0) / s * torch.sqrt(torch.tensor(n1 * n0 / n**2, dtype=torch.float32)) results.append(rpb) result = torch.stack(results) return result.squeeze() if squeeze else result