Returns the number of Rs in a variant of the word "strawberry".
Example — Prompt: strrawberrry, Response: 7 (note: "strrawberrry" actually contains 5 Rs)
#!/usr/bin/env python3
"""
Fine-tune Llama-3.2-1B-Instruct to count Rs in 'strawberry' variants.

A fun exercise in overfitting to a simple task.
"""
import random
import re

import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
def generate_strawberry_variant(target_r_count: int) -> str:
    """
    Generate a 'strawberry' variant containing ``target_r_count`` Rs.

    The base word s-t-r-a-w-b-e-r-r-y has two runs of Rs: one right after
    "st" and one right after "awbe". Every variant has the shape
    ``"st" + "r" * a + "awbe" + "r" * b + "y"``.

    - Counts below 1 (including negative values) return "stawbey" (no Rs).
    - A count of 1 returns "strawbey" or "stawbery" ("stawbery" is drawn
      twice as often).
    - A count of 2 returns one of "strawbery", "stawberry", "strrawbey".
    - For counts of 3 or more, the first run gets at least 1 R and the second
      run at least 2. Each of the remaining ``target_r_count - 3`` Rs goes to
      the first run with probability 1/3 and to the second with 2/3.

    Args:
        target_r_count: Desired number of Rs.

    Returns:
        A word whose ``count("r")`` equals ``max(target_r_count, 0)``.
    """
    if target_r_count < 1:
        return "stawbey"

    if target_r_count == 1:
        # Only two distinct one-R spellings exist; indices 1 and 2 both place
        # the R in the second run. The 3-way draw is kept so seeded runs keep
        # producing the same data.
        return ("strawbey", "stawbery", "stawbery")[random.choice([0, 1, 2])]

    if target_r_count == 2:
        return ("strawbery", "stawberry", "strrawbey")[random.choice([0, 1, 2])]

    # Three slots, each seeded with one R. Slot 0 is the run after "st".
    # Slots 1 and 2 are adjacent after "awbe", so together they form a single
    # run in the output word.
    slots = [1, 1, 1]
    for _ in range(target_r_count - 3):
        slots[random.randint(0, 2)] += 1

    return "st" + "r" * slots[0] + "awbe" + "r" * (slots[1] + slots[2]) + "y"
def create_dataset_samples(num_samples: int = 10000, max_r_count: int = 100) -> list[tuple[str, int]]:
    """
    Generate ``(word, r_count)`` training samples with varied R counts.

    A single uniform draw picks the range for each sample's requested count:
    30% from 1-10, 30% from 1-30 and 40% from 1-``max_r_count``. This
    over-represents small counts while still covering the full range. The
    first two ranges are capped at ``max_r_count``.

    Args:
        num_samples: Number of samples to generate (0 yields an empty list).
        max_r_count: Largest R count that may be requested; must be >= 1.

    Returns:
        A list of ``(word, count)`` pairs, where ``count`` is the number of
        ``'r'`` characters actually present in ``word``.

    Raises:
        ValueError: If ``max_r_count`` is less than 1.
    """
    if max_r_count < 1:
        raise ValueError(f"max_r_count must be >= 1, got {max_r_count}")

    samples = []
    for _ in range(num_samples):
        # Draw once. Calling random.random() again in the elif would shift the
        # bucket probabilities to 30% / 42% / 28%.
        roll = random.random()
        if roll < 0.3:
            upper = min(10, max_r_count)
        elif roll < 0.6:
            upper = min(30, max_r_count)
        else:
            upper = max_r_count

        word = generate_strawberry_variant(random.randint(1, upper))
        # Label with the count actually present in the word, so the target is
        # correct by construction.
        samples.append((word, word.lower().count("r")))
    return samples
class StrawberryDataset(Dataset):
    """
    Dataset for the R-counting task.

    Each item is the text ``"Input: {word}\\nOutput: {count}"``, tokenized and
    followed by the tokenizer's end-of-sequence token when it fits. It is then
    right-padded to ``max_length``. Labels are -100 (ignored by the loss) for
    the prompt and for padding, so only the answer and the end-of-sequence
    token are trained on.
    """

    def __init__(self, samples: list[tuple[str, int]], tokenizer, max_length: int = 128):
        """
        Args:
            samples: ``(word, count)`` pairs.
            tokenizer: Callable on a string that returns a mapping with
                ``"input_ids"``. Its ``eos_token_id`` and ``pad_token_id``
                attributes are also read.
            max_length: Length every item is truncated / padded to.
        """
        self.samples = samples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        word, count = self.samples[idx]

        # The model should learn to complete the text after "Output:".
        prompt = f"Input: {word}\nOutput:"
        full_text = f"Input: {word}\nOutput: {count}"

        token_ids = list(
            self.tokenizer(full_text, max_length=self.max_length, truncation=True)["input_ids"]
        )
        prompt_length = len(
            self.tokenizer(prompt, max_length=self.max_length, truncation=True)["input_ids"]
        )

        # Train on one explicit end-of-sequence token so generation learns to
        # stop after the answer, without putting loss on every padding slot.
        eos_id = self.tokenizer.eos_token_id
        if (
            eos_id is not None
            and len(token_ids) < self.max_length
            and (not token_ids or token_ids[-1] != eos_id)
        ):
            token_ids.append(eos_id)

        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = eos_id if eos_id is not None else 0

        num_real = len(token_ids)
        num_pad = self.max_length - num_real

        # Pad on the right here, not in the tokenizer, so the prompt mask below
        # lines up regardless of the tokenizer's padding side.
        input_ids = token_ids + [pad_id] * num_pad
        attention_mask = [1] * num_real + [0] * num_pad
        # -100 is ignored by the loss: mask the prompt (we don't want loss on
        # it) and the padding (otherwise it would dominate the loss).
        labels = [
            -100 if position < prompt_length else token
            for position, token in enumerate(token_ids)
        ] + [-100] * num_pad

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
def evaluate_model(model, tokenizer, device, num_samples: int = 50, max_r_count: int = 100):
    """
    Evaluate the model on freshly generated random samples.

    Each prompt ``"Input: {word}\\nOutput:"`` is completed with greedy
    decoding. The first integer at the start of the newly generated text is
    the prediction; -1 means no integer was found.

    Args:
        model: Causal LM exposing ``eval()`` and ``generate()``.
        tokenizer: Tokenizer matching ``model``.
        device: Device the prompt tensors are moved to.
        num_samples: Number of random samples to evaluate.
        max_r_count: Largest R count in the generated test words.

    Returns:
        ``(accuracy, results)``. ``accuracy`` is the fraction of correct
        predictions (0.0 when no samples were evaluated). ``results`` is a
        list of ``(word, expected_count, predicted, is_correct)`` tuples.

    Note:
        Leaves the model in eval mode; callers that continue training must
        call ``model.train()`` afterwards.
    """
    model.eval()
    results = []
    test_samples = create_dataset_samples(num_samples, max_r_count=max_r_count)

    with torch.no_grad():
        for word, expected_count in test_samples:
            prompt = f"Input: {word}\nOutput:"
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            # Decode only the continuation. Splitting the whole text on
            # "Output:" would pick up a later "Output:" if the model goes on
            # to generate a new Input/Output pair.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            completion = tokenizer.decode(new_tokens, skip_special_tokens=True)
            # Take the leading integer so answers like "7." or "7\n..." parse.
            match = re.match(r"\s*(-?\d+)", completion)
            predicted = int(match.group(1)) if match else -1
            results.append((word, expected_count, predicted, predicted == expected_count))

    correct = sum(1 for *_, is_correct in results if is_correct)
    accuracy = correct / len(results) if results else 0.0
    return accuracy, results
def _print_predictions(results, limit: int, width: int) -> None:
    """
    Print up to ``limit`` rows of ``(word, expected, predicted, correct)`` results.

    Words longer than ``width`` characters are cut and marked with "...".
    """
    for word, expected, predicted, correct in results[:limit]:
        status = "✓" if correct else "✗"
        shown = word if len(word) <= width else word[:width] + "..."
        print(f"  {status} '{shown}' expected={expected}, got={predicted}")


def main():
    """
    Fine-tune Llama-3.2-1B-Instruct to count Rs and report before/after accuracy.

    The run proceeds as follows:
    - generate synthetic training data and evaluate the base model;
    - train with gradient accumulation and a linear warmup/decay schedule,
      evaluating after each epoch;
    - evaluate again and run a few hand-picked words;
    - save the model and tokenizer to ``strawberry-llama``.
    """
    # Configuration
    model_name = "meta-llama/Llama-3.2-1B-Instruct"
    num_train_samples = 15000
    num_epochs = 3
    batch_size = 8
    learning_rate = 2e-5
    max_r_count = 100
    gradient_accumulation_steps = 4
    seed = 42

    # Seed both RNGs so the synthetic data and the training run are reproducible.
    random.seed(seed)
    torch.manual_seed(seed)

    print("=" * 60)
    print("Fine-tuning Llama-3.2-1B-Instruct to count Rs in strawberry")
    print("=" * 60)

    # Device setup
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Load tokenizer
    print(f"\nLoading tokenizer from {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load model
    print(f"Loading model from {model_name}...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16 if use_cuda else torch.float32,
        device_map="auto" if use_cuda else None,
    )
    if not use_cuda:
        model = model.to(device)

    # Generate training data
    print(f"\nGenerating {num_train_samples} training samples...")
    train_samples = create_dataset_samples(num_train_samples, max_r_count)

    print("\nSample training data:")
    for word, count in train_samples[:5]:
        print(f"  '{word}' -> {count}")

    train_dataset = StrawberryDataset(train_samples, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    # Evaluate before training
    print("\n" + "=" * 60)
    print("Evaluating BEFORE fine-tuning...")
    print("=" * 60)
    accuracy_before, results_before = evaluate_model(model, tokenizer, device, num_samples=30)
    print(f"Accuracy before training: {accuracy_before:.1%}")
    print("\nSample predictions (before):")
    _print_predictions(results_before, limit=10, width=30)

    # Optimizer and scheduler. The training loop also steps on the last,
    # possibly partial, accumulation window of each epoch, so each epoch has
    # ceil(batches / accumulation_steps) optimizer steps.
    batches_per_epoch = len(train_loader)
    updates_per_epoch = (
        batches_per_epoch + gradient_accumulation_steps - 1
    ) // gradient_accumulation_steps
    total_steps = updates_per_epoch * num_epochs
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=total_steps // 10,
        num_training_steps=total_steps,
    )

    # Training loop
    print("\n" + "=" * 60)
    print("Starting training...")
    print("=" * 60)

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")

        for batch_idx, batch in enumerate(progress_bar):
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device),
            )
            # Scale so the accumulated gradient averages over the window.
            (outputs.loss / gradient_accumulation_steps).backward()
            epoch_loss += outputs.loss.item()
            num_batches += 1

            window_done = (batch_idx + 1) % gradient_accumulation_steps == 0
            last_batch = batch_idx + 1 == batches_per_epoch
            # Stepping on the last batch too prevents leftover gradients from
            # leaking into the next epoch's first update, or from being
            # silently dropped after the final epoch.
            if window_done or last_batch:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            progress_bar.set_postfix({"loss": f"{epoch_loss / num_batches:.4f}"})

        avg_loss = epoch_loss / max(num_batches, 1)
        print(f"Epoch {epoch + 1} completed. Average loss: {avg_loss:.4f}")

        # Mid-training evaluation puts the model in eval mode; switch back.
        print(f"\nMid-training evaluation after epoch {epoch + 1}:")
        accuracy_mid, _ = evaluate_model(model, tokenizer, device, num_samples=30)
        print(f"Accuracy: {accuracy_mid:.1%}")
        model.train()

    # Final evaluation
    print("\n" + "=" * 60)
    print("Evaluating AFTER fine-tuning...")
    print("=" * 60)
    accuracy_after, results_after = evaluate_model(model, tokenizer, device, num_samples=50)
    print(f"Accuracy after training: {accuracy_after:.1%}")
    print("\nSample predictions (after):")
    _print_predictions(results_after, limit=15, width=40)

    # Test on the classic examples. Expected counts are computed from the words
    # themselves so the reference answers cannot be miscounted by hand.
    print("\n" + "=" * 60)
    print("Testing on classic examples...")
    print("=" * 60)
    classic_words = [
        "strawberry",
        "strrawberrrrry",
        "strrrrrawberrrrrrrrrry",
        "stawbey",
    ]
    model.eval()
    with torch.no_grad():
        for word in classic_words:
            expected = word.count("r")
            prompt = f"Input: {word}\nOutput:"
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                num_beams=1,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            # Decode only the continuation and take its leading integer.
            completion = tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
            )
            match = re.match(r"\s*(-?\d+)", completion)
            predicted = match.group(1) if match else "N/A"
            print(f"  Input: '{word}'")
            print(f"  Expected: {expected}, Predicted: {predicted}")
            print()

    # Save the model
    output_dir = "strawberry-llama"
    print(f"\nSaving model to {output_dir}...")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print("Done!")

    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f"Accuracy before training: {accuracy_before:.1%}")
    print(f"Accuracy after training: {accuracy_after:.1%}")
    print(f"Improvement: {(accuracy_after - accuracy_before):.1%}")
# Run the training script only when executed directly, not when imported.
if __name__ == "__main__":
    main()
- Downloads last month
- 17