This repository contains the modular implementation of SAMBA, a novel architecture that combines State-space Mamba models with Graph Neural Networks for stock price prediction. This work is based on the paper "Mamba Meets Financial Markets: A Graph-Mamba Approach for Stock Price Prediction" accepted for publication in IEEE ICASSP 2025.
Original Paper Repository: https://github.com/Ali-Meh619/SAMBA
SAMBA (State-space Mamba with Graph Neural Networks) is designed for stock price prediction using real-world financial market data. The model leverages:
- Mamba blocks for efficient sequence modeling with selective state spaces
- Graph Neural Networks with Chebyshev polynomials for spatial relationships
- Gaussian kernel-based adjacency matrices for adaptive graph learning
- Bidirectional processing for enhanced temporal understanding
The model consists of several key components:
- Mamba Backbone: Processes temporal sequences using selective state space models
- Graph Convolution Layers: Capture spatial dependencies using Chebyshev polynomials
- Adaptive Adjacency Matrix: Learns graph structure using Gaussian kernels
- Residual Connections: Enable deep network training with skip connections
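To make the graph components above concrete, here is a minimal NumPy sketch of a Gaussian-kernel adjacency matrix feeding a Chebyshev graph convolution. It is an illustration under stated assumptions, not the code in `models/graph_layers.py`: the function names, the row normalisation, and the use of fixed (rather than learned) node embeddings are all hypothetical choices made for the example.

```python
import numpy as np

def gaussian_adjacency(embeddings: np.ndarray, sigma: float = 1.0) -> np.ndarray:
    """Row-normalised Gaussian-kernel adjacency from node embeddings of shape (N, d)."""
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    sq_dist = (diff ** 2).sum(axis=-1)                 # pairwise squared distances, (N, N)
    weights = np.exp(-sq_dist / (2.0 * sigma ** 2))    # Gaussian kernel
    return weights / weights.sum(axis=1, keepdims=True)

def chebyshev_basis(adj: np.ndarray, order: int) -> np.ndarray:
    """Stack T_0 .. T_{order-1} using T_k = 2 * A @ T_{k-1} - T_{k-2}."""
    n = adj.shape[0]
    basis = [np.eye(n), adj]
    for _ in range(2, order):
        basis.append(2.0 * adj @ basis[-1] - basis[-2])
    return np.stack(basis[:order])                     # (K, N, N)

def cheb_graph_conv(x: np.ndarray, adj: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Graph convolution: sum_k T_k(A) @ x @ W_k.

    x: (N, F_in), weights: (K, F_in, F_out), result: (N, F_out).
    """
    basis = chebyshev_basis(adj, weights.shape[0])
    return np.einsum("knm,mf,kfo->no", basis, x, weights)

rng = np.random.default_rng(0)
node_embeddings = rng.normal(size=(5, 3))   # in a trained model these would be learned
adjacency = gaussian_adjacency(node_embeddings)
features = rng.normal(size=(5, 4))
kernels = rng.normal(size=(3, 4, 2))        # K = 3 plays the role of cheb_k
output = cheb_graph_conv(features, adjacency, kernels)
print(output.shape)
```

In this sketch the polynomial order `K` corresponds to the `cheb_k` setting: higher orders let each node aggregate information from more distant neighbours in the learned graph.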
- Clone the repository:

      git clone <repository-url>
      cd samba-stock-prediction

- Install dependencies:

      pip install -r requirements.txt

- Ensure you have a CUDA-compatible GPU for optimal performance.
- Install dependencies:

      pip install -r requirements.txt

- Create the Dataset folder and add your CSV files:

      mkdir Dataset
      # Copy your CSV files to the Dataset folder:
      # - combined_dataframe_IXIC.csv
      # - combined_dataframe_NYSE.csv
      # - combined_dataframe_DJI.csv

- Run the model:

      python main.py
Run the model from Python:

    from main import main

    # Run training with paper configuration
    main()

You can modify the training configuration in main.py or paper_config.py:

    # In main.py, the configuration is loaded from paper_config.py
    model_args, config = get_paper_config()

    # You can modify the config before training
    config.epochs = 500      # Reduce epochs for faster training
    config.batch_size = 64   # Increase batch size
Important: Create a Dataset folder and place your CSV files in it.
The model expects CSV data with the following format:
- Date column as index
- Name column (will be removed during preprocessing)
- Price column for target values
- Additional feature columns (technical indicators, market data, etc.)
Example:

    Date,Name,Price,Volume,RSI,MACD,...
    2023-01-01,IXIC,100.0,1000000,0.5,0.3,...
    2023-01-02,IXIC,101.0,1200000,0.6,0.4,...

Note: The num_nodes parameter is automatically determined from the input data shape (number of features), so you don't need to specify it manually.
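As a rough illustration of the format described above, the following sketch loads a two-row sample with pandas, drops the Name column, and derives the node count from the remaining columns. It is not the loader in `utils/data_utils.py`; in particular, counting Price among the features when computing `num_nodes` is an assumption made for this example.

```python
import io

import pandas as pd

# Inline sample in the documented layout (the real files live in Dataset/).
sample_csv = io.StringIO(
    "Date,Name,Price,Volume,RSI,MACD\n"
    "2023-01-01,IXIC,100.0,1000000,0.5,0.3\n"
    "2023-01-02,IXIC,101.0,1200000,0.6,0.4\n"
)

frame = pd.read_csv(sample_csv, index_col="Date", parse_dates=True)
frame = frame.drop(columns=["Name"])        # Name is removed during preprocessing

target = frame["Price"].to_numpy()          # prediction target
features = frame.to_numpy(dtype=float)      # shape: (time steps, features)
num_nodes = features.shape[1]               # assumption: every remaining column is a node

print(num_nodes, target)
```

To try it on a real file, replace `sample_csv` with a path such as `Dataset/combined_dataframe_IXIC.csv`.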
This repository is configured to work with three real-world datasets from the US stock market with 82 daily stock features:
Folder structure:

    Dataset/
    ├── combined_dataframe_IXIC.csv   # NASDAQ Composite Index
    ├── combined_dataframe_NYSE.csv   # New York Stock Exchange
    └── combined_dataframe_DJI.csv    # Dow Jones Industrial Average

- Dataset period: January 2010 to November 2023
- Features: 82 daily stock features including technical indicators, market data, and financial metrics
Each dataset contains comprehensive historical price data with multiple technical indicators as features, providing rich information for the Graph-Mamba model to learn complex market patterns.
- config/: Configuration classes for model and training parameters
- models/: Model implementations (SAMBA, Mamba, Graph layers)
- utils/: Utility functions (data loading, metrics, logging)
- trainer/: Training loop and evaluation
- SAMBA: Main model combining Mamba and GNN
- Mamba: State-space sequence model
- MambaBlock: Individual Mamba block with selective scanning
- gconv: Graph convolution with Chebyshev polynomials
- Trainer: Training and evaluation pipeline
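For readers new to selective scanning, the recurrence at the heart of a Mamba block can be sketched sequentially in NumPy. This is a simplified reference form using the discretisation commonly described for Mamba (an exponential for the state matrix and a first-order approximation for the input matrix). The function name and argument layout are assumptions for illustration and do not reproduce `models/mamba_block.py`.

```python
import numpy as np

def selective_scan(x, delta, A, B, C):
    """Sequential selective state-space scan.

    x:     (L, D) input sequence
    delta: (L, D) input-dependent step sizes
    A:     (D, N) state matrix (N plays the role of d_state)
    B, C:  (L, N) input-dependent projections
    Returns y: (L, D).
    """
    seq_len, dim = x.shape
    h = np.zeros((dim, A.shape[1]))                  # hidden state, (D, N)
    y = np.empty((seq_len, dim))
    for t in range(seq_len):
        a_bar = np.exp(delta[t][:, None] * A)        # discretised state transition
        b_bar = delta[t][:, None] * B[t][None, :]    # discretised input matrix
        h = a_bar * h + b_bar * x[t][:, None]        # state update
        y[t] = h @ C[t]                              # readout
    return y

x = np.array([[2.0], [0.0]])
delta = np.full((2, 1), 0.5)
A = np.array([[-1.0]])
B = np.ones((2, 1))
C = np.full((2, 1), 3.0)
y = selective_scan(x, delta, A, B, C)
print(y)
```

Because `delta`, `B`, and `C` vary per time step, the model can decide how much of each input to write into the state and how quickly to forget it. Practical implementations replace this Python loop with a parallel scan.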
The model evaluates performance using:
- MAE: Mean Absolute Error
- RMSE: Root Mean Squared Error
- IC: Information Coefficient (Pearson correlation)
- RIC: Rank Information Coefficient (Spearman correlation)
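The four metrics above can be computed in a few lines of NumPy. This is a sketch rather than the contents of `utils/metrics.py`. The helper names are hypothetical, and the rank helper breaks ties by order of appearance instead of averaging tied ranks, so it matches a standard Spearman correlation only when there are no ties.

```python
import numpy as np

def _ranks(values: np.ndarray) -> np.ndarray:
    """0-based ranks; ties are broken by order of appearance (a simplification)."""
    order = np.argsort(values, kind="stable")
    ranks = np.empty(len(values), dtype=float)
    ranks[order] = np.arange(len(values))
    return ranks

def evaluate(pred, true) -> dict:
    """Return MAE, RMSE, IC (Pearson) and RIC (Spearman) for two 1-D series."""
    pred = np.asarray(pred, dtype=float).ravel()
    true = np.asarray(true, dtype=float).ravel()
    err = pred - true
    return {
        "MAE": float(np.abs(err).mean()),
        "RMSE": float(np.sqrt((err ** 2).mean())),
        "IC": float(np.corrcoef(pred, true)[0, 1]),
        "RIC": float(np.corrcoef(_ranks(pred), _ranks(true))[0, 1]),
    }

metrics = evaluate([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 9.0])
print(metrics)
```

MAE and RMSE measure error magnitude, while IC and RIC measure how well predictions track the direction and ordering of the true values, which is why both kinds are reported.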
- d_model: Model dimension
- n_layer: Number of Mamba layers
- vocab_size: Number of features (automatically determined from input data)
- seq_in: Input sequence length
- seq_out: Output sequence length
- d_state: State dimension
- expand: Expansion factor
- cheb_k: Chebyshev polynomial order
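To show how `seq_in` and `seq_out` shape the training data, the sketch below slices a feature matrix into input and target windows. The function name `make_windows` and the choice to keep all features in the target are assumptions for illustration; the repository's own windowing in `utils/data_utils.py` may differ.

```python
import numpy as np

def make_windows(series: np.ndarray, seq_in: int, seq_out: int):
    """Slice a (T, F) series into inputs (S, seq_in, F) and targets (S, seq_out, F)."""
    n_samples = series.shape[0] - seq_in - seq_out + 1
    if n_samples <= 0:
        raise ValueError("series is too short for the requested window lengths")
    inputs = np.stack([series[i:i + seq_in] for i in range(n_samples)])
    targets = np.stack(
        [series[i + seq_in:i + seq_in + seq_out] for i in range(n_samples)]
    )
    return inputs, targets

series = np.arange(20, dtype=float).reshape(10, 2)   # 10 time steps, 2 features
inputs, targets = make_windows(series, seq_in=5, seq_out=1)
print(inputs.shape, targets.shape)
```

Each sample pairs `seq_in` consecutive days with the `seq_out` days that immediately follow, so a longer `seq_in` gives the model more history at the cost of fewer training samples.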
- epochs: Number of training epochs
- lr_init: Initial learning rate
- batch_size: Training batch size
- early_stop: Enable early stopping
- early_stop_patience: Early stopping patience
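The interaction between `early_stop` and `early_stop_patience` can be illustrated with a small helper. The class below is a hypothetical sketch of the usual patience rule (stop once the validation loss has failed to improve for a given number of consecutive epochs), not the logic in `trainer/trainer.py`.

```python
class EarlyStopper:
    """Signal a stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience: int):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best:
            self.best = val_loss      # new best: reset the counter
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopper(patience=2)
decisions = [stopper.step(loss) for loss in [1.0, 0.8, 0.9, 0.85]]
print(decisions)
```

A training loop would typically save a checkpoint whenever `step` records a new best loss and break out of the epoch loop once it returns `True`.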
The model outputs results to:
- samba_results.txt: Performance metrics
- ./best_model.pth: Best model checkpoint
- Console logs: Training progress and final metrics
    ├── Dataset/                          # Put your CSV files here
    │   ├── combined_dataframe_IXIC.csv
    │   ├── combined_dataframe_NYSE.csv
    │   └── combined_dataframe_DJI.csv
    ├── config/
    │   ├── __init__.py
    │   └── model_config.py
    ├── models/
    │   ├── __init__.py
    │   ├── samba.py
    │   ├── mamba.py
    │   ├── mamba_block.py
    │   ├── graph_layers.py
    │   └── normalization.py
    ├── utils/
    │   ├── __init__.py
    │   ├── data_utils.py
    │   ├── metrics.py
    │   ├── logger.py
    │   └── model_utils.py
    ├── trainer/
    │   ├── __init__.py
    │   └── trainer.py
    ├── main.py                           # Main execution file
    ├── paper_config.py                   # Paper-specific configuration
    ├── test_system.py                    # System test
    ├── requirements.txt
    └── README.md
- PyTorch >= 1.9.0
- NumPy >= 1.21.0
- Pandas >= 1.3.0
- Matplotlib >= 3.4.0
- einops >= 0.4.0
- h5py >= 3.1.0
If you find our paper and code useful, please kindly cite our paper as follows:
    @inproceedings{SAMBA,
      title={Mamba Meets Financial Markets: {A} {G}raph-{M}amba Approach for Stock Price Prediction},
      author={Mehrabian, Ali and Hoseinzade, Ehsan and Mazloum, Mahdi and Chen, Xiaohong},
      booktitle={Proc. IEEE Int. Conf. Acoust., Speech, Signal Process. (ICASSP)},
      address={Hyderabad, India},
      month={Apr.},
      year={2025}
    }

- Paper: "Mamba Meets Financial Markets: A Graph-Mamba Approach for Stock Price Prediction"
- Conference: IEEE ICASSP 2025
- Authors: Ali Mehrabian, Ehsan Hoseinzade, Mahdi Mazloum, Xiaohong Chen
Please feel free to contact us if you have any questions:
- Ali Mehrabian: alimehrabian619@yahoo.com, ali.mehrabian@vectorinstitute.ai
- Original Repository: https://github.com/Ali-Meh619/SAMBA
- Fork the repository
- Create a feature branch
- Make your changes
- Add tests if applicable
- Submit a pull request
If you encounter any issues, please:
- Check the existing issues
- Create a new issue with a detailed description
- Include system information and error logs
If you found this project helpful, please give it a star!
Made with ❤️ for the financial AI community