2025
paper:doi-10-48550-arxiv-2501-06164

Model Alignment Search

TL;DR

Model Alignment Search (MAS) establishes bidirectional causal similarity between neural networks. It learns one orthogonal rotation matrix per model to isolate behaviorally relevant subspaces, then performs interchange interventions (patching those subspaces across frozen model pairs) and scores functional alignment by Interchange Intervention Accuracy (IIA). Comparing GRUs and 2-layer Transformers on numeric tasks shows that correlative methods such as RSA and CKA can give misleading estimates: RSA reports anomalously low embedding-layer similarity between same-architecture GRU seeds, and both CKA and RSA suggest potentially high similarity between GRU and Transformer hidden states, which MAS diagnoses as low because the Transformers employ an anti-Markovian solution that recomputes numeric information at every step. MAS compresses behaviorally relevant information into as few as 4 dimensions while achieving IIA comparable to DAS, and it reduces the number of required comparison matrices from O(n²) to O(n), making it more compute-efficient than traditional model stitching for three or more models. A case study on DeepSeek-R1-Distill-Qwen-1.5B models fine-tuned on toxic versus nontoxic text finds that toxic-to-toxic MAS IIA is measurably higher than toxic-to-nontoxic IIA, whereas nontoxic-to-nontoxic comparisons show no significant difference, suggesting MAS can serve as a diagnostic for representational misalignment. The Counterfactual Latent MAS (CLMAS) extension adds an auxiliary L2 plus cosine loss against prerecorded latent vectors and recovers causal alignment even when one model is causally inaccessible, implying the method may generalize to ANN–biological neural network comparisons where only recordings, not interventions, are available.
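The interchange intervention described above can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the rotations here are random rather than learned, the function names (`random_orthogonal`, `interchange`, `iia`) are hypothetical, and placing the aligned subspace in the first k rotated dimensions is an assumption made for readability.

```python
import numpy as np

def random_orthogonal(d, rng):
    # Stand-in for a learned rotation: QR of a Gaussian matrix gives an orthogonal Q.
    q, _ = np.linalg.qr(rng.standard_normal((d, d)))
    return q

def interchange(h_target, h_source, q_target, q_source, k):
    """Patch the first k aligned dimensions of the source latent into the target latent."""
    z_target = q_target @ h_target   # rotate target latent into its aligned basis
    z_source = q_source @ h_source   # rotate source latent into its aligned basis
    z_patched = z_target.copy()
    z_patched[:k] = z_source[:k]     # swap only the aligned subspace (assumed: first k dims)
    return q_target.T @ z_patched    # rotate back; Q^-1 = Q^T for orthogonal Q

def iia(predictions, counterfactual_targets):
    # Fraction of interventions whose resulting behavior matches the counterfactual label.
    return float(np.mean(np.asarray(predictions) == np.asarray(counterfactual_targets)))

rng = np.random.default_rng(0)
d, k = 16, 4
q_a, q_b = random_orthogonal(d, rng), random_orthogonal(d, rng)
h_a, h_b = rng.standard_normal(d), rng.standard_normal(d)
h_a_patched = interchange(h_a, h_b, q_a, q_b, k)
```

In an actual MAS run the patched latent would be fed back into the frozen target model, and `iia` would be computed over the model's outputs on many such interventions in both directions.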

What to take away

  1. MAS reduces the number of learned comparison matrices from O(n²) for pairwise model stitching to O(n), one orthogonal rotation matrix per model, when comparing n models.
  2. RSA (using Spearman rank correlation on cosine-distance RDMs over 1000 sampled vectors) produces anomalously low embedding-layer similarity scores even for GRU models of the same architecture trained on identical Multi-Object tasks with different random seeds, whereas MAS IIA correctly shows near-ceiling causal transfer.
  3. CKA and RSA both suggest high hidden-state similarity between Multi-Object GRUs and 2-layer RoPE Transformers, but MAS IIA is low because Transformers use an anti-Markovian solution that recomputes numeric information at every sequence step, a difference that correlative methods cannot detect.
  4. MAS can compress all behaviorally relevant causal information into as few as 4 aligned dimensions (matching DAS performance), while model stitching achieves near-perfect IIA even at rank 2 by exploiting the behavioral null-space of the source model.
  5. Fine-tuned DeepSeek-R1-Distill-Qwen-1.5B toxic models exhibit higher stepwise MAS IIA when compared to other toxic models than when compared to nontoxic models, with no significant IIA difference observed in nontoxic-to-nontoxic comparisons.
  6. GRUs trained on Multi-Object and Rounding tasks show lower cross-task MAS IIA for their numeric subspaces than within-task GRU pairs, and restricting the Arithmetic GRU's Cumu Val range to 1–10 raises MAS IIA toward but not to the level of the Rem Ops alignment, consistent with GRUs encoding arithmetic and counting numbers differently.
  7. CLMAS, which augments the MAS loss with an auxiliary L2 plus cosine loss (weighted by hyperparameter ε tested at 0.5, 0.89, 0.94) against prerecorded counterfactual latent vectors, achieves higher IIA in the causally inaccessible intervention direction than both behavioral stitching and latent stitching baselines while matching standard MAS in the accessible direction.
  8. An open question raised is whether including more than two models simultaneously in a single MAS training would harm alignment quality by creating conflicting gradient signals or would instead improve isolation of causally relevant subspaces across all models.
  9. MAS rotation matrices are trained for 1000 epochs using Adam (lr=0.001, batch size 512), with 10,000 intervention samples and 1,000 held-out validation samples, orthogonalized via PyTorch's exponential-of-skew-symmetric parametrization, selecting the checkpoint with best validation IIA; the procedure is fully replicable.
  10. Model stitching can succeed at near-perfect IIA using rank-2 transformations by relying on the source model's behavioral null-space and dormant subspaces, meaning a successful stitch does not imply that the two networks encode the task variable in structurally similar ways.
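The O(n²) versus O(n) claim in the first takeaway can be made concrete with a small count. The sketch below assumes that bidirectional stitching needs one learned matrix per ordered (source, target) pair and that each model's MAS rotation is reused across all of its comparisons; both are assumptions, and the function names are hypothetical.

```python
def stitching_matrices(n):
    # Assumption: one learned matrix per ordered (source, target) pair of models.
    return n * (n - 1)

def mas_matrices(n):
    # One orthogonal rotation matrix per model.
    return n

for n in (2, 3, 5, 10):
    print(f"n={n}: stitching={stitching_matrices(n)}, MAS={mas_matrices(n)}")
```

Under these assumptions the two counts tie at n = 2 and diverge from n = 3 onward, which is consistent with the statement that the saving applies to three or more models.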

Peer brief — for seminar discussion

Grant (2025) introduces Model Alignment Search (MAS), a method for measuring functional similarity between pairs of frozen neural networks by learning one orthogonal rotation matrix per model that simultaneously uncovers causally relevant latent subspaces and maps them onto each other, then uses bidirectional interchange interventions — patching those subspaces across models — and measuring the resulting Interchange Intervention Accuracy (IIA) on counterfactual behavior as the similarity score. The method is conceptually a fusion of model stitching (Bansal et al., 2021) and Distributed Alignment Search (Geiger et al., 2021; 2023), and it was validated on GRUs, LSTMs, 2-layer RoPE Transformers trained on numeric equivalence tasks, and DeepSeek-R1-Distill-Qwen-1.5B models fine-tuned for toxic or nontoxic text generation. An alternative it could have used — but argues against — is standard model stitching, which the paper shows achieves near-perfect IIA even at rank 2 by exploiting the source model's behavioral null-space rather than isolating genuinely shared causal structure. The load-bearing finding is that correlative methods, specifically RSA via Spearman rank correlation on 1000-vector cosine-distance RDMs and CKA via cosine kernels, systematically misrepresent functional similarity in ways that MAS corrects. RSA gives anomalously low embedding-layer scores for same-architecture GRU seeds trained on the identical Multi-Object task, while both CKA and RSA suggest high hidden-state similarity between GRUs and Transformers on the same task — a similarity that MAS IIA correctly identifies as low because Transformers employ an anti-Markovian solution that recomputes numeric state at each token, rendering their hidden states causally non-equivalent to GRU hidden states. 
MAS also reveals that GRUs trained on counting tasks versus arithmetic tasks encode number differently: cross-task numeric subspace IIA is lower than within-task IIA, and restricting the arithmetic model's Cumu Val range to 1–10 raises but does not close the gap to the Rem Ops alignment. In the toxicity case study, toxic DeepSeek-R1-Distill-Qwen-1.5B models show higher IIA with other toxic models than with nontoxic models, while nontoxic-to-nontoxic comparisons show no significant difference. The paper also introduces CLMAS, which adds a weighted auxiliary L2 plus cosine loss (ε ∈ {0.5, 0.89, 0.94}) against prerecorded counterfactual latent vectors to recover causal alignment when one model is causally inaccessible — the scenario relevant for ANN–biological neural network comparisons. The key implication is that causal intervention-based similarity measures should supplement or replace correlative measures whenever the research goal is to determine whether two networks perform a task through the same mechanism, not just whether their representations are linearly related. The paper predicts that CLMAS-like methods could eventually reduce the need for neural stimulation in biological comparisons by pre-computing alignment candidates from recordings alone. A critical reader would push back on the limited scope of the toxicity case study: fine-tuning DeepSeek-R1-Distill-Qwen-1.5B on concatenated toxicity datasets (Jigsaw 2018; ToxicChat; RLHF preference data) with only 3 seeds per condition and reporting token-level rather than trial-level IIA makes it difficult to distinguish a genuine representational signature of toxicity from a superficial output-distribution shift induced by fine-tuning on stylistically distinct corpora. 
The result that toxic models align better with each other than with nontoxic models is consistent with the toxicity hypothesis but equally consistent with the simpler explanation that models fine-tuned on the same data distribution share low-level statistical regularities that MAS picks up regardless of whether those regularities reflect anything semantically meaningful about misalignment.
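For readers less familiar with the correlative baseline, the RSA score discussed in the brief (Spearman rank correlation between cosine-distance RDMs over 1000 sampled vectors) might be computed roughly as follows. This is a sketch with random stand-in data; the helper names are hypothetical, and the rank computation assumes no ties.

```python
import numpy as np

def cosine_rdm(x):
    # Upper triangle of the pairwise cosine-distance matrix (the RDM).
    xn = x / np.linalg.norm(x, axis=1, keepdims=True)
    dist = 1.0 - xn @ xn.T
    return dist[np.triu_indices(len(x), k=1)]

def spearman(a, b):
    # Pearson correlation of ranks; assumes continuous data with no ties.
    ranks_a = np.argsort(np.argsort(a))
    ranks_b = np.argsort(np.argsort(b))
    return float(np.corrcoef(ranks_a, ranks_b)[0, 1])

def rsa(x, y):
    return spearman(cosine_rdm(x), cosine_rdm(y))

rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 32))                    # 1000 sampled latent vectors
q, _ = np.linalg.qr(rng.standard_normal((32, 32)))     # arbitrary orthogonal rotation
rsa_rotated = rsa(x, x @ q)                            # same geometry, rotated basis
```

Note that the score depends only on the geometry of the recorded vectors: it never runs either model, so it cannot register whether a representation is actually used to produce behavior. That is the gap the brief attributes to correlative methods.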

Methods (5)

  • Alignment Function (AF)
    Learnable invertible transformation in DAS/MAS that rotates latent vectors into aligned subspaces; narrowed to orthogonal matrices Q.
  • Counterfactual Latent (CL) Auxiliary Loss
    Auxiliary objective combining L2 and cosine losses against pre-recorded CL vectors to improve causal relevance when one model is causally inaccessible.
  • Latent Stitch
    Baseline method using a single orthogonal matrix trained to map source latents to target latents via CL auxiliary loss without behavioral objective.
  • Optogenetics
    Light-gated ion channels used to control bioelectric states and dissect cellular computation.
  • Stepwise MAS
    MAS variant applying interchange interventions at multiple contiguous token positions from the start of a sequence to a sampled time step t.

Frameworks (6)

  • Counterfactual Latent MAS (CLMAS)
    MAS variant with an auxiliary CL loss objective for cases where one model is causally inaccessible, enabling ANN-BNN comparisons.
  • Gated Recurrent Unit (GRU)
    Recurrent neural network architecture used as the primary model type in numeric task experiments.
  • Linear Representation Hypothesis
    The hypothesis that models internalize concepts as approximately linear directions in representation space; used to interpret MDS injection behavior.
  • Long Short-Term Memory (LSTM)
    Recurrent neural network architecture used alongside GRUs in numeric task experiments; MAS applied to concatenated h and c vectors.
  • Model Alignment Search (MAS)
    The primary contribution of the paper: a bidirectional causal method that learns rotation matrices for each model to uncover and compare causally relevant latent subspaces across neural networks.
  • Shallow Transformer (RoPE-based)
    Two-layer transformer with rotary positional encodings used in numeric task experiments.
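The CLMAS entry above describes an auxiliary loss that combines L2 and cosine terms against prerecorded counterfactual latent vectors, weighted by ε. A rough sketch follows. The summary does not state how ε enters the total loss, so the convex mix used here is an assumption, as are the function names and the unweighted sum of the two auxiliary terms.

```python
import numpy as np

def cl_auxiliary_loss(pred, target):
    # Mean squared L2 distance plus mean cosine distance to the recorded CL vectors.
    l2 = np.mean(np.sum((pred - target) ** 2, axis=-1))
    cos = np.sum(pred * target, axis=-1) / (
        np.linalg.norm(pred, axis=-1) * np.linalg.norm(target, axis=-1)
    )
    return float(l2 + np.mean(1.0 - cos))

def clmas_loss(behavior_loss, pred, target, eps):
    # Assumption: eps weights the auxiliary term and (1 - eps) the behavioral term.
    return eps * cl_auxiliary_loss(pred, target) + (1.0 - eps) * behavior_loss

rng = np.random.default_rng(0)
recorded = rng.standard_normal((512, 16))                    # prerecorded CL vectors (stand-in)
intervened = recorded + 0.1 * rng.standard_normal((512, 16)) # latents after intervention (stand-in)
total = clmas_loss(behavior_loss=0.7, pred=intervened, target=recorded, eps=0.89)
```

The auxiliary term needs only recordings from the inaccessible system, which is why it remains usable when interventions on that system are not possible.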

Datasets (7)

  • Arithmetic Task Dataset
    More complex numeric task involving addition/subtraction operations with cumulative values; used in Appendix B.7 to explore MAS across differing domains.
  • DeepSeek-R1-Distill-Qwen-1.5B
    Small model used in attention head attribution analysis in appendix.
  • Modulo Task Dataset
    Numeric task where the number of response tokens equals the object quantity mod 4.
  • Multi-Object Task Dataset
    Primary numeric task where models count demonstration tokens and produce matching response tokens; used for most MAS analyses.
  • Rounding Task Dataset
    Numeric task where the number of response tokens equals the object quantity rounded to the nearest multiple of 3.
  • Same-Object Task Dataset
    Variant of Multi-Object task using a single token type C instead of multiple demo/response types.
  • Toxicity Finetuning Dataset
    Concatenation of three toxicity-related datasets used to finetune DeepSeek models for the misalignment case study.
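The target rules for the counting tasks listed above can be written down directly. The sketch below covers the Multi-Object, Modulo, and Rounding rules as described in the entries; the token names `D`, `T`, `R`, and `EOS` and the sequence layout are hypothetical placeholders, not the paper's vocabulary.

```python
def multi_object_target(n):
    # Response count matches the number of demonstration tokens.
    return n

def modulo_target(n, base=4):
    # Response count equals the object quantity mod 4.
    return n % base

def rounding_target(n, multiple=3):
    # Nearest multiple of 3, in integer arithmetic; an odd multiple never produces ties.
    return multiple * ((2 * n + multiple) // (2 * multiple))

def build_sequence(n, target_fn):
    # Hypothetical layout: n demo tokens, a trigger token, then the response tokens.
    return ["D"] * n + ["T"] + ["R"] * target_fn(n) + ["EOS"]

example = build_sequence(5, rounding_target)
```

Writing the rules side by side shows that the three tasks share the same counting input but differ in the mapping from count to response length.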


Original abstract

When can we say that two neural systems perform a task in the same way? What nuances do we miss when we fail to causally probe the representations of the systems, and how do we establish bidirectional causal relationships? In this work, we introduce a method that bidirectionally transfers neural activity between artificial neural networks and uses their resulting behavior as a measure of functional similarity. We first show that the method can be used to transfer the behavior from one frozen Neural Network (NN) to another in a manner similar to model stitching, and we show how the method can differ from correlative similarity measures like Representational Similarity Analysis. Next, we empirically and theoretically show how the method can be equivalent to model stitching when desired, or it can take a form that has a more restrictive focus to shared causal information; in both forms, it reduces the number of required matrices for a comparison of n models to be linear in n. We then present a case study on number-related tasks showing that the method can be used to examine specific subtypes of causal information demonstrating that numbers can be encoded differently in recurrent models depending on the task, and we present another case study showing that MAS can reveal misalignment in fine-tuned DeepSeek-r1-Qwen-1.5B models. Lastly, we augment the loss function with a counterfactual latent (CL) auxiliary objective to improve causal relevance when one of the two networks is causally inaccessible (as is often the case in comparisons with biological networks). We use our results to encourage the use of causal methods in neural similarity analyses and to suggest future explorations of network similarity methodology for model misalignment.
