Question about reference performance
I've downloaded your checkpoint and evaluated the performance.
It shows essentially the same accuracy, 59.66%, which differs slightly from the reported number but is within tolerance.
However, I couldn't reproduce the reported performance of Qwen/Qwen2.5-0.5B-Instruct. This document reports about 50%, but when I evaluate the initial checkpoint, I get about 41%.
Did you do additional training, such as supervised fine-tuning or other techniques, before applying PPO?
Thank you in advance
Hi, thanks so much for your active use and detailed feedback! I'm happy to see you were able to replicate the performance on your end.
Regarding the baseline performance of Qwen/Qwen2.5-0.5B-Instruct, I didn't test the initial checkpoint myself. The 49.6% performance I reported was taken directly from Table 10 of their official technical report.
I really appreciate you pointing out that your evaluation shows its performance is around 41%. If that's the case, it actually means the performance uplift from my model is even greater than I originally claimed! That's very encouraging.
To answer your other question, I can confirm that I did not perform any additional supervised fine-tuning (SFT) before applying PPO.
I will also test the baseline performance myself today and will post the results here once I have them.
Thanks again for your valuable feedback!
Can I ask for more details about the training? I'm trying to match the prompt template in the README.md and using max_new_tokens=256, but with the initial checkpoint most of the completions never reach a final answer. If possible, could you share the training hyperparameters? There are a lot of hyperparameters in PPO, and I couldn't find a good starting point for this kind of training. I'm completely new to this domain, so any guidance would be invaluable.
Thank you
Hi Seonghyeon, thank you for your interest in this model! I'd be happy to share the training details that led to the 58.9% (59.66% on your side) GSM8K performance.
Key Training Configuration
Prompt Template (Critical!)
The model was trained using Qwen's chat template format. This is essential for proper inference:
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("your-model-path")
model = AutoModelForCausalLM.from_pretrained("your-model-path")
# Correct prompt format
messages = [{"role": "user", "content": "Your math problem here"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
# This produces:
# <|im_start|>system
# You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Your math problem here<|im_end|>
# <|im_start|>assistant
Generation Parameters
For GSM8K evaluation, use these parameters:
outputs = model.generate(
inputs,
max_new_tokens=256, # NOT max_length!
do_sample=False, # Greedy decoding
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id
)
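Putting the two snippets above together, a minimal end-to-end inference call looks roughly like this. The question string is just a sample GSM8K problem, and device handling is kept as simple as possible:
# Minimal end-to-end sketch combining the prompt template and generation
# settings above (illustrative only; adjust device placement as needed).
question = (
    "Natalia sold clips to 48 of her friends in April, and then she sold "
    "half as many clips in May. How many clips did Natalia sell altogether "
    "in April and May?"
)
messages = [{"role": "user", "content": question}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
# Decode only the newly generated tokens (the completion)
completion = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(completion)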
PPO Training Hyperparameters
Core PPO Settings
- Actor LR: 1e-6 (very conservative for stability)
- Critic LR: 1e-5 (10x higher than actor)
- KL Coefficient: 0.001 (prevents too much drift from base model)
- PPO Mini-batch Size: 512
- Micro-batch Size per GPU: 4
Data & Sequence Settings
- Max Prompt Length: 512
- Max Response Length: 256 (not max_new_tokens)
- Training Batch Size: 2048
- Total Epochs: 2500
Infrastructure
- 8 x GPUs (A100/H100 recommended)
- GPU Memory Utilization: 0.4 (conservative to avoid OOM)
- Tensor Parallel Size: 1
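If it helps, here is the same configuration collected in one place as a plain Python dict, with the derived batch arithmetic spelled out. The key names are my own informal labels, not the exact config keys of any particular framework, so map them onto your trainer's config (e.g. VERL's PPO trainer YAML) yourself:
# Consolidated view of the settings above. Key names are informal; the exact
# config keys differ between framework versions.
ppo_config = {
    # Core PPO settings
    "actor_lr": 1e-6,              # very conservative for stability
    "critic_lr": 1e-5,             # 10x higher than the actor
    "kl_coef": 0.001,              # keeps the policy close to the base model
    "ppo_mini_batch_size": 512,
    "micro_batch_size_per_gpu": 4,
    # Data & sequence settings
    "max_prompt_length": 512,
    "max_response_length": 256,
    "train_batch_size": 2048,
    "total_epochs": 2500,
    # Infrastructure
    "n_gpus": 8,
    "gpu_memory_utilization": 0.4,
    "tensor_parallel_size": 1,
}

# Derived sanity checks (assuming the mini-batch size is global across all 8 GPUs):
# gradient-accumulation steps per mini-batch = 512 / (4 * 8) = 16
# PPO mini-batch updates per rollout batch   = 2048 / 512    = 4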
Key Success Factors
- Chat Template: This is the #1 factor affecting performance. Without proper chat formatting, expect 20-30% lower accuracy.
- Conservative Learning Rates: PPO is unstable with high LRs. Start with 1e-6 for the actor.
- KL Penalty: 0.001 prevents the model from deviating too much from the base model's behavior.
- Reward Function: Used VERL's built-in GSM8K reward function with "flexible" answer extraction (simplified sketch below).
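To give a concrete picture of the reward signal, here is a simplified sketch of a "flexible" GSM8K reward. This is my own reconstruction of the idea, not VERL's actual code, so please check the GSM8K reward module in the VERL repository for the real implementation:
import re

def gsm8k_flexible_reward(solution_str: str, ground_truth: str) -> float:
    """Simplified 'flexible' GSM8K reward: prefer the #### answer,
    otherwise fall back to the last number in the completion."""
    match = re.search(r'####\s*(-?\d[\d,\.]*)', solution_str)
    if match:
        pred = match.group(1)
    else:
        # "Flexible" extraction: take the last number anywhere in the text
        numbers = re.findall(r'-?\d[\d,\.]*', solution_str)
        if not numbers:
            return 0.0  # nothing parsable -> no reward
        pred = numbers[-1]
    pred = pred.replace(',', '').rstrip('.')
    gold = ground_truth.replace(',', '').strip()
    return 1.0 if pred == gold else 0.0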
Expected Progression
- Initial checkpoints: ~30-40% accuracy
- Mid-training (1000-2000 steps): ~45-50%
- Final merged model: 58.9%
Quick Debug Tips
If you're getting poor initial performance:
- Verify chat template usage
- Check that max_new_tokens=256 is set (not max_length)
- Ensure greedy decoding (do_sample=False)
- Use the reward function from the VERL framework
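As a quick sanity check for the first item in that list, something like this (a toy question, nothing more) catches most template mistakes before you spend time on a full evaluation run:
# The rendered prompt must end with the assistant header when
# add_generation_prompt=True; otherwise the chat template is not applied.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is 3 + 5?"}],
    add_generation_prompt=True,
    tokenize=False,
)
assert prompt.rstrip().endswith("<|im_start|>assistant"), "chat template not applied correctly"
print(prompt)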
The model averaging (mergekit) of top-performing checkpoints also contributed ~2-3% improvement over single checkpoints.
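In case you want to try the averaging step yourself: the actual merge was done with mergekit, but plain parameter averaging of a few checkpoints illustrates the same idea. A rough sketch (the checkpoint paths are placeholders, and mergekit's linear merge does this with much more bookkeeping):
import torch
from transformers import AutoModelForCausalLM

# Placeholder paths to the top-performing checkpoints to average.
checkpoint_paths = ["ckpt_step_a", "ckpt_step_b", "ckpt_step_c"]

models = [AutoModelForCausalLM.from_pretrained(p) for p in checkpoint_paths]
merged = AutoModelForCausalLM.from_pretrained(checkpoint_paths[0])

# Uniform average of all floating-point parameters.
merged_state = merged.state_dict()
with torch.no_grad():
    for name, param in merged_state.items():
        if not torch.is_floating_point(param):
            continue  # skip integer buffers, if any
        merged_state[name] = torch.stack(
            [m.state_dict()[name].float() for m in models]
        ).mean(dim=0).to(param.dtype)
merged.load_state_dict(merged_state)
merged.save_pretrained("merged-model")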
Hope this helps! Let me know if you need clarification on any specific aspect.
Thanks for your valuable troubleshooting guide.
I created my evaluation script on my own (with Claude Code), and I'll share some details about how I did it.
- Prompt template
This is probably the most important and impactful factor for performance. I just followed the code that Claude Code generated.
Here is the code snippet for creating a GSM8K prompt.
messages = [{"role": "user", "content": f"Solve this math problem step by step. End your answer with #### followed by the numerical answer.\n\n{question}"}]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
add_thinking_prompt=False # Try to disable thinking for Qwen
)
- Generation config
It's fairly standard, I think: greedy decoding with 512 max new tokens.
- Parsing and evaluation
It might be different from your evaluation script; I probably used a more generous parsing mechanism during evaluation. I just use the code generated by Claude Code, and I think it is reasonably acceptable, even though there will be some false positives and false negatives.
import re
from typing import Optional

def extract_numerical_answer(text: str) -> tuple[Optional[str], str]:
    """
    Extract a numerical answer from text using the GSM8K format.

    Tries multiple patterns in order of preference:
    1. #### format (standard GSM8K)
    2. Common answer patterns ("answer is", "Answer:")
    3. Last number in the text (fallback)

    Args:
        text: The text to extract the answer from

    Returns:
        A (answer, method) tuple, where answer is the extracted number as a
        string (or None if nothing was found) and method names the pattern
        that matched.
    """
    # Primary pattern: #### followed by a number (GSM8K standard format)
    # Use findall to get all matches, then take the last one
    pattern = r'####\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)'
    matches = re.findall(pattern, text)
    if matches:
        return matches[-1].replace(',', ''), '####'

    # Secondary patterns: common answer formats
    answer_patterns = [
        r'answer is\s*\$?\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)',
        r'Answer:\s*\$?\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)',
        r'answer:\s*\$?\s*(-?\d+(?:,\d{3})*(?:\.\d+)?)',
    ]
    for pattern in answer_patterns:
        match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
        if match:
            return match.group(1).replace(',', ''), 'common'

    # Fallback: take the last number that appears anywhere in the text
    numbers = re.findall(r'-?\d+(?:,\d{3})*(?:\.\d+)?', text)
    if numbers:
        return numbers[-1].replace(',', ''), 'last_number'

    return None, 'failed'
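The final comparison against the gold answer is then simple; something along these lines is enough (GSM8K gold answers also end with "#### <number>", so the same extractor works on both sides):
def is_correct(completion: str, gold_answer_text: str) -> bool:
    """Compare the model's extracted answer with the GSM8K gold answer."""
    pred, _method = extract_numerical_answer(completion)
    gold, _ = extract_numerical_answer(gold_answer_text)
    if pred is None or gold is None:
        return False
    try:
        # Compare numerically so "72" and "72.0" count as equal
        return abs(float(pred) - float(gold)) < 1e-6
    except ValueError:
        return pred == gold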
I hope this addresses your questions, and feel free to ask if you'd like details on anything else.
