mech-interp
npx skills add https://github.com/pranav-karra-3301/skills --skill mech-interp
Mechanistic Interpretability
Overview
Mechanistic interpretability (mech interp) is the science of reverse-engineering neural networks to understand the algorithms they learn. The core question: “What computation is this model performing, and how?”
Key concepts:
- Residual stream: The main highway through the model; each layer reads from and writes to it
- Features: Directions in activation space representing interpretable concepts
- Circuits: Subgraphs implementing specific behaviors
- Superposition: Models represent more features than dimensions using non-orthogonal directions
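Superposition can be made concrete with a quick numerical sketch (illustrative only, not taken from any particular model): random unit vectors in a low-dimensional space can pack far more nearly-orthogonal “feature” directions than there are dimensions, at the cost of small interference between them.

```python
import torch

torch.manual_seed(0)
d_model, n_features = 64, 256  # 4x more features than dimensions

# Random unit vectors as stand-in feature directions
features = torch.randn(n_features, d_model)
features = features / features.norm(dim=-1, keepdim=True)

# Off-diagonal overlaps = interference between non-orthogonal features
overlaps = features @ features.T - torch.eye(n_features)
print(f"Max interference: {overlaps.abs().max():.3f}")  # small but nonzero
```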
Why it matters:
- Understand model capabilities and limitations
- Debug unexpected behaviors
- Verify safety properties
- Build interpretable AI systems
Core Workflow
Phase 1: Environment Setup
1. Detect compute resources

   ```python
   import torch

   device = "cuda" if torch.cuda.is_available() else "cpu"
   print(f"Device: {device}")
   if device == "cuda":
       print(f"GPU: {torch.cuda.get_device_name(0)}")
       print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
   ```

2. Install libraries based on model size

   | Model Size | Primary Tool | Install |
   |---|---|---|
   | ≤2B params | TransformerLens | `pip install transformer-lens` |
   | 2B-13B params | nnsight | `pip install nnsight` |
   | SAE training | SAELens | `pip install sae-lens[train]` |
   | Both ecosystems | nnterp | `pip install nnterp` |

3. Load model

   ```python
   from transformer_lens import HookedTransformer

   model = HookedTransformer.from_pretrained(
       "gpt2-small",
       device=device,
   )
   ```
See references/tools.md for detailed tool setup.
Phase 2: Experiment Design
Before writing code, clarify:

1. Research question: What specifically are you trying to understand?
   - “What does attention head L5H3 do?”
   - “How does the model represent ‘is_capital_of’ relationships?”
   - “Which components contribute to this prediction?”
2. Hypothesis: What do you expect to find?
   - Phrase as testable predictions
   - Include what would falsify the hypothesis
3. Technique selection (see table below)
4. Validation plan: How will you verify findings?
   - Causal interventions
   - Held-out examples
   - Alternative explanations to rule out
Phase 3: Implementation
Select technique based on your goal:
| Goal | Technique | Tool/Method | Reference |
|---|---|---|---|
| What tokens does model predict at each layer? | Logit Lens | `resid @ W_U` | techniques.md |
| Which component affects this output? | Activation Patching | `run_with_hooks` | techniques.md |
| How much does each head contribute to logit? | Direct Logit Attribution | Decompose residual | techniques.md |
| What information does this head move? | OV Circuit Analysis | `W_V @ W_O` | techniques.md |
| What attends to what? | QK Circuit Analysis | `W_Q @ W_K.T` | techniques.md |
| Is information X represented here? | Probing | Train classifier | techniques.md |
| Find interpretable features | SAE | Train/load SAE | sae-guide.md |
| Which feature represents concept Y? | Feature Search | Max activating examples | sae-guide.md |
Phase 4: Analysis
1. Run experiments
   - Cache activations: `logits, cache = model.run_with_cache(tokens)`
   - Always use `torch.no_grad()` for inference
   - Save intermediate results
2. Visualize results
   - Attention heatmaps
   - Patching effect matrices
   - Feature activation distributions
3. Iterate
   - Refine hypothesis based on findings
   - Test edge cases
   - Look for counterexamples
Phase 5: Validation
Before claiming findings, verify:
- Causal evidence: Ablating/patching changes behavior as predicted
- Held-out data: Results replicate on unseen examples
- Multiple seeds: Not an artifact of specific randomness
- Alternative explanations: Ruled out simpler stories
- Effect size: Practically meaningful, not just statistically significant
See references/pitfalls.md for common mistakes.
Technique Quick Reference
Logit Lens
Project intermediate representations through unembedding to see evolving predictions.
```python
for layer in range(model.cfg.n_layers):
    resid = cache["resid_post", layer]
    resid_normed = model.ln_final(resid)
    logits = resid_normed @ model.W_U
    top_token = logits[0, -1].argmax()
    print(f"Layer {layer}: {model.to_str_tokens(top_token)}")
```
Activation Patching
Measure causal effect by swapping activations between runs.
```python
# Assumes `clean_cache` comes from model.run_with_cache(clean_tokens)
# and `pos` is the token position being patched.
def patch_hook(activation, hook):
    activation[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return activation

patched_logits = model.run_with_hooks(
    corrupted_tokens,
    fwd_hooks=[(hook_point, patch_hook)],
)
```
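One common convention (one of several; the function name here is ours) normalizes the patched metric so that 0 means the patch did nothing and 1 means it fully restored clean behavior:

```python
def patching_effect(patched: float, clean: float, corrupted: float) -> float:
    """Normalize a metric (e.g. logit difference) from a patched run.

    0.0 = patch had no effect (metric stayed at the corrupted baseline)
    1.0 = patch fully restored the clean run's metric
    """
    return (patched - corrupted) / (clean - corrupted)

# e.g. clean logit diff 4.0, corrupted 0.0, patched run recovers 3.0
print(patching_effect(3.0, 4.0, 0.0))  # 0.75
```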
Direct Logit Attribution
Decompose final logits into per-component contributions.
```python
# Note: this ignores final LayerNorm scaling; see techniques.md for the
# properly folded version.
target_dir = model.W_U[:, target_token_idx]
for layer in range(model.cfg.n_layers):
    attn_contribution = cache["attn_out", layer][0, -1] @ target_dir
    mlp_contribution = cache["mlp_out", layer][0, -1] @ target_dir
    print(f"L{layer} attn: {attn_contribution:.3f}, mlp: {mlp_contribution:.3f}")
```
SAE Feature Analysis
Find interpretable features in activations.
```python
from sae_lens import SAE

# Note: some sae-lens versions return a (sae, cfg_dict, sparsity) tuple here
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")
feature_acts = sae.encode(cache["resid_pre", 8])
top_features = feature_acts[0, -1].topk(10)
```
Model Size Guidance
| Model | Library | Memory (FP16) | Notes |
|---|---|---|---|
| GPT-2-small | TransformerLens | ~0.25GB | Best for learning |
| GPT-2-medium/large | TransformerLens | ~0.7-1.5GB | Good balance |
| GPT-2-xl | TransformerLens | ~3GB | Needs decent GPU |
| Pythia-70M to 410M | TransformerLens | ~0.15-0.8GB | Checkpoints available |
| Pythia-1B to 2.8B | TransformerLens | ~2-5.5GB | Pushes memory |
| Pythia-6.9B+ | nnsight | ~14GB+ | Use nnsight for efficiency |
| Llama-2-7B, Mistral-7B | nnsight | ~14GB | Needs 24GB+ GPU |
| Llama-2-13B+ | nnsight | ~26GB+ | Need A100/multi-GPU |
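The FP16 column follows from two bytes per parameter; this counts weights only, so activations and KV caches add overhead on top. A minimal estimator:

```python
def fp16_weight_memory_gb(n_params: float) -> float:
    """Weight memory at 2 bytes/parameter; activations add overhead."""
    return n_params * 2 / 1e9

print(f"GPT-2-small (~124M params): {fp16_weight_memory_gb(124e6):.2f}GB")
print(f"Llama-2-7B: {fp16_weight_memory_gb(7e9):.1f}GB")
```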
See references/compute-awareness.md for memory estimation.
When to Ask the User
Ask before proceeding when:

- Research question unclear: “What specific behavior or component are you trying to understand?”
- Compute constraints unknown: “What GPU do you have available? This model needs ~XGB VRAM.”
- Multiple valid approaches: “We could use activation patching (causal) or probing (correlational). Which do you prefer?”
- Unexpected results: “The results don’t match expectations. Should we investigate further or try a different approach?”
- Scaling decisions: “Initial results look promising on GPT-2-small. Want to scale up to a larger model?”
Common Tasks
“Set up a mech interp project”
- Create project structure (see repo-maintenance.md)
- Install dependencies based on target model
- Set up CLAUDE.md with project-specific instructions
- Configure experiment tracking (wandb or simple JSON logging)
“What does this attention head do?”
- Visualize attention patterns across diverse inputs
- Analyze QK circuit (what attends to what)
- Analyze OV circuit (what information moves)
- Test with activation patching (is it necessary?)
- Check for known patterns (induction, copying, etc.)
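The QK and OV analyses in steps 2-3 reduce to products of the head's weight matrices. A shape-level sketch with random stand-in weights (a real analysis would pull these from the model, e.g. `model.W_Q[layer, head]` in TransformerLens):

```python
import torch

d_model, d_head = 64, 8  # toy dimensions
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

# QK circuit: which residual-stream directions attend to which
QK = W_Q @ W_K.T  # [d_model, d_model]
# OV circuit: what information moves once a position is attended to
OV = W_V @ W_O    # [d_model, d_model]
```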
“Find the circuit for behavior X”
- Design clean/corrupted input pairs
- Patch residual stream: layer × position heatmap
- Narrow to specific layers
- Patch individual heads
- Validate with ablation
- Analyze winning components
“Train an SAE”
- Choose layer and hook point
- Estimate memory requirements
- Set hyperparameters (expansion factor, L1 coefficient)
- Run training with monitoring (L0, reconstruction loss, dead features)
- Evaluate quality before analysis
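The monitoring quantities in step 4 are cheap to compute from a batch of SAE feature activations; a sketch of two of them (function names are ours):

```python
import torch

def mean_l0(feature_acts: torch.Tensor) -> float:
    """Average number of active features per token.

    feature_acts: [n_tokens, n_features] post-ReLU SAE activations.
    """
    return (feature_acts > 0).float().sum(dim=-1).mean().item()

def dead_feature_fraction(feature_acts: torch.Tensor) -> float:
    """Fraction of features that never fire in this batch."""
    return ((feature_acts > 0).sum(dim=0) == 0).float().mean().item()
```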
See references/sae-guide.md for detailed guidance.
“Interpret SAE features”
- Load pretrained SAE or train your own
- Find max activating examples for features of interest
- Look for patterns in activating contexts
- Test hypothesis with feature steering/ablation
- Validate causal role
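Step 2 above can be sketched with plain tensor ops: given one feature's activations over a batch of examples, rank examples by their peak activation (a minimal version; production tooling streams this over a large dataset):

```python
import torch

def max_activating_examples(feature_acts: torch.Tensor, k: int = 5):
    """Top-k examples by peak activation of a single feature.

    feature_acts: [n_examples, seq_len] activations of one SAE feature.
    Returns a list of (example_idx, position, activation) tuples.
    """
    peak_vals, peak_pos = feature_acts.max(dim=-1)
    top = peak_vals.topk(min(k, len(peak_vals)))
    return [(i.item(), peak_pos[i].item(), v.item())
            for v, i in zip(top.values, top.indices)]
```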
Quality Checklist
Before concluding analysis:
- Research question clearly stated
- Appropriate technique selected
- Code runs without errors
- Results visualized
- Causal validation performed
- Edge cases tested
- Alternative explanations considered
- Results documented with reproducibility info
Reference Files
| File | Contents |
|---|---|
| tools.md | TransformerLens, nnsight, SAELens setup |
| techniques.md | Patching, logit lens, circuits, probing |
| sae-guide.md | SAE training and analysis |
| visualization.md | Plotting patterns and dashboards |
| pitfalls.md | Common mistakes and validation |
| repo-maintenance.md | Project structure templates |
| vocabulary.md | Glossary of terms |
| compute-awareness.md | GPU/memory guidance |