paper
active
2023
9
paper:arxiv-2303-02536

Finding Alignments Between Interpretable Causal Variables and Distributed Neural Representations

TL;DR

Distributed alignment search (DAS) resolves two blocking limitations of prior causal abstraction work—brute-force alignment search and the localist assumption that high-level variables map to disjoint neuron sets—by using gradient descent over orthogonal rotation matrices to find alignments in non-standard bases of neural representations. On a hierarchical equality task, a three-layer feed-forward network with hidden size 16 achieves 100% interchange intervention accuracy (IIA) under DAS at layer 1 with an 8-dimensional intervention subspace, whereas the best brute-force localist search reaches only 0.60 IIA and the closest localist alignment only 0.73 IIA. On the Monotonicity NLI benchmark, BERT-base fine-tuned on MoNLI achieves 100% IIA at layer 9 when 256 non-standard basis dimensions of the [CLS] token encode lexical entailment and 256 others encode negation, while no localist alignment exceeds 0.51 IIA on the same task. A subsequent subspace decomposition reveals a structural asymmetry: the hierarchical equality representations of w=x and y=z cannot be decomposed into representations of individual input identities (subspace DAS IIA ≈ 0.50–0.51), whereas the apparent lexical-entailment representation in BERT decomposes almost perfectly (IIA ≈ 0.97–0.98) into two word-identity representations. DAS implies that previous negative or weak causal abstraction findings may have been artifacts of the localist assumption, and that neural networks can genuinely implement tree-structured symbolic algorithms—but that apparent relational representations may sometimes be data structures over entity identities rather than true relational encodings.

What to take away

  1. DAS (distributed alignment search) finds alignments between high-level causal variables and distributed neural representations by optimizing an orthogonal rotation matrix with stochastic gradient descent rather than brute-force search over localist neuron subsets.
  2. On the hierarchical equality task, DAS achieves 100% IIA at layer 1 of a 16-hidden-unit feed-forward network using an 8-dimensional intervention subspace, compared to 0.60 IIA for brute-force localist search and 0.73 IIA for the closest localist alignment.
  3. For BERT-base fine-tuned on the MoNLI benchmark, DAS finds 100% IIA at layer 9 with a 256-dimensional intervention subspace for the joint negation-and-lexical-entailment high-level model, while all localist alignments remain at or below 0.51 IIA.
  4. The learned rotation matrices are non-trivial: eigenvector rotation analyses show the majority of basis vectors are substantially rotated, indicating that high-level causal structure is genuinely distributed and not recoverable by standard neuron-aligned probes.
  5. Subspace DAS applied to the hierarchical equality task finds that representations of w=x and y=z cannot be decomposed into representations of individual input identities (IIA ≈ 0.50–0.51), establishing that the network encodes abstract relational structure independent of the participating entities.
  6. Subspace DAS applied to the MoNLI BERT model finds that the apparent lexical-entailment representation decomposes nearly perfectly into two word-identity representations (IIA ≈ 0.97–0.98 at layer 9), revealing it is a data structure over word identities rather than a true relational encoding.
  7. DAS runtime for the MoNLI task is approximately 1,105 seconds, versus a tractable brute-force runtime of 198 seconds over a limited hypothesis set, but the brute-force worst-case combinatorial space is estimated at C(768,32) ≈ 2e58 hypotheses, making exhaustive search computationally infeasible.
  8. To replicate DAS, one implements a differentiable orthogonal matrix parameterization (e.g., PyTorch's torch.nn.utils.parametrizations.orthogonal), freezes both low-level and high-level models, and minimizes cross-entropy between the high-level output distribution and the push-forward of the low-level output distribution under distributed interchange interventions.
  9. Applying DAS to randomly initialized, chance-accuracy (50%) networks shows that IIA increases only when the hidden dimension is orders of magnitude larger than the input dimension (e.g., reaching 0.64 IIA only at hidden size 4096 for a 16-dimensional input), confirming that DAS cannot fabricate causal structure absent from the model.
  10. An open question the paper raises is whether non-linear invertible transformations (e.g., normalizing flows) rather than orthogonal matrices would be required to find alignments when high-level variables are encoded in non-linear sub-manifolds of the representation space, which DAS in its current form cannot handle.
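
The replication recipe in the list above can be sketched in PyTorch. This is a minimal illustration under stated assumptions, not the authors' implementation: the stand-in network `low_level` is untrained, the intervention site (after the first ReLU) is chosen for convenience, and `cf_labels` is a random placeholder for the counterfactual outputs that a real high-level causal model would supply. Only `torch.nn.utils.parametrizations.orthogonal` is named by the summary itself.

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

torch.manual_seed(0)
HIDDEN, K = 16, 8  # hidden size and intervention-subspace size quoted in the text

# Hypothetical stand-in for the trained low-level network (untrained here).
low_level = nn.Sequential(nn.Linear(16, HIDDEN), nn.ReLU(), nn.Linear(HIDDEN, 2))
for p in low_level.parameters():
    p.requires_grad_(False)  # the network itself stays frozen

# The only trainable object: a square weight constrained to be orthogonal.
rotation = orthogonal(nn.Linear(HIDDEN, HIDDEN, bias=False))

def distributed_interchange(base, source):
    """Run `base`, but overwrite the first K rotated coordinates of the
    hidden representation with the ones computed from `source`."""
    h_base = low_level[1](low_level[0](base))
    h_src = low_level[1](low_level[0](source))
    R = rotation.weight                           # orthogonal HIDDEN x HIDDEN matrix
    r_base, r_src = h_base @ R.T, h_src @ R.T     # rotate into the learned basis
    mixed = torch.cat([r_src[:, :K], r_base[:, K:]], dim=1)
    return low_level[2](mixed @ R)                # rotate back, finish the forward pass

base, source = torch.randn(64, 16), torch.randn(64, 16)
cf_labels = torch.randint(0, 2, (64,))  # placeholder for high-level counterfactual outputs

opt = torch.optim.Adam(rotation.parameters(), lr=1e-2)
for _ in range(5):  # a few illustrative steps only
    opt.zero_grad()
    logits = distributed_interchange(base, source)
    loss = nn.functional.cross_entropy(logits, cf_labels)
    loss.backward()
    opt.step()
```

Because the parametrization keeps the weight orthogonal after every optimizer step, rotating back is just multiplication by the transpose, and an intervention whose source equals the base leaves the network's output unchanged.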

Peer brief — for seminar discussion

Geiger et al. (2024) address a foundational bottleneck in causal abstraction-based interpretability: prior methods require brute-force search over localist alignments (mappings from high-level causal variables to disjoint neuron subsets), making them both computationally intractable and structurally biased against the distributed representations widely hypothesized to characterize neural networks. The paper introduces distributed alignment search (DAS), which parameterizes the alignment as an orthogonal rotation matrix over a subspace of a neural layer's representation, then optimizes it with stochastic gradient descent using interchange intervention training objectives, with both the neural network and the high-level causal model frozen. An alternative approach the paper could have used is iterative nullspace projection (INLP), which also searches for linear subspaces encoding target concepts but does so adversarially rather than causally and would not directly optimize interchange intervention accuracy.
The load-bearing empirical finding is a clean double dissociation. On a hierarchical equality task, a three-layer feed-forward network with hidden size 16 achieves 100% IIA under DAS at layer 1 with an 8-dimensional intervention subspace, while brute-force localist search plateaus at 0.60 IIA and the nearest localist re-projection at 0.73 IIA. On the Monotonicity NLI benchmark, BERT-base fine-tuned on MoNLI reaches 100% IIA at layer 9 with a 256-dimensional subspace encoding both negation and lexical entailment jointly, whereas no tested localist alignment exceeds 0.51 IIA. A further subspace decomposition (Subspace DAS) then reveals a structural difference between the two cases: the equality representations in the feed-forward network cannot be decomposed into individual entity-identity representations (decomposition IIA ≈ 0.50), whereas the apparent lexical-entailment representation in BERT decomposes nearly perfectly into two word-identity representations (IIA ≈ 0.97–0.98 at layer 9), indicating it is a data structure over lexeme identities rather than a genuine relational encoding.
The paper's interpretive claim is that when 100% IIA is achieved and representations resist decomposition, the neural network literally implements a symbolic, tree-structured algorithm, not merely an approximation of one. This is framed as foundational for understanding the coexistence of symbolic and connectionist computation. The paper also implicitly predicts that many previously reported weak or null causal abstraction findings in the literature will prove to be artifacts of the localist assumption rather than genuine evidence of non-symbolic computation.
A critical reader would push back on the scope of the experimental substrate. Both tasks (hierarchical equality on a toy three-layer MLP and MoNLI on a single BERT-base fine-tune) are constructed to have clean, known symbolic solutions with exactly two intermediate variables, and both models are trained to 100% accuracy before analysis begins. The claim that DAS scales to realistic large models and messy tasks remains undemonstrated: the paper itself acknowledges that rotating the full [CLS] representation of BERT-base across a concatenated token sequence would require approximately 15.4B parameters in the rotation matrix, which is intractable, and scaling is deferred to future work. A skeptic could reasonably argue that the 100% IIA results reflect the extreme simplicity and synthetic construction of the tasks rather than a general property of gradient-descent alignment search, and that the decomposability asymmetry, while striking, is observed in only two settings, one of which (BERT on MoNLI) is a fine-tuned model on a purpose-built dataset that may not generalize to naturalistic language understanding.

Methods (4)

  • Distributed Alignment Search
    The core method introduced in this paper: finds alignments between high-level causal variables and distributed neural representations via gradient descent.
  • Distributed Interchange Intervention
    Extends interchange interventions to non-standard bases by rotating representations, intervening in rotated subspaces, then rotating back.
  • Interchange Intervention Accuracy
    Proportion of aligned interchange interventions with equivalent high-level and low-level effects; graded measure of causal abstraction.
  • Subspace DAS
    Extension of DAS that learns a second rotation matrix on top of a fixed first one to decompose representations into sub-representations.
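
To make the rotate, intervene, rotate-back step and the IIA metric from the list above concrete, here is a small NumPy sketch. The two-dimensional encoding, the fixed 45-degree rotation `R`, and the helper names are illustrative assumptions rather than anything taken from the paper; the point is only that a variable stored along a non-axis direction can be swapped cleanly in the rotated basis but not by overwriting a single neuron.

```python
import numpy as np

def distributed_interchange(h_base, h_source, R, k):
    """Swap the first k coordinates of the rotated representation,
    then map the result back to the neuron basis."""
    r_base, r_source = R @ h_base, R @ h_source
    r_base[:k] = r_source[:k]
    return R.T @ r_base  # R is orthogonal, so R.T is its inverse

def interchange_intervention_accuracy(low_outputs, high_outputs):
    """Fraction of interventions where low-level and high-level outputs agree."""
    return float(np.mean(np.asarray(low_outputs) == np.asarray(high_outputs)))

# Hypothetical encoding: variable a lives along (1, 1)/sqrt(2), b along (1, -1)/sqrt(2).
R = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2.0)

def encode(a, b):
    return a * R[0] + b * R[1]

h_base, h_source = encode(1.0, -2.0), encode(3.0, 5.0)

# Distributed intervention: a is taken from the source, b is kept from the base.
decoded = R @ distributed_interchange(h_base, h_source, R, k=1)
print(decoded)  # approximately [ 3. -2.]

# Localist intervention on neuron 0 corrupts both variables instead.
h_localist = h_base.copy()
h_localist[0] = h_source[0]
print(R @ h_localist)  # approximately [5.5 2.5]

print(interchange_intervention_accuracy([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.75
```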

Frameworks (1)

  • Parallel Distributed Processing Framework
    The theoretical framework from Rumelhart, McClelland, and Smolensky (1986) identifying distributed representations in neural networks; theoretical precursor to DAS.

Datasets (3)

  • Hierarchical Equality Training Data
    1.92M randomly generated input-output pairs used to train the feed-forward network on the hierarchical equality task.
  • MoNLI Dataset
    Natural language inference dataset where premise-hypothesis pairs differ by a single word; used to evaluate DAS on BERT.
  • MultiNLI Dataset
    BERT is first fine-tuned on MultiNLI before being fine-tuned on MoNLI in the NLI experiment.
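
For readers who want to reproduce the hierarchical equality setup listed above, the following pure-Python sketch generates examples and computes counterfactual labels from a two-variable high-level model. It assumes the task output is whether (w = x) and (y = z) have the same truth value and that each of the four inputs is a 4-dimensional vector (consistent with the 16-dimensional input mentioned earlier); the sampling range, the 50% match rate, and all function names are hypothetical choices for illustration.

```python
import random

def high_level_model(w, x, y, z, left=None, right=None):
    """Tree-structured algorithm with two intermediate variables.
    Passing `left` or `right` overrides that variable (an intervention)."""
    left = (w == x) if left is None else left
    right = (y == z) if right is None else right
    return left == right

def random_vector(rng, dim=4):
    # Assumed input format: one small real-valued vector per entity.
    return tuple(rng.uniform(-0.5, 0.5) for _ in range(dim))

def sample_example(rng):
    """Draw (w, x, y, z) so that each pair matches about half the time."""
    w, y = random_vector(rng), random_vector(rng)
    x = w if rng.random() < 0.5 else random_vector(rng)
    z = y if rng.random() < 0.5 else random_vector(rng)
    return (w, x, y, z), high_level_model(w, x, y, z)

def counterfactual_label(base, source):
    """High-level interchange intervention: run `base`, but take the
    value of the w == x variable from `source`."""
    source_w, source_x, _, _ = source
    return high_level_model(*base, left=(source_w == source_x))

rng = random.Random(0)
data = [sample_example(rng) for _ in range(1000)]  # the summary reports 1.92M pairs
```

Labels produced by `counterfactual_label` are the kind of target a DAS objective would compare against the low-level network's output under the matching distributed interchange intervention.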

Original abstract

Causal abstraction is a promising theoretical framework for explainable artificial intelligence that defines when an interpretable high-level causal model is a faithful simplification of a low-level deep learning system. However, existing causal abstraction methods have two major limitations: they require a brute-force search over alignments between the high-level model and the low-level one, and they presuppose that variables in the high-level model will align with disjoint sets of neurons in the low-level one. In this paper, we present distributed alignment search (DAS), which overcomes these limitations. In DAS, we find the alignment between high-level and low-level models using gradient descent rather than conducting a brute-force search, and we allow individual neurons to play multiple distinct roles by analyzing representations in non-standard bases (distributed representations). Our experiments show that DAS can discover internal structure that prior approaches miss. Overall, DAS removes previous obstacles to conducting causal abstraction analyses and allows us to find conceptual structure in trained neural nets.
