Model Card for ML4chemistry/chemical-reaction-t5-v3

Model Details

Model Description

This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.

  • Developed by: [Zhi Huang]
  • Shared by [optional]: [Zhi Huang]
  • Model type: [Chemical reaction prediction]
  • License: [MIT]

Model Sources [optional]

  • Repository: [More Information Needed]
  • Paper [optional]: [More Information Needed]
  • Demo [optional]: [More Information Needed]

Uses

import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForSeq2SeqLM
from collections import Counter
import re
import time

# Optional: RDKit for advanced chemical validation
# RDKit is treated as an optional dependency. The module-level flag
# RDKIT_AVAILABLE records whether the import succeeded; the predictor methods
# check it before touching `Chem` and otherwise use text-based fallbacks.
try:
    from rdkit import Chem
    # NOTE(review): GetMorganFingerprintAsBitVect and TanimotoSimilarity are
    # not referenced in the visible part of this file - confirm they are needed.
    from rdkit.Chem.rdMolDescriptors import GetMorganFingerprintAsBitVect
    from rdkit.DataStructs import TanimotoSimilarity
    from rdkit import RDLogger
    RDLogger.DisableLog('rdApp.*')  # Disable RDKit warnings
    RDKIT_AVAILABLE = True
    print("RDKit found. Advanced chemical validation will be used.")
except ImportError:
    # Only a missing package is tolerated; any other import-time error propagates.
    RDKIT_AVAILABLE = False
    print("RDKit not found. Using basic validation methods.")

class ChemicalReactionPredictor:
    def __init__(self, model_name):
        """Load the tokenizer and PEFT seq2seq model, then set the scoring weights.

        Args:
            model_name: Identifier or path passed unchanged to both
                ``AutoTokenizer.from_pretrained`` and
                ``AutoPeftModelForSeq2SeqLM.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoPeftModelForSeq2SeqLM.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map="auto"
        )

        # Copy the tokenizer's special-token ids onto the model config so
        # generation and decoding agree on them.
        for token_attr in ("bos_token_id", "eos_token_id", "pad_token_id"):
            setattr(self.model.config, token_attr, getattr(self.tokenizer, token_attr))
        # Decoding is started from the pad token.
        self.model.config.decoder_start_token_id = self.tokenizer.pad_token_id

        # Inference only: put the model in evaluation mode.
        self.model.eval()

        # Weights of the four components of the comprehensive score
        # (0.3 + 0.2 + 0.25 + 0.25 = 1.0).
        self.MODEL_SCORE_WEIGHT = 0.3
        self.VALIDITY_SCORE_WEIGHT = 0.2
        self.ATOM_BALANCE_SCORE_WEIGHT = 0.25
        self.CHARGE_BALANCE_SCORE_WEIGHT = 0.25

    def canonicalize_smiles(self, smiles):
        """Canonicalize SMILES using RDKit if available"""
        if not RDKIT_AVAILABLE or not smiles or not isinstance(smiles, str):
            return str(smiles).strip() if isinstance(smiles, str) else ""

        try:
            mol = Chem.MolFromSmiles(smiles, sanitize=True)
            if mol:
                return Chem.MolToSmiles(mol, canonical=True)

            # Try without sanitization
            mol_no_sanitize = Chem.MolFromSmiles(smiles, sanitize=False)
            if mol_no_sanitize:
                return Chem.MolToSmiles(mol_no_sanitize, canonical=True)

            return smiles.strip()
        except Exception:
            return smiles.strip()

    def score_chemical_validity(self, smiles):
        """Score chemical validity of SMILES (0.0 to 1.0)"""
        if not RDKIT_AVAILABLE or not smiles or not isinstance(smiles, str):
            return 0.1  # Basic validity for non-empty strings

        try:
            mol = Chem.MolFromSmiles(smiles, sanitize=True)
            if mol:
                return 1.0  # Fully valid

            mol_no_sanitize = Chem.MolFromSmiles(smiles, sanitize=False)
            if mol_no_sanitize:
                return 0.5  # Partially valid

            return 0.1  # Invalid but parseable
        except Exception:
            return 0.0  # Completely invalid

    def get_atom_counts_from_smiles(self, smiles):
        """Get atom counts from SMILES string"""
        if not smiles or not isinstance(smiles, str):
            return Counter()

        smiles = smiles.strip()
        if not smiles:
            return Counter()

        if RDKIT_AVAILABLE:
            try:
                mol = Chem.MolFromSmiles(smiles, sanitize=True)
                if mol:
                    atom_counts = Counter()
                    for atom in mol.GetAtoms():
                        symbol = atom.GetSymbol()
                        atom_counts[symbol] += 1
                    return atom_counts
            except Exception:
                pass

            try:
                mol = Chem.MolFromSmiles(smiles, sanitize=False)
                if mol:
                    atom_counts = Counter()
                    for atom in mol.GetAtoms():
                        symbol = atom.GetSymbol()
                        atom_counts[symbol] += 1
                    return atom_counts
            except Exception:
                pass

        # Fallback regex method
        return self._regex_atom_count_fallback(smiles)

    def _regex_atom_count_fallback(self, smiles):
        """Fallback atom counting using regex"""
        atom_counts = Counter()
        atom_pattern = r'([A−Z][a−z]?)(?:[+−]?\d∗|\.)∗([A-Z][a-z]?)(?:[+-]?\d*|\.)*|([A-Z][a-z]?)'
        matches = re.findall(atom_pattern, smiles)

        for match in matches:
            atom = match[0] if match[0] else match[1]
            if atom:
                atom_counts[atom] += 1

        return atom_counts

    def get_charge_from_smiles(self, smiles):
        """Extract total charge from SMILES"""
        if not smiles or not isinstance(smiles, str):
            return 0

        total_charge = 0

        if RDKIT_AVAILABLE:
            try:
                mol = Chem.MolFromSmiles(smiles)
                if mol:
                    total_charge = Chem.rdmolops.GetFormalCharge(mol)
                    return total_charge
            except:
                pass

        # Regex fallback
        charge_pattern = r'[A−Z][a−z]?[]]∗?([+−]\d∗)[A-Z][a-z]?[^]]*?([+-]\d*)'
        matches = re.findall(charge_pattern, smiles)

        for charge_str in matches:
            if charge_str == '+':
                total_charge += 1
            elif charge_str == '-':
                total_charge -= 1
            else:
                try:
                    total_charge += int(charge_str)
                except ValueError:
                    pass

        return total_charge

    def parse_multi_component_smiles(self, multi_smiles):
        """Parse multi-component SMILES separated by dots"""
        if not multi_smiles or not isinstance(multi_smiles, str):
            return []
        components = [comp.strip() for comp in multi_smiles.split('.') if comp.strip()]
        return components

    def check_balance(self, reactant_smiles, product_smiles):
        """Check atom balance and charge balance"""
        reactant_components = self.parse_multi_component_smiles(reactant_smiles)
        product_components = self.parse_multi_component_smiles(product_smiles)

        # Atom balance check
        reactant_atoms = Counter()
        for component in reactant_components:
            component_atoms = self.get_atom_counts_from_smiles(component)
            reactant_atoms.update(component_atoms)

        product_atoms = Counter()
        for component in product_components:
            component_atoms = self.get_atom_counts_from_smiles(component)
            product_atoms.update(component_atoms)

        missing_in_products = reactant_atoms - product_atoms
        extra_in_products = product_atoms - reactant_atoms
        missing_in_products = +missing_in_products
        extra_in_products = +extra_in_products

        is_atom_balanced = len(missing_in_products) == 0 and len(extra_in_products) == 0
        total_reactant_atoms = sum(reactant_atoms.values())
        total_imbalance = sum(missing_in_products.values()) + sum(extra_in_products.values())

        if total_reactant_atoms == 0:
            atom_balance_score = 0.0
        else:
            atom_balance_score = max(0.0, 1.0 - (total_imbalance / total_reactant_atoms))

        # Charge balance check
        reactant_charge = sum(self.get_charge_from_smiles(comp) for comp in reactant_components)
        product_charge = sum(self.get_charge_from_smiles(comp) for comp in product_components)

        charge_difference = abs(reactant_charge - product_charge)
        is_charge_balanced = (charge_difference == 0)

        # Charge balance score calculation
        if charge_difference == 0:
            charge_balance_score = 1.0
        elif charge_difference == 1:
            charge_balance_score = 0.7
        elif charge_difference == 2:
            charge_balance_score = 0.4
        else:
            charge_balance_score = max(0.0, 1.0 - charge_difference * 0.3)

        return {
            'is_atom_balanced': is_atom_balanced,
            'reactant_atoms': reactant_atoms,
            'product_atoms': product_atoms,
            'missing_in_products': missing_in_products,
            'extra_in_products': extra_in_products,
            'atom_balance_score': atom_balance_score,
            'is_charge_balanced': is_charge_balanced,
            'reactant_charge': reactant_charge,
            'product_charge': product_charge,
            'charge_difference': charge_difference,
            'charge_balance_score': charge_balance_score
        }

    def predict_with_chemical_enhancement(self, reactant_smiles, num_return_sequences=20,
                                        max_length=256, num_beams=None,
                                        temperature=1.0, do_sample=False,
                                        top_k=50, top_p=0.95):
        """
        Predict reaction products with chemical enhancement and comprehensive scoring.

        Candidates are generated with ``self.model.generate``, canonicalized,
        de-duplicated, scored (position-based model score, validity, atom
        balance, charge balance) and re-ranked by the weighted sum of those
        scores. A report comparing the original and the re-ranked order plus
        summary statistics is printed to stdout.

        Args:
            reactant_smiles: Reactant SMILES fed to the tokenizer; also used as
                the reference side for the balance checks.
            num_return_sequences: Number of sequences requested from ``generate``.
            max_length: Used both as the tokenizer truncation length and as the
                generation ``max_length``.
            num_beams: Beam count; defaults to ``num_return_sequences + 5``.
            temperature: Forwarded only when ``do_sample`` is true (else 1.0).
            do_sample: Forwarded to ``generate``.
            top_k: Forwarded only when ``do_sample`` is true (else ``None``).
            top_p: Forwarded only when ``do_sample`` is true (else ``None``).

        Returns:
            tuple: ``(enhanced_top5, predictions_data)`` - the five best entries
            and the full list, both ordered by ``comprehensive_score``
            descending. Each entry is a dict with the keys ``rank`` (1-based
            position among unique candidates in generation order),
            ``raw_smiles``, ``canonical_smiles``, ``model_score``,
            ``validity_score``, ``atom_balance_score``,
            ``charge_balance_score``, ``comprehensive_score`` and
            ``balance_info``.
        """
        if num_beams is None:
            num_beams = num_return_sequences + 5

        print(f"Input reactants: {reactant_smiles}")
        print("-" * 80)

        # Tokenize input
        inputs = self.tokenizer(reactant_smiles, return_tensors="pt",
                                max_length=max_length, truncation=True)

        # Run on whatever device the model parameters live on
        device = next(self.model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        start_time = time.time()

        # Generate predictions
        with torch.no_grad():
            outputs = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                num_beams=num_beams,
                do_sample=do_sample,
                temperature=temperature if do_sample else 1.0,
                top_k=top_k if do_sample else None,
                top_p=top_p if do_sample else None,
                early_stopping=True,
                eos_token_id=self.model.config.eos_token_id,
                pad_token_id=self.model.config.pad_token_id,
                decoder_start_token_id=self.model.config.decoder_start_token_id,
                return_dict_in_generate=True,
                output_scores=True
            )

        generation_time = time.time() - start_time
        print(f"Generation time: {generation_time:.3f} seconds")

        # Decode predictions
        raw_predictions = self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)

        # Process and score predictions
        predictions_data = []
        seen_canonical = set()
        raw_count = len(raw_predictions)

        for i, pred_raw in enumerate(raw_predictions):
            # Model score based on generation order: earlier sequences score higher
            model_score = 1.0 - (i / raw_count)

            canonical_pred = self.canonicalize_smiles(pred_raw)

            # Skip duplicates (the first occurrence keeps its better model score)
            if canonical_pred in seen_canonical:
                continue
            seen_canonical.add(canonical_pred)

            validity_score = self.score_chemical_validity(canonical_pred)

            # Balance scores stay at zero for an empty prediction
            atom_balance_score = 0.0
            charge_balance_score = 0.0
            balance_info = {}

            if canonical_pred:
                balance_info = self.check_balance(reactant_smiles, canonical_pred)
                atom_balance_score = balance_info['atom_balance_score']
                charge_balance_score = balance_info['charge_balance_score']

            comprehensive_score = (
                model_score * self.MODEL_SCORE_WEIGHT +
                validity_score * self.VALIDITY_SCORE_WEIGHT +
                atom_balance_score * self.ATOM_BALANCE_SCORE_WEIGHT +
                charge_balance_score * self.CHARGE_BALANCE_SCORE_WEIGHT
            )

            predictions_data.append({
                'rank': len(predictions_data) + 1,
                'raw_smiles': pred_raw,
                'canonical_smiles': canonical_pred,
                'model_score': model_score,
                'validity_score': validity_score,
                'atom_balance_score': atom_balance_score,
                'charge_balance_score': charge_balance_score,
                'comprehensive_score': comprehensive_score,
                'balance_info': balance_info
            })

        # Keep the model's own ordering before re-sorting; otherwise the
        # "original ranking" section below would show the re-ranked list.
        original_ranking = list(predictions_data)

        # Sort by comprehensive score (chemical enhancement)
        predictions_data.sort(key=lambda x: x['comprehensive_score'], reverse=True)
        enhanced_top5 = predictions_data[:5]

        # Display results
        print("\n" + "="*80)
        print("ORIGINAL MODEL RANKING vs CHEMICALLY ENHANCED RANKING")
        print("="*80)

        print("\nORIGINAL MODEL RANKING (Top 5):")
        print("-" * 50)
        for i, pred in enumerate(original_ranking[:5]):
            original_rank = pred['rank']
            print(f"{i+1:2d}. {pred['canonical_smiles']}")
            print(f"    Original Rank: #{original_rank}, Model Score: {pred['model_score']:.3f}")
            print(f"    Validity: {pred['validity_score']:.3f}, "
                  f"Atom Balance: {pred['atom_balance_score']:.3f}, "
                  f"Charge Balance: {pred['charge_balance_score']:.3f}")

            # Chemical analysis
            if pred['balance_info']:
                balance = pred['balance_info']
                print("    Chemical Analysis:")
                print(f"      Reactant atoms: {dict(balance['reactant_atoms'])}")
                print(f"      Product atoms:  {dict(balance['product_atoms'])}")
                if balance['missing_in_products']:
                    print(f"      Missing atoms:  {dict(balance['missing_in_products'])}")
                if balance['extra_in_products']:
                    print(f"      Extra atoms:    {dict(balance['extra_in_products'])}")
                # The separator between the two charges was missing, which
                # made e.g. charges 1 and 0 print as "10".
                print(f"      Charge: {balance['reactant_charge']} → {balance['product_charge']} "
                      f"(Δ{balance['charge_difference']})")
            print()

        print("\nCHEMICALLY ENHANCED RANKING (Top 5):")
        print("-" * 50)

        for i, pred in enumerate(enhanced_top5):
            original_rank = pred['rank']
            new_rank = i + 1
            if original_rank > new_rank:
                improvement = f"↑{original_rank - new_rank}"
            elif original_rank < new_rank:
                improvement = f"↓{new_rank - original_rank}"
            else:
                improvement = "="

            print(f"{new_rank:2d}. {pred['canonical_smiles']}")
            print(f"    Rank Change: #{original_rank} → #{new_rank} ({improvement})")
            print(f"    Comprehensive Score: {pred['comprehensive_score']:.3f}")
            print(f"    Component Scores - Model: {pred['model_score']:.3f}, "
                  f"Validity: {pred['validity_score']:.3f}, "
                  f"Atom: {pred['atom_balance_score']:.3f}, "
                  f"Charge: {pred['charge_balance_score']:.3f}")

            # Quality indicators
            quality_indicators = []
            if pred['validity_score'] >= 0.9:
                quality_indicators.append("✓Valid")
            if pred['atom_balance_score'] >= 0.9:
                quality_indicators.append("✓Balanced")
            if pred['charge_balance_score'] >= 0.9:
                quality_indicators.append("✓Charge OK")

            if quality_indicators:
                print(f"    Quality: {' '.join(quality_indicators)}")
            print()

        # Summary statistics
        total = len(predictions_data)
        print("="*80)
        print("SUMMARY STATISTICS")
        print("="*80)
        print(f"Total unique predictions: {total}")
        if total == 0:
            # Nothing to average; avoid dividing by zero below.
            return enhanced_top5, predictions_data

        print(f"Average validity score: {sum(p['validity_score'] for p in predictions_data) / total:.3f}")
        print(f"Average atom balance score: {sum(p['atom_balance_score'] for p in predictions_data) / total:.3f}")
        print(f"Average charge balance score: {sum(p['charge_balance_score'] for p in predictions_data) / total:.3f}")

        valid_predictions = sum(1 for p in predictions_data if p['validity_score'] >= 0.9)
        balanced_predictions = sum(1 for p in predictions_data if p['atom_balance_score'] >= 0.9)
        charge_balanced_predictions = sum(1 for p in predictions_data if p['charge_balance_score'] >= 0.9)

        print(f"High validity predictions (≥0.9): {valid_predictions}/{total} ({valid_predictions/total*100:.1f}%)")
        print(f"Well-balanced predictions (≥0.9): {balanced_predictions}/{total} ({balanced_predictions/total*100:.1f}%)")
        print(f"Charge-balanced predictions (≥0.9): {charge_balanced_predictions}/{total} ({charge_balanced_predictions/total*100:.1f}%)")

        return enhanced_top5, predictions_data

# Build the predictor for the published checkpoint
print("Initializing Chemical Reaction Predictor...")
model_name = "ML4chemistry/chemical-reaction-t5-v3"
predictor = ChemicalReactionPredictor(model_name)

# Example reactant to run through the pipeline
reactant_smiles = "CC(C)(C)N=N"

# Request 20 candidates from 25 beams and re-rank them
top_predictions, all_predictions = predictor.predict_with_chemical_enhancement(
    reactant_smiles, max_length=256, num_beams=25, num_return_sequences=20
)

# Final report: the three best re-ranked candidates
print()
print("=" * 80)
print("FINAL TOP 3 CHEMICALLY ENHANCED PREDICTIONS")
print("=" * 80)
for i, pred in enumerate(top_predictions[:3], start=1):
    print(f"{i}. {pred['canonical_smiles']}\n   Score: {pred['comprehensive_score']:.3f}")
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support