Graph Neural Networks for Learning Equivariant Representations of Neural Networks
Chen Kai

Neural network parameters live in a space with strong permutation symmetries: you can reorder hidden units without changing the function, yet the raw weight tensors look completely different. If a representation ignores this, it ends up learning spurious differences and struggles to generalize across architectures or widths. This paper proposes representing a neural network as a neural graph (nodes as neurons/bias features, edges as weights) and then using a GNN to produce equivariant representations that respect these symmetries. This enables tasks like predicting generalization, classifying networks by behavior, retrieving similar architectures, and meta-learning over model populations.

Why Equivariance Matters for "Learning Over Networks"

The fundamental problem: parameter symmetries

Many tasks treat an entire neural network as a data point:

  • Predict generalization from weights (without running validation)
  • Classify networks by behavior (e.g., task type, dataset)
  • Retrieve similar networks (e.g., find models solving the same task)
  • Meta-learning / architecture analysis (extract patterns across model populations)
  • Model merging (combine weights from different models)

But neural network weights have a key nuisance symmetry: in an MLP, you can permute hidden units (and permute the corresponding incoming/outgoing weights) without changing the function:

$f(x) = W_2\,\sigma(W_1 x + b_1) + b_2 = (W_2 P^\top)\,\sigma(P W_1 x + P b_1) + b_2$, where $P$ is a permutation matrix.
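As a quick sanity check, here is a small numpy sketch (a toy two-layer MLP; all names are hypothetical) verifying that permuting the hidden units, together with the corresponding rows of $W_1, b_1$ and columns of $W_2$, leaves the function unchanged:

```python
import numpy as np

# Toy MLP: f(x) = W2 @ relu(W1 @ x + b1) + b2
rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(3, 2)), rng.normal(size=3)
W2, b2 = rng.normal(size=(1, 3)), rng.normal(size=1)

def mlp(x, W1, b1, W2, b2):
    return W2 @ np.maximum(W1 @ x + b1, 0) + b2

# Permute the 3 hidden units: rows of W1/b1, columns of W2.
P = np.eye(3)[[2, 0, 1]]          # a permutation matrix
W1p, b1p = P @ W1, P @ b1
W2p = W2 @ P.T

x = rng.normal(size=2)
assert np.allclose(mlp(x, W1, b1, W2, b2), mlp(x, W1p, b1p, W2p, b2))
```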

Problem: If you flatten weights into a vector, the representation is not equivariant: equivalent networks (same function, different permutation) map to completely different vectors.

Consequence: A representation that ignores symmetries:

  • Treats equivalent networks as different (wastes capacity learning spurious patterns)
  • Fails to generalize across widths/architectures
  • Cannot transfer knowledge between equivalent models

Why naive approaches fail

1. Flattening weights into a vector

Method: Concatenate all weight matrices into a single vector.

Problem:

  • Not permutation-equivariant (permuting hidden units completely changes the vector)
  • Dimension depends on architecture (cannot compare networks of different widths)
  • Loses structural information (which weights connect to which neurons)

2. Statistical summaries (mean, variance, histograms)

Method: Compute statistics over weight distributions.

Problem:

  • Loses all relational information (which neurons connect to which)
  • Cannot distinguish functionally different networks with similar weight distributions
  • Too coarse for fine-grained tasks (e.g., predicting generalization gap)

3. Training a CNN/MLP on weight matrices

Method: Treat weight matrices as "images" and apply convolution.

Problem:

  • Still not equivariant (CNNs are translation-equivariant, not permutation-equivariant)
  • Dimension mismatch across architectures
  • Ignores graph structure (neurons are not arranged in a grid)

Neural Graphs: Turning Weights Into a Graph

Core idea

Represent a neural network as a directed graph:

  • Nodes: neurons (or computational units)
    • Input layer nodes
    • Hidden layer nodes
    • Output layer nodes
    • Optional: bias features as special nodes
  • Edges: connections with weights as edge features
    • Edge $(i \to j)$ exists if neuron $i$ connects to neuron $j$
    • Edge feature: weight $w_{ij}$ (or a tuple of weight + bias if combined)

Example: For an MLP with architecture input(2) → hidden(3) → output(1):

Nodes: [input_1, input_2, hidden_1, hidden_2, hidden_3, output_1]
Edges:
(input_1 → hidden_1, weight=W1[1,1])
(input_1 → hidden_2, weight=W1[2,1])
...
(hidden_3 → output_1, weight=W2[1,3])
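The construction above can be sketched in a few lines of numpy. `mlp_to_graph` is a hypothetical helper, not the paper's implementation; it uses 0-based indexing and attaches each bias as a node feature (input nodes get 0):

```python
import numpy as np

def mlp_to_graph(weights, biases):
    """Build a neural graph (node features + weighted edge list) from MLP
    parameters. weights[l] has shape (fan_out, fan_in). A sketch only."""
    sizes = [weights[0].shape[1]] + [W.shape[0] for W in weights]
    offsets = np.cumsum([0] + sizes)          # global node id of each layer's first neuron
    # Node feature: the neuron's bias (0 for input nodes).
    node_feats = np.concatenate([np.zeros(sizes[0])] + [np.asarray(b) for b in biases])
    edges = []
    for l, W in enumerate(weights):
        for j in range(W.shape[0]):           # target neuron in layer l+1
            for i in range(W.shape[1]):       # source neuron in layer l
                edges.append((int(offsets[l]) + i, int(offsets[l + 1]) + j, W[j, i]))
    return node_feats, edges

# input(2) -> hidden(3) -> output(1), as in the example above
W1, W2 = np.ones((3, 2)), np.ones((1, 3))
feats, edges = mlp_to_graph([W1, W2], [np.zeros(3), np.zeros(1)])
assert feats.shape == (6,) and len(edges) == 3 * 2 + 1 * 3
```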

Why this representation is powerful

Key insight: Graph neural networks (GNNs) are inherently permutation-equivariant. If you permute nodes in the graph, the GNN output transforms in a corresponding way.

Mathematically, for a GNN: $\mathrm{GNN}(PX,\, PAP^\top) = P \cdot \mathrm{GNN}(X, A)$, where $P$ is a permutation of the nodes, $X$ the node features, and $A$ the adjacency matrix.

Benefit: The neural-graph representation + GNN automatically respects the symmetries of the underlying neural network.
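A minimal numerical check of this property, using a toy one-layer GNN with sum aggregation (all names hypothetical):

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 5, 4
X = rng.normal(size=(n, d))                    # node features
A = (rng.random((n, n)) < 0.4).astype(float)   # adjacency matrix
W = rng.normal(size=(d, d))                    # shared learnable weight (toy)

def gnn_layer(X, A, W):
    # One sum-aggregation message-passing step.
    return np.maximum((A @ X) @ W, 0)

perm = rng.permutation(n)
P = np.eye(n)[perm]
# Equivariance: GNN(P X, P A P^T) == P * GNN(X, A)
assert np.allclose(gnn_layer(P @ X, P @ A @ P.T, W), P @ gnn_layer(X, A, W))
```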


What "Equivariant" Means Here

Equivariance vs invariance

  • Invariance: Output is the same regardless of permutation

Use case: Graph-level prediction (e.g., predict generalization gap)

  • Equivariance: Output transforms correspondingly with permutation

Use case: Node-level prediction (e.g., neuron importance scores)

Why equivariance is stronger: You may want node-level embeddings to permute correspondingly (e.g., to align neurons across networks), and then optionally produce a graph-level invariant embedding via pooling (sum, mean, max).

Example: Neuron alignment

Suppose you have two MLPs with the same architecture but trained from different initializations. Their hidden units may encode similar features in different orders.

With naive vector representation: No way to identify corresponding neurons.

With equivariant GNN: Node embeddings can be matched (e.g., via optimal transport) to find neuron correspondence.
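A sketch of this matching step, assuming we already have equivariant node embeddings for both networks (here simulated by shuffling one set and adding noise); the assignment problem is solved with `scipy.optimize.linear_sum_assignment`:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(2)
# Hypothetical equivariant node embeddings for 8 hidden units of two nets;
# net B is net A with its neurons shuffled plus small noise.
emb_a = rng.normal(size=(8, 16))
true_perm = rng.permutation(8)
emb_b = emb_a[true_perm] + 0.01 * rng.normal(size=(8, 16))

# Match neurons by minimizing pairwise embedding distance.
cost = np.linalg.norm(emb_a[:, None, :] - emb_b[None, :, :], axis=-1)
row, col = linear_sum_assignment(cost)

# col[i] should be the index in emb_b of the neuron matching emb_a[i],
# i.e. the inverse of true_perm.
inverse = np.empty(8, dtype=int)
inverse[true_perm] = np.arange(8)
assert np.array_equal(col, inverse)
```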


Model Architecture

High-level pipeline

  1. Build neural graph from network parameters

    • Nodes: neurons + biases (optional)
    • Edges: weights
    • Edge features: weight values (possibly concatenated with layer index, neuron type, etc.)
  2. Run message passing (GNN layers)

    • Each node aggregates information from neighbors
    • Node features updated via learnable functions
    • Multiple layers capture multi-hop interactions
  3. Pool node representations (optional, for graph-level tasks)

    • Sum pooling: $h_G = \sum_v h_v$
    • Mean pooling: $h_G = \frac{1}{|V|} \sum_v h_v$
    • Attention pooling: $h_G = \sum_v \alpha_v h_v$ with learned weights $\alpha_v$
  4. Train for downstream task

    • Regression: predict generalization gap, training time, etc.
    • Classification: classify network by task type, dataset, architecture family
    • Retrieval: embed networks into a metric space for similarity search
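The pooling step (3) can be sketched as follows; the attention scoring vector `a` stands in for a learned parameter:

```python
import numpy as np

H = np.random.default_rng(3).normal(size=(6, 8))   # node embeddings, 6 nodes

h_sum = H.sum(axis=0)                # sum pooling
h_mean = H.mean(axis=0)              # mean pooling

# Attention pooling with a (hypothetical) learned scoring vector a.
a = np.random.default_rng(4).normal(size=8)
alpha = np.exp(H @ a)
alpha /= alpha.sum()                 # softmax weights over nodes
h_att = alpha @ H

# All three are permutation-invariant graph embeddings.
P = np.eye(6)[np.random.default_rng(5).permutation(6)]
assert np.allclose((P @ H).sum(axis=0), h_sum)
```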

GNN message passing (simplified)

At each layer:

$$h_v^{(l+1)} = \mathrm{UPDATE}\Big(h_v^{(l)},\ \mathrm{AGGREGATE}\big(\{(h_u^{(l)}, e_{uv}) : u \in \mathcal{N}(v)\}\big)\Big)$$

where:

  • $h_v^{(l)}$: hidden state of node $v$ at layer $l$
  • $\mathcal{N}(v)$: neighbors of node $v$
  • $e_{uv}$: edge feature (weight) from $u$ to $v$
  • AGGREGATE: sum, mean, max, or attention-weighted aggregation
  • UPDATE: MLP, GRU, or other learnable function

Common choices:

  • GCN-style: degree-normalized sum, $h_v^{(l+1)} = \sigma\big(\sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{d_u d_v}} W h_u^{(l)}\big)$
  • GAT-style: Attention over neighbors
  • MPNN-style: Edge networks for edge-conditioned messages
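A minimal edge-conditioned message-passing layer in the MPNN style might look like this (sum AGGREGATE, one-layer-MLP UPDATE; the weight matrices are hypothetical placeholders for learned parameters):

```python
import numpy as np

def mpnn_layer(h, edges, W_msg, W_upd):
    """One edge-conditioned message-passing step (MPNN-style sketch).
    h: (n, d) node states; edges: list of (src, dst, edge_weight)."""
    msg = np.zeros_like(h)
    for u, v, w in edges:
        # Message from u to v, scaled by the edge feature (the NN weight).
        msg[v] += w * (h[u] @ W_msg)          # AGGREGATE = sum
    # UPDATE = one-layer MLP over [state, aggregated message].
    return np.maximum(np.concatenate([h, msg], axis=1) @ W_upd, 0)

rng = np.random.default_rng(6)
h = rng.normal(size=(4, 3))
edges = [(0, 2, 0.5), (1, 2, -1.0), (2, 3, 2.0)]
W_msg, W_upd = rng.normal(size=(3, 3)), rng.normal(size=(6, 3))
h_next = mpnn_layer(h, edges, W_msg, W_upd)
assert h_next.shape == (4, 3)
```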

Downstream Tasks and Applications

Task 1: Predicting generalization gap

Setup: Given a trained network, predict train_acc - test_acc without running validation.

Why equivariance helps: Equivalent networks (same function, different permutation) should predict the same generalization gap.

Method: 1. Build neural graph from trained weights 2. Run GNN to get graph embedding $h_G = f_{\mathrm{GNN}}(G)$ 3. Train a regressor on $h_G$ to predict the gap

Task 2: Network classification

Setup: Classify networks by task type (e.g., CIFAR-10 vs ImageNet), architecture family (ResNet vs VGG), or training method (SGD vs Adam).

Why equivariance helps: Permuting hidden units shouldn't change the classification (e.g., a ResNet is still a ResNet).

Method: Train a classifier on top of graph embeddings.

Task 3: Network retrieval

Setup: Given a query network, find similar networks in a database.

Why equivariance helps: Similarity should be measured in function space, not parameter space.

Method: 1. Embed all networks into a metric space via GNN 2. Use cosine similarity or Euclidean distance for retrieval 3. Optional: Train with contrastive loss (e.g., triplet loss) to encourage functionally similar networks to have nearby embeddings
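Steps 1-2 of this recipe reduce to nearest-neighbor search over embeddings; a sketch with cosine similarity, using random vectors in place of real GNN embeddings:

```python
import numpy as np

rng = np.random.default_rng(7)
db = rng.normal(size=(100, 32))                 # graph embeddings of 100 networks
db /= np.linalg.norm(db, axis=1, keepdims=True) # unit-normalize rows

# Query: a slightly perturbed copy of database entry 42.
query = db[42] + 0.01 * rng.normal(size=32)
query /= np.linalg.norm(query)

scores = db @ query                   # cosine similarity against all entries
top5 = np.argsort(-scores)[:5]        # indices of the 5 most similar networks
assert top5[0] == 42                  # nearest neighbor is the original entry
```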

Task 4: Meta-learning / model merging

Setup: Learn patterns across a population of models (e.g., predict which architectures generalize well, or merge multiple trained models).

Why equivariance helps: Models solving the same task with different permutations should be alignable.

Method: Use equivariant node embeddings to find neuron correspondences, then merge weights via averaging or optimal transport.
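A toy version of permutation-aware merging for a one-hidden-layer MLP, assuming the neuron correspondence `perm` has already been recovered (e.g., from node embeddings as above); all names are hypothetical:

```python
import numpy as np

def merge_mlps(W1_a, W2_a, W1_b, W2_b, perm):
    """Average two 1-hidden-layer MLPs after aligning net B's hidden units
    to net A via perm. A sketch of permutation-aware merging."""
    P = np.eye(W1_a.shape[0])[perm]
    W1_b_aligned = P @ W1_b          # reorder hidden units (rows of W1)
    W2_b_aligned = W2_b @ P.T        # reorder matching columns of W2
    return 0.5 * (W1_a + W1_b_aligned), 0.5 * (W2_a + W2_b_aligned)

rng = np.random.default_rng(8)
W1, W2 = rng.normal(size=(3, 2)), rng.normal(size=(1, 3))

# Net B: net A with its hidden units shuffled (same function).
perm = np.array([2, 0, 1])
Pinv = np.eye(3)[perm].T
W1_b, W2_b = Pinv @ W1, W2 @ Pinv.T

# Aligned merging of two copies of the same function recovers net A exactly.
W1_m, W2_m = merge_mlps(W1, W2, W1_b, W2_b, perm)
assert np.allclose(W1_m, W1) and np.allclose(W2_m, W2)
```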


Comparison to Naive Baselines

Method | Equivariant? | Scalable? | Captures Structure?
--- | --- | --- | ---
Flatten weights | ❌ No | ❌ No (dim grows with width) | ❌ No
Weight statistics | ✅ Yes (invariant) | ✅ Yes | ❌ No (loses relational info)
CNN on weight matrices | ❌ No (translation ≠ permutation) | ⚠️ Medium | ⚠️ Partial
Neural graphs + GNN | ✅ Yes | ✅ Yes (graph size ∝ network size) | ✅ Yes

Practical Considerations

1. Graph construction details

Node features: What to initialize node embeddings with?

  • Random: initialize $h_v^{(0)}$ with random noise
  • Layer index: Encode which layer the neuron belongs to
  • Neuron type: Encode whether it's input/hidden/output
  • Bias values: Include bias as a node feature or separate node

Edge features: What information to encode?

  • Weight value: $w_{ij}$ (most important)
  • Layer index: Which layer the connection belongs to
  • Connection type: Conv, linear, attention, etc.

2. Handling different architectures

Problem: Networks with different architectures have different graph sizes/topologies.

Solutions:

  • Invariant pooling: Use sum/mean pooling to get fixed-size graph embedding
  • Attention pooling: Learn to weight important neurons
  • Hierarchical pooling: Coarsen the graph in multiple stages
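Why invariant pooling solves the size mismatch, in two lines: the pooled embedding has a fixed dimension regardless of graph size, so networks of different widths and depths become directly comparable.

```python
import numpy as np

def graph_embed(H):
    """Sum-pool node embeddings H of shape (num_nodes, d) to shape (d,)."""
    return H.sum(axis=0)

small = np.random.default_rng(9).normal(size=(6, 16))    # 6-node neural graph
large = np.random.default_rng(10).normal(size=(50, 16))  # 50-node neural graph
assert graph_embed(small).shape == graph_embed(large).shape == (16,)
```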

3. Computational cost

Concern: GNNs can be slow on large graphs.

Mitigations:

  • Sampling: Sample subgraphs for mini-batch training
  • Sparse GNNs: Use efficient sparse operations (most neural graphs are sparse)
  • Layer-wise processing: Process one layer of the neural network at a time

4. Training stability

Issue: GNNs can suffer from over-smoothing (node representations become too similar after many layers).

Solutions:

  • Residual connections: $h_v^{(l+1)} = h_v^{(l)} + \Delta h_v^{(l)}$, i.e., add the message-passing update to the previous state
  • Layer normalization: Normalize node features at each layer
  • Shallow GNNs: Use 2-4 layers instead of very deep models

Takeaways

  1. Flattening weights loses structure and makes symmetry handling hard. Naive approaches treat equivalent networks as different, wasting capacity.

  2. Graph-based representations enable equivariance. By representing a network as a neural graph, GNNs naturally respect permutation symmetries.

  3. Equivariance enables cross-architecture transfer. The approach works across widths and architectures where naive parameter alignment is not meaningful.

  4. Applications are diverse: Predicting generalization, classifying networks, retrieving similar models, meta-learning, and model merging all benefit from equivariant representations.

  5. Open challenges: Scaling to very large models (e.g., LLMs with billions of parameters), handling diverse architectures (e.g., Transformers, graph networks), and designing task-specific graph constructions.



Summary: GNN for Neural Network Representations in 5 Steps

  1. Represent network as a graph: Nodes = neurons, edges = weights
  2. Run GNN message passing: Aggregate neighbor information, update node embeddings
  3. Pool to graph embedding: Sum/mean/attention pooling for graph-level tasks
  4. Train for downstream task: Regression (generalization), classification (task type), retrieval (similarity)
  5. Leverage equivariance: Equivalent networks (same function, different permutation) → consistent representations

Key insight: GNNs naturally respect permutation symmetries, enabling better generalization across architectures and widths.

  • Post title: Graph Neural Networks for Learning Equivariant Representations of Neural Networks
  • Post author: Chen Kai
  • Create time: 2023-09-02 00:00:00
  • Post link: https://www.chenk.top/en/gnn-equivariant-representations/
  • Copyright Notice: All articles in this blog are licensed under BY-NC-SA unless stated otherwise.