Neural Collapse in Hierarchical Loss Functions
Investigating neural collapse phenomena and hierarchical loss functions in deep classification networks
Course Project: EE5179 - Deep Learning for Imaging | December 2023
Team: Siddharth Betala, Ruban Vishnu Pandian, Nikhil Anand
Institution: IIT Madras
Problem Statement
Neural Collapse (NC) [1] is a recently discovered phenomenon occurring in the terminal phase of training (TPT), during which deep classification networks exhibit four key characteristics:
- Variability collapse (NC1): Last-layer features of training samples converge to their class means
- Simplex ETF formation (NC2): The (globally centered) class means form an equiangular tight frame
- Self-dual alignment (NC3): The last-layer classifier weights align with the class means
- Nearest class-mean decoding (NC4): Classification reduces to assigning each sample to the nearest class mean
While NC correlates with good generalization, it may not be optimal for hierarchically structured data where mistake severity varies (e.g., confusing a dog with a cat is less severe than confusing a dog with a car).
Research Objectives
- Validation: Replicate neural collapse findings across different architectures and loss functions
- Extension: Investigate whether hierarchical loss functions modify NC characteristics
- Evaluation: Assess if hierarchical approaches can maintain accuracy while encoding semantic class relationships
Theoretical Framework
Neural Collapse Metrics
We employ four quantitative metrics to measure NC emergence:
NC1 (Variability Collapse): Measures ratio of within-class to between-class covariance
\[NC1 = \frac{1}{K}\,\mathrm{tr}\!\left(\Sigma_w \Sigma_b^\dagger\right)\]
NC2 (Simplex ETF): Quantifies deviation from ideal simplex geometry
\[NC2 = \left\|\frac{MM^T}{\|MM^T\|_F} - \frac{1}{\sqrt{K-1}}\left(I_K - \frac{1}{K}1_K1_K^T\right)\right\|_F\]
NC3 (Self-Dual Alignment): Measures alignment between weights and class means
\[NC3 = \left\|\frac{AM^T}{\|AM^T\|_F} - \frac{1}{\sqrt{K-1}}\left(I_K - \frac{1}{K}1_K1_K^T\right)\right\|_F\]
NC4 (Nearest Class-Mean): Fraction of samples not classified by nearest mean rule
\[NC4 = \frac{1}{Kn} \sum_{i=1}^{Kn} \mathbb{I}\!\left(c_i \neq \arg\min_k \|f_i - \mu_k\|_2\right)\]
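As a concrete illustration, below is a minimal PyTorch sketch of how NC1 and NC4 can be computed from a matrix of last-layer features and the corresponding labels. It assumes balanced classes and a common normalization convention (within-class covariance averaged over all samples, between-class covariance over classes); the function names `nc1_metric` and `nc4_metric` are illustrative and not taken from the project code.

```python
import torch

def nc1_metric(features, labels, num_classes):
    """NC1 sketch: (1/K) * tr(Sigma_w @ pinv(Sigma_b)).

    features: (N, d) last-layer features; labels: (N,) integer class indices.
    """
    d = features.size(1)
    global_mean = features.mean(dim=0)
    sigma_w = torch.zeros(d, d)
    sigma_b = torch.zeros(d, d)
    for k in range(num_classes):
        f_k = features[labels == k]
        mu_k = f_k.mean(dim=0)
        centered = f_k - mu_k
        sigma_w += centered.T @ centered / features.size(0)   # within-class covariance
        diff = (mu_k - global_mean).unsqueeze(1)
        sigma_b += diff @ diff.T / num_classes                # between-class covariance
    return torch.trace(sigma_w @ torch.linalg.pinv(sigma_b)) / num_classes

def nc4_metric(features, labels, num_classes):
    """NC4 sketch: fraction of samples whose nearest class mean differs from c_i."""
    means = torch.stack([features[labels == k].mean(dim=0) for k in range(num_classes)])
    nearest = torch.cdist(features, means).argmin(dim=1)      # nearest class-mean decision
    return (nearest != labels).float().mean()
```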
Hierarchical Loss Formulations
We explored two main approaches to incorporate class hierarchy, both following Bertinetto et al. [2]:
1. Hierarchical Cross-Entropy (HXE)
Computes conditional probabilities along the hierarchy path, weighting them based on tree height:
\[\mathcal{L}_{HXE}(p, C) = -\sum_{l=0}^{h-1} \lambda\!\left(C^{(l)}\right) \log p\!\left(C^{(l)} \mid C^{(l+1)}\right)\]
where \(\lambda(C) = \exp(-\alpha\, h(C))\) and \(h(C)\) is the height of class \(C\) in the hierarchy.
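A minimal PyTorch sketch of this loss is shown below. It assumes the hierarchy is supplied as, for each leaf class, the list of leaf-index sets of its ancestors ordered from the leaf up to the root, and it takes \(h(C^{(l)}) = l\) along that path (i.e., a balanced tree); the function name and argument layout are illustrative rather than the project's actual interface.

```python
import math
import torch
import torch.nn.functional as F

def hxe_loss(logits, targets, ancestor_leaf_sets, alpha=0.1):
    """Hierarchical cross-entropy sketch.

    ancestor_leaf_sets[c]: list of leaf-index lists for the nodes on the path
    from leaf class c up to the root (leaf first, root last, root = all leaves).
    """
    probs = F.softmax(logits, dim=1)                     # (B, K) leaf probabilities
    loss = logits.new_zeros(())
    for b, c in enumerate(targets.tolist()):
        path = ancestor_leaf_sets[c]
        for l in range(len(path) - 1):
            p_node = probs[b, path[l]].sum()             # p(C^(l)) = mass of leaves under the node
            p_parent = probs[b, path[l + 1]].sum()       # p(C^(l+1))
            cond = (p_node / p_parent.clamp_min(1e-12)).clamp_min(1e-12)
            loss = loss - math.exp(-alpha * l) * torch.log(cond)
    return loss / logits.size(0)
```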
2. Soft Labels
Replaces one-hot encodings with soft labels weighted by hierarchical distance:
\[\mathcal{L}_{soft} = -\sum_{A \in \mathcal{C}} y_A^{soft}(C) \log p(A)\]
where
\[y_A^{soft}(C) = \frac{\exp(-\beta\, d(A,C))}{\sum_{B \in \mathcal{C}} \exp(-\beta\, d(B,C))}, \qquad d(A,C) = \frac{\mathrm{height}(\mathrm{LCA}(A,C))}{\mathrm{height}(\mathrm{tree})}\]
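For illustration, here is a small PyTorch sketch of how these soft targets and the corresponding loss could be computed from a precomputed matrix of pairwise LCA heights; the names `build_soft_labels` and `soft_label_loss` are ours, not the project's.

```python
import torch
import torch.nn.functional as F

def build_soft_labels(lca_height, tree_height, beta=10.0):
    """Soft-label matrix from pairwise LCA heights (sketch).

    lca_height: (K, K) tensor with lca_height[a, c] = height of LCA(a, c); diagonal is 0.
    Returns a (K, K) matrix whose row c is the soft target y^soft(c).
    """
    d = lca_height / tree_height                     # normalized hierarchical distance d(A, C)
    return F.softmax(-beta * d, dim=1)               # y_A^soft(C) proportional to exp(-beta * d(A, C))

def soft_label_loss(logits, targets, soft_labels):
    """Cross-entropy against the hierarchy-derived soft targets."""
    log_p = F.log_softmax(logits, dim=1)             # log p(A) for every class A
    return -(soft_labels[targets] * log_p).sum(dim=1).mean()
```

As \(\beta \to \infty\) each row of the soft-label matrix approaches a one-hot vector, recovering standard cross-entropy.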
Experimental Setup
Baseline Validation:
- Datasets: MNIST, CIFAR-10
- Architectures: MLP, ResNet18
- Loss Functions: MSE Loss, Cross-Entropy Loss
- Goal: Replicate NC emergence patterns from literature
Hierarchical Extension:
- Constructed a semantic hierarchy over the CIFAR-10 classes (an illustrative sketch follows this list)
- Implemented both HXE and soft label approaches
- Compared NC metrics against standard training
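The sketch below shows one plausible two-level grouping of the CIFAR-10 classes and helpers for reading off LCA heights; the exact hierarchy used in the project may differ, and all names here are purely illustrative.

```python
# One plausible two-level grouping of the CIFAR-10 classes (the hierarchy
# actually used in the project may differ); leaves are CIFAR-10 label names.
CIFAR10_HIERARCHY = {
    "vehicle": {"land": ["automobile", "truck"], "other": ["airplane", "ship"]},
    "animal":  {"pet": ["cat", "dog"], "wild": ["bird", "deer", "frog", "horse"]},
}

def leaf_paths(tree, prefix=()):
    """Map each leaf class name to its tuple of ancestor names (root excluded)."""
    paths = {}
    for node, child in tree.items():
        if isinstance(child, dict):
            paths.update(leaf_paths(child, prefix + (node,)))
        else:
            for leaf in child:
                paths[leaf] = prefix + (node, leaf)
    return paths

def lca_height(path_a, path_b, tree_height=3):
    """Height of the least common ancestor, measured in edges from the leaves."""
    common = 0
    for a, b in zip(path_a, path_b):
        if a != b:
            break
        common += 1
    return tree_height - common
```

Together with \(d(A,C) = \mathrm{height}(\mathrm{LCA})/\mathrm{height}(\mathrm{tree})\), this is enough to populate the pairwise distance matrix used by both hierarchical losses.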
Key Concepts
Hierarchical Neural Networks
The goal is not only to minimize classification errors but also to reduce their severity. Key definitions:
- Top-k prediction: The k classes with the highest predicted probabilities
- Hierarchical distance: The mean height of the Least Common Ancestor (LCA) between the ground-truth and predicted classes
- Hierarchical Average Top-k Error: The mean LCA height between the ground truth and the k most likely classes (see the sketch after this list)
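As a sketch, assuming a precomputed `(K, K)` tensor `lca_height` of pairwise LCA heights (as in the hierarchy utilities above), the hierarchical average top-k error could be computed as follows; the function name is ours.

```python
import torch

def hierarchical_topk_error(logits, targets, lca_height, k=5):
    """Hierarchical average top-k error (sketch): mean LCA height between the
    ground-truth class and each of the k most probable predicted classes.

    lca_height: (K, K) tensor of pairwise LCA heights (0 on the diagonal).
    """
    topk = logits.topk(k, dim=1).indices              # (B, k) most likely classes
    heights = lca_height[targets.unsqueeze(1), topk]  # (B, k) LCA heights vs. ground truth
    return heights.float().mean()
```

With k = 1 this reduces to the mean hierarchical distance of the predictions.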
Three Approaches to Incorporate Hierarchy
- Label embeddings \(y^H\): Class embeddings whose pairwise cosine similarity encodes hierarchical distance, so that classes close in the hierarchy receive similar embeddings (a minimal construction is sketched after this list)
- Hierarchical loss function \(\mathcal{L}^H\): Modified loss with penalties based on mistake severity
- Hierarchical architecture \(\phi^H\): Generic distinctions in early layers, fine-grained distinctions in later layers
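One simple construction of such embeddings (a sketch under our own assumptions, not necessarily the project's method) targets the similarity matrix \(S = 1 - d(A,C)\) and takes its Cholesky factor, whose rows are unit-norm vectors with the prescribed pairwise inner products; this requires \(S\) to be positive semidefinite, which tree-derived similarities of this form typically satisfy.

```python
import torch

def hierarchy_label_embeddings(lca_height, tree_height, jitter=1e-6):
    """Sketch: unit-norm class embeddings whose pairwise inner products match
    the hierarchical similarity S = 1 - d(A, C).

    lca_height: (K, K) tensor of pairwise LCA heights (0 on the diagonal).
    Assumes S is positive semidefinite; the jitter keeps Cholesky numerically safe.
    """
    s = 1.0 - lca_height / tree_height                 # target cosine similarities
    s = s + jitter * torch.eye(s.size(0))              # small diagonal regularization
    return torch.linalg.cholesky(s)                    # row i is the embedding of class i
```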
Current Status & Findings
Baseline Replication ✓
Successfully replicated NC emergence patterns consistent with literature:
- All four NC metrics converge toward zero during extended training
- Behavior consistent across MLP and ResNet18 architectures
- Both MSE and Cross-Entropy losses exhibit NC characteristics
Hierarchical Loss Investigation (In Progress)
The hierarchical loss experiments are under active development. An initial implementation is complete for:
- Class hierarchy construction for CIFAR-10
- HXE and soft label loss functions
- Integrated training pipeline
Pending Work:
- Complete training runs for hierarchical models
- Compute NC metrics for hierarchical loss functions
- Statistical comparison of NC emergence patterns
- Analysis of hierarchical distance of mistakes
Expected Contributions
- Empirical Validation: Comprehensive replication of NC across multiple settings
- Novel Investigation: To our knowledge, the first systematic study of NC under hierarchical loss functions
- Practical Insights: Understanding trade-offs between NC properties and semantic structure
Technical Implementation
Framework: PyTorch
Key Components:
- Custom NC metric computation (NC1-NC4)
- Hierarchical loss implementations (HXE, Soft Labels)
- Class hierarchy construction utilities
- Visualization tools for feature space geometry
Future Directions
- Complete hierarchical loss experimental runs
- Extend to ImageNet with richer hierarchical structure
- Investigate NC in modern architectures (Vision Transformers)
- Theoretical analysis of NC under hierarchical constraints
- Study interaction with few-shot and continual learning
References
[1] Kothapalli, V., Rasromani, E., & Awatramani, V. (2023). Neural Collapse: A Review on Modelling Principles and Generalization. Trans. Mach. Learn. Res.
[2] Bertinetto, L., et al. (2020). Making Better Mistakes: Leveraging Class Hierarchies With Deep Networks. CVPR, 12503-12512.
Code & Documentation
📂 GitHub: Neural Collapse and Class Hierarchy
📄 Report: Detailed mathematical formulations and theoretical background