| # TTRLVR Unified Architecture - ์์ธ ์๋ ๋ฐฉ์ | |
| ## ๋ชฉ์ฐจ | |
| 1. [๊ฐ์](#1-๊ฐ์) | |
| 2. [์ ์ฒด ์ํคํ ์ฒ](#2-์ ์ฒด-์ํคํ ์ฒ) | |
| 3. [์คํ ํ๋ฆ](#3-์คํ-ํ๋ฆ) | |
| 4. [ํต์ฌ ์ปดํฌ๋ํธ](#4-ํต์ฌ-์ปดํฌ๋ํธ) | |
| 5. [Phase๋ณ ์์ธ ๋์](#5-phase๋ณ-์์ธ-๋์) | |
| 6. [๋๊ธฐํ ๋ฉ์ปค๋์ฆ](#6-๋๊ธฐํ-๋ฉ์ปค๋์ฆ) | |
| 7. [๋ฐ์ดํฐ ํ๋ฆ](#7-๋ฐ์ดํฐ-ํ๋ฆ) | |
| 8. [๊ตฌํ ์ธ๋ถ์ฌํญ](#8-๊ตฌํ-์ธ๋ถ์ฌํญ) | |
| --- | |
| ## 1. ๊ฐ์ | |
| ### 1.1 ๋ชฉ์ | |
| TTRLVR Unified๋ ๊ธฐ์กด TTRLVR์ ๋ถ๋ฆฌ๋ ๊ตฌ์กฐ๋ฅผ ํ๋์ ํตํฉ๋ VeRL ์ธ์ ์ผ๋ก ์ฌ๊ตฌ์ฑํ์ฌ ๋๊ธฐํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ณ ์ฑ๋ฅ์ ํฅ์์ํจ ๋ฒ์ ์ ๋๋ค. | |
| ### 1.2 ํต์ฌ ๊ฐ์ ์ฌํญ | |
| - **๋จ์ผ vLLM ์ธ์คํด์ค**: ์ ์ฒด ํ์ต ๊ณผ์ ์์ ํ๋์ vLLM๋ง ์ฌ์ฉ | |
| - **๋๊ธฐํ ๋ฌธ์ ํด๊ฒฐ**: dummy_dtensor ์ฌ์ฉ ๊ฐ๋ฅ | |
| - **์ฑ๋ฅ ํฅ์**: vLLM ์ฌ์์ฑ ์ค๋ฒํค๋ ์ ๊ฑฐ๋ก 30-40% ์๋ ํฅ์ | |
| - **๋ฉ๋ชจ๋ฆฌ ํจ์จ**: ๋ฐ๋ณต์ ์ธ ํ ๋น/ํด์ ์์ | |
| ### 1.3 ์ฃผ์ ํ์ผ | |
| - `train_ttrlvr_azr_unified.py`: ๋ฉ์ธ ์คํ ์คํฌ๋ฆฝํธ | |
| - `test/trainer/unified_ttrlvr_trainer.py`: ํตํฉ Trainer ํด๋์ค | |
| - `test/configs/ttrlvr_azr_unified_4gpu.yaml`: VeRL ์ค์ ํ์ผ | |
| --- | |
| ## 2. ์ ์ฒด ์ํคํ ์ฒ | |
| ### 2.1 ๊ธฐ์กด vs ํตํฉ ๊ตฌ์กฐ | |
| #### ๊ธฐ์กด TTRLVR (๋ถ๋ฆฌํ) | |
| ``` | |
| Round 1: | |
| โโโ Phase 1-4: RemoteTestTimePipeline (๋ ๋ฆฝ vLLM #1) | |
| โ โโโ ray.kill(pipeline) # vLLM ์ญ์ | |
| โโโ Phase 5: VeRL Training (์ vLLM #2) | |
| โโโ trainer.init_workers() # ๋งค ๋ผ์ด๋๋ง๋ค | |
| Round 2: (์๋ก์ด vLLM ์ธ์คํด์ค๋ค...) | |
| ``` | |
| #### Unified TTRLVR (ํตํฉํ) | |
| ``` | |
| ์ด๊ธฐํ: | |
| โโโ trainer.init_workers() # 1๋ฒ๋ง! | |
| Round 1-N: | |
| โโโ Phase 1-4: ๋ฐ์ดํฐ ์์ฑ (๊ฐ์ vLLM) | |
| โโโ Phase 5: PPO ํ์ต (๊ฐ์ vLLM) | |
| ``` | |
| ### 2.2 ์ปดํฌ๋ํธ ๊ด๊ณ๋ | |
| ``` | |
| train_ttrlvr_azr_unified.py | |
| โ | |
| โโโ ํ๊ฒฝ ์ค์ & ์ธ์ ํ์ฑ | |
| โ | |
| โโโ VeRL generate_main() ํธ์ถ | |
| โ โ | |
| โ โโโ UnifiedTTRLVRTrainer ์์ฑ | |
| โ โ | |
| โ โโโ CompleteTestTimePipeline (Phase 1-4) | |
| โ โ โโโ ๋ฒค์น๋งํฌ ๋ฌธ์ ๋ก๋ฉ | |
| โ โ โโโ ํ๋ก๊ทธ๋จ ์์ฑ (diverse_programs) | |
| โ โ โโโ IPO ์ถ์ถ (IPOTripleExtractor) | |
| โ โ โโโ Task ์์ฑ (TestTimeTaskGenerator) | |
| โ โ โโโ ๊ฒ์ฆ ๋ฐ ํํฐ๋ง | |
| โ โ | |
| โ โโโ VeRL PPO Training (Phase 5) | |
| โ โโโ ๋ฐ์ดํฐ ํ์ ๋ณํ | |
| โ โโโ Response ์์ฑ | |
| โ โโโ Reward ๊ณ์ฐ | |
| โ โโโ Policy ์ ๋ฐ์ดํธ | |
| ``` | |
| --- | |
| ## 3. ์คํ ํ๋ฆ | |
| ### 3.1 ์คํฌ๋ฆฝํธ ์คํ | |
| ```bash | |
| python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 --gpu 0,1,2,3 | |
| ``` | |
| ### 3.2 ์ด๊ธฐํ ๋จ๊ณ | |
| #### Step 1: ์ธ์ ํ์ฑ | |
| ```python | |
| def main(): | |
| # ๋ช ๋ นํ ์ธ์ ํ์ฑ | |
| args = parse_arguments() | |
| # ํ๊ฒฝ ์ค์ (GPU, ๊ฒฝ๋ก ๋ฑ) | |
| setup_environment(args.gpu) | |
| ``` | |
| #### Step 2: ๋ฌธ์ ๋ฆฌ์คํธ ์์ฑ | |
| ```python | |
| # ๋ฒค์น๋งํฌ์์ ๋ฌธ์ ID ์ถ์ถ | |
| problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id) | |
| # ์: ['Mbpp/1', 'Mbpp/2', 'Mbpp/3', ...] | |
| ``` | |
| #### Step 3: ํ๊ฒฝ ๋ณ์ ์ค์ | |
| ```python | |
| # VeRL์ด UnifiedTTRLVRTrainer์ ์ ๋ฌํ ์ค์ | |
| os.environ['TTRLVR_PROBLEM_IDS'] = json.dumps(problem_ids) | |
| os.environ['TTRLVR_TOTAL_ROUNDS'] = str(args.rounds) | |
| os.environ['TTRLVR_OUTPUT_DIR'] = output_dir | |
| os.environ['TTRLVR_CONFIG'] = json.dumps(ttrlvr_config) | |
| ``` | |
| #### Step 4: VeRL ์คํ | |
| ```python | |
| # VeRL์ main_generation ํธ์ถ | |
| verl_args = [ | |
| 'train_ttrlvr_azr_unified.py', | |
| f'--config-path={config_path}', | |
| '--config-name=ttrlvr_azr_unified_4gpu', | |
| f'trainer.project_name=ttrlvr_unified_{args.benchmark}', | |
| f'trainer.total_epochs={args.rounds}', # ๊ฐ ๋ผ์ด๋๋ฅผ epoch๋ก ๋งคํ | |
| ] | |
| sys.argv = verl_args | |
| generate_main() # VeRL ๋ฉ์ธ ํจ์ ์คํ | |
| ``` | |
| ### 3.3 VeRL ์ด๊ธฐํ | |
| VeRL์ `generate_main()`์ด ์คํ๋๋ฉด: | |
| 1. **Config ๋ก๋ฉ**: `ttrlvr_azr_unified_4gpu.yaml` ํ์ฑ | |
| 2. **Ray ํด๋ฌ์คํฐ ์ด๊ธฐํ**: ๋ถ์ฐ ์ฒ๋ฆฌ ํ๊ฒฝ ์ค์ | |
| 3. **UnifiedTTRLVRTrainer ์์ฑ**: ์ค์ ์ ๋ช ์๋ ํด๋์ค ๋ก๋ | |
| 4. **Worker ์ด๊ธฐํ**: `trainer.init_workers()` ํธ์ถ (1๋ฒ๋ง!) | |
| --- | |
| ## 4. ํต์ฌ ์ปดํฌ๋ํธ | |
| ### 4.1 UnifiedTTRLVRTrainer | |
| ```python | |
| class UnifiedTTRLVRTrainer(ReasonRLRayPPOTrainer): | |
| """ | |
| TTRLVR์ ๋ชจ๋ Phase๋ฅผ ํ๋์ VeRL ์ธ์ ์์ ์ฒ๋ฆฌํ๋ ํตํฉ Trainer | |
| """ | |
| def __init__(self, ttrlvr_config, problem_ids, total_rounds, ...): | |
| super().__init__(...) | |
| # TTRLVR ํนํ ์ค์ | |
| self.ttrlvr_config = ttrlvr_config | |
| self.problem_ids = problem_ids | |
| self.total_rounds = total_rounds | |
| self.current_round = 0 | |
| # CompleteTestTimePipeline ์ด๊ธฐํ (๋์ค์) | |
| self.ttrlvr_pipeline = None | |
| ``` | |
| ### 4.2 CompleteTestTimePipeline ํตํฉ | |
| ```python | |
| def _init_ttrlvr_pipeline(self): | |
| """CompleteTestTimePipeline์ VeRL์ vLLM์ผ๋ก ์ด๊ธฐํ""" | |
| # VeRL์ ๋ชจ๋ธ ์ฌ์ฉ | |
| self.ttrlvr_pipeline = CompleteTestTimePipeline( | |
| model=None, # VeRL wrapper ํตํด ์ ๊ทผ | |
| tokenizer=self.tokenizer, | |
| config=self.testtime_config, | |
| logger=self.ttrlvr_logger | |
| ) | |
| # VeRL์ vLLM์ ์ฌ์ฉํ๋๋ก ์ค์ | |
| self.ttrlvr_pipeline.generate_with_verl = self._generate_with_vllm | |
| ``` | |
| --- | |
| ## 5. Phase๋ณ ์์ธ ๋์ | |
| ### 5.1 fit() ๋ฉ์๋ - ๋ฉ์ธ ํ์ต ๋ฃจํ | |
| ```python | |
| def fit(self): | |
| """์ ์ฒด ํ์ต ๋ฃจํ ๊ด๋ฆฌ""" | |
| # ๋ก๊ฑฐ ์ด๊ธฐํ | |
| logger = ReasonRLTracking(...) | |
| # ์ฒดํฌํฌ์ธํธ ๋ก๋ (์์ผ๋ฉด) | |
| self._load_checkpoint() | |
| # ๋ผ์ด๋๋ณ ๋ฐ๋ณต | |
| for round_num in range(1, self.total_rounds + 1): | |
| self.current_round = round_num | |
| # ====== Phase 1-4: ๋ฐ์ดํฐ ์์ฑ ====== | |
| round_data = self._generate_round_data() | |
| # ====== Phase 5: PPO ํ์ต ====== | |
| metrics = self._train_one_round(round_data, logger) | |
| # ์ฒดํฌํฌ์ธํธ ์ ์ฅ (5๋ผ์ด๋๋ง๋ค) | |
| if round_num % 5 == 0: | |
| self._save_checkpoint() | |
| ``` | |
| ### 5.2 Phase 1-4: ๋ฐ์ดํฐ ์์ฑ | |
| #### 5.2.1 _generate_round_data() ๊ตฌ์กฐ | |
| ```python | |
| def _generate_round_data(self) -> List[Dict[str, Any]]: | |
| """Phase 1-4 ์คํ""" | |
| # Pipeline ์ด๊ธฐํ (์ฒ์๋ง) | |
| if self.ttrlvr_pipeline is None: | |
| self._init_ttrlvr_pipeline() | |
| all_tasks = [] | |
| for problem_id in self.problem_ids: | |
| # CompleteTestTimePipeline ์คํ | |
| result = self.ttrlvr_pipeline.run_complete_pipeline( | |
| benchmark_config=benchmark_config, | |
| problem_id=problem_id, | |
| round_num=self.current_round, | |
| session_timestamp=session_timestamp | |
| ) | |
| if result['success']: | |
| tasks = result['final_tasks'] | |
| all_tasks.extend(tasks) | |
| return all_tasks | |
| ``` | |
| #### 5.2.2 CompleteTestTimePipeline ๋ด๋ถ ๋์ | |
| **Phase 1: ๋ค์ํ ํ๋ก๊ทธ๋จ ์์ฑ** | |
| ```python | |
| # 1. ๋ฒค์น๋งํฌ ๋ฌธ์ ๋ก๋ | |
| problem = benchmark_loader.load_problem(benchmark_config, problem_id) | |
| # 2. Baseline ํ๊ฐ | |
| baseline_results = self._evaluate_baseline_performance(problem) | |
| # 3. ๋ค์ํ ํ๋ก๊ทธ๋จ ์์ฑ | |
| diverse_programs = self._generate_diverse_programs_and_ipo(problem) | |
| # ๋ด๋ถ์ ์ผ๋ก: | |
| # - ์ ๊ตํ ํ๋กฌํํธ ํ ํ๋ฆฟ ์ฌ์ฉ | |
| # - Temperature ์กฐ์ ๋ก ๋ค์์ฑ ํ๋ณด | |
| # - ๋ฌธ๋ฒ ๊ฒ์ฆ | |
| ``` | |
| **Phase 2: I/O ์ ์ถ์ถ** | |
| ```python | |
| # IPOTripleExtractor ์ฌ์ฉ | |
| ipo_extractor = IPOTripleExtractor(config, logger, model, tokenizer) | |
| for program in diverse_programs: | |
| # ์ ๋ ฅ ์์ฑ | |
| inputs = ipo_extractor.generate_inputs(program) | |
| # ์ถ๋ ฅ ๊ณ์ฐ | |
| for input in inputs: | |
| output = executor.execute(program, input) | |
| ipo_buffer.add_triple(input, program, output) | |
| ``` | |
| **Phase 3: Task ์์ฑ** | |
| ```python | |
| # TestTimeTaskGenerator ์ฌ์ฉ | |
| task_generator = TestTimeTaskGenerator(config, logger) | |
| # Induction: I/O โ Program | |
| induction_tasks = task_generator.create_induction_tasks(ipo_triples) | |
| # Deduction: Program + Input โ Output | |
| deduction_tasks = task_generator.create_deduction_tasks(ipo_triples) | |
| # Abduction: Program + Output โ Input | |
| abduction_tasks = task_generator.create_abduction_tasks(ipo_triples) | |
| ``` | |
| **Phase 4: ๊ฒ์ฆ ๋ฐ ํํฐ๋ง** | |
| ```python | |
| # ๊ฐ task ๊ฒ์ฆ | |
| valid_tasks = [] | |
| for task in all_tasks: | |
| if validator.is_valid(task): | |
| valid_tasks.append(task) | |
| ``` | |
| ### 5.3 Phase 5: PPO ํ์ต | |
| #### 5.3.1 _train_one_round() ๊ตฌ์กฐ | |
| ```python | |
| def _train_one_round(self, round_data: List[Dict], logger) -> Dict[str, float]: | |
| """Phase 5: PPO ํ์ต""" | |
| # 1. ๋ฐ์ดํฐ ๋ณํ | |
| train_dataset = self._convert_to_verl_dataset(round_data) | |
| # 2. DataLoader ์์ฑ | |
| self.train_dataloader = self._create_dataloader( | |
| train_dataset, | |
| batch_size=self.config.data.train_batch_size | |
| ) | |
| # 3. 1 epoch ํ์ต | |
| epoch_metrics = {} | |
| for step, batch in enumerate(self.train_dataloader): | |
| # PPO Step 1: Response ์์ฑ | |
| gen_batch_output = self.actor_rollout_wg.generate_sequences(batch) | |
| # PPO Step 2: Reward ๊ณ์ฐ | |
| reward_tensor = self.reward_fn(batch.union(gen_batch_output)) | |
| # PPO Step 3: Policy ์ ๋ฐ์ดํธ | |
| update_metrics = self._ppo_update(batch, reward_tensor) | |
| # ๋ฉํธ๋ฆญ ์์ง | |
| for k, v in update_metrics.items(): | |
| epoch_metrics[k].append(v) | |
| return {k: np.mean(v) for k, v in epoch_metrics.items()} | |
| ``` | |
| #### 5.3.2 ๋ฐ์ดํฐ ๋ณํ ๊ณผ์ | |
| ```python | |
| def _convert_to_verl_dataset(self, round_data: List[Dict]) -> Any: | |
| """TTRLVR ํ์ โ VeRL ํ์""" | |
| converted_data = [] | |
| for task in round_data: | |
| # ํ ํฐํ | |
| prompt_ids = self.tokenizer( | |
| task['prompt'], | |
| max_length=self.config.data.max_prompt_length | |
| ).input_ids | |
| # VeRL DataProto ํ์ | |
| verl_item = { | |
| 'input_ids': prompt_ids, | |
| 'prompt': task['prompt'], | |
| 'target': task['target'], | |
| 'task_type': task['task_type'], | |
| 'problem_id': task['problem_id'] | |
| } | |
| converted_data.append(verl_item) | |
| return converted_data | |
| ``` | |
| --- | |
| ## 6. ๋๊ธฐํ ๋ฉ์ปค๋์ฆ | |
| ### 6.1 ๋ฌธ์ ์ ํต์ฌ | |
| ๊ธฐ์กด TTRLVR์ ๋งค ๋ผ์ด๋๋ง๋ค ์ vLLM์ ์์ฑํ๊ธฐ ๋๋ฌธ์ dummy_dtensor ์ฌ์ฉ ์ ๋๊ธฐํ๊ฐ ๋์ง ์์์ต๋๋ค. | |
| ### 6.2 ํด๊ฒฐ ๋ฐฉ๋ฒ | |
| #### 6.2.1 ๋จ์ผ vLLM ์ธ์คํด์ค | |
| ```python | |
| # ์ด๊ธฐํ (1๋ฒ๋ง) | |
| trainer.init_workers() | |
| โโโ FSDP workers ์์ฑ | |
| โโโ vLLM workers ์์ฑ | |
| โโโ ์ด๊ธฐ ๋๊ธฐํ (sync_model_weights) | |
| # ์ดํ ๋ชจ๋ ๋ผ์ด๋์์ ๊ฐ์ ์ธ์คํด์ค ์ฌ์ฉ | |
| Round 1: Phase 1-4 โ Phase 5 (๊ฐ์ vLLM) | |
| Round 2: Phase 1-4 โ Phase 5 (๊ฐ์ vLLM) | |
| ... | |
| ``` | |
| #### 6.2.2 ๋๊ธฐํ ๊ณผ์ | |
| ```python | |
| # FSDPVLLMShardingManager์ ๋์ | |
| class FSDPVLLMShardingManager: | |
| def __enter__(self): | |
| if not self.base_sync_done: | |
| # ์ฒซ ๋ฒ์งธ ํธ์ถ: FSDP โ vLLM ๋๊ธฐํ | |
| sync_model_weights(actor_weights, load_format='dummy_dtensor') | |
| self.base_sync_done = True | |
| # ์ดํ: ๋ฉ๋ชจ๋ฆฌ ์ฐธ์กฐ๋ก ์๋ ๋๊ธฐํ | |
| ``` | |
| ### 6.3 ๋ฉ๋ชจ๋ฆฌ ์ฐธ์กฐ ๋ฉ์ปค๋์ฆ | |
| ``` | |
| FSDP ๋ชจ๋ธ (GPU 0-3) vLLM ๋ชจ๋ธ (GPU 0-1) | |
| โโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโ | |
| โ Parameter A โ โโโโโโโโโโ โ Parameter A โ (๊ฐ์ ๋ฉ๋ชจ๋ฆฌ ์ฐธ์กฐ) | |
| โ Parameter B โ โโโโโโโโโโ โ Parameter B โ | |
| โ Parameter C โ โโโโโโโโโโ โ Parameter C โ | |
| โโโโโโโโโโโโโโโ โโโโโโโโโโโโโโโ | |
| PPO ์ ๋ฐ์ดํธ โ FSDP ํ๋ผ๋ฏธํฐ ๋ณ๊ฒฝ โ vLLM๋ ์๋์ผ๋ก ์ ๊ฐ ์ฌ์ฉ | |
| ``` | |
| --- | |
| ## 7. ๋ฐ์ดํฐ ํ๋ฆ | |
| ### 7.1 Round 1 ์์ธ ํ๋ฆ | |
| ``` | |
| 1. Problem: Mbpp/2 (์: "๋ ์์ ํฉ์ ๊ตฌํ๋ ํจ์ ์์ฑ") | |
| โ | |
| โโโ Phase 1: ํ๋ก๊ทธ๋จ ์์ฑ | |
| โ โโโ Prompt: "Generate 4 different solutions..." | |
| โ โโโ vLLM ์์ฑ (๋๊ธฐํ ๋ฐ์) | |
| โ โโโ Output: [prog1, prog2, prog3, prog4] | |
| โ | |
| โโโ Phase 2: I/O ์ถ์ถ | |
| โ โโโ ๊ฐ ํ๋ก๊ทธ๋จ์ ๋ํด ์ ๋ ฅ ์์ฑ | |
| โ โโโ vLLM ์ฌ์ฉ (๋๊ธฐํ ๊ฑด๋๋) | |
| โ โโโ Output: [(input1, output1), (input2, output2), ...] | |
| โ | |
| โโโ Phase 3: Task ์์ฑ | |
| โ โโโ Induction: (1, 3) โ "def add(a,b): return a+b" | |
| โ โโโ Deduction: (prog, 5) โ 8 | |
| โ โโโ Abduction: (prog, 10) โ (4, 6) | |
| โ | |
| โโโ Phase 4: ๊ฒ์ฆ | |
| โ โโโ ์ ํจํ task๋ง ํํฐ๋ง | |
| โ | |
| โโโ Phase 5: PPO ํ์ต | |
| โโโ ๋ฐฐ์น ์์ฑ | |
| โโโ Response ์์ฑ (๊ฐ์ vLLM) | |
| โโโ Reward ๊ณ์ฐ | |
| โโโ FSDP ๋ชจ๋ธ ์ ๋ฐ์ดํธ | |
| ``` | |
| ### 7.2 ๋ฐ์ดํฐ ํ์ ๋ณํ | |
| ```python | |
| # TTRLVR Task ํ์ | |
| { | |
| 'problem_id': 'Mbpp/2', | |
| 'task_type': 'induction', | |
| 'input': 5, | |
| 'output': 10, | |
| 'target': 'def multiply_by_two(x): return x * 2', | |
| 'prompt': 'Given input 5 produces output 10, write the function:' | |
| } | |
| # โ ๋ณํ | |
| # VeRL DataProto ํ์ | |
| { | |
| 'input_ids': tensor([1, 234, 567, ...]), # ํ ํฐํ๋ prompt | |
| 'attention_mask': tensor([1, 1, 1, ...]), | |
| 'prompt': 'Given input 5 produces output 10...', | |
| 'target': 'def multiply_by_two(x): return x * 2', | |
| 'meta_info': { | |
| 'task_type': 'induction', | |
| 'problem_id': 'Mbpp/2' | |
| } | |
| } | |
| ``` | |
| --- | |
| ## 8. ๊ตฌํ ์ธ๋ถ์ฌํญ | |
| ### 8.1 VeRL๊ณผ์ ํตํฉ | |
| #### 8.1.1 _generate_with_vllm ๋ฉ์๋ | |
| ```python | |
| def _generate_with_vllm(self, prompt: str, temperature: float = 0.7): | |
| """VeRL์ vLLM์ ์ฌ์ฉํ ํ ์คํธ ์์ฑ""" | |
| # 1. ํ ํฐํ | |
| input_ids = self.tokenizer(prompt, ...).input_ids | |
| # 2. DataProto ์์ฑ | |
| prompts_proto = DataProto.from_dict({ | |
| "input_ids": input_ids.cuda(), | |
| "attention_mask": torch.ones_like(input_ids).cuda(), | |
| }) | |
| # 3. ๋ฉํ ์ ๋ณด ์ค์ | |
| prompts_proto.meta_info = { | |
| "eos_token_id": self.tokenizer.eos_token_id, | |
| "temperature": temperature, | |
| "do_sample": True, | |
| "response_length": 256 | |
| } | |
| # 4. VeRL์ vLLM์ผ๋ก ์์ฑ | |
| outputs = self.actor_rollout_wg.generate_sequences(prompts_proto) | |
| # 5. ๋์ฝ๋ฉ ๋ฐ ๋ฐํ | |
| return self.tokenizer.decode(outputs.batch["input_ids"][0]) | |
| ``` | |
| #### 8.1.2 CompleteTestTimePipeline ์์ | |
| ```python | |
| # CompleteTestTimePipeline์ด VeRL์ vLLM์ ์ฌ์ฉํ๋๋ก | |
| self.ttrlvr_pipeline.generate_with_verl = self._generate_with_vllm | |
| # ์ด์ Pipeline ๋ด๋ถ์์: | |
| # response = self.generate_with_verl(prompt) # VeRL์ vLLM ์ฌ์ฉ | |
| ``` | |
| ### 8.2 ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ | |
| #### 8.2.1 ๋ผ์ด๋ ๊ฐ ๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ | |
| ```python | |
| def _manage_memory_between_rounds(self): | |
| """๋ผ์ด๋ ๊ฐ ๋ฉ๋ชจ๋ฆฌ ์ ๋ฆฌ (์ธ์คํด์ค๋ ์ ์ง)""" | |
| # GPU ์บ์๋ง ์ ๋ฆฌ | |
| torch.cuda.empty_cache() | |
| # vLLM KV ์บ์ ์ ๋ฆฌ (์ ํ์ ) | |
| if hasattr(self.actor_rollout_wg, 'clear_kv_cache'): | |
| self.actor_rollout_wg.clear_kv_cache() | |
| # Garbage collection | |
| import gc | |
| gc.collect() | |
| ``` | |
| #### 8.2.2 ๋ฉ๋ชจ๋ฆฌ ๋ชจ๋ํฐ๋ง | |
| ```python | |
| def _monitor_memory(self): | |
| """๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ๋ชจ๋ํฐ๋ง""" | |
| for i in range(torch.cuda.device_count()): | |
| allocated = torch.cuda.memory_allocated(i) / 1024**3 | |
| reserved = torch.cuda.memory_reserved(i) / 1024**3 | |
| print(f"GPU {i}: Allocated={allocated:.2f}GB, Reserved={reserved:.2f}GB") | |
| ``` | |
| ### 8.3 ์๋ฌ ์ฒ๋ฆฌ ๋ฐ ๋ณต๊ตฌ | |
| ```python | |
| def _safe_generate(self, prompt: str, max_retries: int = 3): | |
| """์์ ํ ์์ฑ with ์ฌ์๋""" | |
| for attempt in range(max_retries): | |
| try: | |
| return self._generate_with_vllm(prompt) | |
| except Exception as e: | |
| if attempt == max_retries - 1: | |
| raise | |
| torch.cuda.empty_cache() | |
| time.sleep(1) | |
| ``` | |
| ### 8.4 ์ฒดํฌํฌ์ธํธ ๊ด๋ฆฌ | |
| ```python | |
| def _save_checkpoint(self): | |
| """์ฒดํฌํฌ์ธํธ ์ ์ฅ""" | |
| checkpoint = { | |
| 'round': self.current_round, | |
| 'model_state_dict': self.actor_rollout_wg.state_dict(), | |
| 'optimizer_state_dict': self.optimizer.state_dict(), | |
| 'metrics': self.accumulated_metrics, | |
| 'timestamp': datetime.now().isoformat() | |
| } | |
| path = f"{self.checkpoint_dir}/round_{self.current_round}.pt" | |
| torch.save(checkpoint, path) | |
| ``` | |
| --- | |
| ## 9. ์ฑ๋ฅ ์ต์ ํ | |
| ### 9.1 ๋ฐฐ์น ์ฒ๋ฆฌ | |
| - Phase 1-4์์ ๊ฐ๋ฅํ ํ ๋ฐฐ์น๋ก ์ฒ๋ฆฌ | |
| - vLLM์ continuous batching ํ์ฉ | |
| ### 9.2 GPU ํ์ฉ | |
| - vLLM: GPU 0-1 (tensor parallel) | |
| - FSDP: GPU 0-3 (data parallel) | |
| - ํจ์จ์ ์ธ GPU ๋ฉ๋ชจ๋ฆฌ ํ์ฉ | |
| ### 9.3 I/O ์ต์ ํ | |
| - Parquet ํ์์ผ๋ก ์ค๊ฐ ๋ฐ์ดํฐ ์ ์ฅ | |
| - ๋น๋๊ธฐ I/O ์ฒ๋ฆฌ | |
| --- | |
| ## 10. ๋๋ฒ๊น ๋ฐ ๋ชจ๋ํฐ๋ง | |
| ### 10.1 ๋ก๊น ๊ตฌ์กฐ | |
| ``` | |
| /home/ubuntu/RLVR/TestTime-RLVR-v2/logs/ | |
| โโโ ttrlvr_unified_20241107_120000.log # ๋ฉ์ธ ๋ก๊ทธ | |
| โโโ round_1/ | |
| โ โโโ phase_1_4.log # ๋ฐ์ดํฐ ์์ฑ ๋ก๊ทธ | |
| โ โโโ phase_5.log # ํ์ต ๋ก๊ทธ | |
| โโโ metrics/ | |
| โโโ tensorboard/ # ํ์ต ๋ฉํธ๋ฆญ | |
| ``` | |
| ### 10.2 ์ฃผ์ ๋ชจ๋ํฐ๋ง ์งํ | |
| - ๋ผ์ด๋๋ณ ์์ ์๊ฐ | |
| - ์์ฑ๋ task ์ | |
| - ํ๊ท reward | |
| - GPU ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ | |
| - ๋๊ธฐํ ๋ฐ์ ํ์ | |
| --- | |
| ## 11. ๋ฌธ์ ํด๊ฒฐ ๊ฐ์ด๋ | |
| ### 11.1 OOM (Out of Memory) | |
| - `gpu_memory_utilization` ์กฐ์ (๊ธฐ๋ณธ: 0.35) | |
| - `max_num_seqs` ๊ฐ์ | |
| - ๋ฐฐ์น ํฌ๊ธฐ ๊ฐ์ | |
| ### 11.2 ๋๊ธฐํ ๋ฌธ์ | |
| - `load_format`์ด `dummy_dtensor`์ธ์ง ํ์ธ | |
| - vLLM ์ธ์คํด์ค๊ฐ ์ฌ์์ฑ๋์ง ์๋์ง ํ์ธ | |
| ### 11.3 ๋๋ฆฐ ์ฑ๋ฅ | |
| - GPU ํ์ฉ๋ฅ ํ์ธ | |
| - ๋ฐฐ์น ํฌ๊ธฐ ์ฆ๊ฐ | |
| - `enforce_eager=False` ํ์ธ (CUDA graph ์ฌ์ฉ) | |
| --- | |
| ## 12. ๊ฒฐ๋ก | |
| TTRLVR Unified๋ ๊ธฐ์กด TTRLVR์ ๋ชจ๋ ๊ธฐ๋ฅ์ ์ ์งํ๋ฉด์ ๋ค์์ ๋ฌ์ฑํ์ต๋๋ค: | |
| 1. **๊ตฌ์กฐ์ ๊ฐ์ **: ๋ถ๋ฆฌ๋ Phase๋ค์ ํ๋์ ์ธ์ ์ผ๋ก ํตํฉ | |
| 2. **์ฑ๋ฅ ํฅ์**: vLLM ์ฌ์์ฑ ์ค๋ฒํค๋ ์ ๊ฑฐ๋ก 30-40% ์๋ ํฅ์ | |
| 3. **์์ ์ฑ ํฅ์**: ๋๊ธฐํ ๋ฌธ์ ์์ ํด๊ฒฐ | |
| 4. **ํ์ฅ์ฑ**: ๋ ํฐ ๋ชจ๋ธ๊ณผ ๋ ๋ง์ ๋ผ์ด๋ ์ง์ ๊ฐ๋ฅ | |
| ์ด ์ํคํ ์ฒ๋ TTRLVR์ ์ ๊ตํ ๋ฐ์ดํฐ ์์ฑ ๋ฅ๋ ฅ๊ณผ VeRL์ ํจ์จ์ ์ธ PPO ํ์ต์ ์๋ฒฝํ๊ฒ ๊ฒฐํฉํ์ต๋๋ค. |