# Copyright (c) 2026 AIMS Foundations. MIT License.
"""Factor rotation methods for interpretability.
Consolidated from factor-model/calibration/rotation.py.
"""
from __future__ import annotations
import torch
[docs]
def varimax_rotation(
loadings: torch.Tensor, max_iter: int = 100, tol: float = 1e-6
) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply Varimax rotation to factor loadings.
Varimax maximizes the variance of squared loadings within each factor,
producing a simpler structure.
Parameters
----------
loadings : torch.Tensor
Factor loading matrix (n_items, n_factors).
max_iter : int
Maximum iterations.
tol : float
Convergence tolerance.
Returns
-------
rotated_loadings : torch.Tensor
Rotated loading matrix (n_items, n_factors).
rotation_matrix : torch.Tensor
Rotation matrix (n_factors, n_factors).
"""
n_factors = loadings.shape[1]
if n_factors < 2:
# factor_analyzer's Rotator crashes on single-factor input: its _varimax
# returns 1 value when n_factors < 2 but fit_transform unpacks 2.
# Single-factor rotation is the identity, so short-circuit here.
identity = torch.eye(n_factors, dtype=loadings.dtype, device=loadings.device)
return loadings.clone(), identity
try:
from factor_analyzer.rotator import Rotator
rotator = Rotator(method="varimax")
rotated = rotator.fit_transform(loadings.detach().cpu().numpy())
rotated_t = torch.tensor(rotated, dtype=loadings.dtype, device=loadings.device)
# Compute rotation matrix: R = L_orig^+ @ L_rotated
rotation = torch.linalg.lstsq(loadings.detach(), rotated_t).solution
return rotated_t, rotation
except ImportError:
# Simple varimax implementation
return _varimax_torch(loadings, max_iter, tol)
def _varimax_torch(loadings: torch.Tensor, max_iter: int = 100, tol: float = 1e-6) -> tuple[torch.Tensor, torch.Tensor]:
"""Pure PyTorch Varimax rotation."""
p, k = loadings.shape
rotation = torch.eye(k, dtype=loadings.dtype, device=loadings.device)
rotated = loadings.clone()
for _ in range(max_iter):
old = rotated.clone()
for i in range(k):
for j in range(i + 1, k):
# Compute rotation angle
u = rotated[:, i] ** 2 - rotated[:, j] ** 2
v = 2 * rotated[:, i] * rotated[:, j]
a = u.sum()
b = v.sum()
c = (u**2 - v**2).sum()
d = (2 * u * v).sum()
angle = 0.25 * torch.atan2(2 * d - 2 * a * b / p, c - (a**2 - b**2) / p)
cos_a = torch.cos(angle)
sin_a = torch.sin(angle)
# Apply rotation
col_i = rotated[:, i] * cos_a + rotated[:, j] * sin_a
col_j = -rotated[:, i] * sin_a + rotated[:, j] * cos_a
rotated[:, i] = col_i
rotated[:, j] = col_j
# Update rotation matrix
rot_i = rotation[:, i] * cos_a + rotation[:, j] * sin_a
rot_j = -rotation[:, i] * sin_a + rotation[:, j] * cos_a
rotation[:, i] = rot_i
rotation[:, j] = rot_j
if (rotated - old).abs().max() < tol:
break
return rotated, rotation
[docs]
def promax_rotation(loadings: torch.Tensor, power: int = 4, **kwargs) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply Promax (oblique) rotation to factor loadings.
Promax starts with Varimax and then applies a power transformation
to achieve simple structure while allowing correlated factors.
Parameters
----------
loadings : torch.Tensor
Factor loading matrix (n_items, n_factors).
power : int
Power parameter for the Promax transformation.
Returns
-------
rotated_loadings : torch.Tensor
Promax-rotated loadings.
rotation_matrix : torch.Tensor
Rotation matrix.
"""
n_factors = loadings.shape[1]
if n_factors < 2:
# factor_analyzer's _promax inherits the same single-factor unpacking
# bug as _varimax (returns 1 value when n_factors < 2). Short-circuit.
identity = torch.eye(n_factors, dtype=loadings.dtype, device=loadings.device)
return loadings.clone(), identity
try:
from factor_analyzer.rotator import Rotator
rotator = Rotator(method="promax", power=power)
rotated = rotator.fit_transform(loadings.detach().cpu().numpy())
rotated_t = torch.tensor(rotated, dtype=loadings.dtype, device=loadings.device)
rotation = torch.linalg.lstsq(loadings.detach(), rotated_t).solution
return rotated_t, rotation
except ImportError:
# Fall back to varimax
return varimax_rotation(loadings, **kwargs)
[docs]
def bifactor_rotation(
U: torch.Tensor,
V: torch.Tensor,
Z: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Apply bifactor rotation: whiten, then Varimax, then separate general factor.
Parameters
----------
U : torch.Tensor
Subject abilities (n_subjects, K).
V : torch.Tensor
Item loadings (n_items, K).
Z : torch.Tensor
Item intercepts (n_items,).
Returns
-------
U_rot : torch.Tensor
Rotated abilities.
V_rot : torch.Tensor
Rotated loadings.
Z : torch.Tensor
Unchanged intercepts.
"""
# Whiten U
U_centered = U - U.mean(dim=0)
cov = (U_centered.T @ U_centered) / (U.shape[0] - 1)
L = torch.linalg.cholesky(cov + 1e-6 * torch.eye(cov.shape[0], device=cov.device))
L_inv = torch.linalg.inv(L)
U_white = U_centered @ L_inv.T
V_transformed = V @ L.T
# Varimax rotation
V_rot, rotation = varimax_rotation(V_transformed)
U_rot = U_white @ rotation
return U_rot, V_rot, Z