SAMBA: A Graph-Mamba Approach for Stock Price Prediction
Paper | Dataset

📖 About

This repository contains the modular implementation of SAMBA, a novel architecture that combines State-space Mamba models with Graph Neural Networks for stock price prediction. This work is based on the paper "Mamba Meets Financial Markets: A Graph-Mamba Approach for Stock Price Prediction" accepted for publication in IEEE ICASSP 2025.

🔗 Original Paper Repository: https://github.com/Ali-Meh619/SAMBA

🎯 Overview

SAMBA (State-space Mamba with Graph Neural Networks) is designed for stock price prediction using real-world financial market data. The model leverages:

  • 🧠 Mamba blocks for efficient sequence modeling with selective state spaces
  • 🕸️ Graph Neural Networks with Chebyshev polynomials for spatial relationships
  • 📊 Gaussian kernel-based adjacency matrices for adaptive graph learning
  • ⚡ Bidirectional processing for enhanced temporal understanding
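The selective state-space and bidirectional ideas above can be illustrated with a minimal NumPy sketch of a single-channel selective scan, run once forward and once over the time-reversed sequence. It assumes a Mamba-style discretization (zero-order hold for A, a simple Δ·B step for B). The projection parameters `w_delta`, `w_b`, `w_c` and the choice to merge the two directions by summation are hypothetical and are not taken from `models/mamba_block.py`.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)); keeps the step size Delta positive
    return np.logaddexp(0.0, x)


def selective_scan(x, a, w_delta, w_b, w_c):
    """Single-channel selective scan over a 1-D sequence x of length T.

    a: (N,) negative diagonal of the state matrix A (N = d_state).
    w_delta (scalar), w_b and w_c ((N,) arrays) are hypothetical projections
    that make Delta, B and C depend on the current input (the "selective" part).
    """
    h = np.zeros_like(a)
    ys = []
    for x_t in x:
        delta = softplus(w_delta * x_t)    # input-dependent step size
        b_t = w_b * x_t                    # input-dependent B
        c_t = w_c * x_t                    # input-dependent C
        a_bar = np.exp(delta * a)          # zero-order-hold discretization of A
        h = a_bar * h + delta * b_t * x_t  # h_t = A_bar * h_{t-1} + (Delta * B) * x_t
        ys.append(float(c_t @ h))          # y_t = C_t . h_t
    return np.array(ys)


def bidirectional_scan(x, params):
    # Scan forward, scan the reversed sequence, flip the backward output so it
    # lines up with the original time axis, then merge (summation is an assumption).
    forward = selective_scan(x, *params)
    backward = selective_scan(x[::-1], *params)[::-1]
    return forward + backward


rng = np.random.default_rng(0)
params = (-np.arange(1, 5, dtype=float), 0.5, rng.normal(size=4), rng.normal(size=4))
x = rng.normal(size=10)
print(bidirectional_scan(x, params).shape)  # (10,)
```

Keeping A negative makes exp(Δ·A) fall between 0 and 1, so the hidden state decays unless new inputs refresh it; making Δ, B and C depend on the input is what lets the model decide what to keep.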

🏗️ Architecture

The model consists of several key components:

  1. 🧠 Mamba Backbone: Processes temporal sequences using selective state space models
  2. 🕸️ Graph Convolution Layers: Capture spatial dependencies using Chebyshev polynomials
  3. 📊 Adaptive Adjacency Matrix: Learns graph structure using Gaussian kernels
  4. 🔗 Residual Connections: Enable deep network training with skip connections
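Components 2 to 4 can be sketched in NumPy as follows. This is an illustrative reading, not the code in `models/graph_layers.py`: it assumes each feature (node) has an embedding vector, builds a Gaussian-kernel adjacency from pairwise embedding distances with a hypothetical bandwidth `sigma`, rescales it into a Chebyshev-ready Laplacian (using the common λ_max ≈ 2 approximation), applies a `cheb_k`-term Chebyshev expansion, and adds a residual connection. The repository's normalization may differ.

```python
import numpy as np


def gaussian_adjacency(emb, sigma=1.0):
    # A_ij = exp(-||e_i - e_j||^2 / (2 * sigma^2)); emb has shape (N, d)
    sq_dist = np.sum((emb[:, None, :] - emb[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq_dist / (2.0 * sigma ** 2))


def scaled_laplacian(adj):
    # L = I - D^{-1/2} A D^{-1/2}; with lambda_max ~= 2, the rescaled
    # Laplacian 2L / lambda_max - I reduces to -D^{-1/2} A D^{-1/2}.
    d_inv_sqrt = 1.0 / np.sqrt(adj.sum(axis=1))
    return -(d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :])


def cheb_conv(x, lap, weights):
    """x: (N, F_in) node features; weights: (K, F_in, F_out) with K = cheb_k."""
    k = weights.shape[0]
    t_prev, t_curr = np.eye(lap.shape[0]), lap  # T_0 = I, T_1 = L~
    out = t_prev @ x @ weights[0]
    if k > 1:
        out = out + t_curr @ x @ weights[1]
    for i in range(2, k):
        # Chebyshev recurrence: T_k = 2 L~ T_{k-1} - T_{k-2}
        t_prev, t_curr = t_curr, 2.0 * lap @ t_curr - t_prev
        out = out + t_curr @ x @ weights[i]
    return out


rng = np.random.default_rng(0)
n_nodes, emb_dim, f_in, cheb_k = 5, 3, 4, 3
emb = rng.normal(size=(n_nodes, emb_dim))  # stands in for learnable node embeddings
x = rng.normal(size=(n_nodes, f_in))
w = rng.normal(size=(cheb_k, f_in, f_in)) * 0.1
adj = gaussian_adjacency(emb)
h = x + cheb_conv(x, scaled_laplacian(adj), w)  # residual connection
print(h.shape)  # (5, 4)
```

Because the kernel gives every node a self-weight of exp(0) = 1, all degrees are positive and the normalization never divides by zero.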

🚀 Installation

  1. Clone the repository:

    git clone <repository-url>
    cd samba-stock-prediction
  2. Install dependencies:

    pip install -r requirements.txt
  3. Ensure you have a CUDA-compatible GPU for optimal performance. 🎮
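To confirm that PyTorch actually sees a GPU, you can run a quick check like the one below. The helper name `pick_device` is hypothetical and not part of the repository; it falls back to the CPU when CUDA, or PyTorch itself, is unavailable, so it runs on any machine.

```python
def pick_device():
    """Return "cuda" if PyTorch can see a CUDA GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed yet (e.g. before `pip install -r requirements.txt`)
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(f"Training would run on: {pick_device()}")
```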

💻 Usage

⚡ Quick Start

  1. 📦 Install dependencies:

    pip install -r requirements.txt
  2. 📁 Create a Dataset folder and add your CSV files:

    mkdir Dataset
    # Copy your CSV files to the Dataset folder:
    # - combined_dataframe_IXIC.csv
    # - combined_dataframe_NYSE.csv
    # - combined_dataframe_DJI.csv
  3. 🚀 Run the model:

    python main.py
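Before running `main.py`, a quick pre-flight check can catch a misplaced file. The sketch below uses a hypothetical helper, `missing_datasets`, that lists which of the three expected CSV files (names taken from step 2) are absent from the `Dataset` folder.

```python
from pathlib import Path

# File names from the Quick Start instructions
EXPECTED_FILES = (
    "combined_dataframe_IXIC.csv",
    "combined_dataframe_NYSE.csv",
    "combined_dataframe_DJI.csv",
)


def missing_datasets(folder="Dataset"):
    """Return the expected CSV names that are not present in `folder`."""
    root = Path(folder)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_datasets()
    if missing:
        print("Missing from Dataset/:", ", ".join(missing))
    else:
        print("All dataset files found.")
```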

🎯 Basic Training

Run the model:

from main import main

# Run training with paper configuration
main()

⚙️ Custom Configuration

You can modify the training configuration in main.py or paper_config.py:

# In main.py, the configuration is loaded from paper_config.py
from paper_config import get_paper_config

model_args, config = get_paper_config()

# You can modify the config before training
config.epochs = 500  # Reduce epochs for faster training
config.batch_size = 64  # Increase batch size
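If `config` is a plain attribute object, a misspelled name (for example `config.epoch = 500`) would silently create a new field instead of changing the intended one. A small checker guards against that. `apply_overrides` is a hypothetical helper, not part of the repository, and the `SimpleNamespace` only stands in for the object returned by `get_paper_config()`; its `None` values are placeholders, not the paper's settings.

```python
from types import SimpleNamespace


def apply_overrides(config, **overrides):
    """Set attributes on `config`, refusing names it does not already have."""
    for name, value in overrides.items():
        if not hasattr(config, name):
            raise AttributeError(f"Unknown config field: {name!r}")
        setattr(config, name, value)
    return config


# Placeholder stand-in for the config returned by get_paper_config()
config = SimpleNamespace(epochs=None, batch_size=None, lr_init=None)
apply_overrides(config, epochs=500, batch_size=64)
print(config.epochs, config.batch_size)  # 500 64
```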

📊 Data Format

⚠️ Important: Create a Dataset folder and place your CSV files in it.

The model expects CSV data with the following format:

  • 📅 Date column as index
  • 🏷️ Name column (will be removed during preprocessing)
  • 💰 Price column for target values
  • 📈 Additional feature columns (technical indicators, market data, etc.)

Example:

Date,Name,Price,Volume,RSI,MACD,...
2023-01-01,IXIC,100.0,1000000,0.5,0.3,...
2023-01-02,IXIC,101.0,1200000,0.6,0.4,...

💡 Note: The num_nodes parameter is automatically determined from the input data shape (number of features), so you don't need to specify it manually.
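The format above can be loaded with pandas roughly as follows. This is a sketch under assumptions: it reads `Date` as the index, drops `Name`, uses `Price` as the target while also keeping it as an input feature, and takes the node count from the number of remaining columns. The repository's `utils/data_utils.py` may also normalize, split, and window the data differently. The inline CSV reuses the example rows shown above, cut down to four feature columns; `load_features` is a hypothetical name.

```python
import io

import pandas as pd

CSV_TEXT = """Date,Name,Price,Volume,RSI,MACD
2023-01-01,IXIC,100.0,1000000,0.5,0.3
2023-01-02,IXIC,101.0,1200000,0.6,0.4
"""


def load_features(source):
    """Return (features, target, num_nodes) from a CSV in the documented format."""
    df = pd.read_csv(source, index_col="Date", parse_dates=True)
    df = df.drop(columns=["Name"])  # Name is not a feature
    target = df["Price"]            # prediction target (assumed also kept as a feature)
    num_nodes = df.shape[1]         # one graph node per feature column
    return df, target, num_nodes


features, target, num_nodes = load_features(io.StringIO(CSV_TEXT))
print(num_nodes)  # 4
```

With the full datasets described below, the same shape lookup would give one node per daily feature column.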

📈 Available Datasets

This repository is configured to work with three real-world datasets from the US stock market with 82 daily stock features:

📁 Folder Structure:

Dataset/
├── combined_dataframe_IXIC.csv    # 📊 NASDAQ Composite Index
├── combined_dataframe_NYSE.csv    # 🏛️ New York Stock Exchange
└── combined_dataframe_DJI.csv     # 📈 Dow Jones Industrial Average

📅 Dataset Period: January 2010 to November 2023
🔢 Features: 82 daily stock features including technical indicators, market data, and financial metrics

Each dataset contains comprehensive historical price data with multiple technical indicators as features, providing rich information for the Graph-Mamba model to learn complex market patterns.

🧩 Model Components

🔧 Core Modules

  • 📋 config/: Configuration classes for model and training parameters
  • 🧠 models/: Model implementations (SAMBA, Mamba, Graph layers)
  • 🛠️ utils/: Utility functions (data loading, metrics, logging)
  • 🏃 trainer/: Training loop and evaluation

🎯 Key Classes

  • 🚀 SAMBA: Main model combining Mamba and GNN
  • 🧠 Mamba: State-space sequence model
  • 🔗 MambaBlock: Individual Mamba block with selective scanning
  • 🕸️ gconv: Graph convolution with Chebyshev polynomials
  • 🏃 Trainer: Training and evaluation pipeline

📊 Metrics

The model evaluates performance using:

  • 📏 MAE: Mean Absolute Error
  • 📐 RMSE: Root Mean Squared Error
  • 🔗 IC: Information Coefficient (Pearson correlation)
  • 📈 RIC: Rank Information Coefficient (Spearman correlation)
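The four metrics can be computed with NumPy and pandas as below. This sketch treats predictions and targets as flat 1-D series; `utils/metrics.py` in the repository may aggregate differently (for example per prediction step). IC is the Pearson correlation, and RIC is computed as the Pearson correlation of average-tie ranks, which is the definition of Spearman correlation. The function name `evaluate` is hypothetical.

```python
import numpy as np
import pandas as pd


def evaluate(y_true, y_pred):
    """Return MAE, RMSE, IC (Pearson) and RIC (Spearman) for two 1-D series."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_pred - y_true
    # Spearman correlation = Pearson correlation of the ranks (ties get their average rank)
    rank_true = pd.Series(y_true).rank().to_numpy()
    rank_pred = pd.Series(y_pred).rank().to_numpy()
    return {
        "MAE": float(np.mean(np.abs(err))),
        "RMSE": float(np.sqrt(np.mean(err ** 2))),
        "IC": float(np.corrcoef(y_true, y_pred)[0, 1]),
        "RIC": float(np.corrcoef(rank_true, rank_pred)[0, 1]),
    }


metrics = evaluate([1, 2, 3, 4], [1, 2, 3, 5])
print({k: round(v, 4) for k, v in metrics.items()})
# {'MAE': 0.25, 'RMSE': 0.5, 'IC': 0.9827, 'RIC': 1.0}
```

In the example, the prediction overshoots only the last value, so IC drops slightly below 1 while RIC stays at 1 because the ordering is unchanged.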

⚙️ Configuration

🧠 Model Parameters

  • 🔢 d_model: Model dimension
  • 📚 n_layer: Number of Mamba layers
  • 🎯 vocab_size: Number of features (automatically determined from input data)
  • 📥 seq_in: Input sequence length
  • 📤 seq_out: Output sequence length
  • 🏠 d_state: State dimension
  • 📈 expand: Expansion factor
  • 🧮 cheb_k: Chebyshev polynomial order
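Several of these parameters fix tensor shapes inside a Mamba block. In the reference Mamba design, the expanded inner width is `expand * d_model`, and each inner channel carries a `d_state`-dimensional hidden state, so the diagonal state matrix A has shape `(d_inner, d_state)`. The dataclass below is a hypothetical sketch of such a config: its field values are placeholders, not the paper's settings (see `paper_config.py` for those), and `d_inner` and `state_shape` are derived properties added for illustration.

```python
from dataclasses import dataclass


@dataclass
class ModelParams:
    # Placeholder values for illustration only, not the paper's configuration
    d_model: int = 32
    n_layer: int = 2
    vocab_size: int = 82  # number of features; inferred from the data in practice
    seq_in: int = 5
    seq_out: int = 1
    d_state: int = 16
    expand: int = 2
    cheb_k: int = 3

    @property
    def d_inner(self) -> int:
        # Width of the expanded channel inside each Mamba block
        return self.expand * self.d_model

    @property
    def state_shape(self) -> tuple:
        # Diagonal state matrix A: one d_state-sized vector per inner channel
        return (self.d_inner, self.d_state)


p = ModelParams()
print(p.d_inner, p.state_shape)  # 64 (64, 16)
```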

🏃 Training Parameters

  • 🔄 epochs: Number of training epochs
  • 📚 lr_init: Initial learning rate
  • 📦 batch_size: Training batch size
  • ⏹️ early_stop: Enable early stopping
  • ⏰ early_stop_patience: Early stopping patience
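`early_stop` and `early_stop_patience` usually work together as in this sketch: training stops once the validation loss has failed to improve for `patience` consecutive epochs. The class name `EarlyStopper` and the exact rule (strict improvement, no minimum delta) are assumptions; the logic in `trainer/trainer.py` may differ.

```python
class EarlyStopper:
    """Signal a stop after `patience` epochs without a new best validation loss."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best:
            self.best = val_loss  # new best: a checkpoint would typically be saved here
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=2)
for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.75, 0.6]):
    if stopper.step(loss):
        print(f"Stopping at epoch {epoch}, best loss {stopper.best}")
        break
# Stopping at epoch 3, best loss 0.7
```

Note that the later improvement to 0.6 is never seen: a larger patience trades extra training time for robustness to noisy validation curves.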

📈 Results

The model outputs results to:

  • 📄 samba_results.txt: Performance metrics
  • 💾 ./best_model.pth: Best model checkpoint
  • 📺 Console logs: Training progress and final metrics

📁 File Structure

├── 📂 Dataset/                # Put your CSV files here
│   ├── 📊 combined_dataframe_IXIC.csv
│   ├── 🏛️ combined_dataframe_NYSE.csv
│   └── 📈 combined_dataframe_DJI.csv
├── 📂 config/
│   ├── __init__.py
│   └── model_config.py
├── 📂 models/
│   ├── __init__.py
│   ├── 🚀 samba.py
│   ├── 🧠 mamba.py
│   ├── 🔗 mamba_block.py
│   ├── 🕸️ graph_layers.py
│   └── 📐 normalization.py
├── 📂 utils/
│   ├── __init__.py
│   ├── 📊 data_utils.py
│   ├── 📈 metrics.py
│   ├── 📝 logger.py
│   └── 🛠️ model_utils.py
├── 📂 trainer/
│   ├── __init__.py
│   └── 🏃 trainer.py
├── 🚀 main.py                 # Main execution file
├── 📋 paper_config.py         # Paper-specific configuration
├── 🧪 test_system.py          # System test
├── 📦 requirements.txt
└── 📖 README.md

📦 Dependencies

  • 🔥 PyTorch >= 1.9.0
  • 🔢 NumPy >= 1.21.0
  • 🐼 Pandas >= 1.3.0
  • 📊 Matplotlib >= 3.4.0
  • 🧮 einops >= 0.4.0
  • 💾 h5py >= 3.1.0

📚 Citation

If you find our paper and code useful, please cite our paper as follows:

@inproceedings{SAMBA,
  title={Mamba Meets Financial Markets: {A} {G}raph-{M}amba Approach for Stock Price Prediction},
  author={Mehrabian, Ali and Hoseinzade, Ehsan and Mazloum, Mahdi and Chen, Xiaohong},
  booktitle={Proc. IEEE Int. Conf. Acoust., Speech, Signal Process. (ICASSP)},
  address={Hyderabad, India},
  month={Apr.},
  year={2025}
}

📄 Paper: "Mamba Meets Financial Markets: A Graph-Mamba Approach for Stock Price Prediction"
🏛️ Conference: IEEE ICASSP 2025
👥 Authors: Ali Mehrabian, Ehsan Hoseinzade, Mahdi Mazloum, Xiaohong Chen

📞 Contact

Please feel free to contact us if you have any questions:

🤝 Contributing

  1. 🍴 Fork the repository
  2. 🌿 Create a feature branch
  3. ✏️ Make your changes
  4. 🧪 Add tests if applicable
  5. 📤 Submit a pull request

🐛 Issues

If you encounter any issues, please:

  1. 🔍 Check the existing issues
  2. 📝 Create a new issue with a detailed description
  3. 💻 Include system information and error logs

⭐ If you found this project helpful, please give it a star! ⭐

Made with ❤️ for the financial AI community
