---
library_name: transformers
license: mit
tags: []
---

# Model Card for ML4chemistry/chemical-reaction-t5-v3

## Model Details

### Model Description

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

The model is a sequence-to-sequence model distributed as a PEFT adapter (loaded with `AutoPeftModelForSeq2SeqLM`). It takes reactant SMILES as input and generates candidate product SMILES.

- **Developed by:** Zhi Huang
- **Shared by:** Zhi Huang
- **Model type:** Chemical reaction prediction
- **License:** MIT

### Model Sources [optional]

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

The example below does two things:

1. Generates candidate products with beam search.
2. Re-ranks them with simple chemistry heuristics: SMILES validity, atom balance, and charge balance.

It requires `torch`, `transformers` and `peft`. `rdkit` is optional; when installed, it is used for SMILES canonicalization, validity checks, atom counting and formal charges.

```python
import re
import time
from collections import Counter

import torch
from peft import AutoPeftModelForSeq2SeqLM
from transformers import AutoTokenizer

# Optional: RDKit for advanced chemical validation
try:
    from rdkit import Chem
    from rdkit import RDLogger
    # NOTE: the next two imports are not used by the code below; an ImportError
    # from either of them still disables RDKit support for the whole example.
    from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
    from rdkit.DataStructs import TanimotoSimilarity

    RDLogger.DisableLog('rdApp.*')  # Disable RDKit warnings
    RDKIT_AVAILABLE = True
    print("RDKit found. Advanced chemical validation will be used.")
except ImportError:
    RDKIT_AVAILABLE = False
    print("RDKit not found. Using basic validation methods.")


class ChemicalReactionPredictor:
    """Predict reaction products with a PEFT seq2seq model and re-rank them.

    Candidates returned by ``model.generate`` are deduplicated by canonical
    SMILES. Each one is scored by a weighted sum of four components:
    generation order, SMILES validity, atom balance, and charge balance
    against the reactants.
    """

    # SMILES atom tokens.
    # Group 1: element symbol of a bracket atom, after an optional isotope
    #   (e.g. "[13CH4]", "[nH]", "[Na+]", "[se]").
    # Group 2: a bare organic-subset atom. "Br"/"Cl" are listed before "B"/"C"
    #   so they are not split; lowercase letters are aromatic atoms.
    _ATOM_TOKEN_RE = re.compile(
        r"\[\d*([A-Z][a-z]?|[a-z][a-z]?|\*)[^\]]*\]"
        r"|(Br|Cl|B|C|N|O|P|S|F|I|b|c|n|o|p|s|\*)"
    )
    # Charge written on a bracket atom, before an optional atom-map number:
    # "[O-]" -> ("-", ""), "[Fe++]" -> ("++", ""), "[Cu+2]" -> ("+", "2").
    _BRACKET_CHARGE_RE = re.compile(r"\[[^\]]*?([+-]+)(\d*)(?::\d+)?\]")

    def __init__(self, model_name, torch_dtype=None):
        """Load the tokenizer and the PEFT seq2seq model.

        Args:
            model_name: Hub id or local path of the PEFT adapter.
            torch_dtype: dtype for the model weights. When None, float16 is
                used if CUDA is available and float32 otherwise. float16
                matmuls are unsupported or slow on CPU in many PyTorch builds.
        """
        if torch_dtype is None:
            torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map="auto",
        )

        # Configure model tokens
        self.model.config.bos_token_id = self.tokenizer.bos_token_id
        self.model.config.eos_token_id = self.tokenizer.eos_token_id
        self.model.config.pad_token_id = self.tokenizer.pad_token_id
        self.model.config.decoder_start_token_id = self.tokenizer.pad_token_id
        self.model.eval()

        # Scoring weights for the comprehensive score (they sum to 1.0)
        self.MODEL_SCORE_WEIGHT = 0.3
        self.VALIDITY_SCORE_WEIGHT = 0.2
        self.ATOM_BALANCE_SCORE_WEIGHT = 0.25
        self.CHARGE_BALANCE_SCORE_WEIGHT = 0.25

    @staticmethod
    def _parse_smiles(smiles):
        """Parse *smiles* with RDKit, sanitized first and then unsanitized.

        Returns:
            (mol, sanitized). mol is None when RDKit is unavailable or neither
            parse succeeds; exceptions raised by RDKit count as failures.
        """
        if not RDKIT_AVAILABLE:
            return None, False
        for sanitize in (True, False):
            try:
                mol = Chem.MolFromSmiles(smiles, sanitize=sanitize)
            except Exception:
                mol = None
            if mol is not None:
                return mol, sanitize
        return None, False

    def canonicalize_smiles(self, smiles):
        """Return the RDKit canonical SMILES for *smiles*.

        Non-string input yields "". Without RDKit, or when RDKit cannot
        parse or write the molecule, the whitespace-stripped input is
        returned unchanged.
        """
        if not isinstance(smiles, str):
            return ""
        smiles = smiles.strip()
        if not smiles:
            return smiles
        mol, _ = self._parse_smiles(smiles)
        if mol is None:
            return smiles
        try:
            return Chem.MolToSmiles(mol, canonical=True)
        except Exception:
            # Writing may fail (e.g. for an unsanitized molecule); keep the input.
            return smiles

    def score_chemical_validity(self, smiles):
        """Score the chemical validity of *smiles* on a 0.0-1.0 scale.

        Scores:
            0.0 for empty or non-string input, or SMILES RDKit cannot parse.
            0.1 for any non-empty string when RDKit is not installed
                (validity unknown).
            0.5 when RDKit parses it only without sanitization.
            1.0 when it parses with sanitization.
        """
        if not isinstance(smiles, str) or not smiles.strip():
            return 0.0
        if not RDKIT_AVAILABLE:
            return 0.1
        mol, sanitized = self._parse_smiles(smiles.strip())
        if mol is None:
            return 0.0
        return 1.0 if sanitized else 0.5

    def get_atom_counts_from_smiles(self, smiles):
        """Count atoms by element symbol in a single SMILES string.

        Uses RDKit when it can parse the string, otherwise a regex
        tokenizer. Empty or non-string input gives an empty Counter.
        """
        if not isinstance(smiles, str):
            return Counter()
        smiles = smiles.strip()
        if not smiles:
            return Counter()
        mol, _ = self._parse_smiles(smiles)
        if mol is not None:
            return Counter(atom.GetSymbol() for atom in mol.GetAtoms())
        return self._regex_atom_count_fallback(smiles)

    def _regex_atom_count_fallback(self, smiles):
        """Count atoms by tokenizing *smiles* with ``_ATOM_TOKEN_RE``.

        Only atoms written as SMILES atoms are counted. Hydrogen counts
        inside brackets (e.g. the 4 in "[NH4+]") are not counted.
        Aromatic symbols are capitalized ("c" -> "C", "se" -> "Se") so the
        keys use element symbols.
        """
        return Counter(
            (bracket_symbol or bare_symbol).capitalize()
            for bracket_symbol, bare_symbol in self._ATOM_TOKEN_RE.findall(smiles)
        )

    def get_charge_from_smiles(self, smiles):
        """Return the total formal charge written in *smiles*.

        Uses RDKit's formal charge when RDKit can parse the string.
        Otherwise it sums the charges on bracket atoms: "[O-]" is -1,
        "[Fe++]" is +2, "[Cu+2]" is +2. Empty or non-string input gives 0.
        """
        if not isinstance(smiles, str) or not smiles:
            return 0

        mol, _ = self._parse_smiles(smiles)
        if mol is not None:
            try:
                return Chem.rdmolops.GetFormalCharge(mol)
            except Exception:
                pass  # fall through to the regex estimate

        total_charge = 0
        for signs, digits in self._BRACKET_CHARGE_RE.findall(smiles):
            sign = 1 if signs[0] == '+' else -1
            # An explicit magnitude ("+2") wins; otherwise repeated signs count ("++").
            total_charge += sign * (int(digits) if digits else len(signs))
        return total_charge

    def parse_multi_component_smiles(self, multi_smiles):
        """Split dot-separated SMILES into non-empty, stripped components."""
        if not multi_smiles or not isinstance(multi_smiles, str):
            return []
        return [comp.strip() for comp in multi_smiles.split('.') if comp.strip()]

    def _sum_atom_counts(self, components):
        """Sum the per-element atom counts of several SMILES components."""
        total = Counter()
        for component in components:
            total.update(self.get_atom_counts_from_smiles(component))
        return total

    def check_balance(self, reactant_smiles, product_smiles):
        """Compare atom counts and total charge of reactants and products.

        Both arguments are dot-separated multi-component SMILES.

        Atom balance score:
            1 - (total atom imbalance / reactant atom count), clipped at 0.
            It is 0.0 when the reactants contain no atoms.

        Charge balance score, by absolute charge difference:
            0 -> 1.0, 1 -> 0.7, 2 -> 0.4,
            otherwise max(0, 1 - 0.3 * difference).

        Returns:
            dict with the atom Counters, the missing/extra atoms, the
            balance flags and scores, and the reactant/product charges.
        """
        reactant_components = self.parse_multi_component_smiles(reactant_smiles)
        product_components = self.parse_multi_component_smiles(product_smiles)

        # Atom balance check (Counter subtraction keeps only positive counts)
        reactant_atoms = self._sum_atom_counts(reactant_components)
        product_atoms = self._sum_atom_counts(product_components)
        missing_in_products = reactant_atoms - product_atoms
        extra_in_products = product_atoms - reactant_atoms
        is_atom_balanced = not missing_in_products and not extra_in_products

        total_reactant_atoms = sum(reactant_atoms.values())
        total_imbalance = sum(missing_in_products.values()) + sum(extra_in_products.values())
        if total_reactant_atoms == 0:
            atom_balance_score = 0.0
        else:
            atom_balance_score = max(0.0, 1.0 - (total_imbalance / total_reactant_atoms))

        # Charge balance check
        reactant_charge = sum(self.get_charge_from_smiles(comp) for comp in reactant_components)
        product_charge = sum(self.get_charge_from_smiles(comp) for comp in product_components)
        charge_difference = abs(reactant_charge - product_charge)
        is_charge_balanced = charge_difference == 0

        if charge_difference == 0:
            charge_balance_score = 1.0
        elif charge_difference == 1:
            charge_balance_score = 0.7
        elif charge_difference == 2:
            charge_balance_score = 0.4
        else:
            charge_balance_score = max(0.0, 1.0 - charge_difference * 0.3)

        return {
            'is_atom_balanced': is_atom_balanced,
            'reactant_atoms': reactant_atoms,
            'product_atoms': product_atoms,
            'missing_in_products': missing_in_products,
            'extra_in_products': extra_in_products,
            'atom_balance_score': atom_balance_score,
            'is_charge_balanced': is_charge_balanced,
            'reactant_charge': reactant_charge,
            'product_charge': product_charge,
            'charge_difference': charge_difference,
            'charge_balance_score': charge_balance_score,
        }

    def predict_with_chemical_enhancement(self, reactant_smiles, num_return_sequences=20,
                                          max_length=256, num_beams=None, temperature=1.0,
                                          do_sample=False, top_k=50, top_p=0.95):
        """Predict reaction products and re-rank them with chemical heuristics.

        Args:
            reactant_smiles: reactant SMILES fed to the model.
            num_return_sequences: number of candidates to generate.
            max_length: max token length for both input truncation and output.
            num_beams: beam count; defaults to num_return_sequences + 5.
            temperature, top_k, top_p: used only when do_sample is True.
            do_sample: sample instead of deterministic beam search.

        Returns:
            (enhanced_top5, predictions_data). predictions_data holds one dict
            per unique canonical prediction, sorted by 'comprehensive_score'
            (descending). enhanced_top5 is its first five entries. Each dict's
            'rank' is its 1-based position in generation order.
        """
        if num_beams is None:
            num_beams = num_return_sequences + 5

        print(f"Input reactants: {reactant_smiles}")
        print("-" * 80)

        raw_predictions = self._generate(
            reactant_smiles, num_return_sequences, max_length, num_beams,
            temperature, do_sample, top_k, top_p,
        )
        predictions_data = self._score_predictions(reactant_smiles, raw_predictions)

        # predictions_data is still in generation order here; keep a copy so the
        # "original" report shows the model's own ranking, not the re-ranked one.
        original_order = list(predictions_data)
        predictions_data.sort(key=lambda x: x['comprehensive_score'], reverse=True)
        enhanced_top5 = predictions_data[:5]

        self._print_rankings(original_order[:5], enhanced_top5)
        self._print_summary(predictions_data)
        return enhanced_top5, predictions_data

    def _generate(self, reactant_smiles, num_return_sequences, max_length, num_beams,
                  temperature, do_sample, top_k, top_p):
        """Run ``model.generate`` and return the decoded candidates in output order."""
        inputs = self.tokenizer(reactant_smiles, return_tensors="pt",
                                max_length=max_length, truncation=True)
        # Put the inputs on the device of the model's first parameter.
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        start_time = time.perf_counter()
        with torch.no_grad():
            # Per-step scores are not used, so output_scores is left off to
            # avoid keeping (beams x steps x vocab) logits in memory.
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                num_beams=num_beams,
                do_sample=do_sample,
                temperature=temperature if do_sample else 1.0,
                top_k=top_k if do_sample else None,
                top_p=top_p if do_sample else None,
                early_stopping=True,
                eos_token_id=self.model.config.eos_token_id,
                pad_token_id=self.model.config.pad_token_id,
                decoder_start_token_id=self.model.config.decoder_start_token_id,
                return_dict_in_generate=True,
            )
        print(f"Generation time: {time.perf_counter() - start_time:.3f} seconds")

        return self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)

    def _score_predictions(self, reactant_smiles, raw_predictions):
        """Deduplicate decoded candidates by canonical SMILES and score each one."""
        predictions_data = []
        seen_canonical = set()
        total_raw = len(raw_predictions)

        for index, pred_raw in enumerate(raw_predictions):
            # Model score decays linearly with output position (1.0 for the first);
            # presumably best-first under beam search.
            model_score = 1.0 - (index / total_raw)

            canonical_pred = self.canonicalize_smiles(pred_raw)
            if canonical_pred in seen_canonical:
                continue
            seen_canonical.add(canonical_pred)

            validity_score = self.score_chemical_validity(canonical_pred)

            atom_balance_score = 0.0
            charge_balance_score = 0.0
            balance_info = {}
            if canonical_pred:
                balance_info = self.check_balance(reactant_smiles, canonical_pred)
                atom_balance_score = balance_info['atom_balance_score']
                charge_balance_score = balance_info['charge_balance_score']

            comprehensive_score = (
                model_score * self.MODEL_SCORE_WEIGHT
                + validity_score * self.VALIDITY_SCORE_WEIGHT
                + atom_balance_score * self.ATOM_BALANCE_SCORE_WEIGHT
                + charge_balance_score * self.CHARGE_BALANCE_SCORE_WEIGHT
            )

            predictions_data.append({
                'rank': len(predictions_data) + 1,
                'raw_smiles': pred_raw,
                'canonical_smiles': canonical_pred,
                'model_score': model_score,
                'validity_score': validity_score,
                'atom_balance_score': atom_balance_score,
                'charge_balance_score': charge_balance_score,
                'comprehensive_score': comprehensive_score,
                'balance_info': balance_info,
            })
        return predictions_data

    @staticmethod
    def _print_rankings(original_top, enhanced_top):
        """Print the top candidates in generation order, then in re-ranked order."""
        print("\n" + "=" * 80)
        print("ORIGINAL MODEL RANKING vs CHEMICALLY ENHANCED RANKING")
        print("=" * 80)

        print("\nORIGINAL MODEL RANKING (Top 5):")
        print("-" * 50)
        for position, pred in enumerate(original_top, 1):
            print(f"{position:2d}. {pred['canonical_smiles']}")
            print(f"    Original Rank: #{pred['rank']}, Model Score: {pred['model_score']:.3f}")
            print(f"    Validity: {pred['validity_score']:.3f}, "
                  f"Atom Balance: {pred['atom_balance_score']:.3f}, "
                  f"Charge Balance: {pred['charge_balance_score']:.3f}")

            balance = pred['balance_info']
            if balance:
                print("    Chemical Analysis:")
                print(f"      Reactant atoms: {dict(balance['reactant_atoms'])}")
                print(f"      Product atoms: {dict(balance['product_atoms'])}")
                if balance['missing_in_products']:
                    print(f"      Missing atoms: {dict(balance['missing_in_products'])}")
                if balance['extra_in_products']:
                    print(f"      Extra atoms: {dict(balance['extra_in_products'])}")
                print(f"      Charge: {balance['reactant_charge']} → {balance['product_charge']} "
                      f"(Δ{balance['charge_difference']})")
            print()

        print("\nCHEMICALLY ENHANCED RANKING (Top 5):")
        print("-" * 50)
        for position, pred in enumerate(enhanced_top, 1):
            original_rank = pred['rank']
            if original_rank > position:
                improvement = f"↑{original_rank - position}"
            elif original_rank < position:
                improvement = f"↓{position - original_rank}"
            else:
                improvement = "="

            print(f"{position:2d}. {pred['canonical_smiles']}")
            print(f"    Rank Change: #{original_rank} → #{position} ({improvement})")
            print(f"    Comprehensive Score: {pred['comprehensive_score']:.3f}")
            print(f"    Component Scores - Model: {pred['model_score']:.3f}, "
                  f"Validity: {pred['validity_score']:.3f}, "
                  f"Atom: {pred['atom_balance_score']:.3f}, "
                  f"Charge: {pred['charge_balance_score']:.3f}")

            quality_indicators = []
            if pred['validity_score'] >= 0.9:
                quality_indicators.append("✓Valid")
            if pred['atom_balance_score'] >= 0.9:
                quality_indicators.append("✓Balanced")
            if pred['charge_balance_score'] >= 0.9:
                quality_indicators.append("✓Charge OK")
            if quality_indicators:
                print(f"    Quality: {' '.join(quality_indicators)}")
            print()

    @staticmethod
    def _print_summary(predictions_data):
        """Print the average component scores and how many candidates score >= 0.9."""
        print("=" * 80)
        print("SUMMARY STATISTICS")
        print("=" * 80)
        total = len(predictions_data)
        print(f"Total unique predictions: {total}")
        if total == 0:
            return  # nothing to average; avoids division by zero

        for label, key in (("validity", 'validity_score'),
                           ("atom balance", 'atom_balance_score'),
                           ("charge balance", 'charge_balance_score')):
            average = sum(p[key] for p in predictions_data) / total
            print(f"Average {label} score: {average:.3f}")

        for label, key in (("High validity predictions", 'validity_score'),
                           ("Well-balanced predictions", 'atom_balance_score'),
                           ("Charge-balanced predictions", 'charge_balance_score')):
            count = sum(1 for p in predictions_data if p[key] >= 0.9)
            print(f"{label} (≥0.9): {count}/{total} ({count / total * 100:.1f}%)")


if __name__ == "__main__":
    # Initialize the predictor
    print("Initializing Chemical Reaction Predictor...")
    model_name = "ML4chemistry/chemical-reaction-t5-v3"
    predictor = ChemicalReactionPredictor(model_name)

    # Example reactant SMILES
    reactant_smiles = "CC(C)(C)N=N"

    # Run prediction with chemical enhancement
    top_predictions, all_predictions = predictor.predict_with_chemical_enhancement(
        reactant_smiles,
        num_return_sequences=20,
        num_beams=25,
        max_length=256,
    )

    print("\n" + "=" * 80)
    print("FINAL TOP 3 CHEMICALLY ENHANCED PREDICTIONS")
    print("=" * 80)
    for i, pred in enumerate(top_predictions[:3], 1):
        print(f"{i}. {pred['canonical_smiles']}")
        print(f"   Score: {pred['comprehensive_score']:.3f}")
```