neural-mesh-v2 / Update /unified_ttrlvr_architecture.md
hjkim00's picture
Restore all essential files - code, configs, and MBPP/HumanEval data
24c2665 verified
# TTRLVR Unified Architecture - ์ƒ์„ธ ์ž‘๋™ ๋ฐฉ์‹
## ๋ชฉ์ฐจ
1. [๊ฐœ์š”](#1-๊ฐœ์š”)
2. [์ „์ฒด ์•„ํ‚คํ…์ฒ˜](#2-์ „์ฒด-์•„ํ‚คํ…์ฒ˜)
3. [์‹คํ–‰ ํ๋ฆ„](#3-์‹คํ–‰-ํ๋ฆ„)
4. [ํ•ต์‹ฌ ์ปดํฌ๋„ŒํŠธ](#4-ํ•ต์‹ฌ-์ปดํฌ๋„ŒํŠธ)
5. [Phase๋ณ„ ์ƒ์„ธ ๋™์ž‘](#5-phase๋ณ„-์ƒ์„ธ-๋™์ž‘)
6. [๋™๊ธฐํ™” ๋ฉ”์ปค๋‹ˆ์ฆ˜](#6-๋™๊ธฐํ™”-๋ฉ”์ปค๋‹ˆ์ฆ˜)
7. [๋ฐ์ดํ„ฐ ํ๋ฆ„](#7-๋ฐ์ดํ„ฐ-ํ๋ฆ„)
8. [๊ตฌํ˜„ ์„ธ๋ถ€์‚ฌํ•ญ](#8-๊ตฌํ˜„-์„ธ๋ถ€์‚ฌํ•ญ)
---
## 1. ๊ฐœ์š”
### 1.1 ๋ชฉ์ 
TTRLVR Unified๋Š” ๊ธฐ์กด TTRLVR์˜ ๋ถ„๋ฆฌ๋œ ๊ตฌ์กฐ๋ฅผ ํ•˜๋‚˜์˜ ํ†ตํ•ฉ๋œ VeRL ์„ธ์…˜์œผ๋กœ ์žฌ๊ตฌ์„ฑํ•˜์—ฌ ๋™๊ธฐํ™” ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ณ  ์„ฑ๋Šฅ์„ ํ–ฅ์ƒ์‹œํ‚จ ๋ฒ„์ „์ž…๋‹ˆ๋‹ค.
### 1.2 ํ•ต์‹ฌ ๊ฐœ์„ ์‚ฌํ•ญ
- **๋‹จ์ผ vLLM ์ธ์Šคํ„ด์Šค**: ์ „์ฒด ํ•™์Šต ๊ณผ์ •์—์„œ ํ•˜๋‚˜์˜ vLLM๋งŒ ์‚ฌ์šฉ
- **๋™๊ธฐํ™” ๋ฌธ์ œ ํ•ด๊ฒฐ**: dummy_dtensor ์‚ฌ์šฉ ๊ฐ€๋Šฅ
- **์„ฑ๋Šฅ ํ–ฅ์ƒ**: vLLM ์žฌ์ƒ์„ฑ ์˜ค๋ฒ„ํ—ค๋“œ ์ œ๊ฑฐ๋กœ 30-40% ์†๋„ ํ–ฅ์ƒ
- **๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ**: ๋ฐ˜๋ณต์ ์ธ ํ• ๋‹น/ํ•ด์ œ ์—†์Œ
### 1.3 ์ฃผ์š” ํŒŒ์ผ
- `train_ttrlvr_azr_unified.py`: ๋ฉ”์ธ ์‹คํ–‰ ์Šคํฌ๋ฆฝํŠธ
- `test/trainer/unified_ttrlvr_trainer.py`: ํ†ตํ•ฉ Trainer ํด๋ž˜์Šค
- `test/configs/ttrlvr_azr_unified_4gpu.yaml`: VeRL ์„ค์ • ํŒŒ์ผ
---
## 2. ์ „์ฒด ์•„ํ‚คํ…์ฒ˜
### 2.1 ๊ธฐ์กด vs ํ†ตํ•ฉ ๊ตฌ์กฐ
#### ๊ธฐ์กด TTRLVR (๋ถ„๋ฆฌํ˜•)
```
Round 1:
โ”œโ”€โ”€ Phase 1-4: RemoteTestTimePipeline (๋…๋ฆฝ vLLM #1)
โ”‚ โ””โ”€โ”€ ray.kill(pipeline) # vLLM ์‚ญ์ œ
โ””โ”€โ”€ Phase 5: VeRL Training (์ƒˆ vLLM #2)
โ””โ”€โ”€ trainer.init_workers() # ๋งค ๋ผ์šด๋“œ๋งˆ๋‹ค
Round 2: (์ƒˆ๋กœ์šด vLLM ์ธ์Šคํ„ด์Šค๋“ค...)
```
#### Unified TTRLVR (ํ†ตํ•ฉํ˜•)
```
์ดˆ๊ธฐํ™”:
โ””โ”€โ”€ trainer.init_workers() # 1๋ฒˆ๋งŒ!
Round 1-N:
โ”œโ”€โ”€ Phase 1-4: ๋ฐ์ดํ„ฐ ์ƒ์„ฑ (๊ฐ™์€ vLLM)
โ””โ”€โ”€ Phase 5: PPO ํ•™์Šต (๊ฐ™์€ vLLM)
```
### 2.2 ์ปดํฌ๋„ŒํŠธ ๊ด€๊ณ„๋„
```
train_ttrlvr_azr_unified.py
โ”‚
โ”œโ”€โ”€ ํ™˜๊ฒฝ ์„ค์ • & ์ธ์ž ํŒŒ์‹ฑ
โ”‚
โ”œโ”€โ”€ VeRL generate_main() ํ˜ธ์ถœ
โ”‚ โ”‚
โ”‚ โ””โ”€โ”€ UnifiedTTRLVRTrainer ์ƒ์„ฑ
โ”‚ โ”‚
โ”‚ โ”œโ”€โ”€ CompleteTestTimePipeline (Phase 1-4)
โ”‚ โ”‚ โ”œโ”€โ”€ ๋ฒค์น˜๋งˆํฌ ๋ฌธ์ œ ๋กœ๋”ฉ
โ”‚ โ”‚ โ”œโ”€โ”€ ํ”„๋กœ๊ทธ๋žจ ์ƒ์„ฑ (diverse_programs)
โ”‚ โ”‚ โ”œโ”€โ”€ IPO ์ถ”์ถœ (IPOTripleExtractor)
โ”‚ โ”‚ โ”œโ”€โ”€ Task ์ƒ์„ฑ (TestTimeTaskGenerator)
โ”‚ โ”‚ โ””โ”€โ”€ ๊ฒ€์ฆ ๋ฐ ํ•„ํ„ฐ๋ง
โ”‚ โ”‚
โ”‚ โ””โ”€โ”€ VeRL PPO Training (Phase 5)
โ”‚ โ”œโ”€โ”€ ๋ฐ์ดํ„ฐ ํ˜•์‹ ๋ณ€ํ™˜
โ”‚ โ”œโ”€โ”€ Response ์ƒ์„ฑ
โ”‚ โ”œโ”€โ”€ Reward ๊ณ„์‚ฐ
โ”‚ โ””โ”€โ”€ Policy ์—…๋ฐ์ดํŠธ
```
---
## 3. ์‹คํ–‰ ํ๋ฆ„
### 3.1 ์Šคํฌ๋ฆฝํŠธ ์‹คํ–‰
```bash
python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 --gpu 0,1,2,3
```
### 3.2 ์ดˆ๊ธฐํ™” ๋‹จ๊ณ„
#### Step 1: ์ธ์ž ํŒŒ์‹ฑ
```python
def main():
# ๋ช…๋ นํ–‰ ์ธ์ž ํŒŒ์‹ฑ
args = parse_arguments()
# ํ™˜๊ฒฝ ์„ค์ • (GPU, ๊ฒฝ๋กœ ๋“ฑ)
setup_environment(args.gpu)
```
#### Step 2: ๋ฌธ์ œ ๋ฆฌ์ŠคํŠธ ์ƒ์„ฑ
```python
# ๋ฒค์น˜๋งˆํฌ์—์„œ ๋ฌธ์ œ ID ์ถ”์ถœ
problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id)
# ์˜ˆ: ['Mbpp/1', 'Mbpp/2', 'Mbpp/3', ...]
```
#### Step 3: ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์„ค์ •
```python
# VeRL์ด UnifiedTTRLVRTrainer์— ์ „๋‹ฌํ•  ์„ค์ •
os.environ['TTRLVR_PROBLEM_IDS'] = json.dumps(problem_ids)
os.environ['TTRLVR_TOTAL_ROUNDS'] = str(args.rounds)
os.environ['TTRLVR_OUTPUT_DIR'] = output_dir
os.environ['TTRLVR_CONFIG'] = json.dumps(ttrlvr_config)
```
#### Step 4: VeRL ์‹คํ–‰
```python
# VeRL์˜ main_generation ํ˜ธ์ถœ
verl_args = [
'train_ttrlvr_azr_unified.py',
f'--config-path={config_path}',
'--config-name=ttrlvr_azr_unified_4gpu',
f'trainer.project_name=ttrlvr_unified_{args.benchmark}',
f'trainer.total_epochs={args.rounds}', # ๊ฐ ๋ผ์šด๋“œ๋ฅผ epoch๋กœ ๋งคํ•‘
]
sys.argv = verl_args
generate_main() # VeRL ๋ฉ”์ธ ํ•จ์ˆ˜ ์‹คํ–‰
```
### 3.3 VeRL ์ดˆ๊ธฐํ™”
VeRL์˜ `generate_main()`์ด ์‹คํ–‰๋˜๋ฉด:
1. **Config ๋กœ๋”ฉ**: `ttrlvr_azr_unified_4gpu.yaml` ํŒŒ์‹ฑ
2. **Ray ํด๋Ÿฌ์Šคํ„ฐ ์ดˆ๊ธฐํ™”**: ๋ถ„์‚ฐ ์ฒ˜๋ฆฌ ํ™˜๊ฒฝ ์„ค์ •
3. **UnifiedTTRLVRTrainer ์ƒ์„ฑ**: ์„ค์ •์— ๋ช…์‹œ๋œ ํด๋ž˜์Šค ๋กœ๋“œ
4. **Worker ์ดˆ๊ธฐํ™”**: `trainer.init_workers()` ํ˜ธ์ถœ (1๋ฒˆ๋งŒ!)
---
## 4. ํ•ต์‹ฌ ์ปดํฌ๋„ŒํŠธ
### 4.1 UnifiedTTRLVRTrainer
```python
class UnifiedTTRLVRTrainer(ReasonRLRayPPOTrainer):
"""
TTRLVR์˜ ๋ชจ๋“  Phase๋ฅผ ํ•˜๋‚˜์˜ VeRL ์„ธ์…˜์—์„œ ์ฒ˜๋ฆฌํ•˜๋Š” ํ†ตํ•ฉ Trainer
"""
def __init__(self, ttrlvr_config, problem_ids, total_rounds, ...):
super().__init__(...)
# TTRLVR ํŠนํ™” ์„ค์ •
self.ttrlvr_config = ttrlvr_config
self.problem_ids = problem_ids
self.total_rounds = total_rounds
self.current_round = 0
# CompleteTestTimePipeline ์ดˆ๊ธฐํ™” (๋‚˜์ค‘์—)
self.ttrlvr_pipeline = None
```
### 4.2 CompleteTestTimePipeline ํ†ตํ•ฉ
```python
def _init_ttrlvr_pipeline(self):
"""CompleteTestTimePipeline์„ VeRL์˜ vLLM์œผ๋กœ ์ดˆ๊ธฐํ™”"""
# VeRL์˜ ๋ชจ๋ธ ์‚ฌ์šฉ
self.ttrlvr_pipeline = CompleteTestTimePipeline(
model=None, # VeRL wrapper ํ†ตํ•ด ์ ‘๊ทผ
tokenizer=self.tokenizer,
config=self.testtime_config,
logger=self.ttrlvr_logger
)
# VeRL์˜ vLLM์„ ์‚ฌ์šฉํ•˜๋„๋ก ์„ค์ •
self.ttrlvr_pipeline.generate_with_verl = self._generate_with_vllm
```
---
## 5. Phase๋ณ„ ์ƒ์„ธ ๋™์ž‘
### 5.1 fit() ๋ฉ”์„œ๋“œ - ๋ฉ”์ธ ํ•™์Šต ๋ฃจํ”„
```python
def fit(self):
"""์ „์ฒด ํ•™์Šต ๋ฃจํ”„ ๊ด€๋ฆฌ"""
# ๋กœ๊ฑฐ ์ดˆ๊ธฐํ™”
logger = ReasonRLTracking(...)
# ์ฒดํฌํฌ์ธํŠธ ๋กœ๋“œ (์žˆ์œผ๋ฉด)
self._load_checkpoint()
# ๋ผ์šด๋“œ๋ณ„ ๋ฐ˜๋ณต
for round_num in range(1, self.total_rounds + 1):
self.current_round = round_num
# ====== Phase 1-4: ๋ฐ์ดํ„ฐ ์ƒ์„ฑ ======
round_data = self._generate_round_data()
# ====== Phase 5: PPO ํ•™์Šต ======
metrics = self._train_one_round(round_data, logger)
# ์ฒดํฌํฌ์ธํŠธ ์ €์žฅ (5๋ผ์šด๋“œ๋งˆ๋‹ค)
if round_num % 5 == 0:
self._save_checkpoint()
```
### 5.2 Phase 1-4: ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
#### 5.2.1 _generate_round_data() ๊ตฌ์กฐ
```python
def _generate_round_data(self) -> List[Dict[str, Any]]:
"""Phase 1-4 ์‹คํ–‰"""
# Pipeline ์ดˆ๊ธฐํ™” (์ฒ˜์Œ๋งŒ)
if self.ttrlvr_pipeline is None:
self._init_ttrlvr_pipeline()
all_tasks = []
for problem_id in self.problem_ids:
# CompleteTestTimePipeline ์‹คํ–‰
result = self.ttrlvr_pipeline.run_complete_pipeline(
benchmark_config=benchmark_config,
problem_id=problem_id,
round_num=self.current_round,
session_timestamp=session_timestamp
)
if result['success']:
tasks = result['final_tasks']
all_tasks.extend(tasks)
return all_tasks
```
#### 5.2.2 CompleteTestTimePipeline ๋‚ด๋ถ€ ๋™์ž‘
**Phase 1: ๋‹ค์–‘ํ•œ ํ”„๋กœ๊ทธ๋žจ ์ƒ์„ฑ**
```python
# 1. ๋ฒค์น˜๋งˆํฌ ๋ฌธ์ œ ๋กœ๋“œ
problem = benchmark_loader.load_problem(benchmark_config, problem_id)
# 2. Baseline ํ‰๊ฐ€
baseline_results = self._evaluate_baseline_performance(problem)
# 3. ๋‹ค์–‘ํ•œ ํ”„๋กœ๊ทธ๋žจ ์ƒ์„ฑ
diverse_programs = self._generate_diverse_programs_and_ipo(problem)
# ๋‚ด๋ถ€์ ์œผ๋กœ:
# - ์ •๊ตํ•œ ํ”„๋กฌํ”„ํŠธ ํ…œํ”Œ๋ฆฟ ์‚ฌ์šฉ
# - Temperature ์กฐ์ ˆ๋กœ ๋‹ค์–‘์„ฑ ํ™•๋ณด
# - ๋ฌธ๋ฒ• ๊ฒ€์ฆ
```
**Phase 2: I/O ์Œ ์ถ”์ถœ**
```python
# IPOTripleExtractor ์‚ฌ์šฉ
ipo_extractor = IPOTripleExtractor(config, logger, model, tokenizer)
for program in diverse_programs:
# ์ž…๋ ฅ ์ƒ์„ฑ
inputs = ipo_extractor.generate_inputs(program)
# ์ถœ๋ ฅ ๊ณ„์‚ฐ
for input in inputs:
output = executor.execute(program, input)
ipo_buffer.add_triple(input, program, output)
```
**Phase 3: Task ์ƒ์„ฑ**
```python
# TestTimeTaskGenerator ์‚ฌ์šฉ
task_generator = TestTimeTaskGenerator(config, logger)
# Induction: I/O โ†’ Program
induction_tasks = task_generator.create_induction_tasks(ipo_triples)
# Deduction: Program + Input โ†’ Output
deduction_tasks = task_generator.create_deduction_tasks(ipo_triples)
# Abduction: Program + Output โ†’ Input
abduction_tasks = task_generator.create_abduction_tasks(ipo_triples)
```
**Phase 4: ๊ฒ€์ฆ ๋ฐ ํ•„ํ„ฐ๋ง**
```python
# ๊ฐ task ๊ฒ€์ฆ
valid_tasks = []
for task in all_tasks:
if validator.is_valid(task):
valid_tasks.append(task)
```
### 5.3 Phase 5: PPO ํ•™์Šต
#### 5.3.1 _train_one_round() ๊ตฌ์กฐ
```python
def _train_one_round(self, round_data: List[Dict], logger) -> Dict[str, float]:
"""Phase 5: PPO ํ•™์Šต"""
# 1. ๋ฐ์ดํ„ฐ ๋ณ€ํ™˜
train_dataset = self._convert_to_verl_dataset(round_data)
# 2. DataLoader ์ƒ์„ฑ
self.train_dataloader = self._create_dataloader(
train_dataset,
batch_size=self.config.data.train_batch_size
)
# 3. 1 epoch ํ•™์Šต
epoch_metrics = {}
for step, batch in enumerate(self.train_dataloader):
# PPO Step 1: Response ์ƒ์„ฑ
gen_batch_output = self.actor_rollout_wg.generate_sequences(batch)
# PPO Step 2: Reward ๊ณ„์‚ฐ
reward_tensor = self.reward_fn(batch.union(gen_batch_output))
# PPO Step 3: Policy ์—…๋ฐ์ดํŠธ
update_metrics = self._ppo_update(batch, reward_tensor)
# ๋ฉ”ํŠธ๋ฆญ ์ˆ˜์ง‘
for k, v in update_metrics.items():
epoch_metrics[k].append(v)
return {k: np.mean(v) for k, v in epoch_metrics.items()}
```
#### 5.3.2 ๋ฐ์ดํ„ฐ ๋ณ€ํ™˜ ๊ณผ์ •
```python
def _convert_to_verl_dataset(self, round_data: List[Dict]) -> Any:
"""TTRLVR ํ˜•์‹ โ†’ VeRL ํ˜•์‹"""
converted_data = []
for task in round_data:
# ํ† ํฐํ™”
prompt_ids = self.tokenizer(
task['prompt'],
max_length=self.config.data.max_prompt_length
).input_ids
# VeRL DataProto ํ˜•์‹
verl_item = {
'input_ids': prompt_ids,
'prompt': task['prompt'],
'target': task['target'],
'task_type': task['task_type'],
'problem_id': task['problem_id']
}
converted_data.append(verl_item)
return converted_data
```
---
## 6. ๋™๊ธฐํ™” ๋ฉ”์ปค๋‹ˆ์ฆ˜
### 6.1 ๋ฌธ์ œ์˜ ํ•ต์‹ฌ
๊ธฐ์กด TTRLVR์€ ๋งค ๋ผ์šด๋“œ๋งˆ๋‹ค ์ƒˆ vLLM์„ ์ƒ์„ฑํ–ˆ๊ธฐ ๋•Œ๋ฌธ์— dummy_dtensor ์‚ฌ์šฉ ์‹œ ๋™๊ธฐํ™”๊ฐ€ ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.
### 6.2 ํ•ด๊ฒฐ ๋ฐฉ๋ฒ•
#### 6.2.1 ๋‹จ์ผ vLLM ์ธ์Šคํ„ด์Šค
```python
# ์ดˆ๊ธฐํ™” (1๋ฒˆ๋งŒ)
trainer.init_workers()
โ”œโ”€โ”€ FSDP workers ์ƒ์„ฑ
โ”œโ”€โ”€ vLLM workers ์ƒ์„ฑ
โ””โ”€โ”€ ์ดˆ๊ธฐ ๋™๊ธฐํ™” (sync_model_weights)
# ์ดํ›„ ๋ชจ๋“  ๋ผ์šด๋“œ์—์„œ ๊ฐ™์€ ์ธ์Šคํ„ด์Šค ์‚ฌ์šฉ
Round 1: Phase 1-4 โ†’ Phase 5 (๊ฐ™์€ vLLM)
Round 2: Phase 1-4 โ†’ Phase 5 (๊ฐ™์€ vLLM)
...
```
#### 6.2.2 ๋™๊ธฐํ™” ๊ณผ์ •
```python
# FSDPVLLMShardingManager์˜ ๋™์ž‘
class FSDPVLLMShardingManager:
def __enter__(self):
if not self.base_sync_done:
# ์ฒซ ๋ฒˆ์งธ ํ˜ธ์ถœ: FSDP โ†’ vLLM ๋™๊ธฐํ™”
sync_model_weights(actor_weights, load_format='dummy_dtensor')
self.base_sync_done = True
# ์ดํ›„: ๋ฉ”๋ชจ๋ฆฌ ์ฐธ์กฐ๋กœ ์ž๋™ ๋™๊ธฐํ™”
```
### 6.3 ๋ฉ”๋ชจ๋ฆฌ ์ฐธ์กฐ ๋ฉ”์ปค๋‹ˆ์ฆ˜
```
FSDP ๋ชจ๋ธ (GPU 0-3) vLLM ๋ชจ๋ธ (GPU 0-1)
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚ Parameter A โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ†’ โ”‚ Parameter A โ”‚ (๊ฐ™์€ ๋ฉ”๋ชจ๋ฆฌ ์ฐธ์กฐ)
โ”‚ Parameter B โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ†’ โ”‚ Parameter B โ”‚
โ”‚ Parameter C โ”‚ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ†’ โ”‚ Parameter C โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
PPO ์—…๋ฐ์ดํŠธ โ†’ FSDP ํŒŒ๋ผ๋ฏธํ„ฐ ๋ณ€๊ฒฝ โ†’ vLLM๋„ ์ž๋™์œผ๋กœ ์ƒˆ ๊ฐ’ ์‚ฌ์šฉ
```
---
## 7. ๋ฐ์ดํ„ฐ ํ๋ฆ„
### 7.1 Round 1 ์ƒ์„ธ ํ๋ฆ„
```
1. Problem: Mbpp/2 (์˜ˆ: "๋‘ ์ˆ˜์˜ ํ•ฉ์„ ๊ตฌํ•˜๋Š” ํ•จ์ˆ˜ ์ž‘์„ฑ")
โ”‚
โ”œโ”€โ”€ Phase 1: ํ”„๋กœ๊ทธ๋žจ ์ƒ์„ฑ
โ”‚ โ”œโ”€โ”€ Prompt: "Generate 4 different solutions..."
โ”‚ โ”œโ”€โ”€ vLLM ์ƒ์„ฑ (๋™๊ธฐํ™” ๋ฐœ์ƒ)
โ”‚ โ””โ”€โ”€ Output: [prog1, prog2, prog3, prog4]
โ”‚
โ”œโ”€โ”€ Phase 2: I/O ์ถ”์ถœ
โ”‚ โ”œโ”€โ”€ ๊ฐ ํ”„๋กœ๊ทธ๋žจ์— ๋Œ€ํ•ด ์ž…๋ ฅ ์ƒ์„ฑ
โ”‚ โ”œโ”€โ”€ vLLM ์‚ฌ์šฉ (๋™๊ธฐํ™” ๊ฑด๋„ˆ๋œ€)
โ”‚ โ””โ”€โ”€ Output: [(input1, output1), (input2, output2), ...]
โ”‚
โ”œโ”€โ”€ Phase 3: Task ์ƒ์„ฑ
โ”‚ โ”œโ”€โ”€ Induction: (1, 3) โ†’ "def add(a,b): return a+b"
โ”‚ โ”œโ”€โ”€ Deduction: (prog, 5) โ†’ 8
โ”‚ โ””โ”€โ”€ Abduction: (prog, 10) โ†’ (4, 6)
โ”‚
โ”œโ”€โ”€ Phase 4: ๊ฒ€์ฆ
โ”‚ โ””โ”€โ”€ ์œ ํšจํ•œ task๋งŒ ํ•„ํ„ฐ๋ง
โ”‚
โ””โ”€โ”€ Phase 5: PPO ํ•™์Šต
โ”œโ”€โ”€ ๋ฐฐ์น˜ ์ƒ์„ฑ
โ”œโ”€โ”€ Response ์ƒ์„ฑ (๊ฐ™์€ vLLM)
โ”œโ”€โ”€ Reward ๊ณ„์‚ฐ
โ””โ”€โ”€ FSDP ๋ชจ๋ธ ์—…๋ฐ์ดํŠธ
```
### 7.2 ๋ฐ์ดํ„ฐ ํ˜•์‹ ๋ณ€ํ™˜
```python
# TTRLVR Task ํ˜•์‹
{
'problem_id': 'Mbpp/2',
'task_type': 'induction',
'input': 5,
'output': 10,
'target': 'def multiply_by_two(x): return x * 2',
'prompt': 'Given input 5 produces output 10, write the function:'
}
# โ†“ ๋ณ€ํ™˜
# VeRL DataProto ํ˜•์‹
{
'input_ids': tensor([1, 234, 567, ...]), # ํ† ํฐํ™”๋œ prompt
'attention_mask': tensor([1, 1, 1, ...]),
'prompt': 'Given input 5 produces output 10...',
'target': 'def multiply_by_two(x): return x * 2',
'meta_info': {
'task_type': 'induction',
'problem_id': 'Mbpp/2'
}
}
```
---
## 8. ๊ตฌํ˜„ ์„ธ๋ถ€์‚ฌํ•ญ
### 8.1 VeRL๊ณผ์˜ ํ†ตํ•ฉ
#### 8.1.1 _generate_with_vllm ๋ฉ”์„œ๋“œ
```python
def _generate_with_vllm(self, prompt: str, temperature: float = 0.7):
"""VeRL์˜ vLLM์„ ์‚ฌ์šฉํ•œ ํ…์ŠคํŠธ ์ƒ์„ฑ"""
# 1. ํ† ํฐํ™”
input_ids = self.tokenizer(prompt, ...).input_ids
# 2. DataProto ์ƒ์„ฑ
prompts_proto = DataProto.from_dict({
"input_ids": input_ids.cuda(),
"attention_mask": torch.ones_like(input_ids).cuda(),
})
# 3. ๋ฉ”ํƒ€ ์ •๋ณด ์„ค์ •
prompts_proto.meta_info = {
"eos_token_id": self.tokenizer.eos_token_id,
"temperature": temperature,
"do_sample": True,
"response_length": 256
}
# 4. VeRL์˜ vLLM์œผ๋กœ ์ƒ์„ฑ
outputs = self.actor_rollout_wg.generate_sequences(prompts_proto)
# 5. ๋””์ฝ”๋”ฉ ๋ฐ ๋ฐ˜ํ™˜
return self.tokenizer.decode(outputs.batch["input_ids"][0])
```
#### 8.1.2 CompleteTestTimePipeline ์ˆ˜์ •
```python
# CompleteTestTimePipeline์ด VeRL์˜ vLLM์„ ์‚ฌ์šฉํ•˜๋„๋ก
self.ttrlvr_pipeline.generate_with_verl = self._generate_with_vllm
# ์ด์ œ Pipeline ๋‚ด๋ถ€์—์„œ:
# response = self.generate_with_verl(prompt) # VeRL์˜ vLLM ์‚ฌ์šฉ
```
### 8.2 ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ
#### 8.2.1 ๋ผ์šด๋“œ ๊ฐ„ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
```python
def _manage_memory_between_rounds(self):
"""๋ผ์šด๋“œ ๊ฐ„ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ (์ธ์Šคํ„ด์Šค๋Š” ์œ ์ง€)"""
# GPU ์บ์‹œ๋งŒ ์ •๋ฆฌ
torch.cuda.empty_cache()
# vLLM KV ์บ์‹œ ์ •๋ฆฌ (์„ ํƒ์ )
if hasattr(self.actor_rollout_wg, 'clear_kv_cache'):
self.actor_rollout_wg.clear_kv_cache()
# Garbage collection
import gc
gc.collect()
```
#### 8.2.2 ๋ฉ”๋ชจ๋ฆฌ ๋ชจ๋‹ˆํ„ฐ๋ง
```python
def _monitor_memory(self):
"""๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ๋ชจ๋‹ˆํ„ฐ๋ง"""
for i in range(torch.cuda.device_count()):
allocated = torch.cuda.memory_allocated(i) / 1024**3
reserved = torch.cuda.memory_reserved(i) / 1024**3
print(f"GPU {i}: Allocated={allocated:.2f}GB, Reserved={reserved:.2f}GB")
```
### 8.3 ์—๋Ÿฌ ์ฒ˜๋ฆฌ ๋ฐ ๋ณต๊ตฌ
```python
def _safe_generate(self, prompt: str, max_retries: int = 3):
"""์•ˆ์ „ํ•œ ์ƒ์„ฑ with ์žฌ์‹œ๋„"""
for attempt in range(max_retries):
try:
return self._generate_with_vllm(prompt)
except Exception as e:
if attempt == max_retries - 1:
raise
torch.cuda.empty_cache()
time.sleep(1)
```
### 8.4 ์ฒดํฌํฌ์ธํŠธ ๊ด€๋ฆฌ
```python
def _save_checkpoint(self):
"""์ฒดํฌํฌ์ธํŠธ ์ €์žฅ"""
checkpoint = {
'round': self.current_round,
'model_state_dict': self.actor_rollout_wg.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'metrics': self.accumulated_metrics,
'timestamp': datetime.now().isoformat()
}
path = f"{self.checkpoint_dir}/round_{self.current_round}.pt"
torch.save(checkpoint, path)
```
---
## 9. ์„ฑ๋Šฅ ์ตœ์ ํ™”
### 9.1 ๋ฐฐ์น˜ ์ฒ˜๋ฆฌ
- Phase 1-4์—์„œ ๊ฐ€๋Šฅํ•œ ํ•œ ๋ฐฐ์น˜๋กœ ์ฒ˜๋ฆฌ
- vLLM์˜ continuous batching ํ™œ์šฉ
### 9.2 GPU ํ™œ์šฉ
- vLLM: GPU 0-1 (tensor parallel)
- FSDP: GPU 0-3 (data parallel)
- ํšจ์œจ์ ์ธ GPU ๋ฉ”๋ชจ๋ฆฌ ํ™œ์šฉ
### 9.3 I/O ์ตœ์ ํ™”
- Parquet ํ˜•์‹์œผ๋กœ ์ค‘๊ฐ„ ๋ฐ์ดํ„ฐ ์ €์žฅ
- ๋น„๋™๊ธฐ I/O ์ฒ˜๋ฆฌ
---
## 10. ๋””๋ฒ„๊น… ๋ฐ ๋ชจ๋‹ˆํ„ฐ๋ง
### 10.1 ๋กœ๊น… ๊ตฌ์กฐ
```
/home/ubuntu/RLVR/TestTime-RLVR-v2/logs/
โ”œโ”€โ”€ ttrlvr_unified_20241107_120000.log # ๋ฉ”์ธ ๋กœ๊ทธ
โ”œโ”€โ”€ round_1/
โ”‚ โ”œโ”€โ”€ phase_1_4.log # ๋ฐ์ดํ„ฐ ์ƒ์„ฑ ๋กœ๊ทธ
โ”‚ โ””โ”€โ”€ phase_5.log # ํ•™์Šต ๋กœ๊ทธ
โ””โ”€โ”€ metrics/
โ””โ”€โ”€ tensorboard/ # ํ•™์Šต ๋ฉ”ํŠธ๋ฆญ
```
### 10.2 ์ฃผ์š” ๋ชจ๋‹ˆํ„ฐ๋ง ์ง€ํ‘œ
- ๋ผ์šด๋“œ๋ณ„ ์†Œ์š” ์‹œ๊ฐ„
- ์ƒ์„ฑ๋œ task ์ˆ˜
- ํ‰๊ท  reward
- GPU ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰
- ๋™๊ธฐํ™” ๋ฐœ์ƒ ํšŸ์ˆ˜
---
## 11. ๋ฌธ์ œ ํ•ด๊ฒฐ ๊ฐ€์ด๋“œ
### 11.1 OOM (Out of Memory)
- `gpu_memory_utilization` ์กฐ์ • (๊ธฐ๋ณธ: 0.35)
- `max_num_seqs` ๊ฐ์†Œ
- ๋ฐฐ์น˜ ํฌ๊ธฐ ๊ฐ์†Œ
### 11.2 ๋™๊ธฐํ™” ๋ฌธ์ œ
- `load_format`์ด `dummy_dtensor`์ธ์ง€ ํ™•์ธ
- vLLM ์ธ์Šคํ„ด์Šค๊ฐ€ ์žฌ์ƒ์„ฑ๋˜์ง€ ์•Š๋Š”์ง€ ํ™•์ธ
### 11.3 ๋А๋ฆฐ ์„ฑ๋Šฅ
- GPU ํ™œ์šฉ๋ฅ  ํ™•์ธ
- ๋ฐฐ์น˜ ํฌ๊ธฐ ์ฆ๊ฐ€
- `enforce_eager=False` ํ™•์ธ (CUDA graph ์‚ฌ์šฉ)
---
## 12. ๊ฒฐ๋ก 
TTRLVR Unified๋Š” ๊ธฐ์กด TTRLVR์˜ ๋ชจ๋“  ๊ธฐ๋Šฅ์„ ์œ ์ง€ํ•˜๋ฉด์„œ ๋‹ค์Œ์„ ๋‹ฌ์„ฑํ–ˆ์Šต๋‹ˆ๋‹ค:
1. **๊ตฌ์กฐ์  ๊ฐœ์„ **: ๋ถ„๋ฆฌ๋œ Phase๋“ค์„ ํ•˜๋‚˜์˜ ์„ธ์…˜์œผ๋กœ ํ†ตํ•ฉ
2. **์„ฑ๋Šฅ ํ–ฅ์ƒ**: vLLM ์žฌ์ƒ์„ฑ ์˜ค๋ฒ„ํ—ค๋“œ ์ œ๊ฑฐ๋กœ 30-40% ์†๋„ ํ–ฅ์ƒ
3. **์•ˆ์ •์„ฑ ํ–ฅ์ƒ**: ๋™๊ธฐํ™” ๋ฌธ์ œ ์™„์ „ ํ•ด๊ฒฐ
4. **ํ™•์žฅ์„ฑ**: ๋” ํฐ ๋ชจ๋ธ๊ณผ ๋” ๋งŽ์€ ๋ผ์šด๋“œ ์ง€์› ๊ฐ€๋Šฅ
์ด ์•„ํ‚คํ…์ฒ˜๋Š” TTRLVR์˜ ์ •๊ตํ•œ ๋ฐ์ดํ„ฐ ์ƒ์„ฑ ๋Šฅ๋ ฅ๊ณผ VeRL์˜ ํšจ์œจ์ ์ธ PPO ํ•™์Šต์„ ์™„๋ฒฝํ•˜๊ฒŒ ๊ฒฐํ•ฉํ–ˆ์Šต๋‹ˆ๋‹ค.