Neural network parameters live in a space with strong permutation symmetries: you can reorder hidden units without changing the function, yet the raw weight tensors look completely different. If a representation ignores this, it ends up learning spurious differences and struggles to generalize across architectures or widths. This paper proposes representing a neural network as a neural graph (nodes as neurons/bias features, edges as weights) and then using a GNN to produce equivariant representations that respect these symmetries. This enables tasks like predicting generalization, classifying networks by behavior, retrieving similar architectures, and meta-learning over model populations.
Why Equivariance Matters for "Learning Over Networks"
The fundamental problem: parameter symmetries
Many tasks treat an entire neural network as a data point:
- Predict generalization from weights (without running validation)
- Classify networks by behavior (e.g., task type, dataset)
- Retrieve similar networks (e.g., find models solving the same task)
- Meta-learning / architecture analysis (extract patterns across model populations)
- Model merging (combine weights from different models)
But neural network weights have a key nuisance symmetry: in an MLP, you can permute hidden units (and permute the corresponding incoming/outgoing weights) without changing the function. For a two-layer MLP $f(x) = W_2\,\sigma(W_1 x + b_1) + b_2$ and any permutation matrix $P$, the reparameterized network $(P W_1,\; P b_1,\; W_2 P^\top,\; b_2)$ computes exactly the same function.
Problem: If you flatten the weights into a vector, these two parameterizations look completely different even though they define the same function.
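This symmetry is easy to verify numerically. A minimal sketch with a random two-layer ReLU MLP in plain numpy (illustrative, not code from the paper):

```python
# Permuting hidden units, together with the matching rows/columns of the
# weight matrices, leaves the MLP's function unchanged.
import numpy as np

rng = np.random.default_rng(0)

def mlp(x, W1, b1, W2, b2):
    h = np.maximum(W1 @ x + b1, 0.0)  # hidden layer with ReLU
    return W2 @ h + b2

# Random MLP: 2 inputs -> 3 hidden units -> 1 output
W1, b1 = rng.normal(size=(3, 2)), rng.normal(size=3)
W2, b2 = rng.normal(size=(1, 3)), rng.normal(size=1)

# Permute the hidden units: rows of W1/b1 and columns of W2
perm = np.array([2, 0, 1])
W1p, b1p, W2p = W1[perm], b1[perm], W2[:, perm]

x = rng.normal(size=2)
y_original = mlp(x, W1, b1, W2, b2)
y_permuted = mlp(x, W1p, b1p, W2p, b2)
assert np.allclose(y_original, y_permuted)  # same function...
assert not np.allclose(W1, W1p)             # ...but different raw tensors
```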
Consequence: A representation that ignores symmetries:
- Treats equivalent networks as different (wastes capacity learning spurious patterns)
- Fails to generalize across widths/architectures
- Cannot transfer knowledge between equivalent models
Why naive approaches fail
1. Flattening weights into a vector
Method: Concatenate all weight matrices into a single vector $\theta$.
Problem:
- Not permutation-equivariant (permuting hidden units completely changes $\theta$)
- Dimension depends on architecture (cannot compare networks of different widths)
- Loses structural information (which weights connect to which neurons)
2. Statistical summaries (mean, variance, histograms)
Method: Compute statistics over weight distributions.
Problem:
- Loses all relational information (which neurons connect to which)
- Cannot distinguish functionally different networks with similar weight distributions
- Too coarse for fine-grained tasks (e.g., predicting generalization gap)
3. Training a CNN/MLP on weight matrices
Method: Treat weight matrices as "images" and apply convolution.
Problem:
- Still not equivariant (CNNs are translation-equivariant, not permutation-equivariant)
- Dimension mismatch across architectures
- Ignores graph structure (neurons are not arranged in a grid)
Neural Graphs: Turning Weights Into a Graph
Core idea
Represent a neural network as a directed graph:
- Nodes: neurons (or computational units)
  - Input layer nodes
  - Hidden layer nodes
  - Output layer nodes
  - Optional: bias features as special nodes
- Edges: connections with weights as edge features
  - An edge $i \to j$ exists if neuron $i$ connects to neuron $j$
  - Edge feature: the weight $w_{ij}$ (or a tuple of weight + bias if combined)
Example: For an MLP with architecture input(2) → hidden(3) → output(1):
- Nodes: [input_1, input_2, hidden_1, hidden_2, hidden_3, output_1]
- Edges: 2·3 = 6 input→hidden weights plus 3·1 = 3 hidden→output weights, each carrying its weight value as an edge feature
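As a concrete sketch, the neural graph for a small MLP can be built directly from the weight tensors. The construction below is illustrative (node/edge featurization is an assumption, not the paper's exact scheme):

```python
# Build (node_features, edge_index, edge_features) from MLP weight matrices.
# Node features: [layer index, bias]; edge features: the weight value.
import numpy as np

def mlp_to_graph(weights, biases):
    """Convert a list of weight matrices + bias vectors into a neural graph."""
    sizes = [weights[0].shape[1]] + [W.shape[0] for W in weights]
    offsets = np.cumsum([0] + sizes)          # global node index per layer
    n_nodes = offsets[-1]

    node_feats = np.zeros((n_nodes, 2))
    for layer, size in enumerate(sizes):
        node_feats[offsets[layer]:offsets[layer] + size, 0] = layer
    for layer, b in enumerate(biases):        # input nodes keep bias 0
        node_feats[offsets[layer + 1]:offsets[layer + 2], 1] = b

    # One directed edge per weight, source -> target
    src, dst, ef = [], [], []
    for layer, W in enumerate(weights):
        for i in range(W.shape[0]):           # target neuron
            for j in range(W.shape[1]):       # source neuron
                src.append(offsets[layer] + j)
                dst.append(offsets[layer + 1] + i)
                ef.append(W[i, j])
    return node_feats, np.array([src, dst]), np.array(ef)

rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(3, 2)), rng.normal(size=(1, 3))
b1, b2 = rng.normal(size=3), rng.normal(size=1)
nodes, edges, edge_feats = mlp_to_graph([W1, W2], [b1, b2])
assert nodes.shape[0] == 6    # 2 input + 3 hidden + 1 output nodes
assert edges.shape[1] == 9    # 6 + 3 weight edges
```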
Why this representation is powerful
Key insight: Graph neural networks (GNNs) are inherently permutation-equivariant. If you permute nodes in the graph, the GNN output transforms in a corresponding way.
Mathematically, for a GNN $f$ operating on an adjacency matrix $A$ and node features $X$, permutation equivariance means $f(PAP^\top, PX) = P\,f(A, X)$ for every permutation matrix $P$.
Benefit: The neural-graph representation + GNN automatically respects the symmetries of the underlying neural network.
What "Equivariant" Means Here
Equivariance vs invariance
- Invariance: Output is the same regardless of permutation: $f(PAP^\top, PX) = f(A, X)$.
  Use case: Graph-level prediction (e.g., predict generalization gap)
- Equivariance: Output transforms correspondingly with the permutation: $f(PAP^\top, PX) = P\,f(A, X)$.
  Use case: Node-level prediction (e.g., neuron importance scores)
Why equivariance is stronger: You may want node-level embeddings to permute correspondingly (e.g., to align neurons across networks), and then optionally produce a graph-level invariant embedding via pooling (sum, mean, max).
Example: Neuron alignment
Suppose you have two MLPs with the same architecture but trained from different initializations. Their hidden units may encode similar features in different orders.
With naive vector representation: No way to identify corresponding neurons.
With equivariant GNN: Node embeddings can be matched (e.g., via optimal transport) to find neuron correspondence.
Model Architecture
High-level pipeline
Build neural graph from network parameters
- Nodes: neurons + biases (optional)
- Edges: weights
- Edge features: weight values (possibly concatenated with layer index, neuron type, etc.)
Run message passing (GNN layers)
- Each node aggregates information from neighbors
- Node features updated via learnable functions
- Multiple layers capture multi-hop interactions
Pool node representations (optional, for graph-level tasks)
- Sum pooling: $h_G = \sum_i h_i$
- Mean pooling: $h_G = \frac{1}{n}\sum_i h_i$
- Attention pooling: $h_G = \sum_i \alpha_i h_i$ with learned attention weights $\alpha_i$
Train for downstream task
- Regression: predict generalization gap, training time, etc.
- Classification: classify network by task type, dataset, architecture family
- Retrieval: embed networks into a metric space for similarity search
GNN message passing (simplified)
At each layer, every node updates its features by aggregating messages from its neighbors:
$$h_i^{(l+1)} = \phi\Big(h_i^{(l)},\; \bigoplus_{j \in \mathcal{N}(i)} \psi\big(h_i^{(l)}, h_j^{(l)}, e_{ij}\big)\Big)$$
Common choices:
- GCN-style: normalized neighborhood averaging
- GAT-style: attention over neighbors
- MPNN-style: edge networks for edge-conditioned messages
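The update above can be sketched in plain numpy as one edge-conditioned message-passing step followed by sum pooling. This is a simplified MPNN stand-in (real implementations would use a GNN library such as PyTorch Geometric), with random weights for illustration:

```python
# One message-passing step: each node sums messages from its in-neighbors,
# where each message is the source node's features scaled by the edge weight.
import numpy as np

def message_passing_step(node_feats, edge_index, edge_feats, W_msg, W_upd):
    """h_i' = relu(W_upd @ [h_i, sum_j psi(h_j, e_ji)])."""
    n = node_feats.shape[0]
    src, dst = edge_index
    msgs = (node_feats[src] * edge_feats[:, None]) @ W_msg   # edge-conditioned messages
    agg = np.zeros((n, msgs.shape[1]))
    np.add.at(agg, dst, msgs)                                # sum-aggregate per target node
    return np.maximum(np.concatenate([node_feats, agg], axis=1) @ W_upd, 0.0)

def sum_pool(node_feats):
    return node_feats.sum(axis=0)    # permutation-invariant graph-level readout

rng = np.random.default_rng(0)
n, d = 6, 4
node_feats = rng.normal(size=(n, d))
edge_index = np.array([[0, 0, 1, 1, 2], [3, 4, 3, 5, 5]])    # src row, dst row
edge_feats = rng.normal(size=5)
W_msg, W_upd = rng.normal(size=(d, d)), rng.normal(size=(2 * d, d))

h = message_passing_step(node_feats, edge_index, edge_feats, W_msg, W_upd)
g = sum_pool(h)
assert h.shape == (n, d) and g.shape == (d,)
```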
Downstream Tasks and Applications
Task 1: Predicting generalization gap
Setup: Given a trained network, predict `train_acc - test_acc` without running validation.
Why equivariance helps: Equivalent networks (same function, different permutation) should predict the same generalization gap.
Method:
1. Build a neural graph from the trained weights
2. Run the GNN to get a graph embedding
3. Regress the generalization gap from the embedding
Task 2: Network classification
Setup: Classify networks by task type (e.g., CIFAR-10 vs ImageNet), architecture family (ResNet vs VGG), or training method (SGD vs Adam).
Why equivariance helps: Permuting hidden units shouldn't change the classification (e.g., a ResNet is still a ResNet).
Method: Train a classifier on top of graph embeddings.
Task 3: Network retrieval
Setup: Given a query network, find similar networks in a database.
Why equivariance helps: Similarity should be measured in function space, not parameter space.
Method:
1. Embed all networks into a metric space via the GNN
2. Use cosine similarity or Euclidean distance for retrieval
3. Optional: train with a contrastive loss (e.g., triplet loss) to encourage functionally similar networks to have nearby embeddings
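Step 2 is standard nearest-neighbor search over embeddings. A small sketch, using random vectors as stand-ins for GNN outputs:

```python
# Cosine-similarity retrieval over a database of network embeddings.
import numpy as np

def cosine_retrieve(query, database, k=3):
    """Return indices of the k database embeddings most similar to `query`."""
    q = query / np.linalg.norm(query)
    db = database / np.linalg.norm(database, axis=1, keepdims=True)
    sims = db @ q                      # cosine similarity to every entry
    return np.argsort(-sims)[:k]       # top-k, most similar first

rng = np.random.default_rng(0)
database = rng.normal(size=(100, 16))               # 100 network embeddings
query = database[42] + 0.01 * rng.normal(size=16)   # near-duplicate of entry 42
top = cosine_retrieve(query, database)
assert top[0] == 42    # nearest neighbor is the near-duplicate
```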
Task 4: Meta-learning / model merging
Setup: Learn patterns across a population of models (e.g., predict which architectures generalize well, or merge multiple trained models).
Why equivariance helps: Models solving the same task with different permutations should be alignable.
Method: Use equivariant node embeddings to find neuron correspondences, then merge weights via averaging or optimal transport.
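A toy sketch of that alignment step: match hidden units of two same-architecture models by embedding similarity, then average the aligned weights. A greedy argmax match stands in here for optimal transport / Hungarian matching, and the weight rows themselves stand in for learned node embeddings:

```python
# Permutation-aware merging of two MLP layers: recover the hidden-unit
# permutation, un-permute model B, then average with model A.
import numpy as np

def greedy_match(emb_a, emb_b):
    """Greedily pair rows of emb_a with rows of emb_b by cosine similarity."""
    a = emb_a / np.linalg.norm(emb_a, axis=1, keepdims=True)
    b = emb_b / np.linalg.norm(emb_b, axis=1, keepdims=True)
    sims = a @ b.T
    perm = np.full(len(a), -1)
    for _ in range(len(a)):
        i, j = np.unravel_index(np.argmax(sims), sims.shape)
        perm[i] = j
        sims[i, :] = -np.inf    # each row/column matched at most once
        sims[:, j] = -np.inf
    return perm

# Model B is a hidden-unit permutation of model A, so alignment should recover it.
rng = np.random.default_rng(0)
W1_a = rng.normal(size=(4, 3))          # 4 hidden units, 3 inputs
true_perm = np.array([2, 3, 0, 1])
W1_b = W1_a[true_perm]

perm = greedy_match(W1_b, W1_a)         # maps B's unit i -> A's unit perm[i]
assert np.array_equal(perm, true_perm)

W1_merged = 0.5 * (W1_a + W1_b[np.argsort(perm)])   # average after un-permuting B
assert np.allclose(W1_merged, W1_a)
```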
Comparison to Naive Baselines
| Method | Equivariant? | Scalable? | Captures Structure? |
|---|---|---|---|
| Flatten weights | ❌ No | ❌ No (dim grows with width) | ❌ No |
| Weight statistics | ✅ Yes (invariant) | ✅ Yes | ❌ No (loses relational info) |
| CNN on weight matrices | ❌ No (translation ≠ permutation) | ⚠️ Medium | ⚠️ Partial |
| Neural graphs + GNN | ✅ Yes | ✅ Yes (graph size ~ network size) | ✅ Yes |
Practical Considerations
1. Graph construction details
Node features: What to initialize node embeddings with?
- Random: Initialize node embeddings with random vectors
- Layer index: Encode which layer the neuron belongs to
- Neuron type: Encode whether it's input/hidden/output
- Bias values: Include bias as a node feature or separate node
Edge features: What information to encode?
- Weight value: $w_{ij}$ (most important)
- Layer index: Which layer the connection belongs to
- Connection type: Conv, linear, attention, etc.
2. Handling different architectures
Problem: Networks with different architectures have different graph sizes/topologies.
Solutions:
- Invariant pooling: Use sum/mean pooling to get fixed-size graph embedding
- Attention pooling: Learn to weight important neurons
- Hierarchical pooling: Coarsen the graph in multiple stages
3. Computational cost
Concern: GNNs can be slow on large graphs.
Mitigations:
- Sampling: Sample subgraphs for mini-batch training
- Sparse GNNs: Use efficient sparse operations (most neural graphs are sparse)
- Layer-wise processing: Process one layer of the neural network at a time
4. Training stability
Issue: GNNs can suffer from over-smoothing (node representations become too similar after many layers).
Solutions:
- Residual connections: $h^{(l+1)} = h^{(l)} + \text{GNN}(h^{(l)})$, so node identity is preserved across layers
- Layer normalization: Normalize node features at each layer
- Shallow GNNs: Use 2-4 layers instead of very deep models
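Combining the first two mitigations in one layer might look like the following sketch (plain-numpy stand-in for a real GNN layer; the GCN-style mixing and shapes are illustrative):

```python
# Residual + layer-normalized GNN update, a common recipe against over-smoothing.
import numpy as np

def layer_norm(h, eps=1e-5):
    mu = h.mean(axis=1, keepdims=True)
    sigma = h.std(axis=1, keepdims=True)
    return (h - mu) / (sigma + eps)

def residual_gnn_layer(h, adj, W):
    """h_{l+1} = LayerNorm(h_l + relu(A_hat @ h_l @ W))."""
    msg = np.maximum(adj @ h @ W, 0.0)   # simple GCN-style neighborhood mixing
    return layer_norm(h + msg)           # residual keeps each node's own signal

rng = np.random.default_rng(0)
n, d = 5, 8
h = rng.normal(size=(n, d))
adj = (rng.random(size=(n, n)) < 0.4).astype(float)
adj /= adj.sum(axis=1, keepdims=True) + 1e-8    # row-normalize the adjacency

out = h
for _ in range(4):                      # stack a few layers
    out = residual_gnn_layer(out, adj, W=rng.normal(size=(d, d)))
assert out.shape == (n, d)
assert np.isfinite(out).all()
```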
Takeaways
Flattening weights loses structure and makes symmetry handling hard. Naive approaches treat equivalent networks as different, wasting capacity.
Graph-based representations enable equivariance. By representing a network as a neural graph, GNNs naturally respect permutation symmetries.
Equivariance enables cross-architecture transfer. The approach works across widths and architectures where naive parameter alignment is not meaningful.
Applications are diverse: Predicting generalization, classifying networks, retrieving similar models, meta-learning, and model merging all benefit from equivariant representations.
Open challenges: Scaling to very large models (e.g., LLMs with billions of parameters), handling diverse architectures (e.g., Transformers, graph networks), and designing task-specific graph constructions.
Further Reading
- Original paper: Graph Neural Networks for Learning Equivariant Representations of Neural Networks
- Related work:
- Code: Check the original paper for implementation details and benchmarks.
Summary: GNN for Neural Network Representations in 5 Steps
- Represent network as a graph: Nodes = neurons, edges = weights
- Run GNN message passing: Aggregate neighbor information, update node embeddings
- Pool to graph embedding: Sum/mean/attention pooling for graph-level tasks
- Train for downstream task: Regression (generalization), classification (task type), retrieval (similarity)
- Leverage equivariance: Equivalent networks (same function, different permutation) → consistent representations
Key insight: GNNs naturally respect permutation symmetries, enabling better generalization across architectures and widths.
- Post title: Graph Neural Networks for Learning Equivariant Representations of Neural Networks
- Post author: Chen Kai
- Create time: 2023-09-02 00:00:00
- Post link: https://www.chenk.top/en/gnn-equivariant-representations/
- Copyright notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.