Finetuning T5 problems

Hello everyone,

I want to use a finetuning script for a pretrained T5 model to map one sequence of tokens to another. While I get reasonable improvements for smaller subsets (e.g. 80 train, 20 val), training completely breaks when I test it on larger amounts of data (e.g. 400 train, 100 val). I already experimented with batch sizes, gradient accumulation, and weight decay. For learning rates I tried starting at 3e-4 as well as at 5e-5. I also attached the loss curve for the larger case where the training breaks.

Does anyone have hints or clues about what might be wrong in my setup?

Thank you for your time and help

1 Like

It’s difficult to pinpoint the issue with nothing to go on beyond the learning rate and the loss curve, but there are a few common pitfalls worth checking first.


Hey, thank you very much for your answer. I am trying to go through the most important points. I also tried to print some examples during training to check the padding. My actual use case is translating tokenized protein structures to sequences, so I am not working with text or sentences (the model I am using is pretrained on this protein data as well). For example, below I print the tokens the model receives, the model's prediction, the label/gold, and the padding and attention masks. A problem might also be that some sequences are very uninformative and contain lots of repeating tokens, so I am now experimenting with filtering the data. Here are some examples:

--- save-pred example batch0-idx0 ---
TOKENS : ['d', 'v', 'a', 'v', 'q', 'a', 'v', 'v', 'v', 'v', 'y', 'v', 'y', 'y', 'v', 'v', 'v', 'v', 'v', 'q', 'v', 'v', 'q', 'c', 'v', 'l', 'l', 'l', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'c', 'y']
PRED   : ['4', '12', '2', '2', '4', '4', '4', '4', '12', '2', '5', '4', '4', '13', '17', '4', '17', '4', '12', '7', 'LABEL_0', '12', '5', '7', '5', '5', '5', '7', 'LABEL_0', '12', '12', '12', '5', '3', '5', '5', '5', '12', '5', '12']
GOLD   : ['10', '8', '2', '18', '13', '9', '12', '15', '-100', '15', '5', '12', '19', '3', '9', '14', '7', '3', '17', '13', '12', '8', '15', '-100', '-100', '14', 'LABEL_0', '14', '19', '3', '16', '3', '5', '15', '14', '5', 'LABEL_0', '17', '8', 'LABEL_0']
PAD POS: [178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210]
ATT=0  : [178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210]
SEQ LEN: 211, VALID: 178

--- save-pred example batch0-idx1 ---
TOKENS : ['d', 'v', 'a', 'v', 'q', 'a', 'v', 'v', 'v', 'v', 'y', 'v', 'y', 'y', 'v', 'v', 'v', 'v', 'v', 'q', 'v', 'v', 'q', 'c', 'v', 'l', 'l', 'l', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'c', 'y']
PRED   : ['8', '17', '9', '17', '17', '5', '5', '17', '17', '17', 'LABEL_1', '17', '17', '2', '2', '5', '5', '17', '17', '17', '17', '5', '16', '17', '19', '13', '9', '7', '8', '8', '12', '8', '5', '17', '17', '5', '3', '4', '5', '17']
GOLD   : ['10', '19', '5', '11', '18', '5', '14', '4', '7', '14', '17', '11', '9', '15', '16', '5', '2', '7', '8', '17', '3', '3', '19', '2', '3', '3', '9', 'LABEL_0', '8', '8', '18', '9', '5', '15', '14', '5', '9', 'LABEL_0', '7', '19']
PAD POS: None
ATT=0  : None
SEQ LEN: 211, VALID: 211

--- save-pred example batch0-idx2 ---
TOKENS : ['d', 'v', 'a', 'v', 'q', 'a', 'v', 'v', 'v', 'v', 'y', 'v', 'y', 'y', 'v', 'v', 'v', 'v', 'v', 'q', 'v', 'v', 'q', 'c', 'v', 'l', 'l', 'l', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'c', 'y']
PRED   : ['15', '17', '4', '5', '13', '12', 'LABEL_0', '5', '13', '17', '7', '9', '5', '2', '15', '11', '15', '5', '2', '2', '9', '9', '17', '17', '5', '11', 'LABEL_0', '17', '3', '3', 'LABEL_0', '5', '2', '2', '17', '13', '2', '16', '5', '13']
GOLD   : ['14', '4', '3', '16', '3', '3', '16', '14', '8', '9', '3', '7', '15', '5', '10', '12', '9', '5', '2', '9', '4', '-100', '7', '14', '4', 'LABEL_0', 'LABEL_0', '14', '15', '12', '3', '4', '8', '8', '15', '12', 'LABEL_0', '17', '14', '8']
PAD POS: [169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210]
ATT=0  : [169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210]
SEQ LEN: 211, VALID: 169

--- save-pred example batch0-idx3 ---
TOKENS : ['d', 'v', 'a', 'v', 'q', 'a', 'v', 'v', 'v', 'v', 'y', 'v', 'y', 'y', 'v', 'v', 'v', 'v', 'v', 'q', 'v', 'v', 'q', 'c', 'v', 'l', 'l', 'l', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'v', 'c', 'y']
PRED   : ['12', '8', '5', '13', '18', '7', '7', '13', '12', '5', '12', '3', '3', '9', '14', '7', '7', '8', '14', '17', '9', '14', '19', '19', '14', '19', '9', '5', '8', '2', '17', '2', '8', '2', '8', '8', '8', '15', '5', '8']
GOLD   : ['14', 'LABEL_0', '8', '17', 'LABEL_0', '10', '15', '-100', '4', '3', '12', '-100', '3', '19', '7', '14', '19', '2', '9', '9', '3', '8', '11', '7', '2', '7', '17', '14', '8', '14', '9', '11', '14', '12', '9', '16', '9', '15', '3', '8']
PAD POS: [201, 202, 203, 204, 205, 206, 207, 208, 209, 210]
ATT=0  : [201, 202, 203, 204, 205, 206, 207, 208, 209, 210]
SEQ LEN: 211, VALID: 201
1 Like

In that case, either the DataCollator or the input IDs might be incorrect. Here’s some safe code.

# pip install -U transformers accelerate datasets huggingface_hub[hf_xet] trackio
# Minimal, safe baseline for token→token seq2seq (e.g., protein tokens).
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments, Seq2SeqTrainer,
    EarlyStoppingCallback,
)
from transformers.integrations import TrackioCallback
import torch
import numpy as np
import random

def set_seed(s=13):
    random.seed(s); np.random.seed(s); torch.manual_seed(s)
    if torch.cuda.is_available(): torch.cuda.manual_seed_all(s)
set_seed()

# --- toy tokenized protein-like pairs (replace with your data) ---
def join(chars): return " ".join(list(chars))
pairs = [
    dict(src=join("dvavqavvvvyvyyvvvvqvvqcvllllvvvvvvvvvcy"),
         tgt="10 8 2 18 13 9 12 15 15 5 12 19 3 9 14 7 3 17 13 12"),
]*40
raw = Dataset.from_list(pairs).train_test_split(test_size=0.1, seed=0)

# --- model + tokenizer ---
model_name = "t5-small"  # swap with your pretrained protein T5
tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# --- preprocess: build labels via text_target (no -100 here) ---
src_max, tgt_max = 128, 64
def preprocess(ex):
    enc = tok(ex["src"], truncation=True, max_length=src_max)
    lab = tok(text_target=ex["tgt"], truncation=True, max_length=tgt_max)
    enc["labels"] = lab["input_ids"]  # collator will mask pad to -100
    return enc

train = raw["train"].map(preprocess, remove_columns=raw["train"].column_names)
val   = raw["test"].map(preprocess,  remove_columns=raw["test"].column_names)

# --- collator: masks label padding to -100 automatically ---
collator = DataCollatorForSeq2Seq(tokenizer=tok, model=model, pad_to_multiple_of=8)
# sanity: check that label pads become -100
batch = collator([train[i] for i in range(min(2, len(train)))])
assert (batch["labels"] == -100).any().item(), "Label pad masking failed"

# --- training args: early stopping tracks eval_loss ---
args = Seq2SeqTrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    num_train_epochs=1,
    learning_rate=1e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.05,
    eval_strategy="steps",
    eval_steps=5,
    save_strategy="steps",
    save_steps=5,                      # keep equal to eval_steps
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    predict_with_generate=True,         # ensure .predict() returns token ids
    generation_max_length=tgt_max,
    group_by_length=True,
    fp16=False,                         # safer for T5; use bf16 if available
    logging_strategy="steps",
    logging_steps=1,
    logging_first_step=True,
    report_to="none",
    #report_to="trackio", # https://huggingface.co/docs/trackio/index
)

trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    data_collator=collator,
    processing_class=tok,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3), TrackioCallback(),],
)

# --- train + quick decode demo ---
metrics = trainer.train()
print(metrics)
print(trainer.evaluate())

pred = trainer.predict(val)
pred_ids = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
decoded = tok.batch_decode(pred_ids[:3], skip_special_tokens=True)
print("DECODED SAMPLES:")
for s in decoded: print(s)
"""
Step	Training Loss	Validation Loss
5	4.073900	4.016737
10	3.765400	3.520519
15	3.668000	3.347217
20	3.344500	3.184257
25	3.069900	3.072643
30	3.276600	3.006171
35	3.094000	2.975766

{'eval_loss': 2.97576642036438, 'eval_runtime': 0.5958, 'eval_samples_per_second': 6.714, 'eval_steps_per_second': 6.714, 'epoch': 1.0}
DECODED SAMPLES:
d v a v q a v y v y y v q v q c v l l l l l l l l l l l l l l 
d v a v q a v y v y y v q v q c v l l l l l l l l l l l l l l 
d v a v q a v y v y y v q v q c v l l l l l l l l l l l l l l 
"""

Hey, thank you very much again. I tried to rebuild this with my current setup, but I think my case is a bit different because I am in a classification scenario, so I am using PT5_classification_model and DataCollatorForTokenClassification. One problem might be that I have a mapping from the aa tokens to the classes:

def create_token_to_aa_mapping(tokenizer):
    """
    Create mapping from tokenizer token IDs to amino acid class indices (0-19).
    """
    # Standard 20 amino acids
    AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
    AA_TO_CLASS = {aa: i for i, aa in enumerate(AMINO_ACIDS)}
    
    # Create mapping from token ID to class ID
    token_to_class = {}
    
    # Map single amino acid tokens
    for aa in AMINO_ACIDS:
        token_ids = tokenizer.encode(aa, add_special_tokens=False)
        if len(token_ids) == 1:
            token_to_class[token_ids[0]] = AA_TO_CLASS[aa]
    
    # Handle special tokens - map them to -100 (ignore in loss)
    special_token_ids = [
        tokenizer.pad_token_id,
        tokenizer.eos_token_id,
        tokenizer.unk_token_id,
    ]
    
    for token_id in special_token_ids:
        if token_id is not None:
            token_to_class[token_id] = -100
    
    return token_to_class
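
The helper convert_tokens_to_classes used in main() below is not shown; a minimal sketch consistent with the mapping above (assuming any unmapped ID should simply be ignored in the loss) could look like this:

def convert_tokens_to_classes(token_ids, token_to_class):
    """Map tokenizer token IDs to amino acid class indices; unmapped IDs become -100 (ignored by the loss)."""
    return [token_to_class.get(tid, -100) for tid in token_ids]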

So my main function currently looks like this:

def main():

    args = get_input()

    set_seed()

    df = pd.read_csv("data.csv")

    if 'data_split' in df.columns:
        train_df = df[df['data_split'] == 'train']  
        val_df = df[df['data_split'] == 'valid']      

        train_df = train_df.iloc[:args.num_samples]
        val_df = val_df.iloc[:int(args.num_samples*0.1)]

        print(f'training samples:{len(train_df)}')
        print(f'test samples:{len(val_df)}')
    
        train_pairs = create_pairs(train_df)
        val_pairs = create_pairs(val_df)
        
        # Create datasets
        train = Dataset.from_list(train_pairs)
        val = Dataset.from_list(val_pairs)
        
        print(f"Train samples: {len(train_pairs)}")
        print(f"Val samples: {len(val_pairs)}")

    # Examples
    print("\nFirst few training examples:")
    for i in range(min(3, len(train))):
        example = train[i]
        print(f"Example {i+1}:")
        print(f"  Source (3Di): {example['src'][:50]}...")
        print(f"  Target (AA):  {example['tgt'][:50]}...")
        if 'key' in example:
            print(f"  Key: {example['key']}")
        print()

    # Check sequence lengths
    src_lengths = [len(example['src'].split()) for example in train]
    tgt_lengths = [len(example['tgt'].split()) for example in train]

    print(f"Source (3Di) length stats: min={min(src_lengths)}, max={max(src_lengths)}, avg={sum(src_lengths)/len(src_lengths):.1f}")
    print(f"Target (AA) length stats: min={min(tgt_lengths)}, max={max(tgt_lengths)}, avg={sum(tgt_lengths)/len(tgt_lengths):.1f}")

    ####################### model + tokenizer #######################

    model, tokenizer = PT5_classification_model(
            num_labels=20, model_dir="model_snapshot"
        )
    
    # Create token to amino acid class mapping
    token_to_class = create_token_to_aa_mapping(tokenizer)

    print(f"\n=== Token Mapping ===")
    print(f"Tokenizer vocab size: {len(tokenizer)}")
    print(f"Token to class mappings created: {len(token_to_class)}")
    print(f"Sample mappings: {dict(list(token_to_class.items())[:10])}")

    # Set sequence lengths
    src_max = min(max(src_lengths) + 10, 512)
    tgt_max = min(max(tgt_lengths) + 10, 512)
    print(f"Using src_max={src_max}, tgt_max={tgt_max}")

    print(f"Using src_max={src_max}, tgt_max={tgt_max}")

    # Data collator
    data_collator = DataCollatorForTokenClassification(tokenizer)

    def preprocess(ex):
        # Tokenize source 
        enc = tokenizer(ex["src"], truncation=True, max_length=src_max)
        
        # Tokenize target (AA sequence) 
        tgt_tokens = tokenizer(ex["tgt"], truncation=True, max_length=tgt_max, 
                               add_special_tokens=False)
        
        # Convert token IDs to amino acid class labels (0-19)
        class_labels = convert_tokens_to_classes(tgt_tokens["input_ids"], token_to_class)
        
        # Pad/truncate labels to match input length
        # For seq2seq, labels should align with decoder inputs
        enc["labels"] = class_labels
        
        return enc
    
    # Process datasets
    train_processed = train.map(preprocess, remove_columns=train.column_names)
    val_processed = val.map(preprocess, remove_columns=val.column_names)

    # Verification
    print("\n=== DEBUGGING TOKENIZATION ===")
    print(f"Tokenizer vocab size: {len(tokenizer)}")

    sample = train_processed[0]
    print(f"Sample input_ids: {sample['input_ids'][:10]}...")
    print(f"Sample labels (should be 0-19 or -100): {sample['labels'][:10]}...")

    # Check label values
    all_labels = []
    for item in train_processed:
        all_labels.extend([x for x in item['labels'] if x != -100])

    if all_labels:
        unique_labels = set(all_labels)
        print(f"Unique label values: {sorted(unique_labels)}")
        print(f"Max label: {max(all_labels)}")
        print(f"Min label: {min(all_labels)}")
        
        if max(all_labels) >= 20:
            print(f"❌ ERROR: Labels exceed 20 classes! Max: {max(all_labels)}")
            print("Need to debug the token mapping...")
            
            # Debug: show what's being tokenized
            test_aa = "ACDEFG"
            tokens = tokenizer(test_aa, add_special_tokens=False)
            print(f"\nTest AA sequence: {test_aa}")
            print(f"Token IDs: {tokens['input_ids']}")
            print(f"Tokens: {tokenizer.convert_ids_to_tokens(tokens['input_ids'])}")
            
            for tid in tokens['input_ids']:
                class_id = token_to_class.get(tid, -100)
                print(f"  Token {tid} -> Class {class_id}")
            
            return
        else:
            print(f"✓ All labels within [0, 19] range (20 amino acid classes)")

    # Sanity check: verify that label pads become -100
    batch = data_collator([train_processed[i] for i in range(min(2, len(train_processed)))])
    assert (batch["labels"] == -100).any().item(), "Label pad masking failed"
    print("✓ Label padding correctly masked to -100")

    # DEBUG: Check tokenization results
    print("\n=== DEBUGGING TOKENIZATION ===")
    print(f"Tokenizer vocab size: {len(tokenizer)}")
    sample = train_processed[0]
    print(f"Sample input_ids: {sample['input_ids'][:10]}...")
    print(f"Sample labels: {sample['labels'][:10]}...")
    print(f"Max input_id: {max(sample['input_ids']) if sample['input_ids'] else 'None'}")
    print(f"Min input_id: {min(sample['input_ids']) if sample['input_ids'] else 'None'}")
    print(f"Max label: {max([x for x in sample['labels'] if x != -100]) if sample['labels'] else 'None'}")
    print(f"Min label: {min([x for x in sample['labels'] if x != -100]) if sample['labels'] else 'None'}")

    # Check if any labels are out of bounds
    all_labels = []
    for item in train_processed:
        all_labels.extend([x for x in item['labels'] if x != -100])

    vocab_size = len(tokenizer)
    out_of_bounds = [x for x in all_labels if x >= vocab_size or x < 0]
    if out_of_bounds:
        print(f"❌ FOUND {len(out_of_bounds)} OUT-OF-BOUNDS LABELS:")
        print(f"   Vocab size: {vocab_size}")
        print(f"   Out of bounds values: {sorted(set(out_of_bounds))[:10]}...")
        
        
        # See what the tokenizer produces for sequences
        print(f"\nDEBUG: Raw sequences vs tokenized:")
        raw_3di = train[0]['src'][:20]  # First 20 chars
        raw_aa = train[0]['tgt'][:20]   # First 20 chars
        print(f"Raw 3Di: {raw_3di}")
        print(f"Raw AA:  {raw_aa}")
        
        tok_3di = tokenizer(raw_3di)['input_ids']
        tok_aa = tokenizer(text_target=raw_aa)['input_ids']
        print(f"Tokenized 3Di: {tok_3di}")
        print(f"Tokenized AA:  {tok_aa}")
        
        return  # Stop execution to debug
    else:
        print("✓ All labels are within vocabulary bounds")

    # Debug: Check what's in the datasets before processing
    print(f"\n=== Dataset Debug ===")
    print(f"Raw train_pairs: {len(train_pairs)}")
    print(f"Raw val_pairs: {len(val_pairs)}")


    print(f"Processed train: {len(train_processed)}")
    print(f"Processed val: {len(val_processed)}")


    # Training arguments (following safe code pattern)
    training_args = Seq2SeqTrainingArguments(
        output_dir="finetuning",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=1,
        learning_rate=1e-4,
        lr_scheduler_type="linear",
        warmup_ratio=0.05,
        eval_strategy="steps",
        eval_steps=100,  # Adjust based on your dataset size
        save_strategy="steps",
        save_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        predict_with_generate=False,
        generation_max_length=tgt_max,
        group_by_length=True,
        fp16=False,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        report_to="none",
        remove_unused_columns=False, # added
        save_safetensors=False
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_processed,
        eval_dataset=val_processed,
        data_collator=data_collator,
        processing_class=tokenizer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )

    # Train
    print("\nStarting training...")
    metrics = trainer.train()
    print(metrics)
    
    print("\nEvaluation results:")
    eval_results = trainer.evaluate()
    print(eval_results)

    # Test predictions
    print("\n=== Testing predictions ===")
    pred = trainer.predict(val_processed.select(range(min(3, len(val_processed)))))
    
    # For classification, predictions are logits
    if hasattr(pred, 'predictions'):
        pred_logits = pred.predictions
        pred_classes = np.argmax(pred_logits, axis=-1)
        
        print(f"Prediction shape: {pred_classes.shape}")
        print(f"First prediction (class indices): {pred_classes[0][:20]}")
        
        # Convert class indices back to amino acids
        CLASS_TO_AA = "ACDEFGHIKLMNPQRSTVWY"
        for i in range(min(3, len(pred_classes))):
            pred_aa = ''.join([CLASS_TO_AA[c] if 0 <= c < 20 else '?' 
                              for c in pred_classes[i] if c != -100])
            print(f"Sample {i+1} predicted AA: {pred_aa[:50]}...")

    print("\n✓ Training completed successfully!")
1 Like

I’ve filled in some gaps with assumptions, so this might not work as-is, but there’s probably a mismatch in the finer details of how the function is used.


Hey, thank you again for your detailed answer. Things have become a bit clearer for me now. I realized that the actual token classification case will come later, when I have a different dataset, and that I will first focus on the seq2seq case for my current dataset. Sorry for the confusion. So I am preprocessing my dataset, putting the task specification token in front, and using the whole T5 model:

prefix_s2t = "<fold2AA>"

def preprocess(ex):
        """
        Preprocess examples for seq2seq training.
        Add the <fold2AA> prefix to source sequences
        """

        # Add prefix to source sequences (3Di)
        inputs = [f"{prefix_s2t} {src}" for src in ex["src"]]
        targets = ex["tgt"]

        # Tokenize inputs (3Di sequence)
        model_inputs = tokenizer(
            inputs,
            max_length=src_max,
            truncation=True,
            padding=False,  # DataCollator will handle padding
        )
        
        # Tokenize target (AA sequence) 
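        # Note: as_target_tokenizer() is deprecated in newer transformers; tokenizer(text_target=targets, ...) does the same thing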
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                targets,
                max_length=tgt_max,
                truncation=True,
                padding=False,
            )

        # Add labels to model inputs
        model_inputs["labels"] = labels["input_ids"]
        
        
        return model_inputs

train_processed = train.map(preprocess, remove_columns=train.column_names, batched=True, batch_size=1)
val_processed = val.map(preprocess, remove_columns=val.column_names, batched=True, batch_size=1)

Then I am using DataCollatorForSeq2Seq and Seq2SeqTrainingArguments:

data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding='max_length',
        max_length=src_max,
        label_pad_token_id=-100,
    )

# Training arguments (following safe code pattern)
training_args = Seq2SeqTrainingArguments(
        output_dir="finetuning_prostt5_safecode",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        num_train_epochs=100,
        learning_rate=5e-5,
        max_grad_norm=1.0,
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        eval_strategy="steps",
        eval_steps=100,  # Adjust based on your dataset size
        save_strategy="steps",
        save_steps=100,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        predict_with_generate=True,
        generation_max_length=tgt_max,
        group_by_length=True,
        fp16=False,
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        report_to="none",
        remove_unused_columns=False, # added
        save_safetensors=False
    )

trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_processed,
        eval_dataset=val_processed,
        data_collator=data_collator,
        processing_class=tokenizer,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
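
compute_metrics is not shown above; a minimal sketch of a sequence-recovery metric could look like this, assuming predict_with_generate=True (so eval predictions are generated token IDs), that tokenizer is in scope, and that recovery means the fraction of positions where the prediction matches the target. The exact definition used here may differ:

import numpy as np

def compute_metrics(eval_pred):
    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]
    # both preds and labels may be padded with -100; swap it out so they can be decoded
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    recoveries = []
    for p, g in zip(decoded_preds, decoded_labels):
        p_toks, g_toks = p.split(), g.split()
        if not g_toks:
            continue
        matches = sum(a == b for a, b in zip(p_toks, g_toks))
        recoveries.append(matches / len(g_toks))
    return {"seq_recovery": float(np.mean(recoveries)) if recoveries else 0.0}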

I ran it on a small set of sequences (10 train, 1 val) where the corresponding sequences are relatively short (<50). I ran it for 100 epochs; the train loss went from ~5 to ~1 and the eval loss from ~2.55 to ~0.4 at the end. However, I am not sure whether there might still be problems, because the sequence recovery is very low on the evaluation step.

Thank you again very much for your patience and help

1 Like

Hey, I have another update/question on my finetuning approach. My problem setup is that I have a pretrained T5 model which goes from a sequence of tokens → a sequence of tokens, so a 1:1 mapping. My finetuning approach is now that I have n sequences which I want to map to the same output sequence of tokens (per example), so an n:1 mapping. In order to feed the model the n input sequences per example, I calculate a frequency matrix of dimension (seq_len, token_id_frequency) and multiply it with the initial embedding matrix of the tokens, which has dimension (token_id, embedding_dim), in order to interpolate the initial embeddings. I was aiming for an improvement in sequence prediction/recovery because the new (weighted) embeddings would contain more information for the model to learn from. However, my finetuning saturates quickly and I was wondering what the problem might be.

Thank you again for your time and help, in case you are seeing my answer.

1 Like

Hmm?


Your saturation is a predictable outcome of the specific fusion you chose: you compressed n discrete sequences into one sequence of averaged embeddings. That compression usually (1) throws away exactly the information you hoped to add, and (2) pushes the encoder into an input regime it was not pretrained on.

Below are the most likely causes in your exact setup, plus fixes that have strong precedent in multi-input seq2seq work.


1) What your “frequency × embedding” fusion really does

You built “soft tokens.”

At position t, you estimate a distribution over vocab IDs from the n candidates, then compute the expected embedding:

  • e_t = Σ_v p(t,v) * E[v]

This is a valid way to feed T5 because the model forward accepts inputs_embeds. (Hugging Face)
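
Roughly, in code, the soft-token construction looks like this (a sketch with assumed tensors: freq holds per-position counts over the vocabulary, and attention_mask/labels come from your pipeline):

import torch

# freq: (batch, seq_len, vocab_size) counts of each token ID per position across the n candidates
E = model.get_input_embeddings().weight                       # (vocab_size, d_model)
probs = freq / freq.sum(dim=-1, keepdim=True).clamp(min=1)    # normalize counts to probabilities
inputs_embeds = probs @ E                                     # (batch, seq_len, d_model) expected embeddings
outputs = model(inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                labels=labels)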

But it has two structural problems.

Problem A: it keeps only per-position marginals, not sequence structure

Your fusion preserves “which tokens appear at position t.”
It destroys:

  • which tokens co-occur in the same candidate sequence
  • cross-position dependencies inside each candidate
  • “candidate identity” (which source said what)

So you reduce n full sequences to a much weaker signal. The model learns the easy part quickly (positions where candidates agree) then hits a ceiling.

Problem B: if candidates are even slightly misaligned, you are averaging unrelated tokens

Your method assumes token position t means the same thing across all candidates. If any candidate has an insertion, deletion, or shifted tokenization, then the histogram at position t mixes unrelated tokens. That behaves like structured noise and causes early plateau.


2) Out-of-distribution embeddings and “centroid collapse”

T5 was pretrained on embeddings that come from one discrete token ID per position. Your mixture embeddings often land in regions of embedding space that do not correspond to any real token embedding the model frequently saw.

Averaging embeddings also tends to shrink distinctions: many different distributions can produce similar mean vectors. So lots of examples become “more similar” to the encoder than they should be. That reduces separability and limits achievable loss reduction.

This is exactly why uncertain-input literature usually avoids naive averaging and instead preserves alternatives as lattices/confusion networks and consumes them with an architecture designed for uncertainty. (arXiv)


3) The most common silent bug: wrong scaling and PAD contamination

Even if the modeling idea were fine, two implementation details can kill learning.

Scaling

If you used raw counts (row sums ≈ n), then embedding magnitudes scale with n. Even with layer norms, you changed the distribution of activations the encoder sees, which can flatten attention or make gradients small.

Fix: per position, normalize to probabilities (row-sum = 1). Then optionally sharpen (below).

PAD contamination

If some candidates are shorter and you pad them, PAD gets counted. Then PAD embedding leaks into e_t. That is poison because you are injecting a strong “nothing here” vector into real positions.

Fix: exclude PAD tokens from the histogram entirely. Keep attention_mask correct for real vs padded time steps.
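
For example, building the per-position histogram with both fixes applied (a sketch; cand_ids holds the n aligned candidates for one example):

import torch

def position_histogram(cand_ids, vocab_size, pad_id):
    # cand_ids: (n, seq_len) token IDs of the n aligned candidates for one example
    counts = torch.zeros(cand_ids.shape[1], vocab_size)
    for seq in cand_ids:
        for t, tid in enumerate(seq.tolist()):
            if tid != pad_id:                          # exclude PAD from the histogram
                counts[t, tid] += 1
    row_sums = counts.sum(dim=-1, keepdim=True)
    return counts / row_sums.clamp(min=1)              # probabilities, row-sum = 1 where real tokens exist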


4) Why your loss curve has spikes and saturates (training dynamics explanation)

Two common patterns create a “down then stall with spikes” plot.

Warm restarts (scheduler)

If you use cosine annealing with warm restarts, the LR jumps back up at each restart. That can cause periodic loss spikes even if everything else is correct. (PyTorch Docs)

Debug choice: switch to linear warmup + decay while you diagnose modeling. Avoid restarts until stable.

Adafactor configuration traps (common with T5)

Adafactor settings are easy to misconfigure. For example, warmup_init=True requires relative_step=True, and that conflicts with setting a manual LR. This is a documented Transformers pitfall. (GitHub)

Debug choice: use AdamW first (simple), or use Adafactor with a known-good configuration from the docs.
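
If you stay with Adafactor, the known-good configuration from the Transformers docs looks like this (a sketch; model is assumed to be in scope):

from transformers import Adafactor

# fixed LR, no relative-step updates, no warmup_init: the documented safe setting for T5 finetuning
optimizer = Adafactor(
    model.parameters(),
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
    lr=1e-3,
)
# hand it to the trainer, e.g. Seq2SeqTrainer(..., optimizers=(optimizer, None)),
# and let the Trainer build its usual LR schedule on top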


5) The core issue: you want n:1 evidence fusion, but you implemented “early fusion by averaging”

For n:1, the strongest practical template is:

  • keep each candidate as a real sequence
  • encode each candidate separately
  • fuse later (decoder attention or learned pooling)

This preserves structure and lets the model learn “which candidate to trust.”

The proven pattern: Fusion-in-Decoder (FiD)

FiD encodes each input separately and concatenates encoder states; the decoder cross-attends to all of them. It is explicitly designed for “many inputs to one output.” (arXiv)

FiD is usually used for retrieved passages, but your “n candidate sequences” is the same abstraction: multiple evidence streams.

Classic multi-source seq2seq precedent

Multi-source neural translation uses multiple encoders and a single decoder, exploring combination methods. Same structural problem, older but foundational. (arXiv)

Uncertain-input precedent

If your n sequences represent alternative hypotheses with implicit probabilities, lattice-to-seq models show why preserving posterior structure matters and how to incorporate it. (arXiv)


6) Solutions, ranked by “probability of fixing your saturation”

Solution 1 (fastest baseline): treat candidates as augmentation, not fusion

Create n training pairs per original example:

  • (candidate_i → target) for i in 1..n

This often beats embedding averaging because you stay fully in-distribution (discrete tokens) and you do not destroy structure.
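
A sketch of the expansion step with a Hugging Face Dataset (the column names candidates/tgt and the dataset name raw_dataset are assumptions):

def expand_to_pairs(examples):
    # examples["candidates"]: list of n candidate source strings per example; examples["tgt"]: one target
    srcs, tgts = [], []
    for cands, tgt in zip(examples["candidates"], examples["tgt"]):
        for cand in cands:
            srcs.append(cand)
            tgts.append(tgt)
    return {"src": srcs, "tgt": tgts}

# raw_dataset is a datasets.Dataset with "candidates" and "tgt" columns
expanded = raw_dataset.map(expand_to_pairs, batched=True,
                           remove_columns=raw_dataset.column_names)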

Inference options:

  • run all candidates and choose the best by log-likelihood
  • ensemble outputs (logprob sum)

Solution 2 (simple single-pass fusion): concatenate candidates with separators and tags

Input text like:

  • cand1: ... </s> cand2: ... </s> ...

Pros: trivial to implement, preserves order.
Cons: context length grows; attention cost grows.

Solution 3 (best “correct fusion”): FiD-style encode-separately then fuse in decoder

High level:

  1. reshape batch to encode each candidate independently
  2. concatenate encoder hidden states
  3. feed concatenated states to decoder

If you want an existence proof, the FiD repo shows the exact reshape-and-concatenate trick. (GitHub)
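
A rough sketch of the reshape-and-concatenate idea on top of a stock T5 (shapes only, not the actual FiD code; input_ids, attention_mask, labels, and model are assumed):

from transformers.modeling_outputs import BaseModelOutput

# input_ids, attention_mask: (batch, n_candidates, seq_len) - each candidate tokenized separately
B, N, L = input_ids.shape
flat_ids = input_ids.reshape(B * N, L)
flat_mask = attention_mask.reshape(B * N, L)

# 1) encode each candidate independently
enc = model.get_encoder()(input_ids=flat_ids, attention_mask=flat_mask)

# 2) concatenate encoder states along the sequence axis: (batch, n_candidates * seq_len, d_model)
hidden = enc.last_hidden_state.reshape(B, N * L, -1)
mask = flat_mask.reshape(B, N * L)

# 3) the decoder cross-attends to all candidates at once
out = model(encoder_outputs=BaseModelOutput(last_hidden_state=hidden),
            attention_mask=mask,
            labels=labels)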

Solution 4 (if you insist on your histogram idea): make it less lossy and learnable

If you keep “distribution over vocab per position,” do not map it to embeddings with a fixed linear average and stop.

Do this instead:

  1. Normalize counts to probabilities.
  2. Sharpen to reduce blur: q ∝ p^α with α > 1.
  3. Add a confidence channel (entropy or max prob) so the model knows which positions are ambiguous.
  4. Add a trainable projection: e'_t = W e_t or small MLP before feeding encoder.

This lets the model learn how to interpret your soft evidence.
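
A sketch of steps 1–4 combined (counts, attention_mask, labels, and model are assumed; the projection is an extra trainable module that you would register with the model and move to the right device):

import torch
import torch.nn as nn

alpha = 2.0
probs = counts / counts.sum(-1, keepdim=True).clamp(min=1)            # 1) normalize counts to probabilities
probs = probs.pow(alpha)
probs = probs / probs.sum(-1, keepdim=True).clamp(min=1e-9)           # 2) sharpen and renormalize

E = model.get_input_embeddings().weight                               # (vocab_size, d_model)
soft_emb = probs @ E

# 3) confidence channel: per-position entropy as one extra feature
entropy = -(probs.clamp(min=1e-9).log() * probs).sum(-1, keepdim=True)
fused = torch.cat([soft_emb, entropy], dim=-1)                        # (batch, seq_len, d_model + 1)

# 4) small trainable projection back to d_model before the encoder
proj = nn.Linear(E.shape[1] + 1, E.shape[1])
inputs_embeds = proj(fused)
out = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels)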


7) “Plumbing” checks that you should run regardless

These do not solve the fundamental fusion issue, but they can mimic saturation.

Check A: label padding is ignored (-100)

For seq2seq, padded label positions must be -100 so loss ignores them. DataCollatorForSeq2Seq defaults label_pad_token_id=-100 and documents that -100 is ignored by PyTorch losses. (Hugging Face)

If you accidentally compute loss on padded labels, you will see an artificial floor.

Check B: generation with inputs_embeds is special

Forward can use inputs_embeds, but generate() historically does not accept it cleanly for seq2seq. People work around it by calling the encoder first and passing encoder_outputs to generate(). (GitHub)

This matters for evaluation: you can think you are evaluating the model you trained, but you are not actually using the same conditioning path unless you do this correctly. A recent discussion clarifies how generate() behaves when you pass encoder_outputs. (Hugging Face Forums)
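
A sketch of that workaround (the exact behavior of generate() with encoder_outputs has shifted across transformers versions, so verify it against yours):

# run the encoder yourself on the fused embeddings, then hand its output to generate()
encoder_outputs = model.get_encoder()(inputs_embeds=inputs_embeds,
                                      attention_mask=attention_mask)
generated = model.generate(encoder_outputs=encoder_outputs,
                           attention_mask=attention_mask,
                           max_length=64)   # set to your target length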


8) A concrete debug plan that will tell you what is wrong in 1–2 experiments

Experiment 1: can the model overfit a tiny subset with normal discrete input?

Take 32 examples, pick one candidate per example, train until near-zero loss.

  • If you cannot overfit, you have a training/pipeline bug (labels, masking, LR, optimizer).
  • If you can, your saturation is caused by your fusion method.

Experiment 2: compare three n:1 strategies head-to-head

Hold everything constant and run:

  1. augmentation baseline (n separate pairs)
  2. concatenation baseline
  3. your histogram-embedding fusion

If (1) or (2) beats (3) quickly, your averaging destroyed signal.

Experiment 3: measure blur

On a batch, compute:

  • mean per-position entropy of p(t,·)
  • mean norm of fused e_t
  • cosine similarity between fused sequences across examples

High entropy + high similarity between examples is the signature of “centroid collapse.”
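
A sketch of these three diagnostics on one batch (probs and inputs_embeds are assumed to be the per-position distributions and the fused embeddings):

import torch
import torch.nn.functional as F

# mean per-position entropy of p(t, .)
entropy = -(probs.clamp(min=1e-9).log() * probs).sum(-1)              # (batch, seq_len)
print("mean entropy:", entropy.mean().item())

# mean norm of the fused embeddings
print("mean ||e_t||:", inputs_embeds.norm(dim=-1).mean().item())

# cosine similarity between examples (mean-pooled over positions)
pooled = inputs_embeds.mean(dim=1)                                    # (batch, d_model)
sim = F.cosine_similarity(pooled.unsqueeze(1), pooled.unsqueeze(0), dim=-1)
off_diag = sim[~torch.eye(sim.shape[0], dtype=torch.bool)]
print("mean cross-example cosine:", off_diag.mean().item())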


Summary

  • Your fusion averages away sequence structure and candidate identity, so the model learns the easy consensus then saturates.
  • Misalignment across candidates and PAD leakage can turn your histogram into noise.
  • Scaling issues (counts vs probabilities) can cause early plateaus.
  • Strong fixes: augmentation, concatenation with separators, or FiD-style encode-separately and fuse in decoder.
  • Also verify label padding -100, avoid warm restarts while debugging, and confirm Adafactor settings.
1 Like

Detailed version.

1 Like

Thank you very much for the detailed analysis and tips. I will start with the augmentation approach, compare its performance against the baseline, and work through your points. Thank you again for your time and help.

2 Likes