Source code for torch_measure.models.bradley_terry

# Copyright (c) 2026 AIMS Foundations. MIT License.

"""Bradley-Terry model for pairwise comparison data."""

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
from torch import nn

from torch_measure.fitting._losses import bernoulli_nll
from torch_measure.models._predictor import Predictor

if TYPE_CHECKING:
    from torch_measure.data.pairwise import PairwiseComparisons


class BradleyTerry(Predictor):
    """Bradley-Terry model for pairwise comparison data.

    Models the probability that subject *a* beats subject *b* as:

    .. math::

        P(a > b) = \\sigma(\\theta_a - \\theta_b)

    Mathematically equivalent to Rasch, but the "item" axis is itself a
    subject — so ``predict(query)`` consumes ``subject_idx`` (the A-side)
    and ``item_idx`` (the B-side).

    Parameters
    ----------
    n_subjects : int
        Number of subjects (e.g., LLMs).
    device : str
        Device to place parameters on.

    Attributes
    ----------
    ability : torch.nn.Parameter
        One latent ability per subject, shape ``(n_subjects,)``,
        initialised to zeros.

    Examples
    --------
    >>> from torch_measure.models import BradleyTerry
    >>> from torch_measure.models._predictor import predict_dense
    >>> model = BradleyTerry(n_subjects=3)
    >>> predict_dense(model)  # (3, 3) win probability matrix
    """

    def __init__(self, n_subjects: int, device: str = "cpu") -> None:
        # Both axes of the prediction are subjects; pass n_subjects twice.
        super().__init__(n_subjects, n_subjects, device)
        self.ability = nn.Parameter(torch.zeros(n_subjects, device=self._device))

    def predict(self, query: dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute P(a beats b) at query rows.

        ``query["subject_idx"]`` is the A-side; ``query["item_idx"]`` is the
        B-side (also a subject index).
        """
        a = query["subject_idx"]
        b = query["item_idx"]
        return torch.sigmoid(self.ability[a] - self.ability[b])

    def predict_pairwise(self, subject_a: torch.Tensor, subject_b: torch.Tensor) -> torch.Tensor:
        """Domain-named convenience: ``P(a beats b)`` for explicit pair tensors.

        Equivalent to
        ``self.predict({"subject_idx": subject_a, "item_idx": subject_b})``.
        """
        return self.predict({"subject_idx": subject_a, "item_idx": subject_b})

    def fit(
        self,
        comparisons: PairwiseComparisons,
        method: str = "mle",
        max_epochs: int = 1000,
        lr: float = 0.01,
        regularization: float = 0.01,
        convergence_tol: float = 1e-6,
        verbose: bool = True,
    ) -> dict:
        """Fit the model to pairwise comparison data.

        Parameters
        ----------
        comparisons : PairwiseComparisons
            Pairwise comparison data with ``subject_a``, ``subject_b``, and
            ``outcome`` tensors.
        method : str
            Fitting method: ``"mle"`` (Adam optimizer) or ``"jml"`` (LBFGS
            with L2 regularization).
        max_epochs : int
            Maximum number of optimization epochs.
        lr : float
            Learning rate.
        regularization : float
            L2 regularization weight (only used for ``method="jml"``).
        convergence_tol : float
            Stop if loss change is below this threshold.
        verbose : bool
            Show progress bar (silently skipped when ``tqdm`` is not
            installed).

        Returns
        -------
        dict
            Training history with ``'losses'`` key: one loss value per epoch
            that was run. If the loss becomes NaN/inf, that value is recorded
            as the last entry and fitting stops.

        Raises
        ------
        ValueError
            If ``method`` is neither ``"mle"`` nor ``"jml"``.
        """
        # Validate first so a typo fails before any tensor is moved or any
        # optimizer state is allocated.
        if method not in ("mle", "jml"):
            raise ValueError(f"Unknown method: {method!r}. Use 'mle' or 'jml'.")
        use_lbfgs = method == "jml"

        subject_a = comparisons.subject_a.to(self._device)
        subject_b = comparisons.subject_b.to(self._device)
        outcome = comparisons.outcome.to(self._device)

        if use_lbfgs:
            optimizer = torch.optim.LBFGS(self.parameters(), lr=lr, max_iter=20)
        else:
            optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        def closure() -> torch.Tensor:
            """Zero the gradients, evaluate the objective and backpropagate."""
            optimizer.zero_grad()
            # Clamp keeps the probabilities away from exactly 0 and 1 before
            # they are handed to the likelihood.
            probs = self.predict_pairwise(subject_a, subject_b).clamp(1e-7, 1 - 1e-7)
            loss = bernoulli_nll(probs, outcome)
            if use_lbfgs:
                # The L2 penalty is part of the objective only for "jml".
                loss = loss + regularization * self.ability.pow(2).mean()
            loss.backward()
            return loss

        history: dict[str, list] = {"losses": []}

        iterator = range(max_epochs)
        if verbose:
            try:
                from tqdm import tqdm

                iterator = tqdm(iterator, desc=f"BT {method.upper()} fitting")
            except ImportError:
                # tqdm is optional; fall back to a plain range.
                pass

        prev_loss = float("inf")
        for _epoch in iterator:
            if use_lbfgs:
                # LBFGS re-evaluates the objective itself, so it drives the
                # closure; the recorded loss is whatever step() returns.
                loss = optimizer.step(closure)
            else:
                loss = closure()
                optimizer.step()
            loss_val = loss.item()

            history["losses"].append(loss_val)

            if verbose and hasattr(iterator, "set_postfix"):
                iterator.set_postfix({"loss": f"{loss_val:.6f}"})

            # A NaN/inf loss can never satisfy the tolerance test below, so
            # without this guard a diverged fit would run every remaining
            # epoch. The bad value is already in the history.
            if not torch.isfinite(loss):
                break

            if abs(prev_loss - loss_val) < convergence_tol:
                break
            prev_loss = loss_val

        return history

    def __repr__(self) -> str:
        return f"BradleyTerry(n_subjects={self._n_subjects})"