A deep learning model for synthetic accessibility prediction based on molecular prices.
Clone the repository and create a virtual environment with conda:
# Get the code
git clone https://github.com/fredhastedt/MolPrice.git
cd MolPrice
# Create environment
conda env create -f molprice.yml
conda activate molprice
We provide model checkpoints for MolPrice via Figshare. One can choose from the following models:
- SECFP fingerprint (with or without 2D features)
- Morgan fingerprint (with or without 2D features)
Once a model is downloaded, place it in the ./models directory.
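As a quick sanity check after downloading, the following minimal sketch lists the checkpoint files found under a models directory. It assumes each model lives in its own sub-folder containing one or more .ckpt files (for example MP_SECFP_hybrid/best.ckpt, as in the test command further down); the helper name find_checkpoints is hypothetical and not part of MolPrice.

```python
from pathlib import Path


def find_checkpoints(models_dir="models"):
    """Map each model folder name to the .ckpt files found directly inside it.

    Assumed layout (not confirmed by the repository): <models_dir>/<model_name>/<file>.ckpt
    """
    root = Path(models_dir)
    found = {}
    if not root.is_dir():
        # Nothing downloaded yet, or the directory has a different name.
        return found
    for ckpt in sorted(root.glob("*/*.ckpt")):
        found.setdefault(ckpt.parent.name, []).append(ckpt.name)
    return found


if __name__ == "__main__":
    for model_name, files in find_checkpoints("models").items():
        print(model_name, files)
```

An empty result suggests the download was placed in a different location than the one assumed here.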
Predictions can be run for a single molecule or in batch mode. For batch prediction, first save all molecules as SMILES strings in a .csv file and pass the name of the SMILES column via --smiles-col.
# Single molecule prediction
python -m bin.predict --mol "CC(=O)OC1=CC=CC=C1C(=O)O" --cn MP_SECFP_hybrid
# Batch prediction
python -m bin.predict --mol molecules.csv --cn MP_SECFP_hybrid --smiles-col SMILES_COLUMN
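For the batch case, the input file can be prepared with a few lines of Python. This is a minimal sketch, assuming a single-column CSV with a header row is sufficient; the column name SMILES and the helper write_smiles_csv are illustrative choices, not names required by MolPrice. Whatever header is used should be the value passed to --smiles-col.

```python
import csv
from pathlib import Path


def write_smiles_csv(smiles, path, column="SMILES"):
    """Write one SMILES string per row under a single header column."""
    path = Path(path)
    with path.open("w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow([column])  # header row; use this name for --smiles-col
        for smi in smiles:
            writer.writerow([smi.strip()])  # drop stray whitespace around each entry
    return path


if __name__ == "__main__":
    molecules = [
        "CC(=O)OC1=CC=CC=C1C(=O)O",  # same molecule as the single-molecule example
        "CCO",                        # ethanol, added here purely for illustration
    ]
    write_smiles_csv(molecules, "molecules.csv")
```

With this file, the batch command above would be called with --smiles-col SMILES.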
The test datasets for the SA comparison can be obtained from Figshare (test files). Once the files are downloaded, place them in the ./testing directory.
The results for each test dataset can be obtained by running:
python -m bin.test main_ood --model Fingerprint --cn MODEL_CHECKPOINT --test_name TEST_FILE1,TEST_FILE2 --combined
For example, if one downloaded the MP_SECFP_hybrid model and saved the files for test set 3 as TS3_hs.csv and TS3_es.csv, one can run:
python -m bin.test main_ood --model Fingerprint --cn MP_SECFP_hybrid/best.ckpt --test_name TS3_hs.csv,TS3_es.csv --combined
If one has access to a database containing molecules along with their prices, one can run the following script to train their own model (given that prices are in log(USD)/mmol):
python -m bin.train --model MODEL_TYPE --fp FINGERPRINT_TYPE
Within the script, the following arguments can be adjusted:
- model: Choose between [Fingerprint, RoBERTa, Transformer, LSTM_EFG]
- fp: Choose between [atom, rdkit, morgan, mhfp] (mhfp is the SECFP fingerprint encoder)
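The price unit required for training data may need a conversion step when the source database lists a price per pack size. The sketch below is one possible way to do this, under two assumptions that the text above does not settle: that the target is the logarithm of the price in USD per mmol, and that the natural logarithm is meant. The function name and arguments are hypothetical.

```python
import math


def log_price_per_mmol(price_usd, amount_g, mol_weight_g_per_mol):
    """Return ln(USD per mmol) for one catalogue entry.

    Assumes natural log and a price quoted for a pack of `amount_g` grams.
    """
    if price_usd <= 0 or amount_g <= 0 or mol_weight_g_per_mol <= 0:
        raise ValueError("price, amount and molecular weight must be positive")
    mmol = amount_g / mol_weight_g_per_mol * 1000.0  # g -> mol -> mmol
    return math.log(price_usd / mmol)


if __name__ == "__main__":
    # Made-up catalogue entry: 25 USD for 100 g of a compound with MW 180.16 g/mol.
    print(log_price_per_mmol(25.0, 100.0, 180.16))
```

If the training script expects a different log base or unit, only the last line of the function needs to change.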
If one has a pre-trained Fingerprint model, one can further train it with the contrastive loss by calling:
python -m bin.train --model Fingerprint --fp FINGERPRINT_TYPE --combined --cn MODEL_CHECKPOINT