Model Card for ML4chemistry/chemical-reaction-t5-v3
Model Details
Model Description
This is the model card for a 🤗 Transformers model that has been pushed to the Hub. This model card was automatically generated.
- Developed by: Zhi Huang
- Shared by [optional]: Zhi Huang
- Model type: Chemical reaction prediction
- License: MIT
Model Sources [optional]
- Repository: [More Information Needed]
- Paper [optional]: [More Information Needed]
- Demo [optional]: [More Information Needed]
Uses
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForSeq2SeqLM
from collections import Counter
import re
import time
# Optional: RDKit for advanced chemical validation.
# Only the imports live inside ``try`` so that nothing except a genuinely
# missing/broken RDKit installation can flip the availability flag.
try:
    from rdkit import Chem
    # NOTE(review): the two names below are not referenced by the code visible
    # in this file; they are kept because other code may rely on them.
    from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
    from rdkit.DataStructs import TanimotoSimilarity
    from rdkit import RDLogger
except ImportError:
    RDKIT_AVAILABLE = False
    print("RDKit not found. Using basic validation methods.")
else:
    RDLogger.DisableLog('rdApp.*')  # Disable RDKit warnings
    RDKIT_AVAILABLE = True
    print("RDKit found. Advanced chemical validation will be used.")
class ChemicalReactionPredictor:
    """Generate reaction-product SMILES and re-rank them with chemistry checks.

    The class loads a tokenizer and a PEFT seq2seq model, generates candidate
    product SMILES for a reactant SMILES string, and scores every unique
    candidate by a weighted sum of four components: position in the model's
    output, SMILES validity, atom balance and charge balance.

    The chemistry helpers consult the module-level ``RDKIT_AVAILABLE`` flag
    and use ``Chem`` (RDKit) when it is true; otherwise, or when RDKit cannot
    parse a string, regex-based fallbacks are used.
    """

    # Fallback atom tokenizer. Group 1: element inside a bracket atom (after an
    # optional isotope number); group 2: bare organic-subset atom; group 3:
    # bare aromatic atom. Two-letter symbols are listed before single letters
    # so "Cl"/"Br" are not split, while "Cc"/"Sc" outside brackets are read as
    # two atoms rather than one bogus element.
    _ATOM_RE = re.compile(
        r'\[\d*([A-Z][a-z]?|se|as|[bcnops])[^\]]*\]'
        r'|(Cl|Br|[BCNOPSFI])'
        r'|([bcnops])'
    )

    # Fallback charge tokenizer: the charge token of a bracket atom, i.e. a
    # run of signs ("++"), or one sign with optional digits ("+", "-2"),
    # optionally followed by an atom-map suffix (":1") before the closing "]".
    _CHARGE_RE = re.compile(r'\[[^\]]*?(\+{2,}|-{2,}|[+-]\d*)(?::\d+)?\]')

    # Charge-balance score for small absolute charge differences; larger
    # differences use the linear rule in ``check_balance``.
    _CHARGE_SCORES = {0: 1.0, 1: 0.7, 2: 0.4}

    def __init__(self, model_name):
        """Load tokenizer and model and set the scoring weights.

        Args:
            model_name: identifier passed to ``from_pretrained`` for both the
                tokenizer and the PEFT seq2seq model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Use half precision only when CUDA is present; float16 support on CPU
        # is incomplete, so fall back to float32 there.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto",
        )

        # Configure model tokens from the tokenizer; decoding starts from the
        # pad token.
        config = self.model.config
        config.bos_token_id = self.tokenizer.bos_token_id
        config.eos_token_id = self.tokenizer.eos_token_id
        config.pad_token_id = self.tokenizer.pad_token_id
        config.decoder_start_token_id = self.tokenizer.pad_token_id
        self.model.eval()

        # Scoring weights for comprehensive evaluation (they sum to 1.0).
        self.MODEL_SCORE_WEIGHT = 0.3
        self.VALIDITY_SCORE_WEIGHT = 0.2
        self.ATOM_BALANCE_SCORE_WEIGHT = 0.25
        self.CHARGE_BALANCE_SCORE_WEIGHT = 0.25

    def canonicalize_smiles(self, smiles):
        """Return a canonical form of ``smiles``.

        Non-string input yields ``""``. Without RDKit, or when RDKit cannot
        parse the string (with and without sanitization), the stripped input
        is returned unchanged.
        """
        if not isinstance(smiles, str):
            return ""
        stripped = smiles.strip()
        if not RDKIT_AVAILABLE or not stripped:
            return stripped
        try:
            # Prefer a sanitized parse; retry without sanitization.
            for sanitize in (True, False):
                mol = Chem.MolFromSmiles(stripped, sanitize=sanitize)
                if mol is not None:
                    return Chem.MolToSmiles(mol, canonical=True)
        except Exception:
            # Deliberately best-effort: any RDKit failure falls back to the
            # raw string rather than aborting the whole prediction run.
            pass
        return stripped

    def score_chemical_validity(self, smiles):
        """Score the validity of ``smiles`` on a 0.0-1.0 scale.

        Returns:
            0.0 for empty/non-string input or when RDKit raises;
            0.1 when RDKit is unavailable, or when it cannot parse the string;
            0.5 when the string parses only without sanitization;
            1.0 when it parses with sanitization.
        """
        if not isinstance(smiles, str) or not smiles.strip():
            return 0.0  # nothing to validate
        if not RDKIT_AVAILABLE:
            return 0.1  # non-empty string, but no way to check it
        try:
            if Chem.MolFromSmiles(smiles, sanitize=True) is not None:
                return 1.0  # fully valid
            if Chem.MolFromSmiles(smiles, sanitize=False) is not None:
                return 0.5  # parses, but fails sanitization
            return 0.1  # not parseable by RDKit
        except Exception:
            return 0.0

    def get_atom_counts_from_smiles(self, smiles):
        """Return a ``Counter`` mapping element symbol to atom count.

        Uses the atoms of the RDKit molecule when RDKit is available and can
        parse the string (sanitized first, then unsanitized); otherwise uses
        the regex fallback. Empty/non-string input yields an empty ``Counter``.
        """
        if not isinstance(smiles, str):
            return Counter()
        smiles = smiles.strip()
        if not smiles:
            return Counter()
        if RDKIT_AVAILABLE:
            for sanitize in (True, False):
                try:
                    mol = Chem.MolFromSmiles(smiles, sanitize=sanitize)
                    if mol is not None:
                        return Counter(atom.GetSymbol() for atom in mol.GetAtoms())
                except Exception:
                    # Best-effort: try the next parsing mode / the fallback.
                    continue
        # Fallback regex method
        return self._regex_atom_count_fallback(smiles)

    def _regex_atom_count_fallback(self, smiles):
        """Count atoms in ``smiles`` with a regex, without RDKit.

        Bracket atoms contribute their element (hydrogens written inside the
        bracket are not counted); aromatic lowercase atoms are reported under
        the capitalized element symbol.
        """
        atom_counts = Counter()
        for match in self._ATOM_RE.finditer(smiles):
            symbol = match.group(1) or match.group(2) or match.group(3)
            atom_counts[symbol.capitalize()] += 1
        return atom_counts

    def get_charge_from_smiles(self, smiles):
        """Return the total formal charge encoded in ``smiles``.

        Uses RDKit when it is available and parses the string; otherwise sums
        the charges of bracket atoms found by regex. Empty/non-string input
        yields 0.
        """
        if not smiles or not isinstance(smiles, str):
            return 0
        if RDKIT_AVAILABLE:
            try:
                mol = Chem.MolFromSmiles(smiles)
                if mol is not None:
                    return Chem.rdmolops.GetFormalCharge(mol)
            except Exception:
                # Best-effort: fall through to the regex fallback.
                pass
        # Regex fallback
        return self._regex_charge_fallback(smiles)

    def _regex_charge_fallback(self, smiles):
        """Sum bracket-atom charges in ``smiles`` using a regex.

        Handles ``+``/``-``, repeated signs (``++`` is +2) and signed numbers
        (``+2``, ``-3``).
        """
        total_charge = 0
        for match in self._CHARGE_RE.finditer(smiles):
            token = match.group(1)
            sign = 1 if token[0] == '+' else -1
            digits = token.lstrip('+-')
            # With digits the magnitude is the number; otherwise it is the
            # number of repeated sign characters.
            magnitude = int(digits) if digits else len(token)
            total_charge += sign * magnitude
        return total_charge

    def parse_multi_component_smiles(self, multi_smiles):
        """Split a dot-separated SMILES string into non-empty components."""
        if not multi_smiles or not isinstance(multi_smiles, str):
            return []
        return [part.strip() for part in multi_smiles.split('.') if part.strip()]

    def check_balance(self, reactant_smiles, product_smiles):
        """Compare atoms and charge between reactants and products.

        Returns:
            dict with the keys ``is_atom_balanced``, ``reactant_atoms``,
            ``product_atoms``, ``missing_in_products``, ``extra_in_products``,
            ``atom_balance_score``, ``is_charge_balanced``,
            ``reactant_charge``, ``product_charge``, ``charge_difference`` and
            ``charge_balance_score``. Both scores lie in [0.0, 1.0].
        """
        reactant_components = self.parse_multi_component_smiles(reactant_smiles)
        product_components = self.parse_multi_component_smiles(product_smiles)

        # Atom balance check
        reactant_atoms = Counter()
        for component in reactant_components:
            reactant_atoms.update(self.get_atom_counts_from_smiles(component))
        product_atoms = Counter()
        for component in product_components:
            product_atoms.update(self.get_atom_counts_from_smiles(component))

        # Counter subtraction keeps only strictly positive counts.
        missing_in_products = reactant_atoms - product_atoms
        extra_in_products = product_atoms - reactant_atoms
        is_atom_balanced = not missing_in_products and not extra_in_products

        total_reactant_atoms = sum(reactant_atoms.values())
        total_imbalance = (sum(missing_in_products.values())
                           + sum(extra_in_products.values()))
        if total_reactant_atoms == 0:
            atom_balance_score = 0.0
        else:
            atom_balance_score = max(0.0, 1.0 - (total_imbalance / total_reactant_atoms))

        # Charge balance check
        reactant_charge = sum(self.get_charge_from_smiles(comp) for comp in reactant_components)
        product_charge = sum(self.get_charge_from_smiles(comp) for comp in product_components)
        charge_difference = abs(reactant_charge - product_charge)
        is_charge_balanced = (charge_difference == 0)

        # Charge balance score: table for differences 0-2, linear beyond that.
        charge_balance_score = self._CHARGE_SCORES.get(
            charge_difference,
            max(0.0, 1.0 - charge_difference * 0.3),
        )

        return {
            'is_atom_balanced': is_atom_balanced,
            'reactant_atoms': reactant_atoms,
            'product_atoms': product_atoms,
            'missing_in_products': missing_in_products,
            'extra_in_products': extra_in_products,
            'atom_balance_score': atom_balance_score,
            'is_charge_balanced': is_charge_balanced,
            'reactant_charge': reactant_charge,
            'product_charge': product_charge,
            'charge_difference': charge_difference,
            'charge_balance_score': charge_balance_score
        }

    def predict_with_chemical_enhancement(self, reactant_smiles, num_return_sequences=20,
                                          max_length=256, num_beams=None,
                                          temperature=1.0, do_sample=False,
                                          top_k=50, top_p=0.95):
        """Predict reaction products and re-rank them by a comprehensive score.

        Prints the model's original top 5, the re-ranked top 5 and summary
        statistics.

        Args:
            reactant_smiles: reactant SMILES string fed to the model.
            num_return_sequences: number of sequences requested from
                ``generate``.
            max_length: limit used for both tokenizer truncation and
                generation.
            num_beams: beam count; defaults to ``num_return_sequences + 5``.
            temperature, top_k, top_p: sampling settings, forwarded to
                ``generate`` only when ``do_sample`` is true.
            do_sample: whether ``generate`` should sample.

        Returns:
            ``(enhanced_top5, predictions_data)``: the five best entries and
            the full list of unique predictions, both ordered by descending
            ``comprehensive_score``. Each entry is a dict whose ``rank`` is
            its position in the model's original (de-duplicated) output.
        """
        if num_beams is None:
            num_beams = num_return_sequences + 5

        print(f"Input reactants: {reactant_smiles}")
        print("-" * 80)

        raw_predictions = self._generate_raw_predictions(
            reactant_smiles, num_return_sequences, max_length, num_beams,
            temperature, do_sample, top_k, top_p,
        )
        predictions_data = self._score_predictions(reactant_smiles, raw_predictions)

        # Display results
        print("\n" + "="*80)
        print("ORIGINAL MODEL RANKING vs CHEMICALLY ENHANCED RANKING")
        print("="*80)

        # The list is still in generation order here, so this really is the
        # model's own ranking; it must be printed before the re-sort below.
        self._print_original_ranking(predictions_data[:5])

        # Sort by comprehensive score (chemical enhancement)
        predictions_data.sort(key=lambda x: x['comprehensive_score'], reverse=True)
        enhanced_top5 = predictions_data[:5]
        self._print_enhanced_ranking(enhanced_top5)

        self._print_summary(predictions_data)
        return enhanced_top5, predictions_data

    def _generate_raw_predictions(self, reactant_smiles, num_return_sequences,
                                  max_length, num_beams, temperature,
                                  do_sample, top_k, top_p):
        """Run ``generate`` and return the decoded candidate strings in output order."""
        # Tokenize input
        inputs = self.tokenizer(reactant_smiles, return_tensors="pt",
                                max_length=max_length, truncation=True)
        # Move the inputs to whatever device the model parameters live on.
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        generate_kwargs = dict(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            num_beams=num_beams,
            do_sample=do_sample,
            early_stopping=True,
            eos_token_id=self.model.config.eos_token_id,
            pad_token_id=self.model.config.pad_token_id,
            decoder_start_token_id=self.model.config.decoder_start_token_id,
            return_dict_in_generate=True,
            output_scores=True,
        )
        if do_sample:
            # Sampling knobs are only meaningful (and only passed) when sampling.
            generate_kwargs.update(temperature=temperature, top_k=top_k, top_p=top_p)

        start_time = time.time()
        # Generate predictions
        with torch.no_grad():
            outputs = self.model.generate(**generate_kwargs)
        generation_time = time.time() - start_time
        print(f"Generation time: {generation_time:.3f} seconds")

        # Decode predictions
        return self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)

    def _score_predictions(self, reactant_smiles, raw_predictions):
        """Canonicalize, de-duplicate and score candidates, keeping output order."""
        predictions_data = []
        seen_canonical = set()
        total = len(raw_predictions)
        for index, pred_raw in enumerate(raw_predictions):
            # Model score based on generation order
            model_score = 1.0 - (index / total)

            canonical_pred = self.canonicalize_smiles(pred_raw)
            # Skip duplicates
            if canonical_pred in seen_canonical:
                continue
            seen_canonical.add(canonical_pred)

            validity_score = self.score_chemical_validity(canonical_pred)

            # Balance scores stay at zero for an empty prediction.
            atom_balance_score = 0.0
            charge_balance_score = 0.0
            balance_info = {}
            if canonical_pred:
                balance_info = self.check_balance(reactant_smiles, canonical_pred)
                atom_balance_score = balance_info['atom_balance_score']
                charge_balance_score = balance_info['charge_balance_score']

            # Comprehensive score
            comprehensive_score = (
                model_score * self.MODEL_SCORE_WEIGHT +
                validity_score * self.VALIDITY_SCORE_WEIGHT +
                atom_balance_score * self.ATOM_BALANCE_SCORE_WEIGHT +
                charge_balance_score * self.CHARGE_BALANCE_SCORE_WEIGHT
            )
            predictions_data.append({
                'rank': len(predictions_data) + 1,
                'raw_smiles': pred_raw,
                'canonical_smiles': canonical_pred,
                'model_score': model_score,
                'validity_score': validity_score,
                'atom_balance_score': atom_balance_score,
                'charge_balance_score': charge_balance_score,
                'comprehensive_score': comprehensive_score,
                'balance_info': balance_info
            })
        return predictions_data

    def _print_original_ranking(self, top_predictions):
        """Print the given entries (expected in model output order) with their analysis."""
        print("\nORIGINAL MODEL RANKING (Top 5):")
        print("-" * 50)
        for i, pred in enumerate(top_predictions):
            original_rank = pred['rank']
            print(f"{i+1:2d}. {pred['canonical_smiles']}")
            print(f" Original Rank: #{original_rank}, Model Score: {pred['model_score']:.3f}")
            print(f" Validity: {pred['validity_score']:.3f}, "
                  f"Atom Balance: {pred['atom_balance_score']:.3f}, "
                  f"Charge Balance: {pred['charge_balance_score']:.3f}")
            # Chemical analysis
            if pred['balance_info']:
                balance = pred['balance_info']
                print(f" Chemical Analysis:")
                print(f" Reactant atoms: {dict(balance['reactant_atoms'])}")
                print(f" Product atoms: {dict(balance['product_atoms'])}")
                if balance['missing_in_products']:
                    print(f" Missing atoms: {dict(balance['missing_in_products'])}")
                if balance['extra_in_products']:
                    print(f" Extra atoms: {dict(balance['extra_in_products'])}")
                print(f" Charge: {balance['reactant_charge']} → {balance['product_charge']} "
                      f"(Δ{balance['charge_difference']})")
            print()

    def _print_enhanced_ranking(self, enhanced_top5):
        """Print the re-ranked entries with rank change and quality indicators."""
        print("\nCHEMICALLY ENHANCED RANKING (Top 5):")
        print("-" * 50)
        for i, pred in enumerate(enhanced_top5):
            new_rank = i + 1
            original_rank = pred['rank']
            if original_rank > new_rank:
                improvement = f"↑{original_rank - new_rank}"
            elif original_rank < new_rank:
                improvement = f"↓{new_rank - original_rank}"
            else:
                improvement = "="
            print(f"{i+1:2d}. {pred['canonical_smiles']}")
            print(f" Rank Change: #{original_rank} → #{i+1} ({improvement})")
            print(f" Comprehensive Score: {pred['comprehensive_score']:.3f}")
            print(f" Component Scores - Model: {pred['model_score']:.3f}, "
                  f"Validity: {pred['validity_score']:.3f}, "
                  f"Atom: {pred['atom_balance_score']:.3f}, "
                  f"Charge: {pred['charge_balance_score']:.3f}")
            # Quality indicators
            quality_indicators = []
            if pred['validity_score'] >= 0.9:
                quality_indicators.append("✓Valid")
            if pred['atom_balance_score'] >= 0.9:
                quality_indicators.append("✓Balanced")
            if pred['charge_balance_score'] >= 0.9:
                quality_indicators.append("✓Charge OK")
            if quality_indicators:
                print(f" Quality: {' '.join(quality_indicators)}")
            print()

    def _print_summary(self, predictions_data):
        """Print averages and high-score counts over all unique predictions."""
        print("="*80)
        print("SUMMARY STATISTICS")
        print("="*80)
        count = len(predictions_data)
        print(f"Total unique predictions: {count}")
        if count == 0:
            # Nothing to average; avoid dividing by zero.
            return
        print(f"Average validity score: {sum(p['validity_score'] for p in predictions_data) / count:.3f}")
        print(f"Average atom balance score: {sum(p['atom_balance_score'] for p in predictions_data) / count:.3f}")
        print(f"Average charge balance score: {sum(p['charge_balance_score'] for p in predictions_data) / count:.3f}")
        valid_predictions = sum(1 for p in predictions_data if p['validity_score'] >= 0.9)
        balanced_predictions = sum(1 for p in predictions_data if p['atom_balance_score'] >= 0.9)
        charge_balanced_predictions = sum(1 for p in predictions_data if p['charge_balance_score'] >= 0.9)
        print(f"High validity predictions (≥0.9): {valid_predictions}/{count} ({valid_predictions/count*100:.1f}%)")
        print(f"Well-balanced predictions (≥0.9): {balanced_predictions}/{count} ({balanced_predictions/count*100:.1f}%)")
        print(f"Charge-balanced predictions (≥0.9): {charge_balanced_predictions}/{count} ({charge_balanced_predictions/count*100:.1f}%)")
# Example usage. Guarded so that importing this module does not load the
# model or run generation; the names below remain module-level when the file
# is executed as a script.
if __name__ == "__main__":
    # Initialize the predictor
    print("Initializing Chemical Reaction Predictor...")
    model_name = "ML4chemistry/chemical-reaction-t5-v3"
    predictor = ChemicalReactionPredictor(model_name)

    # Test with your example
    reactant_smiles = "CC(C)(C)N=N"

    # Run prediction with chemical enhancement
    top_predictions, all_predictions = predictor.predict_with_chemical_enhancement(
        reactant_smiles,
        num_return_sequences=20,
        num_beams=25,
        max_length=256
    )

    print("\n" + "="*80)
    print("FINAL TOP 3 CHEMICALLY ENHANCED PREDICTIONS")
    print("="*80)
    for i, pred in enumerate(top_predictions[:3], 1):
        print(f"{i}. {pred['canonical_smiles']}")
        print(f" Score: {pred['comprehensive_score']:.3f}")
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support