
Mamba-KAN

A Rigorous Factorial Comparison of Neural Network Architectures

Python 3.9+ | PyTorch 2.0+ | License: MIT | CI | Code style: black

Documentation | Quick Start | Results | Citation


Overview

This project implements a comprehensive 2×3 factorial experiment comparing neural network architectures, investigating the interplay between feedforward components and sequence modeling approaches.

Research Question

Do Kolmogorov-Arnold Networks (KANs) outperform MLPs because of their learnable B-spline activation functions, or because of their unique network topology?

Following Wu et al. (2024), we isolate these effects by including MLP+B-spline baselines alongside Transformer and Mamba sequence models.


Architecture

┌─────────────────────────────────────────────────────────────────┐
│                    2×3 FACTORIAL DESIGN                         │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│   Feedforward Type          Sequence Model                      │
│   ════════════════         ══════════════                       │
│                                                                 │
│   ┌─────────────┐          ┌─────────────┐                      │
│   │  MLP        │──────────│ Transformer │ ─► mlp_transformer   │
│   │  (ReLU/GELU)│          │ (Attention) │                      │
│   └─────────────┘          └─────────────┘                      │
│         │                        │                              │
│         │                        │                              │
│   ┌─────────────┐          ┌─────────────┐                      │
│   │  MLP +      │──────────│ Transformer │ ─► bspline_transformer│
│   │  B-spline   │          │ (Attention) │                      │
│   └─────────────┘          └─────────────┘                      │
│         │                        │                              │
│         │                        │                              │
│   ┌─────────────┐          ┌─────────────┐                      │
│   │  Full KAN   │──────────│ Transformer │ ─► kan_transformer   │
│   │  (Learnable)│          │ (Attention) │                      │
│   └─────────────┘          └─────────────┘                      │
│         │                        │                              │
│         │                  ┌─────────────┐                      │
│         └──────────────────│   Mamba     │ ─► *_mamba variants  │
│                            │   (SSM)     │                      │
│                            └─────────────┘                      │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Six Model Variants

| Variant | Feedforward | Sequence | Purpose |
|---------|-------------|----------|---------|
| `mlp_transformer` | MLP (ReLU/GELU) | Attention | Baseline |
| `bspline_transformer` | MLP + B-spline | Attention | Isolate activation effect |
| `kan_transformer` | Full KAN | Attention | Full KAN architecture |
| `mlp_mamba` | MLP (ReLU/GELU) | SSM | Mamba baseline |
| `bspline_mamba` | MLP + B-spline | SSM | Activation + SSM |
| `kan_mamba` | Full KAN | SSM | Full KAN + SSM (novel) |
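The six variants are simply the Cartesian product of the two factors. A minimal sketch of that enumeration, assuming only the naming convention from the table above (the project's actual configuration system lives in `mamba_kan/configs` and may differ):

```python
# Enumerate the 2x3 factorial design: 3 feedforward types x 2 sequence models.
from itertools import product

feedforward = ["mlp", "bspline", "kan"]   # factor 1: feedforward type
sequence = ["transformer", "mamba"]        # factor 2: sequence model

# Variant names follow the `<feedforward>_<sequence>` convention in the table.
variants = [f"{ff}_{seq}" for ff, seq in product(feedforward, sequence)]
print(variants)  # six variants, from mlp_transformer through kan_mamba
```

Iterating the full product (and crossing it with tasks and seeds) is what makes the comparison factorial rather than a one-off pairwise benchmark.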

Key Results

60 experiments completed: 3 models × 2 tasks × 10 seeds on an NVIDIA H100 80GB

Results Visualization

Experiment Results

Accuracy Comparison

| Task | MLP Transformer | KAN Transformer | B-spline Transformer |
|------|-----------------|-----------------|----------------------|
| Symbolic Regression | 0.0077 ± 0.0023 | 0.0080 ± 0.0028 | 0.0082 ± 0.0038 |
| Language Modeling | 10.8366 ± 0.0007 | 10.8373 ± 0.0006 | 10.8363 ± 0.0017 |

Training Speed Comparison

| Model | Speed (steps/s) | Time per Experiment | Slowdown vs MLP |
|-------|-----------------|---------------------|-----------------|
| MLP Transformer | 92.4 | 26 s | 1.0× (baseline) |
| KAN Transformer | 52.0 | 50 s | 1.78× slower |
| B-spline Transformer | 19.6 | 633 s | 4.72× slower |

Key Findings

| Metric | MLP | KAN | B-spline |
|--------|-----|-----|----------|
| Accuracy | Best | ~Equal | ~Equal |
| Speed | Fastest | 1.78× slower | 4.72× slower |
| Recommendation | Use this | If interpretability needed | Not recommended |

Model Comparison

Conclusions

  1. All models perform similarly on accuracy - differences are within statistical noise
  2. MLP wins on speed - fastest training with best or equal accuracy
  3. KAN is practical with efficient-kan - only 1.78× slower (vs 60,000× with pykan)
  4. B-spline provides no benefit - slowest model without accuracy gains

Quick Start

Installation

# Clone the repository
git clone https://github.com/stchakwdev/Mamba_KAN.git
cd Mamba_KAN

# Create environment
conda create -n mamba_kan python=3.10
conda activate mamba_kan

# Install PyTorch (adjust CUDA version as needed)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install package with development dependencies
pip install -e ".[dev]"

Run Quick Validation

# Quick test (1 model, 1 seed, 100 steps)
make train

# Or directly:
python scripts/run_experiment.py \
    --model mlp_transformer \
    --task symbolic \
    --seeds 1 \
    --max-steps 100 \
    --no-wandb

Run Full Comparison

# Full factorial experiment (all models, 10 seeds)
python scripts/run_full_comparison.py \
    --tasks symbolic special_functions timeseries language long_context \
    --seeds 10 \
    --max-steps 10000 \
    --output-dir ./results \
    --generate-report

Benchmark Tasks

1. Symbolic Regression (KAN-favorable)

Tests function approximation with learnable activations:

  • Basic functions: sin, cos, exp, log, sqrt
  • Special functions: Bessel (J0, J1), Legendre (P2, P3, P4)
  • Deep compositions: sin(exp(cos(x))), nested trigonometrics

2. Time Series Forecasting

Tests temporal pattern recognition:

  • Patterns: sine with trend, multi-seasonal, chaotic, AR process
  • Sequence lengths: 128, 256, 512
  • Prediction horizons: 10, 20 steps

3. Language Modeling (Mamba-favorable)

Tests long-range dependency modeling:

  • Standard sequences: 256, 512 tokens
  • Long-context: 2048, 4096 tokens
  • Mamba's O(n) complexity provides significant advantage

Project Structure

mamba_kan/
├── models/                    # Model implementations
│   ├── base.py               # Task-aware base class
│   ├── mlp_transformer.py    # MLP-Transformer (baseline)
│   ├── bspline_transformer.py # B-spline activation baseline
│   ├── kan_transformer.py    # Full KAN-Transformer
│   ├── *_mamba.py            # Mamba variants
│   └── components/
│       ├── bspline_mlp.py    # Learnable B-spline activation
│       ├── kan_layers.py     # KAN building blocks
│       ├── mamba_layers.py   # Mamba with B-spline support
│       └── transformer_layers.py
├── training/
│   ├── trainer.py            # PyTorch Lightning module
│   ├── scheduler.py          # Learning rate schedules
│   └── callbacks.py          # Training monitoring
├── analysis/
│   └── statistics.py         # Friedman, Wilcoxon, bootstrap CI
├── visualization/            # Plotting and dashboards
│   ├── plots.py              # Training curves, comparisons
│   ├── heatmaps.py           # Statistical visualizations
│   ├── animations.py         # GIF generation
│   └── dashboard.py          # Interactive HTML reports
├── data/
│   └── datasets.py           # All benchmark datasets
└── configs/
    └── base_config.py        # Configuration system

scripts/
├── run_experiment.py         # Single experiment runner
├── run_full_comparison.py    # Full factorial comparison
├── generate_assets.py        # Generate README visualizations
└── runpod_setup.sh           # Cloud GPU setup

Statistical Analysis

The project implements rigorous statistical testing following Demšar (2006):

  • Friedman Test: Non-parametric comparison across multiple classifiers
  • Wilcoxon Signed-Rank: Pairwise post-hoc comparisons
  • Holm-Bonferroni Correction: Multiple comparison adjustment
  • Bootstrap Confidence Intervals: Effect size uncertainty quantification

Example usage:

from mamba_kan.analysis import run_full_analysis, print_analysis_summary

results = {
    'mlp_transformer': [0.52, 0.51, 0.53, ...],  # Loss per seed
    'bspline_transformer': [0.48, 0.47, 0.49, ...],
    'kan_transformer': [0.45, 0.44, 0.46, ...],
    # ... other models
}

analysis = run_full_analysis(results)
print_analysis_summary(analysis)
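To show the idea behind the multiple-comparison step, here is a minimal, self-contained sketch of the Holm-Bonferroni procedure. This mirrors the concept, not the actual code in `mamba_kan.analysis`: the k-th smallest p-value is tested against alpha / (m - k), stopping at the first failure:

```python
# Step-down Holm-Bonferroni correction over m pairwise p-values.
def holm_bonferroni(p_values, alpha=0.05):
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])  # ascending p-values
    reject = [False] * m
    for k, i in enumerate(order):
        if p_values[i] <= alpha / (m - k):   # threshold loosens at each step
            reject[i] = True
        else:
            break                            # step-down: stop at first failure
    return reject

print(holm_bonferroni([0.01, 0.04, 0.03]))   # -> [True, False, False]
```

Compared with plain Bonferroni, the step-down procedure is uniformly more powerful while still controlling the family-wise error rate.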

Hardware Requirements

| Configuration | Specification |
|---------------|---------------|
| Minimum | NVIDIA GPU, 8 GB VRAM, 16 GB RAM |
| Recommended | RTX 3080+ or A100, 32 GB RAM |
| Full experiments | H100, 80 GB VRAM |

Cloud Deployment

# RunPod H100 setup
chmod +x scripts/runpod_setup.sh
./scripts/runpod_setup.sh

# Run full experiment suite
python scripts/run_full_comparison.py --task all --seeds 10

Development

# Install dev dependencies
pip install -e ".[dev]"

# Run tests
make test

# Run linting
make lint

# Format code
make format

# Generate visualizations from results
make visualize

Documentation


References

Papers

Resources

  • efficient-kan - Fast KAN implementation (used in this project, ~250Γ— faster than pykan)
  • pykan - Official KAN implementation
  • mamba-ssm - Official Mamba implementation
  • awesome-kan - Comprehensive KAN resources

Citation

@misc{mamba_kan_2025,
    title={Mamba-KAN: A Factorial Comparison of Neural Network Architectures},
    author={Samuel T. Chakwera},
    year={2025},
    url={https://github.com/stchakwdev/Mamba_KAN},
    note={Investigating whether KAN advantages stem from B-spline activations or network topology}
}

License

MIT License - see LICENSE for details.



Made with PyTorch Lightning and scientific rigor
