# imports
import torch
from transformers import GPT2Tokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

# load the policy model (with a value head), a frozen reference copy for the KL penalty, and the tokenizer
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# initialize the PPO trainer
ppo_config = {"batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)

# generate a model response
generation_kwargs = {
    "min_length": -1,  # no minimum length constraint
    "top_k": 0.0,  # 0 disables top-k filtering
    "top_p": 1.0,  # 1.0 disables nucleus filtering
    "do_sample": True,  # sample rather than greedy-decode
    "pad_token_id": tokenizer.eos_token_id,  # GPT-2 has no pad token, so reuse EOS
    "max_new_tokens": 20,  # length of the generated continuation
}
# ppo_trainer.generate expects a list of 1-D query tensors
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])

# define a reward for the response
# (this could be any scalar, e.g. human feedback or the output of a reward model)
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]
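
# Optional: in practice the scalar usually comes from a learned reward model rather
# than a constant. A minimal, hedged sketch using a transformers sentiment-analysis
# pipeline; the checkpoint name and the USE_SENTIMENT_REWARD toggle are illustrative
# additions, not part of the original snippet. It simply uses the top predicted
# label's score as a stand-in reward.
USE_SENTIMENT_REWARD = False  # flip to True to try a pipeline-based reward
if USE_SENTIMENT_REWARD:
    from transformers import pipeline

    sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb")
    score = sentiment_pipe(query_txt + response_txt)[0]["score"]
    reward = [torch.tensor(score, device=model.pretrained_model.device)]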

# train the model with a single PPO step on this (query, response, reward) triple
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
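
# A quick sanity check one might add after the step (an illustrative addition, not
# part of the original snippet): print the sampled continuation and see which
# statistics ppo_trainer.step returned in its dictionary of logged metrics.
print(f"query:    {query_txt!r}")
print(f"response: {response_txt!r}")
print(sorted(train_stats.keys()))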