Question about reference performance

#1
by sh0416 - opened

I've downloaded your checkpoint and evaluated the performance.
It shows 59.66% accuracy, which is slightly different from the reported number but tolerable.
However, I couldn't reproduce the performance of Qwen/Qwen2.5-0.5B-Instruct, which is reported as about 50% in this document; when I evaluate the initial checkpoint, I only get about 41%.
Did you do additional training such as supervised finetuning or other techniques before applying PPO?

Thank you in advance

Hi, thanks so much for your active use and detailed feedback! I'm happy to see you were able to replicate the performance on your end.

Regarding the baseline performance of Qwen/Qwen2.5-0.5B-Instruct, I didn't test the initial checkpoint myself. The 49.6% performance I reported was taken directly from Table 10 of their official technical report.

I really appreciate you pointing out that your evaluation shows its performance is around 41%. If that's the case, it actually means the performance uplift from my model is even greater than I originally claimed! That's very encouraging.

To answer your other question, I can confirm that I did not perform any additional supervised fine-tuning (SFT) before applying PPO.

I will also test the baseline performance myself today and will post the results here once I have them.

Thanks again for your valuable feedback!

(Table 10 from the Qwen2.5 technical report)

Can I ask for more details about the training? I am trying to match the prompt template in the README.md and using max_new_tokens=256, but with the initial checkpoint most completions never reach a final answer. If possible, could you tell me the training hyperparameters you used? PPO has a lot of hyperparameters, and I couldn't find a good starting point for this kind of training. I'm completely new to this domain, so I need a lot of help, and your feedback would be invaluable if you don't mind.

Thank you

Hi Seonghyeon, thank you for your interest in this model! I'd be happy to share the training details that led to the 58.9% (59.66% on your side) GSM8K performance.

🎯 Key Training Configuration

Prompt Template (Critical!)

The model was trained using Qwen's chat template format. This is essential for proper inference:

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("your-model-path")
model = AutoModelForCausalLM.from_pretrained("your-model-path")

# Correct prompt format
messages = [{"role": "user", "content": "Your math problem here"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

# This produces:
# <|im_start|>system
# You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Your math problem here<|im_end|>
# <|im_start|>assistant

Generation Parameters

For GSM8K evaluation, use these parameters:

# Tokenize the chat-formatted prompt from above, then generate
inputs = tokenizer(prompt, return_tensors="pt")

outputs = model.generate(
    **inputs,
    max_new_tokens=256,  # NOT max_length!
    do_sample=False,     # Greedy decoding
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id
)

πŸ”§ PPO Training Hyperparameters

Core PPO Settings

  • Actor LR: 1e-6 (very conservative for stability)
  • Critic LR: 1e-5 (10x higher than actor)
  • KL Coefficient: 0.001 (prevents too much drift from base model)
  • PPO Mini-batch Size: 512
  • Micro-batch Size per GPU: 4

Data & Sequence Settings

  • Max Prompt Length: 512
  • Max Response Length: 256 (not max_new_tokens)
  • Training Batch Size: 2048
  • Total Epochs: 2500

Infrastructure

  • 8 x GPUs (A100/H100 recommended)
  • GPU Memory Utilization: 0.4 (conservative to avoid OOM)
  • Tensor Parallel Size: 1
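
For reference, the numbers above map onto a VERL launch roughly like the sketch below. This is illustrative rather than my exact command, and the config key names can differ between VERL versions (the micro-batch key in particular).

import subprocess

# Hydra-style overrides mirroring the hyperparameters listed above (illustrative only)
overrides = [
    "data.train_batch_size=2048",
    "data.max_prompt_length=512",
    "data.max_response_length=256",
    "actor_rollout_ref.actor.optim.lr=1e-6",
    "actor_rollout_ref.actor.ppo_mini_batch_size=512",
    "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4",  # key name varies by VERL version
    "actor_rollout_ref.rollout.gpu_memory_utilization=0.4",
    "actor_rollout_ref.rollout.tensor_model_parallel_size=1",
    "critic.optim.lr=1e-5",
    "algorithm.kl_ctrl.kl_coef=0.001",
    "trainer.n_gpus_per_node=8",
    "trainer.total_epochs=2500",
]
subprocess.run(["python3", "-m", "verl.trainer.main_ppo", *overrides], check=True)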

🎯 Key Success Factors

  1. Chat Template: This is the #1 factor affecting performance. Without proper chat formatting, expect 20-30% lower accuracy.

  2. Conservative Learning Rates: PPO is unstable with high LRs. Start with 1e-6 for actor.

  3. KL Penalty: 0.001 prevents the model from deviating too much from the base model's behavior.

  4. Reward Function: Used VERL's built-in GSM8K reward function with "flexible" answer extraction.
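
Conceptually, the "flexible" variant boils down to something like the snippet below (my paraphrase of the idea, not VERL's exact code): take the last number in the response as the prediction and give a binary reward against the ground truth.

import re

def gsm8k_reward(response: str, ground_truth: str) -> float:
    # "Flexible" extraction: accept the last number anywhere in the response
    numbers = re.findall(r"-?\d+(?:,\d{3})*(?:\.\d+)?", response)
    if not numbers:
        return 0.0  # nothing parsable, no reward
    prediction = numbers[-1].replace(",", "")
    return 1.0 if prediction == ground_truth.replace(",", "") else 0.0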

πŸ“Š Expected Progression

  • Initial checkpoints: ~30-40% accuracy
  • Mid-training (1000-2000 steps): ~45-50%
  • Final merged model: 58.9%

πŸ› οΈ Quick Debug Tips

If you're getting poor initial performance:

  1. Verify chat template usage
  2. Check that max_new_tokens=256 (not max_length)
  3. Ensure greedy decoding (do_sample=False)
  4. Use the reward function from VERL framework

Model averaging of the top-performing checkpoints with mergekit also contributed a ~2-3% improvement over any single checkpoint.
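
As a rough illustration (the checkpoint paths and weights below are placeholders, not my actual recipe), a linear merge of two checkpoints can be driven through mergekit's CLI like this:

import pathlib
import subprocess
import textwrap

# Placeholder mergekit config: equal-weight linear average of two checkpoints
config = textwrap.dedent("""\
    merge_method: linear
    models:
      - model: ./checkpoints/ppo_step_a
        parameters:
          weight: 0.5
      - model: ./checkpoints/ppo_step_b
        parameters:
          weight: 0.5
    dtype: bfloat16
    """)
pathlib.Path("merge_config.yml").write_text(config)

# mergekit-yaml <config> <output_dir>
subprocess.run(["mergekit-yaml", "merge_config.yml", "./merged_model"], check=True)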

Hope this helps! Let me know if you need clarification on any specific aspect.

@sh0416 Hi Seonghyeon, by the way, could you please tell me which evaluation script you used to achieve the 59.66% score, and what exact generation settings (max_new_tokens, do_sample, etc.)?

Thanks for your valuable troubleshooting guide.
I created my evaluation script on my own (with claude code), and I'll share some details about how I did it.

  1. Prompt template

This is probably the most important and impactful factor in the final performance. I just followed the code that claude code generated.
Here is the snippet that builds a prompt for GSM8K.

messages = [{"role": "user", "content": f"Solve this math problem step by step. End your answer with #### followed by the numerical answer.\n\n{question}"}]

prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    add_thinking_prompt=False  # Try to disable thinking for Qwen
)
  2. Generation config

It is fairly standard, I think. I used greedy decoding with 512 max new tokens.
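
Roughly, the call looks like the snippet below (a sketch; the actual script was generated by claude code, so the variable names here are just illustrative).

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=512,   # 512 new tokens
    do_sample=False,      # greedy decoding
    pad_token_id=tokenizer.eos_token_id,
)
completion = tokenizer.decode(
    outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
)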

  3. Parsing and evaluation

This part might differ from your evaluation script; I may have used a more generous parsing mechanism. I just used the code generated by claude code, and I think it is reasonable, even though there are probably some false positives and false negatives.

import re
from typing import Optional

def extract_numerical_answer(text: str) -> tuple[Optional[str], str]:
    """
    Extract numerical answer from text using GSM8K format.
    
    Tries multiple patterns in order of preference:
    1. #### format (standard GSM8K)
    2. Common answer patterns
    3. Last number in text (fallback)
    
    Args:
        text: The text to extract answer from
        
    Returns:
        A tuple of (extracted numerical answer as a string, or None if nothing
        was found, and the name of the pattern that matched)
    """
    # Primary pattern: #### followed by number (GSM8K standard format)
    # Use findall to get all matches, then take the last one
    pattern = r'####\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)'
    matches = re.findall(pattern, text)
    if matches:
        return matches[-1].replace(',', ''), '####'
    
    # Secondary patterns: common answer formats
    answer_patterns = [
        r'answer is\s*\$?\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)',
        r'Answer:\s*\$?\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)',
        r'answer:\s*\$?\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)',
    ]
    
    for pattern in answer_patterns:
        match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
        if match:
            return match.group(1).replace(',', ''), 'common'
    
    # Fallback: find the last number in the text
    numbers = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
    if numbers:
        return numbers[-1].replace(',', ''), 'last_number'
    
    return None, 'failed'
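
For scoring, I compare the extracted prediction against the gold GSM8K answer along these lines (another sketch; completion and example are placeholders for a decoded model output and a GSM8K test record).

# Gold labels in GSM8K end with "#### <number>"
gold = example["answer"].split("####")[-1].strip().replace(",", "")
pred, matched_by = extract_numerical_answer(completion)
is_correct = pred is not None and float(pred) == float(gold)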

I hope this addresses your questions, and feel free to ask if you'd like any other details.
