TTRLVR Unified Architecture - μμΈ μλ λ°©μ
λͺ©μ°¨
- κ°μ
- μ 체 μν€ν μ²
- μ€ν νλ¦
- ν΅μ¬ μ»΄ν¬λνΈ
- Phaseλ³ μμΈ λμ
- λκΈ°ν λ©μ»€λμ¦
- λ°μ΄ν° νλ¦
- ꡬν μΈλΆμ¬ν
1. κ°μ
1.1 λͺ©μ
TTRLVR Unifiedλ κΈ°μ‘΄ TTRLVRμ λΆλ¦¬λ ꡬ쑰λ₯Ό νλμ ν΅ν©λ VeRL μΈμ μΌλ‘ μ¬κ΅¬μ±νμ¬ λκΈ°ν λ¬Έμ λ₯Ό ν΄κ²°νκ³ μ±λ₯μ ν₯μμν¨ λ²μ μ λλ€.
1.2 ν΅μ¬ κ°μ μ¬ν
- λ¨μΌ vLLM μΈμ€ν΄μ€: μ 체 νμ΅ κ³Όμ μμ νλμ vLLMλ§ μ¬μ©
- λκΈ°ν λ¬Έμ ν΄κ²°: dummy_dtensor μ¬μ© κ°λ₯
- μ±λ₯ ν₯μ: vLLM μ¬μμ± μ€λ²ν€λ μ κ±°λ‘ 30-40% μλ ν₯μ
- λ©λͺ¨λ¦¬ ν¨μ¨: λ°λ³΅μ μΈ ν λΉ/ν΄μ μμ
1.3 μ£Όμ νμΌ
train_ttrlvr_azr_unified.py: λ©μΈ μ€ν μ€ν¬λ¦½νΈtest/trainer/unified_ttrlvr_trainer.py: ν΅ν© Trainer ν΄λμ€test/configs/ttrlvr_azr_unified_4gpu.yaml: VeRL μ€μ νμΌ
2. μ 체 μν€ν μ²
2.1 κΈ°μ‘΄ vs ν΅ν© ꡬ쑰
κΈ°μ‘΄ TTRLVR (λΆλ¦¬ν)
Round 1:
βββ Phase 1-4: RemoteTestTimePipeline (λ
립 vLLM #1)
β βββ ray.kill(pipeline) # vLLM μμ
βββ Phase 5: VeRL Training (μ vLLM #2)
βββ trainer.init_workers() # λ§€ λΌμ΄λλ§λ€
Round 2: (μλ‘μ΄ vLLM μΈμ€ν΄μ€λ€...)
Unified TTRLVR (ν΅ν©ν)
μ΄κΈ°ν:
βββ trainer.init_workers() # 1λ²λ§!
Round 1-N:
βββ Phase 1-4: λ°μ΄ν° μμ± (κ°μ vLLM)
βββ Phase 5: PPO νμ΅ (κ°μ vLLM)
2.2 μ»΄ν¬λνΈ κ΄κ³λ
train_ttrlvr_azr_unified.py
β
βββ νκ²½ μ€μ & μΈμ νμ±
β
βββ VeRL generate_main() νΈμΆ
β β
β βββ UnifiedTTRLVRTrainer μμ±
β β
β βββ CompleteTestTimePipeline (Phase 1-4)
β β βββ λ²€μΉλ§ν¬ λ¬Έμ λ‘λ©
β β βββ νλ‘κ·Έλ¨ μμ± (diverse_programs)
β β βββ IPO μΆμΆ (IPOTripleExtractor)
β β βββ Task μμ± (TestTimeTaskGenerator)
β β βββ κ²μ¦ λ° νν°λ§
β β
β βββ VeRL PPO Training (Phase 5)
β βββ λ°μ΄ν° νμ λ³ν
β βββ Response μμ±
β βββ Reward κ³μ°
β βββ Policy μ
λ°μ΄νΈ
3. μ€ν νλ¦
3.1 μ€ν¬λ¦½νΈ μ€ν
python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 --gpu 0,1,2,3
3.2 μ΄κΈ°ν λ¨κ³
Step 1: μΈμ νμ±
def main():
# λͺ
λ Ήν μΈμ νμ±
args = parse_arguments()
# νκ²½ μ€μ (GPU, κ²½λ‘ λ±)
setup_environment(args.gpu)
Step 2: λ¬Έμ 리μ€νΈ μμ±
# λ²€μΉλ§ν¬μμ λ¬Έμ ID μΆμΆ
problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id)
# μ: ['Mbpp/1', 'Mbpp/2', 'Mbpp/3', ...]
Step 3: νκ²½ λ³μ μ€μ
# VeRLμ΄ UnifiedTTRLVRTrainerμ μ λ¬ν μ€μ
os.environ['TTRLVR_PROBLEM_IDS'] = json.dumps(problem_ids)
os.environ['TTRLVR_TOTAL_ROUNDS'] = str(args.rounds)
os.environ['TTRLVR_OUTPUT_DIR'] = output_dir
os.environ['TTRLVR_CONFIG'] = json.dumps(ttrlvr_config)
Step 4: VeRL μ€ν
# VeRLμ main_generation νΈμΆ
verl_args = [
'train_ttrlvr_azr_unified.py',
f'--config-path={config_path}',
'--config-name=ttrlvr_azr_unified_4gpu',
f'trainer.project_name=ttrlvr_unified_{args.benchmark}',
f'trainer.total_epochs={args.rounds}', # κ° λΌμ΄λλ₯Ό epochλ‘ λ§€ν
]
sys.argv = verl_args
generate_main() # VeRL λ©μΈ ν¨μ μ€ν
3.3 VeRL μ΄κΈ°ν
VeRLμ generate_main()μ΄ μ€νλλ©΄:
- Config λ‘λ©:
ttrlvr_azr_unified_4gpu.yamlνμ± - Ray ν΄λ¬μ€ν° μ΄κΈ°ν: λΆμ° μ²λ¦¬ νκ²½ μ€μ
- UnifiedTTRLVRTrainer μμ±: μ€μ μ λͺ μλ ν΄λμ€ λ‘λ
- Worker μ΄κΈ°ν:
trainer.init_workers()νΈμΆ (1λ²λ§!)
4. ν΅μ¬ μ»΄ν¬λνΈ
4.1 UnifiedTTRLVRTrainer
class UnifiedTTRLVRTrainer(ReasonRLRayPPOTrainer):
"""
TTRLVRμ λͺ¨λ Phaseλ₯Ό νλμ VeRL μΈμ
μμ μ²λ¦¬νλ ν΅ν© Trainer
"""
def __init__(self, ttrlvr_config, problem_ids, total_rounds, ...):
super().__init__(...)
# TTRLVR νΉν μ€μ
self.ttrlvr_config = ttrlvr_config
self.problem_ids = problem_ids
self.total_rounds = total_rounds
self.current_round = 0
# CompleteTestTimePipeline μ΄κΈ°ν (λμ€μ)
self.ttrlvr_pipeline = None
4.2 CompleteTestTimePipeline ν΅ν©
def _init_ttrlvr_pipeline(self):
"""CompleteTestTimePipelineμ VeRLμ vLLMμΌλ‘ μ΄κΈ°ν"""
# VeRLμ λͺ¨λΈ μ¬μ©
self.ttrlvr_pipeline = CompleteTestTimePipeline(
model=None, # VeRL wrapper ν΅ν΄ μ κ·Ό
tokenizer=self.tokenizer,
config=self.testtime_config,
logger=self.ttrlvr_logger
)
# VeRLμ vLLMμ μ¬μ©νλλ‘ μ€μ
self.ttrlvr_pipeline.generate_with_verl = self._generate_with_vllm
5. Phaseλ³ μμΈ λμ
5.1 fit() λ©μλ - λ©μΈ νμ΅ λ£¨ν
def fit(self):
"""μ 체 νμ΅ λ£¨ν κ΄λ¦¬"""
# λ‘κ±° μ΄κΈ°ν
logger = ReasonRLTracking(...)
# 체ν¬ν¬μΈνΈ λ‘λ (μμΌλ©΄)
self._load_checkpoint()
# λΌμ΄λλ³ λ°λ³΅
for round_num in range(1, self.total_rounds + 1):
self.current_round = round_num
# ====== Phase 1-4: λ°μ΄ν° μμ± ======
round_data = self._generate_round_data()
# ====== Phase 5: PPO νμ΅ ======
metrics = self._train_one_round(round_data, logger)
# 체ν¬ν¬μΈνΈ μ μ₯ (5λΌμ΄λλ§λ€)
if round_num % 5 == 0:
self._save_checkpoint()
5.2 Phase 1-4: λ°μ΄ν° μμ±
5.2.1 _generate_round_data() ꡬ쑰
def _generate_round_data(self) -> List[Dict[str, Any]]:
"""Phase 1-4 μ€ν"""
# Pipeline μ΄κΈ°ν (μ²μλ§)
if self.ttrlvr_pipeline is None:
self._init_ttrlvr_pipeline()
all_tasks = []
for problem_id in self.problem_ids:
# CompleteTestTimePipeline μ€ν
result = self.ttrlvr_pipeline.run_complete_pipeline(
benchmark_config=benchmark_config,
problem_id=problem_id,
round_num=self.current_round,
session_timestamp=session_timestamp
)
if result['success']:
tasks = result['final_tasks']
all_tasks.extend(tasks)
return all_tasks
5.2.2 CompleteTestTimePipeline λ΄λΆ λμ
Phase 1: λ€μν νλ‘κ·Έλ¨ μμ±
# 1. λ²€μΉλ§ν¬ λ¬Έμ λ‘λ
problem = benchmark_loader.load_problem(benchmark_config, problem_id)
# 2. Baseline νκ°
baseline_results = self._evaluate_baseline_performance(problem)
# 3. λ€μν νλ‘κ·Έλ¨ μμ±
diverse_programs = self._generate_diverse_programs_and_ipo(problem)
# λ΄λΆμ μΌλ‘:
# - μ κ΅ν ν둬ννΈ ν
νλ¦Ώ μ¬μ©
# - Temperature μ‘°μ λ‘ λ€μμ± ν보
# - λ¬Έλ² κ²μ¦
Phase 2: I/O μ μΆμΆ
# IPOTripleExtractor μ¬μ©
ipo_extractor = IPOTripleExtractor(config, logger, model, tokenizer)
for program in diverse_programs:
# μ
λ ₯ μμ±
inputs = ipo_extractor.generate_inputs(program)
# μΆλ ₯ κ³μ°
for input in inputs:
output = executor.execute(program, input)
ipo_buffer.add_triple(input, program, output)
Phase 3: Task μμ±
# TestTimeTaskGenerator μ¬μ©
task_generator = TestTimeTaskGenerator(config, logger)
# Induction: I/O β Program
induction_tasks = task_generator.create_induction_tasks(ipo_triples)
# Deduction: Program + Input β Output
deduction_tasks = task_generator.create_deduction_tasks(ipo_triples)
# Abduction: Program + Output β Input
abduction_tasks = task_generator.create_abduction_tasks(ipo_triples)
Phase 4: κ²μ¦ λ° νν°λ§
# κ° task κ²μ¦
valid_tasks = []
for task in all_tasks:
if validator.is_valid(task):
valid_tasks.append(task)
5.3 Phase 5: PPO νμ΅
5.3.1 _train_one_round() ꡬ쑰
def _train_one_round(self, round_data: List[Dict], logger) -> Dict[str, float]:
"""Phase 5: PPO νμ΅"""
# 1. λ°μ΄ν° λ³ν
train_dataset = self._convert_to_verl_dataset(round_data)
# 2. DataLoader μμ±
self.train_dataloader = self._create_dataloader(
train_dataset,
batch_size=self.config.data.train_batch_size
)
# 3. 1 epoch νμ΅
epoch_metrics = {}
for step, batch in enumerate(self.train_dataloader):
# PPO Step 1: Response μμ±
gen_batch_output = self.actor_rollout_wg.generate_sequences(batch)
# PPO Step 2: Reward κ³μ°
reward_tensor = self.reward_fn(batch.union(gen_batch_output))
# PPO Step 3: Policy μ
λ°μ΄νΈ
update_metrics = self._ppo_update(batch, reward_tensor)
# λ©νΈλ¦ μμ§
for k, v in update_metrics.items():
epoch_metrics[k].append(v)
return {k: np.mean(v) for k, v in epoch_metrics.items()}
5.3.2 λ°μ΄ν° λ³ν κ³Όμ
def _convert_to_verl_dataset(self, round_data: List[Dict]) -> Any:
"""TTRLVR νμ β VeRL νμ"""
converted_data = []
for task in round_data:
# ν ν°ν
prompt_ids = self.tokenizer(
task['prompt'],
max_length=self.config.data.max_prompt_length
).input_ids
# VeRL DataProto νμ
verl_item = {
'input_ids': prompt_ids,
'prompt': task['prompt'],
'target': task['target'],
'task_type': task['task_type'],
'problem_id': task['problem_id']
}
converted_data.append(verl_item)
return converted_data
6. λκΈ°ν λ©μ»€λμ¦
6.1 λ¬Έμ μ ν΅μ¬
κΈ°μ‘΄ TTRLVRμ λ§€ λΌμ΄λλ§λ€ μ vLLMμ μμ±νκΈ° λλ¬Έμ dummy_dtensor μ¬μ© μ λκΈ°νκ° λμ§ μμμ΅λλ€.
6.2 ν΄κ²° λ°©λ²
6.2.1 λ¨μΌ vLLM μΈμ€ν΄μ€
# μ΄κΈ°ν (1λ²λ§)
trainer.init_workers()
βββ FSDP workers μμ±
βββ vLLM workers μμ±
βββ μ΄κΈ° λκΈ°ν (sync_model_weights)
# μ΄ν λͺ¨λ λΌμ΄λμμ κ°μ μΈμ€ν΄μ€ μ¬μ©
Round 1: Phase 1-4 β Phase 5 (κ°μ vLLM)
Round 2: Phase 1-4 β Phase 5 (κ°μ vLLM)
...
6.2.2 λκΈ°ν κ³Όμ
# FSDPVLLMShardingManagerμ λμ
class FSDPVLLMShardingManager:
def __enter__(self):
if not self.base_sync_done:
# 첫 λ²μ§Έ νΈμΆ: FSDP β vLLM λκΈ°ν
sync_model_weights(actor_weights, load_format='dummy_dtensor')
self.base_sync_done = True
# μ΄ν: λ©λͺ¨λ¦¬ μ°Έμ‘°λ‘ μλ λκΈ°ν
6.3 λ©λͺ¨λ¦¬ μ°Έμ‘° λ©μ»€λμ¦
FSDP λͺ¨λΈ (GPU 0-3) vLLM λͺ¨λΈ (GPU 0-1)
βββββββββββββββ βββββββββββββββ
β Parameter A β ββββββββββ β Parameter A β (κ°μ λ©λͺ¨λ¦¬ μ°Έμ‘°)
β Parameter B β ββββββββββ β Parameter B β
β Parameter C β ββββββββββ β Parameter C β
βββββββββββββββ βββββββββββββββ
PPO μ
λ°μ΄νΈ β FSDP νλΌλ―Έν° λ³κ²½ β vLLMλ μλμΌλ‘ μ κ° μ¬μ©
7. λ°μ΄ν° νλ¦
7.1 Round 1 μμΈ νλ¦
1. Problem: Mbpp/2 (μ: "λ μμ ν©μ ꡬνλ ν¨μ μμ±")
β
βββ Phase 1: νλ‘κ·Έλ¨ μμ±
β βββ Prompt: "Generate 4 different solutions..."
β βββ vLLM μμ± (λκΈ°ν λ°μ)
β βββ Output: [prog1, prog2, prog3, prog4]
β
βββ Phase 2: I/O μΆμΆ
β βββ κ° νλ‘κ·Έλ¨μ λν΄ μ
λ ₯ μμ±
β βββ vLLM μ¬μ© (λκΈ°ν 건λλ)
β βββ Output: [(input1, output1), (input2, output2), ...]
β
βββ Phase 3: Task μμ±
β βββ Induction: (1, 3) β "def add(a,b): return a+b"
β βββ Deduction: (prog, 5) β 8
β βββ Abduction: (prog, 10) β (4, 6)
β
βββ Phase 4: κ²μ¦
β βββ μ ν¨ν taskλ§ νν°λ§
β
βββ Phase 5: PPO νμ΅
βββ λ°°μΉ μμ±
βββ Response μμ± (κ°μ vLLM)
βββ Reward κ³μ°
βββ FSDP λͺ¨λΈ μ
λ°μ΄νΈ
7.2 λ°μ΄ν° νμ λ³ν
# TTRLVR Task νμ
{
'problem_id': 'Mbpp/2',
'task_type': 'induction',
'input': 5,
'output': 10,
'target': 'def multiply_by_two(x): return x * 2',
'prompt': 'Given input 5 produces output 10, write the function:'
}
# β λ³ν
# VeRL DataProto νμ
{
'input_ids': tensor([1, 234, 567, ...]), # ν ν°νλ prompt
'attention_mask': tensor([1, 1, 1, ...]),
'prompt': 'Given input 5 produces output 10...',
'target': 'def multiply_by_two(x): return x * 2',
'meta_info': {
'task_type': 'induction',
'problem_id': 'Mbpp/2'
}
}
8. ꡬν μΈλΆμ¬ν
8.1 VeRLκ³Όμ ν΅ν©
8.1.1 _generate_with_vllm λ©μλ
def _generate_with_vllm(self, prompt: str, temperature: float = 0.7):
"""VeRLμ vLLMμ μ¬μ©ν ν
μ€νΈ μμ±"""
# 1. ν ν°ν
input_ids = self.tokenizer(prompt, ...).input_ids
# 2. DataProto μμ±
prompts_proto = DataProto.from_dict({
"input_ids": input_ids.cuda(),
"attention_mask": torch.ones_like(input_ids).cuda(),
})
# 3. λ©ν μ 보 μ€μ
prompts_proto.meta_info = {
"eos_token_id": self.tokenizer.eos_token_id,
"temperature": temperature,
"do_sample": True,
"response_length": 256
}
# 4. VeRLμ vLLMμΌλ‘ μμ±
outputs = self.actor_rollout_wg.generate_sequences(prompts_proto)
# 5. λμ½λ© λ° λ°ν
return self.tokenizer.decode(outputs.batch["input_ids"][0])
8.1.2 CompleteTestTimePipeline μμ
# CompleteTestTimePipelineμ΄ VeRLμ vLLMμ μ¬μ©νλλ‘
self.ttrlvr_pipeline.generate_with_verl = self._generate_with_vllm
# μ΄μ Pipeline λ΄λΆμμ:
# response = self.generate_with_verl(prompt) # VeRLμ vLLM μ¬μ©
8.2 λ©λͺ¨λ¦¬ κ΄λ¦¬
8.2.1 λΌμ΄λ κ° λ©λͺ¨λ¦¬ μ 리
def _manage_memory_between_rounds(self):
"""λΌμ΄λ κ° λ©λͺ¨λ¦¬ μ 리 (μΈμ€ν΄μ€λ μ μ§)"""
# GPU μΊμλ§ μ 리
torch.cuda.empty_cache()
# vLLM KV μΊμ μ 리 (μ νμ )
if hasattr(self.actor_rollout_wg, 'clear_kv_cache'):
self.actor_rollout_wg.clear_kv_cache()
# Garbage collection
import gc
gc.collect()
8.2.2 λ©λͺ¨λ¦¬ λͺ¨λν°λ§
def _monitor_memory(self):
"""λ©λͺ¨λ¦¬ μ¬μ©λ λͺ¨λν°λ§"""
for i in range(torch.cuda.device_count()):
allocated = torch.cuda.memory_allocated(i) / 1024**3
reserved = torch.cuda.memory_reserved(i) / 1024**3
print(f"GPU {i}: Allocated={allocated:.2f}GB, Reserved={reserved:.2f}GB")
8.3 μλ¬ μ²λ¦¬ λ° λ³΅κ΅¬
def _safe_generate(self, prompt: str, max_retries: int = 3):
"""μμ ν μμ± with μ¬μλ"""
for attempt in range(max_retries):
try:
return self._generate_with_vllm(prompt)
except Exception as e:
if attempt == max_retries - 1:
raise
torch.cuda.empty_cache()
time.sleep(1)
8.4 체ν¬ν¬μΈνΈ κ΄λ¦¬
def _save_checkpoint(self):
"""체ν¬ν¬μΈνΈ μ μ₯"""
checkpoint = {
'round': self.current_round,
'model_state_dict': self.actor_rollout_wg.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'metrics': self.accumulated_metrics,
'timestamp': datetime.now().isoformat()
}
path = f"{self.checkpoint_dir}/round_{self.current_round}.pt"
torch.save(checkpoint, path)
9. μ±λ₯ μ΅μ ν
9.1 λ°°μΉ μ²λ¦¬
- Phase 1-4μμ κ°λ₯ν ν λ°°μΉλ‘ μ²λ¦¬
- vLLMμ continuous batching νμ©
9.2 GPU νμ©
- vLLM: GPU 0-1 (tensor parallel)
- FSDP: GPU 0-3 (data parallel)
- ν¨μ¨μ μΈ GPU λ©λͺ¨λ¦¬ νμ©
9.3 I/O μ΅μ ν
- Parquet νμμΌλ‘ μ€κ° λ°μ΄ν° μ μ₯
- λΉλκΈ° I/O μ²λ¦¬
10. λλ²κΉ λ° λͺ¨λν°λ§
10.1 λ‘κΉ κ΅¬μ‘°
/home/ubuntu/RLVR/TestTime-RLVR-v2/logs/
βββ ttrlvr_unified_20241107_120000.log # λ©μΈ λ‘κ·Έ
βββ round_1/
β βββ phase_1_4.log # λ°μ΄ν° μμ± λ‘κ·Έ
β βββ phase_5.log # νμ΅ λ‘κ·Έ
βββ metrics/
βββ tensorboard/ # νμ΅ λ©νΈλ¦
10.2 μ£Όμ λͺ¨λν°λ§ μ§ν
- λΌμ΄λλ³ μμ μκ°
- μμ±λ task μ
- νκ· reward
- GPU λ©λͺ¨λ¦¬ μ¬μ©λ
- λκΈ°ν λ°μ νμ
11. λ¬Έμ ν΄κ²° κ°μ΄λ
11.1 OOM (Out of Memory)
gpu_memory_utilizationμ‘°μ (κΈ°λ³Έ: 0.35)max_num_seqsκ°μ- λ°°μΉ ν¬κΈ° κ°μ
11.2 λκΈ°ν λ¬Έμ
load_formatμ΄dummy_dtensorμΈμ§ νμΈ- vLLM μΈμ€ν΄μ€κ° μ¬μμ±λμ§ μλμ§ νμΈ
11.3 λλ¦° μ±λ₯
- GPU νμ©λ₯ νμΈ
- λ°°μΉ ν¬κΈ° μ¦κ°
enforce_eager=FalseνμΈ (CUDA graph μ¬μ©)
12. κ²°λ‘
TTRLVR Unifiedλ κΈ°μ‘΄ TTRLVRμ λͺ¨λ κΈ°λ₯μ μ μ§νλ©΄μ λ€μμ λ¬μ±νμ΅λλ€:
- ꡬ쑰μ κ°μ : λΆλ¦¬λ Phaseλ€μ νλμ μΈμ μΌλ‘ ν΅ν©
- μ±λ₯ ν₯μ: vLLM μ¬μμ± μ€λ²ν€λ μ κ±°λ‘ 30-40% μλ ν₯μ
- μμ μ± ν₯μ: λκΈ°ν λ¬Έμ μμ ν΄κ²°
- νμ₯μ±: λ ν° λͺ¨λΈκ³Ό λ λ§μ λΌμ΄λ μ§μ κ°λ₯
μ΄ μν€ν μ²λ TTRLVRμ μ κ΅ν λ°μ΄ν° μμ± λ₯λ ₯κ³Ό VeRLμ ν¨μ¨μ μΈ PPO νμ΅μ μλ²½νκ² κ²°ν©νμ΅λλ€.