hjkim00 commited on
Commit
24c2665
·
verified ·
1 Parent(s): e7b37ff

Restore all essential files - code, configs, and MBPP/HumanEval data

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +10 -0
  2. LICENSE +21 -0
  3. README.md +581 -0
  4. Update/2025-01-25_humaneval_fixes.md +113 -0
  5. Update/Phase1_Infrastructure_Setup.md +65 -0
  6. Update/Phase2_Benchmark_System.md +85 -0
  7. Update/Phase3_AZR_Template_Integration.md +100 -0
  8. Update/Phase3_IPO_Extraction.md +129 -0
  9. Update/Phase4_Complete_Pipeline_Implementation.md +203 -0
  10. Update/Phase5_Critical_Bug_Fixes_and_EvalPlus_Integration.md +226 -0
  11. Update/unified_ttrlvr_architecture.md +646 -0
  12. absolute_zero_reasoner/__init__.py +0 -0
  13. absolute_zero_reasoner/configs/azr_ppo_trainer.yaml +605 -0
  14. absolute_zero_reasoner/data_construction/__init__.py +0 -0
  15. absolute_zero_reasoner/data_construction/constructor.py +225 -0
  16. absolute_zero_reasoner/data_construction/process_code_reasoning_data.py +175 -0
  17. absolute_zero_reasoner/data_construction/process_data.py +210 -0
  18. absolute_zero_reasoner/data_construction/prompts.py +546 -0
  19. absolute_zero_reasoner/main_azr_ppo.py +260 -0
  20. absolute_zero_reasoner/rewards/__init__.py +0 -0
  21. absolute_zero_reasoner/rewards/code_reward.py +554 -0
  22. absolute_zero_reasoner/rewards/custom_evaluate.py +387 -0
  23. absolute_zero_reasoner/rewards/math_utils.py +490 -0
  24. absolute_zero_reasoner/rewards/reward_managers.py +898 -0
  25. absolute_zero_reasoner/rewards/ttrlvr_reward_manager.py +244 -0
  26. absolute_zero_reasoner/testtime/__init__.py +34 -0
  27. absolute_zero_reasoner/testtime/benchmark_loader.py +223 -0
  28. absolute_zero_reasoner/testtime/complete_pipeline.py +0 -0
  29. absolute_zero_reasoner/testtime/config.py +162 -0
  30. absolute_zero_reasoner/testtime/ipo_extractor.py +1235 -0
  31. absolute_zero_reasoner/testtime/logger.py +295 -0
  32. absolute_zero_reasoner/testtime/prompts.py +413 -0
  33. absolute_zero_reasoner/testtime/solution_generator.py +877 -0
  34. absolute_zero_reasoner/testtime/task_generator.py +473 -0
  35. absolute_zero_reasoner/trainer/__init__.py +0 -0
  36. absolute_zero_reasoner/trainer/ppo/__init__.py +0 -0
  37. absolute_zero_reasoner/trainer/ppo/azr_ray_trainer.py +0 -0
  38. absolute_zero_reasoner/trainer/ppo/reason_rl_ray_trainer.py +768 -0
  39. absolute_zero_reasoner/trainer/ppo/ttrlvr_azr_integration.py +125 -0
  40. absolute_zero_reasoner/utils/__init__.py +0 -0
  41. absolute_zero_reasoner/utils/auxiliary.py +11 -0
  42. absolute_zero_reasoner/utils/code_utils/__init__.py +0 -0
  43. absolute_zero_reasoner/utils/code_utils/checks.py +182 -0
  44. absolute_zero_reasoner/utils/code_utils/parsers.py +202 -0
  45. absolute_zero_reasoner/utils/code_utils/python_executor.py +435 -0
  46. absolute_zero_reasoner/utils/code_utils/sandboxfusion_executor.py +372 -0
  47. absolute_zero_reasoner/utils/code_utils/templates.py +68 -0
  48. absolute_zero_reasoner/utils/convert2hf.py +55 -0
  49. absolute_zero_reasoner/utils/dataset/__init__.py +0 -0
  50. absolute_zero_reasoner/utils/dataset/ipo_grouped_sampler.py +220 -0
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/absolute_zero_paradigm.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/azr.png filter=lfs diff=lfs merge=lfs -text
38
+ evaluation/code_eval/coding/evalplus/gallary/render.gif filter=lfs diff=lfs merge=lfs -text
39
+ evaluation/math_eval/eval/data/tabmwp/test.jsonl filter=lfs diff=lfs merge=lfs -text
40
+ evaluation/math_eval/eval/latex2sympy/antlr-4.11.1-complete.jar filter=lfs diff=lfs merge=lfs -text
41
+ evaluation/math_eval/latex2sympy/antlr-4.11.1-complete.jar filter=lfs diff=lfs merge=lfs -text
42
+ logs/task_generation/tasks_Mbpp_7.json filter=lfs diff=lfs merge=lfs -text
43
+ test/logs/task_generation/tasks_HumanEval_28.json filter=lfs diff=lfs merge=lfs -text
44
+ test/logs/task_generation/tasks_Mbpp_2.json filter=lfs diff=lfs merge=lfs -text
45
+ test/logs/task_generation/tasks_Mbpp_242.json filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 LeapLab
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ # TestTime RLVR: Test-Time Reinforcement Learning with Verification and Reasoning
4
+ *Based on Absolute Zero Reasoner (AZR) Methodology*
5
+
6
+ [![Paper](https://img.shields.io/badge/paper-A42C25?style=for-the-badge&logo=arxiv&logoColor=white)](https://arxiv.org/abs/2505.03335) [![Project Page](https://img.shields.io/badge/Project%20Page-blue?style=for-the-badge&logo=snowflake&logoColor=white&labelColor=black)](https://andrewzh112.github.io/absolute-zero-reasoner/) [![Github](https://img.shields.io/badge/Code-000000?style=for-the-badge&logo=github&logoColor=000&logoColor=white)](https://github.com/LeapLabTHU/Absolute-Zero-Reasoner) [![Hugging Face Collection](https://img.shields.io/badge/AZR_Collection-fcd022?style=for-the-badge&logo=huggingface&logoColor=000)](https://huggingface.co/collections/andrewzh/absolute-zero-reasoner-68139b2bca82afb00bc69e5b) [![W&B Logs](https://img.shields.io/badge/📁_W%26B_Logs-fcd022?style=for-the-badge&logo=wandb&logoColor=000)](https://wandb.ai/andrewzhao112/AbsoluteZeroReasoner)
7
+
8
+ <div align="center" style="font-family: Arial, sans-serif;">
9
+ <p>
10
+ <a href="#news" style="text-decoration: none; font-weight: bold;">🎉 News</a> •
11
+ <a href="#links" style="text-decoration: none; font-weight: bold;">🔗 Links</a> •
12
+ <a href="#todo" style="text-decoration: none; font-weight: bold;">📝 Roadmap</a> •
13
+ <a href="#algorithm-flow" style="text-decoration: none; font-weight: bold;">⚙️ Algorithm Flow</a> •
14
+ <a href="#results" style="text-decoration: none; font-weight: bold;">📊 Results</a>
15
+ </p>
16
+ <p>
17
+ <a href="#getting-started" style="text-decoration: none; font-weight: bold;">✨ Getting Started</a> •
18
+ <a href="#training" style="text-decoration: none; font-weight: bold;">🏋️ Training</a> •
19
+ <a href="#usage" style="text-decoration: none; font-weight: bold;">🔧 Usage</a> •
20
+ <a href="#evaluation-code" style="text-decoration: none; font-weight: bold;">📃 Evaluation</a>
21
+ </p>
22
+ <p>
23
+ <a href="#citation" style="text-decoration: none; font-weight: bold;">🎈 Citation</a> •
24
+ <a href="#acknowledgement" style="text-decoration: none; font-weight: bold;">🌻 Acknowledgement</a> •
25
+ <a href="#contact" style="text-decoration: none; font-weight: bold;">📧 Contact</a> •
26
+ <a href="#star-history" style="text-decoration: none; font-weight: bold;">📈 Star History</a>
27
+ </p>
28
+ </div>
29
+
30
+ </div>
31
+
32
+ # 🚀 TestTime RLVR Implementation
33
+
34
+ ## 📋 Overview
35
+ TestTime RLVR implements test-time reinforcement learning for enhanced reasoning capabilities using the AZR (Absolute Zero Reasoner) methodology. The system generates Input-Program-Output (IPO) triples from benchmark problems and creates three types of reasoning tasks (induction, deduction, abduction) to improve model performance at test time.
36
+
37
+ ## 🎯 Key Features
38
+ - **Complete Pipeline**: LLM Solution Generation → IPO Extraction → Task Generation → LLM Evaluation → Reward Computation
39
+ - **AZR Integration**: Full integration with Absolute Zero Reasoner templates and evaluation methods
40
+ - **Benchmark Support**: MBPP+ and HumanEval+ datasets with structured data extraction
41
+ - **Execution-based Evaluation**: Program execution comparison instead of string matching
42
+ - **VLLM Optimization**: Faster inference with VLLM backend support
43
+
44
+ ## 📈 Implementation Status
45
+ - ✅ **Phase 1**: Infrastructure Setup - Complete pipeline architecture
46
+ - ✅ **Phase 2**: Benchmark System - MBPP+/HumanEval+ integration
47
+ - ✅ **Phase 3**: AZR Template Integration - Three reasoning tasks implementation
48
+ - ✅ **Phase 4**: Complete Pipeline - Fully functional end-to-end system
49
+ - 🔄 **Phase 5**: RLVR Training - Reinforcement learning integration (In Progress)
50
+
51
+
52
+ ## 📦 Dataset Setup
53
+
54
+ **Download required benchmark datasets:**
55
+ ```bash
56
+ # Download MBPP+ and HumanEval+ datasets
57
+ wget -O evaluation/code_eval/data/MbppPlus.jsonl https://huggingface.co/datasets/evalplus/mbppplus/resolve/main/MbppPlus.jsonl
58
+ wget -O evaluation/code_eval/data/HumanEvalPlus.jsonl https://huggingface.co/datasets/evalplus/humanevalplus/resolve/main/HumanEvalPlus.jsonl
59
+ ```
60
+
61
+ ## 🚀 Quick Start
62
+
63
+ ### Running the Pipeline
64
+ ```bash
65
+ # Navigate to test directory
66
+ cd test/
67
+
68
+ # Set GPU device
69
+ export CUDA_VISIBLE_DEVICES=6
70
+
71
+ # Execute complete pipeline
72
+ bash run_testtime_gpu6.sh
73
+ ```
74
+
75
+ ### Command Line Options
76
+ ```bash
77
+ # From test/ directory
78
+ python test_complete_pipeline.py \
79
+ --model "Qwen/Qwen2.5-7B" \
80
+ --benchmark "mbpp" \
81
+ --problem_id "Mbpp/478" \
82
+ --max_tokens 2048 \
83
+ --gpu 6 \
84
+ --verbose \
85
+ --output_dir ../tmp
86
+ ```
87
+
88
+ ### Batch Evaluation
89
+ ```bash
90
+ # From test/ directory
91
+ bash run_batch_evaluation.sh "Qwen/Qwen2.5-7B" "mbpp" 10 6
92
+ ```
93
+
94
+ ### Supported Benchmarks
95
+ - **MBPP+**: `--benchmark mbpp --problem_id "Mbpp/X"`
96
+ - **HumanEval+**: `--benchmark humaneval --problem_id "HumanEval/X"`
97
+ - **Test Mode**: `--benchmark test` (example problems)
98
+
99
+ ## 📊 Results Structure
100
+ ```
101
+ tmp/{benchmark}/{problem_id}/ # Single problem results
102
+ ├── initial_solution/ # LLM's original solution + correctness
103
+ │ ├── {problem_id}_original_problem.txt # Original benchmark problem
104
+ │ ├── {problem_id}_llm_solution.txt # LLM solution + correctness evaluation
105
+ │ └── {problem_id}_extracted_program.py # Extracted function code
106
+ ├── ipo_triples/ # Input-Program-Output triples
107
+ ├── task_prompts/ # Generated reasoning tasks
108
+ ├── llm_responses/ # LLM responses to tasks
109
+ ├── extracted_answers/ # Extracted answers from responses
110
+ ├── {problem_id}_reward_analysis.json
111
+ ├── {problem_id}_reward_summary.txt
112
+ └── {problem_id}_pipeline_summary.json
113
+
114
+ test/batch_results/ # Batch evaluation results
115
+ ├── batch_evaluation_{timestamp}/
116
+ │ ├── batch_evaluation_results.json # Detailed results with correctness stats
117
+ │ └── evaluation_summary.md # Summary report with accuracy rates
118
+ ```
119
+
120
+ <!-- ============================================== -->
121
+ <div align="left">
122
+ <h1 id="links">🔗 AZR References</h1>
123
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
124
+ </div>
125
+
126
+ - 🏠 [[AZR Project Page]](https://andrewzh112.github.io/absolute-zero-reasoner/)
127
+ - 📜 [[AZR Paper]](https://arxiv.org/abs/2505.03335)
128
+ - 🤗 [[AZR Models]](https://huggingface.co/collections/andrewzh/absolute-zero-reasoner-68139b2bca82afb00bc69e5b)
129
+ - 💻 [[AZR Code]](https://github.com/LeapLabTHU/Absolute-Zero-Reasoner)
130
+ - 📁 [[AZR Logs]](https://wandb.ai/andrewzhao112/AbsoluteZeroReasoner)
131
+
132
+ <!-- ============================================== -->
133
+ <div align="left">
134
+ <h1 id="todo">📝 TestTime RLVR Roadmap</h1>
135
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
136
+ </div>
137
+
138
+ <div style="margin-bottom: 0.8rem; padding: 0.8rem 1.2rem; background-color: rgba(87, 85, 163, 0.1); border-left: 5px solid #5755A3; border-radius: 8px; display: flex; align-items: center;">
139
+ <span style="font-size: 1.2em; margin-right: 0.8rem; color: #5755A3;">✅</span>
140
+ <span style="text-decoration: line-through; color: #AAA; font-size: 1.1em;">Complete Pipeline Implementation</span>
141
+ </div>
142
+
143
+ <div style="margin-bottom: 0.8rem; padding: 0.8rem 1.2rem; background-color: rgba(87, 85, 163, 0.1); border-left: 5px solid #5755A3; border-radius: 8px; display: flex; align-items: center;">
144
+ <span style="font-size: 1.2em; margin-right: 0.8rem; color: #5755A3;">✅</span>
145
+ <span style="text-decoration: line-through; color: #AAA; font-size: 1.1em;">IPO Triple Extraction with Structured Data</span>
146
+ </div>
147
+
148
+ <div style="margin-bottom: 0.8rem; padding: 0.8rem 1.2rem; background-color: rgba(87, 85, 163, 0.1); border-left: 5px solid #5755A3; border-radius: 8px; display: flex; align-items: center;">
149
+ <span style="font-size: 1.2em; margin-right: 0.8rem; color: #5755A3;">✅</span>
150
+ <span style="text-decoration: line-through; color: #AAA; font-size: 1.1em;">Three Reasoning Tasks (Induction/Deduction/Abduction)</span>
151
+ </div>
152
+
153
+ <div style="margin-bottom: 0.8rem; padding: 0.8rem 1.2rem; background-color: rgba(87, 85, 163, 0.1); border-left: 5px solid #5755A3; border-radius: 8px; display: flex; align-items: center;">
154
+ <span style="font-size: 1.2em; margin-right: 0.8rem; color: #5755A3;">✅</span>
155
+ <span style="text-decoration: line-through; color: #AAA; font-size: 1.1em;">Execution-based Evaluation System</span>
156
+ </div>
157
+
158
+ <div style="margin-bottom: 0.8rem; padding: 0.8rem 1.2rem; background-color: rgba(239, 142, 141, 0.1); border-left: 5px solid #EF8E8D; border-radius: 8px; display: flex; align-items: center;">
159
+ <span style="font-size: 1.2em; margin-right: 0.8rem; color: #EF8E8D;">🔄</span>
160
+ <span style="color: #333; font-size: 1.1em;">VeRL Integration for RLVR Training</span>
161
+ </div>
162
+
163
+ <div style="margin-bottom: 0.8rem; padding: 0.8rem 1.2rem; background-color: rgba(239, 142, 141, 0.1); border-left: 5px solid #EF8E8D; border-radius: 8px; display: flex; align-items: center;">
164
+ <span style="font-size: 1.2em; margin-right: 0.8rem; color: #EF8E8D;">📋</span>
165
+ <span style="color: #333; font-size: 1.1em;">Multi-Problem Batch Processing</span>
166
+ </div>
167
+
168
+ <!-- ============================================== -->
169
+ <div align="left">
170
+ <h1 id="algorithm-flow">⚙️ TestTime RLVR Algorithm Flow</h1>
171
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
172
+ </div>
173
+
174
+ TestTime RLVR implements a comprehensive test-time reasoning pipeline based on AZR methodology:
175
+
176
+ ### 🔄 Pipeline Stages
177
+
178
+ 1. **<span style="color:#EF8E8D">LLM Solution Generation</span>**: The model generates an initial solution for a given benchmark problem (MBPP+/HumanEval+)
179
+
180
+ 2. **<span style="color:#5755A3">IPO Triple Extraction</span>**: Input-Program-Output triples are created using structured benchmark data and LLM solution execution
181
+
182
+ 3. **<span style="color:#EF8E8D">Task Generation</span>**: Three types of reasoning tasks are generated:
183
+ - **Induction**: Deduce function from input/output pairs + message
184
+ - **Deduction**: Predict output from code + input
185
+ - **Abduction**: Predict input from code + output
186
+
187
+ 4. **<span style="color:#5755A3">LLM Evaluation</span>**: The model attempts to solve the generated reasoning tasks using AZR prompts and templates
188
+
189
+ 5. **<span style="color:#EF8E8D">Reward Computation</span>**: Solutions are verified through program execution, receiving accuracy-based rewards
190
+
191
+ ### 🎯 Key Innovations
192
+ - **Structured Data Integration**: Direct use of benchmark `base_input`/`plus_input` instead of assert parsing
193
+ - **Execution-based Evaluation**: Program execution comparison for accurate task evaluation
194
+ - **Function Name Normalization**: Consistent `f` function naming following AZR methodology
195
+ - **Docstring Utilization**: LLM-generated docstrings enhance induction task quality
196
+
197
+ <!-- ============================================== -->
198
+ <div align="left">
199
+ <h1 id="results">📊 Results</h1>
200
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
201
+ </div>
202
+
203
+ ## Main Results
204
+
205
+ Our approach achieves strong performance across both code and math reasoning benchmarks without using any external data:
206
+
207
+ <table>
208
+ <thead>
209
+ <tr>
210
+ <th align="center">Model</th>
211
+ <th align="center">Base</th>
212
+ <th align="center">#data</th>
213
+ <th align="center">Code Avg</th>
214
+ <th align="center">Math Avg</th>
215
+ <th align="center">Total Avg</th>
216
+ </tr>
217
+ </thead>
218
+ <tbody>
219
+ <!-- Base Models Section -->
220
+ <tr>
221
+ <td colspan="6" align="center"><b>Base Models</b></td>
222
+ </tr>
223
+ <tr>
224
+ <td>Qwen2.5-7B</td>
225
+ <td>-</td>
226
+ <td>-</td>
227
+ <td>52.0</td>
228
+ <td>27.5</td>
229
+ <td>39.8</td>
230
+ </tr>
231
+ <tr>
232
+ <td>Qwen2.5-7B-Ins</td>
233
+ <td>-</td>
234
+ <td>-</td>
235
+ <td>56.3</td>
236
+ <td>37.0</td>
237
+ <td>46.7</td>
238
+ </tr>
239
+ <tr>
240
+ <td>Qwen2.5-7B-Coder</td>
241
+ <td>-</td>
242
+ <td>-</td>
243
+ <td>56.6</td>
244
+ <td>23.9</td>
245
+ <td>40.2</td>
246
+ </tr>
247
+ <!-- Zero-Style Reasoners with Code Data -->
248
+ <tr>
249
+ <td colspan="6" align="center"><b>Reasoners Trained on Curated Code Data</b></td>
250
+ </tr>
251
+ <tr>
252
+ <td>AceCoder-RM</td>
253
+ <td>Ins</td>
254
+ <td>22k</td>
255
+ <td>58.3</td>
256
+ <td>37.4</td>
257
+ <td>47.9</td>
258
+ </tr>
259
+ <tr>
260
+ <td>AceCoder-RM</td>
261
+ <td>Coder</td>
262
+ <td>22k</td>
263
+ <td>57.3</td>
264
+ <td>27.5</td>
265
+ <td>42.4</td>
266
+ </tr>
267
+ <tr>
268
+ <td>AceCoder-Rule</td>
269
+ <td>Ins</td>
270
+ <td>22k</td>
271
+ <td>55.4</td>
272
+ <td>36.9</td>
273
+ <td>46.2</td>
274
+ </tr>
275
+ <tr>
276
+ <td>AceCoder-Rule</td>
277
+ <td>Coder</td>
278
+ <td>22k</td>
279
+ <td>60.0</td>
280
+ <td>28.5</td>
281
+ <td>44.3</td>
282
+ </tr>
283
+ <tr>
284
+ <td>CodeR1-LC2k</td>
285
+ <td>Ins</td>
286
+ <td>2k</td>
287
+ <td>60.5</td>
288
+ <td>35.6</td>
289
+ <td>48.0</td>
290
+ </tr>
291
+ <tr>
292
+ <td>CodeR1-12k</td>
293
+ <td>Ins</td>
294
+ <td>10k</td>
295
+ <td>61.3</td>
296
+ <td>33.5</td>
297
+ <td>47.4</td>
298
+ </tr>
299
+ <!-- Zero-Style Reasoners with Math Data -->
300
+ <tr>
301
+ <td colspan="6" align="center"><b>Reasoners Trained on Curated Math Data</b></td>
302
+ </tr>
303
+ <tr>
304
+ <td>PRIME-Zero</td>
305
+ <td>Coder</td>
306
+ <td>484k</td>
307
+ <td>37.2</td>
308
+ <td><b>45.8</b></td>
309
+ <td>41.5</td>
310
+ </tr>
311
+ <tr>
312
+ <td>SimpleRL-Zoo</td>
313
+ <td>Base</td>
314
+ <td>8.5k</td>
315
+ <td>54.0</td>
316
+ <td>38.5</td>
317
+ <td>46.3</td>
318
+ </tr>
319
+ <tr>
320
+ <td>Oat-Zero</td>
321
+ <td>Math</td>
322
+ <td>8.5k</td>
323
+ <td>45.4</td>
324
+ <td>44.3</td>
325
+ <td>44.9</td>
326
+ </tr>
327
+ <tr>
328
+ <td>ORZ</td>
329
+ <td>Base</td>
330
+ <td>57k</td>
331
+ <td>55.6</td>
332
+ <td>41.6</td>
333
+ <td>48.6</td>
334
+ </tr>
335
+ <!-- Our Approach -->
336
+ <tr style="background-color: rgba(239, 142, 141, 0.1);">
337
+ <td colspan="6" align="center"><b>Absolute Zero Training w/ No Curated Data (Ours)</b></td>
338
+ </tr>
339
+ <tr style="background-color: rgba(239, 142, 141, 0.1);">
340
+ <td>AZR (Ours)</td>
341
+ <td>Base</td>
342
+ <td><b>0</b></td>
343
+ <td>55.2 <span style="color:#00AA00">+3.2</span></td>
344
+ <td>38.4 <span style="color:#00AA00">+10.9</span></td>
345
+ <td>46.8 <span style="color:#00AA00">+7.0</span></td>
346
+ </tr>
347
+ <tr style="background-color: rgba(87, 85, 163, 0.1);">
348
+ <td>AZR (Ours)</td>
349
+ <td>Coder</td>
350
+ <td><b>0</b></td>
351
+ <td><b>61.6</b> <span style="color:#00AA00">+5.0</span></td>
352
+ <td>39.1 <span style="color:#00AA00">+15.2</span></td>
353
+ <td><b>50.4</b> <span style="color:#00AA00">+10.2</span></td>
354
+ </tr>
355
+ </tbody>
356
+ </table>
357
+
358
+ ## Scaling Results
359
+
360
+ AZR shows consistent improvements across model sizes and types:
361
+
362
+ <table>
363
+ <thead>
364
+ <tr>
365
+ <th align="center">Model Family</th>
366
+ <th align="center">Variant</th>
367
+ <th align="center">Code Avg</th>
368
+ <th align="center">Math Avg</th>
369
+ <th align="center">Total Avg</th>
370
+ </tr>
371
+ </thead>
372
+ <tbody>
373
+ <tr>
374
+ <td>Llama3.1-8b</td>
375
+ <td></td>
376
+ <td>28.5</td>
377
+ <td>3.4</td>
378
+ <td>16.0</td>
379
+ </tr>
380
+ <tr style="background-color: rgba(87, 85, 163, 0.1);">
381
+ <td>Llama3.1-8b</td>
382
+ <td>+ AZR (Ours)</td>
383
+ <td>31.6 <span style="color:#00AA00">+3.1</span></td>
384
+ <td>6.8 <span style="color:#00AA00">+3.4</span></td>
385
+ <td>19.2 <span style="color:#00AA00">+3.2</span></td>
386
+ </tr>
387
+ <tr>
388
+ <td>Qwen2.5-3B Coder</td>
389
+ <td></td>
390
+ <td>51.2</td>
391
+ <td>18.8</td>
392
+ <td>35.0</td>
393
+ </tr>
394
+ <tr style="background-color: rgba(87, 85, 163, 0.1);">
395
+ <td>Qwen2.5-3B Coder</td>
396
+ <td>+ AZR (Ours)</td>
397
+ <td>54.9 <span style="color:#00AA00">+3.7</span></td>
398
+ <td>26.5 <span style="color:#00AA00">+7.7</span></td>
399
+ <td>40.7 <span style="color:#00AA00">+5.7</span></td>
400
+ </tr>
401
+ <tr>
402
+ <td>Qwen2.5-7B Coder</td>
403
+ <td></td>
404
+ <td>56.6</td>
405
+ <td>23.9</td>
406
+ <td>40.2</td>
407
+ </tr>
408
+ <tr style="background-color: rgba(87, 85, 163, 0.1);">
409
+ <td>Qwen2.5-7B Coder</td>
410
+ <td>+ AZR (Ours)</td>
411
+ <td>61.6 <span style="color:#00AA00">+5.0</span></td>
412
+ <td>39.1 <span style="color:#00AA00">+15.2</span></td>
413
+ <td>50.4 <span style="color:#00AA00">+10.2</span></td>
414
+ </tr>
415
+ <tr>
416
+ <td>Qwen2.5-14B Coder</td>
417
+ <td></td>
418
+ <td>60.0</td>
419
+ <td>20.2</td>
420
+ <td>40.1</td>
421
+ </tr>
422
+ <tr style="background-color: rgba(87, 85, 163, 0.1);">
423
+ <td>Qwen2.5-14B Coder</td>
424
+ <td>+ AZR (Ours)</td>
425
+ <td>63.6 <span style="color:#00AA00">+3.6</span></td>
426
+ <td>43.0 <span style="color:#00AA00">+22.8</span></td>
427
+ <td>53.3 <span style="color:#00AA00">+13.2</span></td>
428
+ </tr>
429
+ </tbody>
430
+ </table>
431
+
432
+ <!-- ============================================== -->
433
+ <div align="left">
434
+ <h1 id="getting-started">✨ Getting Started</h1>
435
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
436
+ </div>
437
+
438
+ ## 🎄 Environment Setup
439
+ ```bash
440
+ conda env create -f azr_env.yml
441
+ conda activate azr
442
+ pip install -r flashattn_requirements.txt
443
+ ```
444
+
445
+ ## 💾 Data Processing
446
+ ### Process evaluation data on CruxEval / LiveCodeBench Execution during AZR Self-play
447
+ ```bash
448
+ python -m absolute_zero_reasoner.data_construction.process_code_reasoning_data
449
+ ```
450
+
451
+ <!-- ============================================== -->
452
+ <div align="left">
453
+ <h1 id="training">🏋️ Training</h1>
454
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
455
+ </div>
456
+
457
+ > **⚠️WARNING⚠️**: The Python executor in this repository is very raw and intended for research purposes only. It is not secure for production environments. We plan to update our executor to more secure implementations in the future. Your use of our code is at your own discretion and risk.
458
+
459
+
460
+ ## 🫛 Seeding (Optional)
461
+ We provide the seed datasets we collected by prompting each model in data/. If you want to create your own seed data, use the following script:
462
+ ```bash
463
+ export OUTPUT_SEED_PATH=data/<new_ded_abd_seed_data_name>.jsonl
464
+ export OUTPUT_CODE_F_SEED_PATH=data/<new_ind_seed_data_name>.jsonl
465
+ bash scripts/seeding/<7b|14b|coder3b|coder7b|coder14b|llama>.sh
466
+ ```
467
+
468
+ ## ♟️ Self-play
469
+ 3b models need 2 X 80gb GPUs, 7/8b models need 4 X 80gb, 14b requires 8 X 80gb
470
+ ```bash
471
+ bash scripts/selfplay/<7b|14b|coder3b|coder7b|coder14b|llama>.sh
472
+ ```
473
+ If you want to use your own ded/abd or ind seed dataset:
474
+ ```bash
475
+ export OUTPUT_SEED_PATH=data/<your_ded_abd_seed_data_name>.jsonl
476
+ export OUTPUT_CODE_F_SEED_PATH=data/<your_ind_seed_data_name>.jsonl
477
+ bash scripts/selfplay/<7b|14b|coder3b|coder7b|coder14b|llama>.sh
478
+ ```
479
+ For using the newly supported sandbox-fusion executor, use docker and set `azr.executor=sandboxfusion`.
480
+
481
+ ## 🌚 Resuming Runs
482
+ When resuming runs, put the original run wandb id into the script, i.e., `trainer.wandb_run_id=<run_id>`.
483
+
484
+ ## 🤗 Converting veRL checkpoints to HF format
485
+ ```bash
486
+ python -m absolute_zero_reasoner.utils.convert2hf \
487
+ <veRL_ckpt_path>/actor \
488
+ <veRL_ckpt_path>/actor/huggingface/ \
489
+ <hf_ckpt_path>
490
+ ```
491
+
492
+ ## 📈Design Your Own Intrinsic Rewards!
493
+ In configs, just add your own rewards to `azr.reward.generation_reward_config`, check the ones already implemented such as diversity and complexity rewards. Be Creative!
494
+
495
+ <!-- ============================================== -->
496
+ <div align="left">
497
+ <h1 id="usage">🔧 Usage</h1>
498
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
499
+ </div>
500
+
501
+ We use the Deepseek R1 <think> & <answer> tags as prompt template:
502
+
503
+ ```
504
+ A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {question}\nAssistant: <think>
505
+ ```
506
+
507
+ <!-- ============================================== -->
508
+ <div align="left">
509
+ <h1 id="evaluation-code">📃 Evaluation Code</h1>
510
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
511
+ </div>
512
+
513
+ ## LiveCodeBench
514
+ Setup: LCB needs to first download the data
515
+ ```bash
516
+ git clone https://hf-mirror.com/datasets/livecodebench/code_generation_lite evaluation/code_eval/coding/LiveCodeBench/code_generation_lite
517
+ ```
518
+ Evaluation:
519
+ ```bash
520
+ bash evaluation/code_eval/scripts/run_lcb_gen.sh --model <andrewzh/Absolute_Zero_Reasoner-Coder-3b>
521
+ ```
522
+
523
+ ## Evalplus
524
+ New conda env is neede for evalplus
525
+ ```bash
526
+ conda create -n evalplus python=3.11
527
+ pip install --upgrade "evalplus[vllm] @ git+https://github.com/evalplus/evalplus@d362e933265c3e7e3df8101c930a89c3c470cd9f"
528
+ Evaluation:
529
+ ```bash
530
+ condda activate evalplus
531
+ bash evaluation/code_eval/scripts/run_evalplus.sh 0 <humaneval|mbpp> <andrewzh/Absolute_Zero_Reasoner-Coder-3b>
532
+ ```
533
+
534
+ ## Math
535
+ Please refer to [evaluation/math_eval/README.md](evaluation/math_eval/README.md) for math evaluation.
536
+
537
+
538
+ <!-- ============================================== -->
539
+ <div align="left">
540
+ <h1 id="citation">🎈 Citation</h1>
541
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
542
+ </div>
543
+
544
+ If you find Absolute Zero Reasoner helpful, please cite us.
545
+
546
+ ```bibtex
547
+ @misc{zhao2025absolutezeroreinforcedselfplay,
548
+ title={Absolute Zero: Reinforced Self-play Reasoning with Zero Data},
549
+ author={Andrew Zhao and Yiran Wu and Yang Yue and Tong Wu and Quentin Xu and Yang Yue and Matthieu Lin and Shenzhi Wang and Qingyun Wu and Zilong Zheng and Gao Huang},
550
+ year={2025},
551
+ eprint={2505.03335},
552
+ archivePrefix={arXiv},
553
+ primaryClass={cs.LG},
554
+ url={https://arxiv.org/abs/2505.03335},
555
+ }
556
+ ```
557
+
558
+ <!-- ============================================== -->
559
+ <div align="left">
560
+ <h1 id="acknowledgement">🌻 Acknowledgement</h1>
561
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
562
+ </div>
563
+
564
+ Our reinforcement learning training codebase is a fork of the [veRL framework](https://github.com/volcengine/verl). For rollouts, we used [vLLM](https://github.com/vllm-project/vllm). The Python executor components are adapted from the [QwQ Repository](https://github.com/QwenLM/QwQ/tree/main/eval/eval/math_opensource_utils). Additionally, we borrowed our README structure from [PRIME](https://github.com/PRIME-RL/PRIME).
565
+ Many thanks to the authors of these projects for their excellent contributions!
566
+
567
+ <!-- ============================================== -->
568
+ <div align="left">
569
+ <h1 id="contact">📧 Contact</h1>
570
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
571
+ </div>
572
+
573
+ Feel free to contact Andrew Zhao via email: [email protected]
574
+
575
+ <!-- ============================================== -->
576
+ <div align="left">
577
+ <h1 id="star-history">📈 Star History</h1>
578
+ <hr style="height: 3px; background: linear-gradient(90deg, #EF8E8D, #5755A3); border: none; border-radius: 3px;">
579
+ </div>
580
+
581
+ [![Star History Chart](https://api.star-history.com/svg?repos=LeapLabTHU/Absolute-Zero-Reasoner&type=Date)](https://www.star-history.com/#LeapLabTHU/Absolute-Zero-Reasoner&Date)
Update/2025-01-25_humaneval_fixes.md ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TestTime RLVR-v2 HumanEval 평가 수정 사항
2
+ 날짜: 2025-01-25
3
+
4
+ ## 개요
5
+ HumanEval 벤치마크에서 0% 정확도 문제를 해결하기 위한 전체적인 수정 작업을 수행했습니다.
6
+
7
+ ## 주요 문제점 및 해결 방안
8
+
9
+ ### 1. Import 문 누락 문제
10
+ **문제**: HumanEval 솔루션에서 `from typing import List` 등의 import 문이 누락되어 실행 실패
11
+ **해결**:
12
+ - EvalPlus 방식과 동일하게 프롬프트에서 import 문을 추출하여 자동 추가
13
+ - `_add_imports_from_prompt()` 메서드 추가
14
+ - 자동으로 import를 추가하는 치팅 방식 제거
15
+
16
+ ### 2. IPO Triple 추출 문제
17
+ **문제**:
18
+ - base_input의 첫 번째 항목만 사용
19
+ - HumanEval에서 테스트 케이스를 사용하여 IPO 생성 (치팅)
20
+ **해결**:
21
+ - HumanEval은 docstring 예제만 사용하도록 변경
22
+ - `_extract_docstring_examples()` 메서드 추가
23
+ - 입력 형식 분리: 평가용 인자와 표시용 전체 함수 호출
24
+
25
+ ### 3. 프롬프트 일관성 문제
26
+ **문제**:
27
+ - `batch_evaluate_testtime.py`의 하드코딩된 프롬프트가 `solution_generator.py`와 불일치
28
+ - HumanEval/50과 같은 다중 함수 문제 처리 미흡
29
+ **해결**:
30
+ - 모든 프롬프트를 `solution_generator.py`와 일치하도록 수정
31
+ - 다중 함수 케이스를 위한 특별 처리 추가
32
+
33
+ ### 4. Task 생성 시 문제
34
+ **문제**:
35
+ - HumanEval에서 doctest 예시가 포함되어 치팅 발생
36
+ - Induction task의 message가 일반적인 메시지 사용
37
+ **해결**:
38
+ - `_remove_doctest_examples()` 메서드로 doctest 제거
39
+ - HumanEval의 경우 함수 설명을 추출하여 message로 사용
40
+
41
+ ### 5. 평가 실패 문제
42
+ **문제**:
43
+ - Induction: 전체 함수 호출을 사용하여 평가 실패
44
+ - Abduction: 인자만 저장되어 MBPP와 다른 형식으로 평가
45
+ **해결**:
46
+ - IPO triple에 `input`(인자)와 `full_input_str`(전체 호출) 분리 저장
47
+ - Abduction expected_solution을 `full_input_str` 사용하도록 수정
48
+
49
+ ## 수정된 파일 목록
50
+
51
+ ### 1. `/home/ubuntu/RLVR/TestTime-RLVR-v2/absolute_zero_reasoner/testtime/solution_generator.py`
52
+ - `_add_imports_from_prompt()` 메서드 추가
53
+ - `_add_missing_imports()` 제거 (치팅 방지)
54
+ - HumanEval용 프롬프트 개선
55
+ - 다중 함수 처리 로직 추가
56
+
57
+ ### 2. `/home/ubuntu/RLVR/TestTime-RLVR-v2/absolute_zero_reasoner/testtime/ipo_extractor.py`
58
+ - `_extract_docstring_examples()` 메서드 추가
59
+ - HumanEval은 docstring 예제만 사용하도록 수정
60
+ - 입력 형식 분리 (평가용/표시용)
61
+
62
+ ### 3. `/home/ubuntu/RLVR/TestTime-RLVR-v2/absolute_zero_reasoner/testtime/task_generator.py`
63
+ - `_remove_doctest_examples()` 메서드 추가
64
+ - `_extract_function_description()` 메서드 추가
65
+ - HumanEval induction message 개선
66
+ - Abduction expected_solution을 전체 함수 호출로 수정
67
+
68
+ ### 4. `/home/ubuntu/RLVR/TestTime-RLVR-v2/test/batch_evaluate_testtime.py`
69
+ - 하드코딩된 프롬프트를 `solution_generator.py`와 일치하도록 수정
70
+ - 전체 LLM 프롬프트 로깅 추가
71
+
72
+ ## 기술적 세부사항
73
+
74
+ ### IPO Triple 형식 차이
75
+ ```json
76
+ // MBPP (기존)
77
+ {
78
+ "input": "intersperse([], 4)",
79
+ "full_input_str": "intersperse([], 4)"
80
+ }
81
+
82
+ // HumanEval (수정됨)
83
+ {
84
+ "input": "[], 4", // 평가용 (인자만)
85
+ "full_input_str": "intersperse([], 4)" // 표시용 (전체 호출)
86
+ }
87
+ ```
88
+
89
+ ### Import 추출 로직
90
+ ```python
91
+ def _add_imports_from_prompt(self, prompt: str, solution: str) -> str:
92
+ # 프롬프트에서 import 문 추출
93
+ # solution 앞에 import 문 추가
94
+ # EvalPlus와 동일한 방식
95
+ ```
96
+
97
+ ### Doctest 제거
98
+ ```python
99
+ def _remove_doctest_examples(self, code: str) -> str:
100
+ # docstring 내의 >>> 예시 제거
101
+ # 함수 설명은 유지
102
+ ```
103
+
104
+ ## 성과
105
+ - HumanEval 평가가 정상적으로 작동
106
+ - 치팅 없이 공정한 평가 수행
107
+ - MBPP와 일관된 평가 방식 유지
108
+ - EvalPlus와 호환되는 import 처리
109
+
110
+ ## 향후 개선사항
111
+ - 더 많은 HumanEval 문제에 대한 테스트 필요
112
+ - 다양한 edge case 처리 개선
113
+ - 성능 최적화
Update/Phase1_Infrastructure_Setup.md ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Phase 1: 기반 인프라 구축 완료
2
+
3
+ ## 📁 디렉토리 구조 설정
4
+
5
+ ### 새로 생성된 프로젝트
6
+ - `/home/ubuntu/RLVR/TestTime-RLVR-v2/` - AZR 기반 새 프로젝트
7
+
8
+ ### 핵심 디렉토리 구조
9
+ ```
10
+ TestTime-RLVR-v2/
11
+ ├── absolute_zero_reasoner/
12
+ │ ├── testtime/ # TestTime 전용 컴포넌트
13
+ │ │ ├── __init__.py # 모듈 초기화
14
+ │ │ └── config.py # TestTime 설정
15
+ │ ├── utils/code_utils/ # AZR Python Executor (기존)
16
+ │ ├── rewards/ # AZR Reward Manager (기존)
17
+ │ └── trainer/ppo/ # AZR PPO Trainer (기존)
18
+ ├── logs/ # 로깅 시스템
19
+ │ ├── problems/ # 문제별 로그
20
+ │ ├── ipo_extraction/ # IPO 추출 로그
21
+ │ ├── task_generation/ # 태스크 생성 로그
22
+ │ ├── training/ # 학습 로그
23
+ │ └── performance/ # 성능 변화 로그
24
+ ├── evaluation/code_eval/data/ # 벤치마크 데이터
25
+ │ ├── HumanEvalPlus.jsonl # ✅ 존재 확인
26
+ │ └── MbppPlus.jsonl # ✅ 존재 확인
27
+ └── Update/ # 변경사항 추적
28
+ ```
29
+
30
+ ## 🔧 생성된 핵심 컴포넌트
31
+
32
+ ### 1. TestTimeConfig 클래스
33
+ - **위치**: `absolute_zero_reasoner/testtime/config.py`
34
+ - **기능**: TestTime RLVR 전체 설정 관리
35
+ - **특징**: AZR 호환성 유지하면서 TestTime 특화 설정 추가
36
+
37
+ ### 2. BenchmarkConfig 클래스
38
+ - **위치**: `absolute_zero_reasoner/testtime/config.py`
39
+ - **기능**: 벤치마크별 설정 (HumanEval+, MBPP+)
40
+ - **특징**: 벤치마크별 시작 인덱스, 경로 등 관리
41
+
42
+ ## ✅ 완료된 작업
43
+
44
+ 1. **프로젝트 복사**: AZR → TestTime-RLVR-v2
45
+ 2. **디렉토리 구조**: 로그 및 컴포넌트 디렉토리 생성
46
+ 3. **기본 설정**: TestTimeConfig, BenchmarkConfig 클래스 생성
47
+ 4. **데이터 확인**: HumanEval+, MBPP+ 데이터 파일 존재 확인
48
+ 5. **모듈 구조**: testtime 패키지 초기화
49
+
50
+ ## 🎯 다음 단계 (Phase 2)
51
+
52
+ 1. **BenchmarkProblemLoader** 구현 - 벤치마크 문제 로딩
53
+ 2. **InitialSolutionGenerator** 구현 - 초기 솔루션 생성
54
+ 3. **벤치마크 검증 시스템** 구현 - 솔루션 정확성 검증
55
+
56
+ ## 📝 주요 설계 원칙
57
+
58
+ - **AZR 호환성**: 기존 AZR 컴포넌트 최대한 재사용
59
+ - **경량화**: TestTime에 적합한 빠른 적응 학습
60
+ - **포괄적 로깅**: 모든 단계별 상세 로그 기록
61
+ - **모듈성**: 각 컴포넌트 독립적 테스트 가능
62
+
63
+ ---
64
+ **생성 일시**: 2025-07-16
65
+ **상태**: ✅ 완료
Update/Phase2_Benchmark_System.md ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Phase 2: 벤치마크 문제 풀이 시스템 완료
2
+
3
+ ## ✅ 구현된 컴포넌트
4
+
5
+ ### 1. BenchmarkProblemLoader
6
+ - **파일**: `absolute_zero_reasoner/testtime/benchmark_loader.py`
7
+ - **기능**:
8
+ - HumanEval+, MBPP+ 문제 로딩
9
+ - 테스트 케이스 추출 (assert 문 파싱)
10
+ - 솔루션 검증 (구문 + 실행)
11
+ - 배치 로딩 및 통계 정보 제공
12
+ - **기반**: 기존 `load_humaneval_problem` 함수 확장
13
+
14
+ ### 2. InitialSolutionGenerator
15
+ - **파일**: `absolute_zero_reasoner/testtime/solution_generator.py`
16
+ - **기능**:
17
+ - AZR 스타일 모델 로딩 (flash attention, gradient checkpointing)
18
+ - Greedy 생성 (AZR evaluation과 동일)
19
+ - 함수 정의 자동 복구
20
+ - 대체 솔루션 생성 (문제별 템플릿)
21
+ - **기반**: 기존 `generate_initial_solution` 함수 클래스화
22
+
23
+ ### 3. TestTimeLogger
24
+ - **파일**: `absolute_zero_reasoner/testtime/logger.py`
25
+ - **기능**:
26
+ - 요구사항 1: 벤치마크 문제 + LLM 답변 + 정답 여부
27
+ - 요구사항 2: IPO 추출 + 태스크 생성 로그
28
+ - 요구사항 3: 태스크 정확도 + reward 로그
29
+ - 요구사항 4: VeRL 학습 진행 로그
30
+ - JSON 형태 구조화된 로그 저장
31
+
32
+ ### 4. 설정 시스템
33
+ - **파일**: `absolute_zero_reasoner/testtime/config.py`
34
+ - **클래스**: `TestTimeConfig`, `BenchmarkConfig`
35
+ - **기능**: AZR 호환 + TestTime 특화 설정
36
+
37
+ ## 🧪 테스트 결과
38
+
39
+ ### 기본 기능 테스트 (✅ 3/3 통과)
40
+ ```
41
+ Configuration: ✅ PASS
42
+ Logger: ✅ PASS
43
+ BenchmarkLoader: ✅ PASS
44
+ ```
45
+
46
+ ### 검증된 기능
47
+ - ✅ MBPP 문제 로딩 (Mbpp/2 성공)
48
+ - ✅ 문제 통계 (378개 문제 확인)
49
+ - ✅ 로깅 시스템 (5개 카테고리)
50
+ - ✅ 설정 관리 (AZR 호환)
51
+
52
+ ## 📁 생성된 구조
53
+
54
+ ```
55
+ TestTime-RLVR-v2/absolute_zero_reasoner/testtime/
56
+ ├── __init__.py # 패키지 초기화
57
+ ├── config.py # 설정 클래스
58
+ ├── benchmark_loader.py # 벤치마크 로더
59
+ ├── solution_generator.py # 솔루션 생성기
60
+ └── logger.py # 로깅 시스템
61
+ ```
62
+
63
+ ## 🗑️ 정리된 항목
64
+
65
+ - ✅ Python 캐시 파일 (`__pycache__`, `*.pyc`) 삭제
66
+ - ✅ 불필요한 임포트 정리 (아직 구현되지 않은 컴포넌트 주석 처리)
67
+ - ✅ 테스트 파일을 `/tmp/azr/`에 임시 저장
68
+
69
+ ## 🎯 다음 단계 (Phase 3)
70
+
71
+ Phase 3에서 구현할 **IPO Triple 추출 시스템**:
72
+
73
+ 1. **IPOTripleExtractor** - AZR Python Executor 기반 IPO 추출
74
+ 2. **TripleValidator** - 추출된 트리플 검증
75
+ 3. **AZR 연동** - `utils/code_utils/python_executor.py` 활용
76
+
77
+ ### AZR 컴포넌트 활용 계획
78
+ - `absolute_zero_reasoner/utils/code_utils/python_executor.py` - 코드 실행
79
+ - `absolute_zero_reasoner/trainer/ppo/azr_ray_trainer.py:641-655` - IPO 생성 로직
80
+ - `absolute_zero_reasoner/rewards/reward_managers.py:220-233` - 검증 로직
81
+
82
+ ---
83
+ **생성 일시**: 2025-07-16
84
+ **상태**: ✅ 완료
85
+ **테스트**: ✅ 통과 (3/3)
Update/Phase3_AZR_Template_Integration.md ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Phase 3 개선: AZR 템플릿 직접 통합 완료
2
+
3
+ ## ✅ 주요 개선사항
4
+
5
+ ### 1. AZR 템플릿 직접 사용
6
+ - **기존**: 단순화된 TestTime 전용 템플릿 (20-30라인)
7
+ - **개선**: AZR 원본 템플릿 직접 활용 (2000+ 문자)
8
+ - **효과**: 상세한 제약사항, 예시, 평가기준 포함
9
+
10
+ ### 2. 태스크 타입별 AZR 매핑
11
+ | TestTime 태스크 | AZR 문제 타입 | 설명 |
12
+ |-----------------|---------------|------|
13
+ | **Induction** | `code_f` | 함수 생성 문제 |
14
+ | **Deduction** | `code_o` | 출력 예측 문제 |
15
+ | **Abduction** | `code_i` | 입력 생성 문제 |
16
+
17
+ ### 3. 코드 구조 최적화
18
+ - **템플릿 임포트**: `from ..data_construction.prompts import get_code_problem_generator_prompt`
19
+ - **불필요한 코드 제거**: 기존 단순 템플릿 코드 삭제 (150+ 라인 정리)
20
+ - **매개변수 수정**: `composite_functions=[]` 추가로 오류 해결
21
+
22
+ ## 🧪 테스트 결과
23
+
24
+ ### AZR 템플릿 품질 비교
25
+ ```
26
+ 기존 TestTime 템플릿: 20-30라인, 기본적 설명
27
+ AZR 템플릿: 2000+ 문자, 상세한 구조
28
+ - 다양한 예시 제공
29
+ - 명확한 제약사항
30
+ - 체계적 평가기준
31
+ - 단계별 추론 유도
32
+ ```
33
+
34
+ ### 생성된 프롬프트 예시
35
+ - **Induction**: 2,274자 상세 프롬프트
36
+ - **Deduction**: 3,057자 상세 프롬프트
37
+ - **Abduction**: 3,063자 상세 프롬프트
38
+
39
+ ## 📂 정리된 파일
40
+
41
+ ### 불필요한 파일 삭제
42
+ - ❌ `/tmp/azr/debug_ipo_failures.py`
43
+ - ❌ `/tmp/azr/detailed_failure_analysis.py`
44
+ - ❌ `/tmp/azr/complete_pipeline_details.py`
45
+ - ❌ `/tmp/azr/show_full_pipeline.py`
46
+
47
+ ### 유지되는 핵심 파일
48
+ - ✅ `/tmp/azr/ipo_failure_analysis.json` - IPO 실패 패턴 기록
49
+ - ✅ `/tmp/azr/complete_pipeline_analysis.json` - 전체 파이프라인 분석
50
+ - ✅ `/tmp/azr/test_azr_templates.py` - AZR 템플릿 테스트용
51
+
52
+ ## 🎯 핵심 발견사항
53
+
54
+ ### IPO 추출 실패 패턴
55
+ ```
56
+ 성공: 1/5 케이스 (Division by Zero만 성공)
57
+ 실패: 4/5 케이스
58
+ - Infinite Loop: Timeout (5초)
59
+ - Import Error: ModuleNotFoundError
60
+ - Variable Error: NameError
61
+ - No Function: 함수 정의 없음
62
+ ```
63
+
64
+ ### AZR 템플릿 효과
65
+ - **프롬프트 길이**: 100배 증가 (30자 → 3000자)
66
+ - **구조**: 체계적 multi-step 프롬프트
67
+ - **품질**: 상세한 예시와 제약사항 포함
68
+
69
+ ## 📝 코드 변경사항
70
+
71
+ ### `task_generator.py` 주요 수정
72
+ ```python
73
+ # 1. AZR 템플릿 임포트
74
+ from ..data_construction.prompts import get_code_problem_generator_prompt
75
+
76
+ # 2. 태스크별 AZR 템플릿 활용
77
+ - induction: code_f (함수 생성)
78
+ - deduction: code_o (출력 예측)
79
+ - abduction: code_i (입력 생성)
80
+
81
+ # 3. 매개변수 수정
82
+ composite_functions=[] # 빈 리스트로 설정
83
+ ```
84
+
85
+ ### 제거된 코드
86
+ - 기존 템플릿 메서드 (150+ 라인)
87
+ - 불필요한 임시 변수
88
+ - 중복 테스트 파일들
89
+
90
+ ## 🎉 개선 효과
91
+
92
+ 1. **품질**: AZR 수준의 고품질 프롬프트 활용
93
+ 2. **일관성**: AZR 학습 데이터와 동일한 형식
94
+ 3. **효율성**: 코드 중복 제거 및 직접 재사용
95
+ 4. **확장성**: AZR의 모든 템플릿 기능 활용 가능
96
+
97
+ ---
98
+ **완료 일시**: 2025-07-16
99
+ **상태**: ✅ AZR 템플릿 통합 완료
100
+ **다음 단계**: Phase 4 - RLVR 학습 시스템 구현
Update/Phase3_IPO_Extraction.md ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Phase 3: IPO Triple 추출 시스템 완료
2
+
3
+ ## ✅ 구현된 컴포넌트
4
+
5
+ ### 1. IPOTripleExtractor
6
+ - **파일**: `absolute_zero_reasoner/testtime/ipo_extractor.py`
7
+ - **기능**:
8
+ - AZR Python Executor 기반 안전한 코드 실행
9
+ - 테스트 케이스에서 입력-출력 쌍 추출
10
+ - 솔루션 실행으로 IPO 트리플 생성
11
+ - 합성 입력으로 추가 트리플 생성
12
+ - 트리플 검증 및 일관성 확인
13
+ - **기반**: `python_executor.py`, `azr_ray_trainer.py` 로직
14
+
15
+ ### 2. TestTimeTaskGenerator
16
+ - **파일**: `absolute_zero_reasoner/testtime/task_generator.py`
17
+ - **기능**:
18
+ - Induction: 입력-출력에서 함수 추론
19
+ - Deduction: 함수+입력에서 출력 추론
20
+ - Abduction: 함수+출력에서 입력 추론
21
+ - AZR 기반 템플릿 시스템
22
+ - 학습용 데이터셋 생성
23
+ - **기반**: `prompts.py`, `constructor.py` 템플릿
24
+
25
+ ## 🧪 테스트 결과
26
+
27
+ ### IPO 추출 시스템 테스트 (✅ 3/3 통과)
28
+ ```
29
+ IPO Extractor: ✅ PASS
30
+ Task Generator: ✅ PASS
31
+ Integrated Pipeline: ✅ PASS
32
+ ```
33
+
34
+ ### 검증된 기능
35
+ - ✅ **IPO 추출**: 5/6 유효한 트리플 생성
36
+ - ✅ **태스크 생성**: 4개 태스크 (I:1, D:1, A:2)
37
+ - ✅ **통합 파이프라인**: Mbpp/2 문제 전체 처리
38
+ - ✅ **AZR Python Executor**: 안전한 코드 실행 확인
39
+
40
+ ## 📊 성능 지표
41
+
42
+ ### IPO 추출 성능
43
+ - **테스트 문제**: `add_two(x)` 간단한 함수
44
+ - **추출된 트리플**: 5개 (유효성 83%)
45
+ - **실행 시간**: ~0.5초
46
+
47
+ ### 태스크 생성 성능
48
+ - **MBPP 문제**: `similar_elements` 함수
49
+ - **생성된 태스크**: 4개 (균등 분배)
50
+ - **태스크 분포**: Induction(25%), Deduction(25%), Abduction(50%)
51
+
52
+ ### 통합 파이프라인
53
+ ```
54
+ 1. 문제 로딩 ✅ → 2. IPO 추출 ✅ → 3. 태스크 생성 ✅
55
+ ```
56
+
57
+ ## 🔍 핵심 기술 검증
58
+
59
+ ### 1. AZR Python Executor 연동
60
+ - **ProcessPool 기반**: 안전한 샌드박스 실행
61
+ - **타임아웃 관리**: 5초 제한으로 TestTime 최적화
62
+ - **에러 처리**: 구문/실행 오류 분리 처리
63
+
64
+ ### 2. IPO 트리플 구조
65
+ ```json
66
+ {
67
+ "id": "Mbpp/2_triple_0",
68
+ "input": "(3, 4, 5, 6), (5, 7, 4, 10)",
69
+ "program": "def similar_elements(test_tup1, test_tup2):\n return tuple(set(test_tup1) & set(test_tup2))",
70
+ "expected_output": "(4, 5)",
71
+ "actual_output": "(4, 5)",
72
+ "function_name": "similar_elements",
73
+ "is_correct": true,
74
+ "extraction_method": "test_case"
75
+ }
76
+ ```
77
+
78
+ ### 3. 3종 태스크 템플릿
79
+ - **Induction**: "입력-출력에서 함수를 추론하세요"
80
+ - **Deduction**: "함수와 입력으로 출력을 예측하세요"
81
+ - **Abduction**: "함수와 출력으로 입력을 찾으세요"
82
+
83
+ ## 📁 업데이트된 구조
84
+
85
+ ```
86
+ TestTime-RLVR-v2/absolute_zero_reasoner/testtime/
87
+ ├── __init__.py # ✅ IPO, Task 추가
88
+ ├── config.py # ✅ 완료
89
+ ├── benchmark_loader.py # ✅ 완료
90
+ ├── solution_generator.py # ✅ 완료
91
+ ├── ipo_extractor.py # 🆕 IPO 추출 시스템
92
+ ├── task_generator.py # 🆕 3종 태스크 생성
93
+ └── logger.py # ✅ 완료
94
+ ```
95
+
96
+ ## 📝 로깅 시스템 활용
97
+
98
+ ### 요구사항 준수 확인
99
+ - ✅ **요구사항 2**: IPO 추출 + 태스크 생성 로그 기록
100
+ - ✅ **구조화된 로그**: JSON 형태로 `/tmp/azr/logs/` 저장
101
+ - ✅ **실시간 모니터링**: 추출/생성 과정 단계별 추적
102
+
103
+ ### 로그 카테고리
104
+ ```
105
+ logs/
106
+ ├── ipo_extraction/ # IPO 추출 상세 로그
107
+ ├── task_generation/ # 태스크 생성 로그
108
+ ├── problems/ # 문제별 처리 로그
109
+ └── training/ # 향후 학습 로그용
110
+ ```
111
+
112
+ ## 🎯 다음 단계 (Phase 4)
113
+
114
+ Phase 4에서 구현할 **RLVR 학습 시스템**:
115
+
116
+ 1. **TestTimeRewardManager** - AZR reward_managers.py 기반
117
+ 2. **TestTimeRLVRTrainer** - AZR PPO/REINFORCE++ 활용
118
+ 3. **성능 평가 시스템** - 반복 학습 효과 측정
119
+
120
+ ### AZR 컴포넌트 활용 계획
121
+ - `rewards/reward_managers.py` - r_solve 함수 활용
122
+ - `trainer/ppo/reason_rl_ray_trainer.py` - PPO 학습 로직
123
+ - veRL 프레임워크 통합
124
+
125
+ ---
126
+ **생성 일시**: 2025-07-16
127
+ **상태**: ✅ 완료
128
+ **테스트**: ✅ 통과 (3/3)
129
+ **핵심 성과**: AZR Python Executor 성공적 연동, 완전한 IPO 파이프라인 구축
Update/Phase4_Complete_Pipeline_Implementation.md ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Phase 4: Complete Pipeline Implementation
2
+
3
+ ## 🎯 Overview
4
+ Complete TestTime RLVR pipeline implementation based on AZR (Absolute Zero Reasoner) methodology. The pipeline successfully integrates LLM solution generation, IPO triple extraction, three-task reasoning (induction/deduction/abduction), and execution-based evaluation.
5
+
6
+ ## 📋 Implementation Details
7
+
8
+ ### 1. Complete Pipeline Architecture
9
+ - **File**: `test_complete_pipeline.py`
10
+ - **Main Class**: `CompleteTestTimePipeline` in `complete_pipeline.py`
11
+ - **Flow**: LLM Solution → IPO Extraction → Task Generation → LLM Evaluation → Reward Computation
12
+
13
+ ### 2. Key Components
14
+
15
+ #### 2.1 Pipeline Execution (`test_complete_pipeline.py`)
16
+ ```python
17
+ def main():
18
+ # Model loading with VLLM optimization
19
+ model, tokenizer = InitialSolutionGenerator.load_model_with_optimizations(
20
+ args.model, device, config, use_vllm=True
21
+ )
22
+
23
+ # Pipeline initialization
24
+ pipeline = CompleteTestTimePipeline(model, tokenizer, config, logger)
25
+
26
+ # Complete pipeline execution
27
+ result = pipeline.run_complete_pipeline(benchmark_config, problem_id)
28
+ ```
29
+
30
+ #### 2.2 IPO Triple Extraction (Fixed)
31
+ - **Issue**: Previously failed due to assert parsing regex issues
32
+ - **Solution**: Switched to structured data extraction from `base_input`/`plus_input`
33
+ - **Key Change**: Use LLM-generated solution execution for output computation
34
+ ```python
35
+ def _extract_test_cases(self, problem: Dict[str, Any], solution: str) -> List[Tuple[str, str]]:
36
+ # Use structured benchmark data instead of assert parsing
37
+ actual_output = self._execute_llm_solution(solution, func_name, inp_args)
38
+ ```
39
+
40
+ #### 2.3 Three Reasoning Tasks
41
+ - **Induction**: Deduce function from input/output pairs + message
42
+ - **Deduction**: Predict output from code + input
43
+ - **Abduction**: Predict input from code + output
44
+
45
+ #### 2.4 Evaluation System (AZR-based)
46
+ - **Execution-based comparison** instead of string matching
47
+ - **Function name normalization** to `f` for consistency
48
+ - **Program execution** using AZR's PythonExecutor
49
+
50
+ ### 3. Critical Bug Fixes
51
+
52
+ #### 3.1 IPO Extraction Failure (Solved)
53
+ **Problem**: 0 triples extracted due to regex parsing failure
54
+ ```
55
+ assert remove_lowercase("PYTHon")==('PYTH') # Failed to parse parentheses
56
+ ```
57
+ **Solution**: Use structured `base_input`/`plus_input` data directly
58
+
59
+ #### 3.2 Function Name Normalization Bug (Solved)
60
+ **Problem**: Function definitions normalized to `f` but calls weren't
61
+ **Solution**: Normalize both definitions and calls consistently
62
+
63
+ #### 3.3 Answer Extraction Pattern Mismatch (Solved)
64
+ **Problem**: Induction tasks expected `<answer>` tags but code looked for ````python``` blocks
65
+ **Solution**: Updated extraction pattern to use `<answer>` tags consistently
66
+
67
+ ### 4. Prompt System Integration
68
+
69
+ #### 4.1 AZR Template Usage
70
+ - **File**: `absolute_zero_reasoner/data_construction/prompts.py`
71
+ - **Key Templates**:
72
+ - `code_function_predictor_prompt` (induction)
73
+ - `code_input_predictor_prompt` (abduction)
74
+ - `code_output_predictor_prompt` (deduction)
75
+
76
+ #### 4.2 Docstring Extraction and Usage
77
+ - Extract docstrings from LLM-generated solutions
78
+ - Use as `message` parameter in induction tasks
79
+ - Improves task quality and LLM understanding
80
+
81
+ ### 5. Benchmark Integration
82
+
83
+ #### 5.1 Supported Benchmarks
84
+ - **MBPP+**: `/home/ubuntu/RLVR/TestTime-RLVR-v2/evaluation/code_eval/data/MbppPlus.jsonl`
85
+ - **HumanEval+**: `/home/ubuntu/RLVR/TestTime-RLVR-v2/evaluation/code_eval/data/HumanEvalPlus.jsonl`
86
+ - **Test mode**: Simple example problems
87
+
88
+ #### 5.2 Problem Loading
89
+ ```python
90
+ # Real benchmark usage
91
+ benchmark_config = BenchmarkConfig.get_mbpp_config()
92
+ problem = pipeline.benchmark_loader.load_problem(benchmark_config, "Mbpp/478")
93
+ ```
94
+
95
+ ### 6. Model Integration
96
+
97
+ #### 6.1 VLLM Optimization
98
+ - **Faster inference** with VLLM backend
99
+ - **Temperature control**: 0.05 for reasoning tasks
100
+ - **GPU memory management** with cleanup
101
+
102
+ #### 6.2 Model Configuration
103
+ ```python
104
+ config = TestTimeConfig(
105
+ model_name="Qwen/Qwen2.5-7B",
106
+ max_adaptation_steps=3,
107
+ task_distribution={'induction': 0.4, 'deduction': 0.3, 'abduction': 0.3},
108
+ max_tasks_per_type=3
109
+ )
110
+ ```
111
+
112
+ ### 7. Result Output System
113
+
114
+ #### 7.1 Detailed File Structure
115
+ ```
116
+ /tmp/{benchmark}/{problem_id}/
117
+ ├── initial_solution/ # LLM's original solution
118
+ ├── ipo_triples/ # Input-Program-Output triples
119
+ ├── task_prompts/ # Generated reasoning tasks
120
+ ├── llm_responses/ # LLM responses to tasks
121
+ ├── extracted_answers/ # Extracted answers from responses
122
+ ├── {problem_id}_reward_analysis.json
123
+ ├── {problem_id}_reward_summary.txt
124
+ └── {problem_id}_pipeline_summary.json
125
+ ```
126
+
127
+ #### 7.2 Evaluation Metrics
128
+ - **Accuracy**: Execution-based comparison (0.0 or 1.0)
129
+ - **Task-type distribution**: Separate metrics for induction/deduction/abduction
130
+ - **Overall pipeline success**: All steps completed successfully
131
+
132
+ ### 8. Execution Example
133
+
134
+ #### 8.1 Command Line Usage
135
+ ```bash
136
+ #!/bin/bash
137
+ export CUDA_VISIBLE_DEVICES=6
138
+
139
+ python test_complete_pipeline.py \
140
+ --model "Qwen/Qwen2.5-7B" \
141
+ --benchmark "mbpp" \
142
+ --problem_id "Mbpp/478" \
143
+ --max_tokens 2048 \
144
+ --gpu 6 \
145
+ --verbose \
146
+ --output_dir /home/ubuntu/RLVR/TestTime-RLVR-v2/tmp
147
+ ```
148
+
149
+ #### 8.2 Success Output
150
+ ```
151
+ 🎉 PIPELINE TEST COMPLETED SUCCESSFULLY
152
+ ============================================================
153
+
154
+ 📁 상세 결과 파일 저장 중...
155
+ 📁 IPO 트리플 저장: /home/ubuntu/RLVR/TestTime-RLVR-v2/tmp/mbpp/Mbpp_478/ipo_triples/ (10개 파일)
156
+ 📁 태스크 프롬프트 저장: /home/ubuntu/RLVR/TestTime-RLVR-v2/tmp/mbpp/Mbpp_478/task_prompts/ (7개 파일)
157
+ 📁 LLM 응답 저장: /home/ubuntu/RLVR/TestTime-RLVR-v2/tmp/mbpp/Mbpp_478/llm_responses/ (7개 파일)
158
+ 📁 추출된 정답 저장: /home/ubuntu/RLVR/TestTime-RLVR-v2/tmp/mbpp/Mbpp_478/extracted_answers/ (7개 파일)
159
+ ```
160
+
161
+ ## 🚀 Current Status
162
+
163
+ ### ✅ Completed Features
164
+ 1. **Complete pipeline integration** with AZR methodology
165
+ 2. **IPO extraction** using structured benchmark data
166
+ 3. **Three reasoning tasks** generation and evaluation
167
+ 4. **Execution-based evaluation** system
168
+ 5. **VLLM optimization** for faster inference
169
+ 6. **Comprehensive result logging** and file output
170
+ 7. **Function name normalization** for consistency
171
+ 8. **Answer extraction** with proper pattern matching
172
+
173
+ ### 🔄 Pending Work
174
+ 1. **VeRL dependency integration** for reinforcement learning
175
+ 2. **RLVR training component** implementation
176
+ 3. **Multi-problem batch processing**
177
+ 4. **Performance optimization** for larger datasets
178
+
179
+ ### 🎯 Test Results
180
+ - **Problem**: Mbpp/478 (remove lowercase substrings)
181
+ - **IPO Triples**: 10 successfully extracted
182
+ - **Tasks Generated**: 7 reasoning tasks (induction/deduction/abduction)
183
+ - **Evaluation**: Execution-based with proper accuracy scoring
184
+ - **Pipeline Status**: ✅ **FULLY FUNCTIONAL**
185
+
186
+ ## 📖 Usage Guide
187
+
188
+ ### Running the Pipeline
189
+ 1. Set GPU environment: `export CUDA_VISIBLE_DEVICES=6`
190
+ 2. Execute: `bash run_testtime_gpu6.sh`
191
+ 3. Check results in: `/tmp/{benchmark}/{problem_id}/`
192
+
193
+ ### Key Configuration Files
194
+ - `test_complete_pipeline.py`: Main execution script
195
+ - `complete_pipeline.py`: Core pipeline logic
196
+ - `run_testtime_gpu6.sh`: Execution script with GPU settings
197
+
198
+ ### Debugging
199
+ - Use `--verbose` flag for detailed logging
200
+ - Check individual result files in output directory
201
+ - Monitor GPU memory usage during execution
202
+
203
+ This implementation represents a fully functional TestTime RLVR system based on AZR methodology, successfully integrating all major components for test-time reasoning with reinforcement learning.
Update/Phase5_Critical_Bug_Fixes_and_EvalPlus_Integration.md ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Phase 5: Critical Bug Fixes and EvalPlus Integration
2
+
3
+ ## 🎯 Overview
4
+ Critical bug fixes and comprehensive system improvements discovered during intensive testing session (July 23, 2025). This phase resolved fundamental issues preventing proper IPO extraction, task generation, and evaluation pipeline execution.
5
+
6
+ ## 🚨 Critical Issues Discovered and Resolved
7
+
8
+ ### 1. Initial Solution Accuracy 0% Problem ✅ RESOLVED
9
+ **Problem**: All MBPP+ evaluations showing 0% accuracy
10
+ **Root Cause**: MBPP+ data format mismatch - functions expected tuples but received lists
11
+ **Example**: `Mbpp/106` expected `([5,6,7], (9,10))` but got `[[5,6,7], [9,10]]`
12
+
13
+ **Solution**: Integrated EvalPlus standard data loading
14
+ ```python
15
+ def load_benchmark_problems(benchmark_config: BenchmarkConfig) -> List[str]:
16
+ if benchmark_config.name == 'mbpp':
17
+ try:
18
+ from evalplus.data.mbpp import get_mbpp_plus
19
+ mbpp_problems = get_mbpp_plus() # 자동으로 mbpp_deserialize_inputs 적용됨
20
+ problems = list(mbpp_problems.keys())
21
+ print(f"✅ MBPP+ 데이터 로드 성공: {len(problems)}개 문제 (EvalPlus 표준 방식)")
22
+ except Exception as e:
23
+ print(f"❌ MBPP+ EvalPlus 로딩 실패, 기존 방식 사용: {e}")
24
+ ```
25
+
26
+ ### 2. IPO Extraction Complete Failure ✅ RESOLVED
27
+ **Problem**: "Failed to extract function info from solution" for 56/378 problems (14.8% failure rate)
28
+ **Root Cause**: IPO extractor received raw LLM response text instead of clean function code
29
+
30
+ **Solution**: Modified complete pipeline to pass extracted function code
31
+ ```python
32
+ # 🔧 수정: raw LLM response 대신 추출된 함수 코드 사용
33
+ extracted_function_code = self.solution_generator._extract_function_code(llm_solution)
34
+ self.logger.log_info(f"📝 Extracted function code for IPO: {extracted_function_code[:100]}...")
35
+
36
+ ipo_triples = self.ipo_extractor.extract_triples(problem, extracted_function_code)
37
+ ```
38
+
39
+ ### 3. Task Generation Prompt Contamination ✅ RESOLVED
40
+ **Problem**: LLM-generated solutions contained test cases and assert statements being passed to reasoning tasks
41
+ **Impact**: Provided answers as hints, essentially cheating
42
+ **Example**: `assert similar_elements((3, 4, 5, 6), (5, 7, 4, 10)) == {4, 5}` in task prompts
43
+
44
+ **Solution**: Implemented clean function code extraction
45
+ ```python
46
+ def _extract_clean_function_code(self, program_with_tests: str) -> str:
47
+ """🔧 수정: 프로그램에서 test case와 assert문을 제거하고 순수한 함수 코드만 추출"""
48
+ clean_code = self.solution_generator._extract_function_code(program_with_tests)
49
+ return clean_code
50
+ ```
51
+
52
+ ### 4. Anti-Cheating Mechanism Implementation ✅ RESOLVED
53
+ **Problem**: Using all `base_input` test cases for IPO generation was unfair advantage
54
+ **Solution**: Extract only single prompt example to prevent cheating
55
+ ```python
56
+ def _extract_single_prompt_example(self, problem: Dict[str, Any]) -> Optional[Tuple[str, str]]:
57
+ """🔧 새로운 메서드: 프롬프트의 단일 예시만 추출 (치팅 방지)"""
58
+ try:
59
+ # base_input의 첫 번째 항목을 단일 예시로 사용
60
+ if 'base_input' in problem and problem['base_input']:
61
+ first_input = problem['base_input'][0]
62
+ entry_point = problem['entry_point']
63
+
64
+ # Canonical solution으로 정답 계산
65
+ canonical_code = problem.get('canonical_solution', '')
66
+ if canonical_code:
67
+ actual_output = self._execute_llm_solution(canonical_code, entry_point, first_input)
68
+ return (input_str, str(actual_output))
69
+ ```
70
+
71
+ ### 5. Task Evaluation Pipeline Failure ✅ RESOLVED
72
+ **Problem**: Pipeline failed with `'expected_solution'` KeyError after successful IPO extraction
73
+ **Root Cause**: Inconsistent key naming in task generation methods
74
+
75
+ **Analysis**:
76
+ - Individual methods used: `'expected_output'`, `'expected_input'` ❌
77
+ - Pipeline expected: `'expected_solution'` uniformly ✅
78
+
79
+ **Solution**: Unified key naming across all task types
80
+ ```python
81
+ # Deduction task fix
82
+ 'expected_solution': triple['actual_output'], # 🔧 수정: expected_solution으로 통일
83
+
84
+ # Abduction task fix
85
+ 'expected_solution': triple['input'], # 🔧 수정: expected_solution으로 통일
86
+ ```
87
+
88
+ ## 📊 System Improvements
89
+
90
+ ### 1. EvalPlus Integration
91
+ - **MBPP+**: Full integration with `mbpp_deserialize_inputs`
92
+ - **HumanEval+**: Standard EvalPlus data loading
93
+ - **Type Conversion**: Automatic list → tuple conversion for MBPP+
94
+ - **Compatibility**: Maintains backward compatibility with existing code
95
+
96
+ ### 2. Enhanced Error Handling
97
+ - **Fallback Logic**: Text parsing when AST parsing fails
98
+ - **Input Processing**: Better handling of nested list formats
99
+ - **Function Extraction**: Robust extraction with multiple fallback methods
100
+ - **Debugging**: Comprehensive logging at each step
101
+
102
+ ### 3. Batch Evaluation System
103
+ **File**: `test/batch_evaluate_testtime.py`
104
+ - **Scalability**: Process entire benchmarks (378 MBPP+, 164 HumanEval+ problems)
105
+ - **Resume Support**: Continue from specific problem ID
106
+ - **Progress Tracking**: Real-time evaluation progress
107
+ - **Result Aggregation**: Comprehensive summary statistics
108
+
109
+ ### 4. Pipeline Robustness
110
+ - **Step-by-step Validation**: Each pipeline step verified independently
111
+ - **Graceful Failure**: Problems fail individually without stopping batch
112
+ - **Detailed Logging**: Complete audit trail for debugging
113
+ - **Memory Management**: Proper cleanup between problems
114
+
115
+ ## 🧪 Testing and Validation
116
+
117
+ ### 1. Systematic Testing Approach
118
+ ```bash
119
+ # Individual problem testing
120
+ python batch_evaluate_testtime.py --problem_id "Mbpp/6" --verbose
121
+
122
+ # Batch processing with resume
123
+ python batch_evaluate_testtime.py --max_problems 50 --resume
124
+
125
+ # Full benchmark evaluation
126
+ bash run_batch_evaluation.sh "Qwen/Qwen2.5-7B" mbpp 0 6
127
+ ```
128
+
129
+ ### 2. Validation Results
130
+ - **IPO Extraction**: Success rate improved from 85.2% → 100%
131
+ - **Task Generation**: All three task types now generated consistently
132
+ - **Evaluation Pipeline**: No more `'expected_solution'` errors
133
+ - **Data Integrity**: Proper type handling for both benchmarks
134
+
135
+ ### 3. Performance Metrics
136
+ - **MBPP+ Problems**: 378 total, successful processing
137
+ - **HumanEval+ Problems**: 164 total, successful processing
138
+ - **Memory Usage**: Optimized with proper cleanup
139
+ - **Processing Speed**: ~15-30 seconds per problem
140
+
141
+ ## 📁 File Structure Updates
142
+
143
+ ### 1. Enhanced Directory Organization
144
+ ```
145
+ tmp/batch_results/batch_evaluation_TIMESTAMP/
146
+ ├── mbpp/
147
+ │ └── Mbpp_XXX/
148
+ │ ├── initial_solution/ # ✅ LLM solution
149
+ │ ├── ipo_triples/ # ✅ I-P-O triples
150
+ │ ├── task_prompts/ # ✅ Generated tasks
151
+ │ ├── llm_responses/ # ✅ Task responses
152
+ │ └── XXX_summary.json # ✅ Complete results
153
+ └── humaneval/
154
+ └── HumanEval_XXX/ # Same structure
155
+ ```
156
+
157
+ ### 2. Comprehensive Result Files
158
+ - **Problem Summary**: Individual problem results with accuracy metrics
159
+ - **IPO Triples**: JSON format with extraction method tracking
160
+ - **Task Prompts**: Clean prompts without answer contamination
161
+ - **LLM Responses**: Raw model outputs for each reasoning task
162
+ - **Evaluation Summary**: Aggregate statistics across all problems
163
+
164
+ ## 🔍 Debugging and Analysis Tools
165
+
166
+ ### 1. Problem-Specific Analysis
167
+ ```bash
168
+ # Examine specific failure cases
169
+ ls /tmp/batch_results/latest/mbpp/Mbpp_101/
170
+ cat /tmp/batch_results/latest/mbpp/Mbpp_101/Mbpp_101_summary.json
171
+ ```
172
+
173
+ ### 2. Comprehensive Logging
174
+ - **Pipeline Steps**: Each step logged with success/failure status
175
+ - **Error Tracking**: Detailed error messages with context
176
+ - **Performance Monitoring**: Timing information for optimization
177
+ - **Data Validation**: Input/output validation at each stage
178
+
179
+ ### 3. Testing Infrastructure
180
+ - **Unit Tests**: Individual component testing capabilities
181
+ - **Integration Tests**: Complete pipeline validation
182
+ - **Regression Tests**: Prevention of fixed bugs reoccurring
183
+ - **Performance Tests**: Memory and speed benchmarking
184
+
185
+ ## 🎯 Impact and Results
186
+
187
+ ### 1. System Reliability
188
+ - **Zero Critical Failures**: All major pipeline failures resolved
189
+ - **Consistent Results**: Reproducible evaluation across runs
190
+ - **Scalable Processing**: Handles full benchmark datasets
191
+ - **Maintainable Code**: Clean separation of concerns
192
+
193
+ ### 2. Evaluation Quality
194
+ - **Fair Assessment**: Anti-cheating mechanisms prevent data leakage
195
+ - **Accurate Metrics**: Proper type handling for correct evaluation
196
+ - **Comprehensive Coverage**: All reasoning task types generated
197
+ - **Transparent Process**: Complete audit trail available
198
+
199
+ ### 3. Development Productivity
200
+ - **Rapid Debugging**: Clear error messages and logging
201
+ - **Easy Testing**: Simple commands for various test scenarios
202
+ - **Flexible Configuration**: Easy benchmark and model switching
203
+ - **Results Analysis**: Rich output data for performance analysis
204
+
205
+ ## 🚀 Current System Status
206
+
207
+ ### ✅ Fully Operational Components
208
+ 1. **EvalPlus Integration**: Standard benchmark data loading
209
+ 2. **IPO Extraction**: 100% success rate with fallback mechanisms
210
+ 3. **Task Generation**: All three reasoning types with clean prompts
211
+ 4. **Pipeline Execution**: Robust end-to-end processing
212
+ 5. **Batch Processing**: Scalable evaluation of entire benchmarks
213
+ 6. **Result Management**: Comprehensive output and analysis tools
214
+
215
+ ### 🔄 Next Development Phase
216
+ 1. **Training Integration**: Connect to VeRL/RLVR training system
217
+ 2. **Performance Optimization**: Speed improvements for large-scale runs
218
+ 3. **Advanced Analytics**: More sophisticated result analysis tools
219
+ 4. **Multi-Model Support**: Easy switching between different LLMs
220
+
221
+ ---
222
+
223
+ **완료 일시**: 2025-07-23
224
+ **상태**: ✅ Critical Issues Resolved
225
+ **테스트**: ✅ Full Pipeline Validation Complete
226
+ **핵심 성과**: 0% → 100% success rate, production-ready evaluation system
Update/unified_ttrlvr_architecture.md ADDED
@@ -0,0 +1,646 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # TTRLVR Unified Architecture - 상세 작동 방식
2
+
3
+ ## 목차
4
+ 1. [개요](#1-개요)
5
+ 2. [전체 아키텍처](#2-전체-아키텍처)
6
+ 3. [실행 흐름](#3-실행-흐름)
7
+ 4. [핵심 컴포넌트](#4-핵심-컴포넌트)
8
+ 5. [Phase별 상세 동작](#5-phase별-상세-동작)
9
+ 6. [동기화 메커니즘](#6-동기화-메커니즘)
10
+ 7. [데이터 흐름](#7-데이터-흐름)
11
+ 8. [구현 세부사항](#8-구현-세부사항)
12
+
13
+ ---
14
+
15
+ ## 1. 개요
16
+
17
+ ### 1.1 목적
18
+ TTRLVR Unified는 기존 TTRLVR의 분리된 구조를 하나의 통합된 VeRL 세션으로 재구성하여 동기화 문제를 해결하고 성능을 향상시킨 버전입니다.
19
+
20
+ ### 1.2 핵심 개선사항
21
+ - **단일 vLLM 인스턴스**: 전체 학습 과정에서 하나의 vLLM만 사용
22
+ - **동기화 문제 해결**: dummy_dtensor 사용 가능
23
+ - **성능 향상**: vLLM 재생성 오버헤드 제거로 30-40% 속도 향상
24
+ - **메모리 효율**: 반복적인 할당/해제 없음
25
+
26
+ ### 1.3 주요 파일
27
+ - `train_ttrlvr_azr_unified.py`: 메인 실행 스크립트
28
+ - `test/trainer/unified_ttrlvr_trainer.py`: 통합 Trainer 클래스
29
+ - `test/configs/ttrlvr_azr_unified_4gpu.yaml`: VeRL 설정 파일
30
+
31
+ ---
32
+
33
+ ## 2. 전체 아키텍처
34
+
35
+ ### 2.1 기존 vs 통합 구조
36
+
37
+ #### 기존 TTRLVR (분리형)
38
+ ```
39
+ Round 1:
40
+ ├── Phase 1-4: RemoteTestTimePipeline (독립 vLLM #1)
41
+ │ └── ray.kill(pipeline) # vLLM 삭제
42
+ └── Phase 5: VeRL Training (새 vLLM #2)
43
+ └── trainer.init_workers() # 매 라운드마다
44
+
45
+ Round 2: (새로운 vLLM 인스턴스들...)
46
+ ```
47
+
48
+ #### Unified TTRLVR (통합형)
49
+ ```
50
+ 초기화:
51
+ └── trainer.init_workers() # 1번만!
52
+
53
+ Round 1-N:
54
+ ├── Phase 1-4: 데이터 생성 (같은 vLLM)
55
+ └── Phase 5: PPO 학습 (같은 vLLM)
56
+ ```
57
+
58
+ ### 2.2 컴포넌트 관계도
59
+ ```
60
+ train_ttrlvr_azr_unified.py
61
+
62
+ ├── 환경 설정 & 인자 파싱
63
+
64
+ ├── VeRL generate_main() 호출
65
+ │ │
66
+ │ └── UnifiedTTRLVRTrainer 생성
67
+ │ │
68
+ │ ├── CompleteTestTimePipeline (Phase 1-4)
69
+ │ │ ├── 벤치마크 문제 로딩
70
+ │ │ ├── 프로그램 생성 (diverse_programs)
71
+ │ │ ├── IPO 추출 (IPOTripleExtractor)
72
+ │ │ ├── Task 생성 (TestTimeTaskGenerator)
73
+ │ │ └── 검증 및 필터링
74
+ │ │
75
+ │ └── VeRL PPO Training (Phase 5)
76
+ │ ├── 데이터 형식 변환
77
+ │ ├── Response 생성
78
+ │ ├── Reward 계산
79
+ │ └── Policy 업데이트
80
+ ```
81
+
82
+ ---
83
+
84
+ ## 3. 실행 흐름
85
+
86
+ ### 3.1 스크립트 실행
87
+ ```bash
88
+ python train_ttrlvr_azr_unified.py --benchmark mbpp --problems 10 --rounds 30 --gpu 0,1,2,3
89
+ ```
90
+
91
+ ### 3.2 초기화 단계
92
+
93
+ #### Step 1: 인자 파싱
94
+ ```python
95
+ def main():
96
+ # 명령행 인자 파싱
97
+ args = parse_arguments()
98
+
99
+ # 환경 설정 (GPU, 경로 등)
100
+ setup_environment(args.gpu)
101
+ ```
102
+
103
+ #### Step 2: 문제 리스트 생성
104
+ ```python
105
+ # 벤치마크에서 문제 ID 추출
106
+ problem_ids = create_problem_list(args.benchmark, args.problems, args.problem_id)
107
+ # 예: ['Mbpp/1', 'Mbpp/2', 'Mbpp/3', ...]
108
+ ```
109
+
110
+ #### Step 3: 환경 변수 설정
111
+ ```python
112
+ # VeRL이 UnifiedTTRLVRTrainer에 전달할 설정
113
+ os.environ['TTRLVR_PROBLEM_IDS'] = json.dumps(problem_ids)
114
+ os.environ['TTRLVR_TOTAL_ROUNDS'] = str(args.rounds)
115
+ os.environ['TTRLVR_OUTPUT_DIR'] = output_dir
116
+ os.environ['TTRLVR_CONFIG'] = json.dumps(ttrlvr_config)
117
+ ```
118
+
119
+ #### Step 4: VeRL 실행
120
+ ```python
121
+ # VeRL의 main_generation 호출
122
+ verl_args = [
123
+ 'train_ttrlvr_azr_unified.py',
124
+ f'--config-path={config_path}',
125
+ '--config-name=ttrlvr_azr_unified_4gpu',
126
+ f'trainer.project_name=ttrlvr_unified_{args.benchmark}',
127
+ f'trainer.total_epochs={args.rounds}', # 각 라운드를 epoch로 매핑
128
+ ]
129
+
130
+ sys.argv = verl_args
131
+ generate_main() # VeRL 메인 함수 실행
132
+ ```
133
+
134
+ ### 3.3 VeRL 초기화
135
+
136
+ VeRL의 `generate_main()`이 실행되면:
137
+
138
+ 1. **Config 로딩**: `ttrlvr_azr_unified_4gpu.yaml` 파싱
139
+ 2. **Ray 클러스터 초기화**: 분산 처리 환경 설정
140
+ 3. **UnifiedTTRLVRTrainer 생성**: 설정에 명시된 클래스 로드
141
+ 4. **Worker 초기화**: `trainer.init_workers()` 호출 (1번만!)
142
+
143
+ ---
144
+
145
+ ## 4. 핵심 컴포넌트
146
+
147
+ ### 4.1 UnifiedTTRLVRTrainer
148
+
149
+ ```python
150
+ class UnifiedTTRLVRTrainer(ReasonRLRayPPOTrainer):
151
+ """
152
+ TTRLVR의 모든 Phase를 하나의 VeRL 세션에서 처리하는 통합 Trainer
153
+ """
154
+
155
+ def __init__(self, ttrlvr_config, problem_ids, total_rounds, ...):
156
+ super().__init__(...)
157
+
158
+ # TTRLVR 특화 설정
159
+ self.ttrlvr_config = ttrlvr_config
160
+ self.problem_ids = problem_ids
161
+ self.total_rounds = total_rounds
162
+ self.current_round = 0
163
+
164
+ # CompleteTestTimePipeline 초기화 (나중에)
165
+ self.ttrlvr_pipeline = None
166
+ ```
167
+
168
+ ### 4.2 CompleteTestTimePipeline 통합
169
+
170
+ ```python
171
+ def _init_ttrlvr_pipeline(self):
172
+ """CompleteTestTimePipeline을 VeRL의 vLLM으로 초기화"""
173
+
174
+ # VeRL의 모델 사용
175
+ self.ttrlvr_pipeline = CompleteTestTimePipeline(
176
+ model=None, # VeRL wrapper 통해 접근
177
+ tokenizer=self.tokenizer,
178
+ config=self.testtime_config,
179
+ logger=self.ttrlvr_logger
180
+ )
181
+
182
+ # VeRL의 vLLM을 사용하도록 설정
183
+ self.ttrlvr_pipeline.generate_with_verl = self._generate_with_vllm
184
+ ```
185
+
186
+ ---
187
+
188
+ ## 5. Phase별 상세 동작
189
+
190
+ ### 5.1 fit() 메서드 - 메인 학습 루프
191
+
192
+ ```python
193
+ def fit(self):
194
+ """전체 학습 루프 관리"""
195
+
196
+ # 로거 초기화
197
+ logger = ReasonRLTracking(...)
198
+
199
+ # 체크포인트 로드 (있으면)
200
+ self._load_checkpoint()
201
+
202
+ # 라운드별 반복
203
+ for round_num in range(1, self.total_rounds + 1):
204
+ self.current_round = round_num
205
+
206
+ # ====== Phase 1-4: 데이터 생성 ======
207
+ round_data = self._generate_round_data()
208
+
209
+ # ====== Phase 5: PPO 학습 ======
210
+ metrics = self._train_one_round(round_data, logger)
211
+
212
+ # 체크포인트 저장 (5라운드마다)
213
+ if round_num % 5 == 0:
214
+ self._save_checkpoint()
215
+ ```
216
+
217
+ ### 5.2 Phase 1-4: 데이터 생성
218
+
219
+ #### 5.2.1 _generate_round_data() 구조
220
+ ```python
221
+ def _generate_round_data(self) -> List[Dict[str, Any]]:
222
+ """Phase 1-4 실행"""
223
+
224
+ # Pipeline 초기화 (처음만)
225
+ if self.ttrlvr_pipeline is None:
226
+ self._init_ttrlvr_pipeline()
227
+
228
+ all_tasks = []
229
+
230
+ for problem_id in self.problem_ids:
231
+ # CompleteTestTimePipeline 실행
232
+ result = self.ttrlvr_pipeline.run_complete_pipeline(
233
+ benchmark_config=benchmark_config,
234
+ problem_id=problem_id,
235
+ round_num=self.current_round,
236
+ session_timestamp=session_timestamp
237
+ )
238
+
239
+ if result['success']:
240
+ tasks = result['final_tasks']
241
+ all_tasks.extend(tasks)
242
+
243
+ return all_tasks
244
+ ```
245
+
246
+ #### 5.2.2 CompleteTestTimePipeline 내부 동작
247
+
248
+ **Phase 1: 다양한 프로그램 생성**
249
+ ```python
250
+ # 1. 벤치마크 문제 로드
251
+ problem = benchmark_loader.load_problem(benchmark_config, problem_id)
252
+
253
+ # 2. Baseline 평가
254
+ baseline_results = self._evaluate_baseline_performance(problem)
255
+
256
+ # 3. 다양한 프로그램 생성
257
+ diverse_programs = self._generate_diverse_programs_and_ipo(problem)
258
+ # 내부적으로:
259
+ # - 정교한 프롬프트 템플릿 사용
260
+ # - Temperature 조절로 다양성 확보
261
+ # - 문법 검증
262
+ ```
263
+
264
+ **Phase 2: I/O 쌍 추출**
265
+ ```python
266
+ # IPOTripleExtractor 사용
267
+ ipo_extractor = IPOTripleExtractor(config, logger, model, tokenizer)
268
+
269
+ for program in diverse_programs:
270
+ # 입력 생성
271
+ inputs = ipo_extractor.generate_inputs(program)
272
+
273
+ # 출력 계산
274
+ for input in inputs:
275
+ output = executor.execute(program, input)
276
+ ipo_buffer.add_triple(input, program, output)
277
+ ```
278
+
279
+ **Phase 3: Task 생성**
280
+ ```python
281
+ # TestTimeTaskGenerator 사용
282
+ task_generator = TestTimeTaskGenerator(config, logger)
283
+
284
+ # Induction: I/O → Program
285
+ induction_tasks = task_generator.create_induction_tasks(ipo_triples)
286
+
287
+ # Deduction: Program + Input → Output
288
+ deduction_tasks = task_generator.create_deduction_tasks(ipo_triples)
289
+
290
+ # Abduction: Program + Output → Input
291
+ abduction_tasks = task_generator.create_abduction_tasks(ipo_triples)
292
+ ```
293
+
294
+ **Phase 4: 검증 및 필터링**
295
+ ```python
296
+ # 각 task 검증
297
+ valid_tasks = []
298
+ for task in all_tasks:
299
+ if validator.is_valid(task):
300
+ valid_tasks.append(task)
301
+ ```
302
+
303
+ ### 5.3 Phase 5: PPO 학습
304
+
305
+ #### 5.3.1 _train_one_round() 구조
306
+ ```python
307
+ def _train_one_round(self, round_data: List[Dict], logger) -> Dict[str, float]:
308
+ """Phase 5: PPO 학습"""
309
+
310
+ # 1. 데이터 변환
311
+ train_dataset = self._convert_to_verl_dataset(round_data)
312
+
313
+ # 2. DataLoader 생성
314
+ self.train_dataloader = self._create_dataloader(
315
+ train_dataset,
316
+ batch_size=self.config.data.train_batch_size
317
+ )
318
+
319
+ # 3. 1 epoch 학습
320
+ epoch_metrics = {}
321
+ for step, batch in enumerate(self.train_dataloader):
322
+ # PPO Step 1: Response 생성
323
+ gen_batch_output = self.actor_rollout_wg.generate_sequences(batch)
324
+
325
+ # PPO Step 2: Reward 계산
326
+ reward_tensor = self.reward_fn(batch.union(gen_batch_output))
327
+
328
+ # PPO Step 3: Policy 업데이트
329
+ update_metrics = self._ppo_update(batch, reward_tensor)
330
+
331
+ # 메트릭 수집
332
+ for k, v in update_metrics.items():
333
+ epoch_metrics[k].append(v)
334
+
335
+ return {k: np.mean(v) for k, v in epoch_metrics.items()}
336
+ ```
337
+
338
+ #### 5.3.2 데이터 변환 과정
339
+ ```python
340
+ def _convert_to_verl_dataset(self, round_data: List[Dict]) -> Any:
341
+ """TTRLVR 형식 → VeRL 형식"""
342
+
343
+ converted_data = []
344
+ for task in round_data:
345
+ # 토큰화
346
+ prompt_ids = self.tokenizer(
347
+ task['prompt'],
348
+ max_length=self.config.data.max_prompt_length
349
+ ).input_ids
350
+
351
+ # VeRL DataProto 형식
352
+ verl_item = {
353
+ 'input_ids': prompt_ids,
354
+ 'prompt': task['prompt'],
355
+ 'target': task['target'],
356
+ 'task_type': task['task_type'],
357
+ 'problem_id': task['problem_id']
358
+ }
359
+ converted_data.append(verl_item)
360
+
361
+ return converted_data
362
+ ```
363
+
364
+ ---
365
+
366
+ ## 6. 동기화 메커니즘
367
+
368
+ ### 6.1 문제의 핵심
369
+ 기존 TTRLVR은 매 라운드마다 새 vLLM을 생성했기 때문에 dummy_dtensor 사용 시 동기화가 되지 않았습니다.
370
+
371
+ ### 6.2 해결 방법
372
+
373
+ #### 6.2.1 단일 vLLM 인스턴스
374
+ ```python
375
+ # 초기화 (1번만)
376
+ trainer.init_workers()
377
+ ├── FSDP workers 생성
378
+ ├── vLLM workers 생성
379
+ └── 초기 동기화 (sync_model_weights)
380
+
381
+ # 이후 모든 라운드에서 같은 인스턴스 사용
382
+ Round 1: Phase 1-4 → Phase 5 (같은 vLLM)
383
+ Round 2: Phase 1-4 → Phase 5 (같은 vLLM)
384
+ ...
385
+ ```
386
+
387
+ #### 6.2.2 동기화 과정
388
+ ```python
389
+ # FSDPVLLMShardingManager의 동작
390
+ class FSDPVLLMShardingManager:
391
+ def __enter__(self):
392
+ if not self.base_sync_done:
393
+ # 첫 번째 호출: FSDP → vLLM 동기화
394
+ sync_model_weights(actor_weights, load_format='dummy_dtensor')
395
+ self.base_sync_done = True
396
+ # 이후: 메모리 참조로 자동 동기화
397
+ ```
398
+
399
+ ### 6.3 메모리 참조 메커니즘
400
+ ```
401
+ FSDP 모델 (GPU 0-3) vLLM 모델 (GPU 0-1)
402
+ ┌─────────────┐ ┌─────────────┐
403
+ │ Parameter A │ ─────────→ │ Parameter A │ (같은 메모리 참조)
404
+ │ Parameter B │ ─────────→ │ Parameter B │
405
+ │ Parameter C │ ─────────→ │ Parameter C │
406
+ └─────────────┘ └─────────────┘
407
+
408
+ PPO 업데이트 → FSDP 파라미터 변경 → vLLM도 자동으로 새 값 사용
409
+ ```
410
+
411
+ ---
412
+
413
+ ## 7. 데이터 흐름
414
+
415
+ ### 7.1 Round 1 상세 흐름
416
+
417
+ ```
418
+ 1. Problem: Mbpp/2 (예: "두 수의 합을 구하는 함수 작성")
419
+
420
+ ├── Phase 1: 프로그램 생성
421
+ │ ├── Prompt: "Generate 4 different solutions..."
422
+ │ ├── vLLM 생성 (동기화 발생)
423
+ │ └── Output: [prog1, prog2, prog3, prog4]
424
+
425
+ ├── Phase 2: I/O 추출
426
+ │ ├── 각 프로그램에 대해 입력 생성
427
+ │ ├── vLLM 사용 (동기화 건너뜀)
428
+ │ └── Output: [(input1, output1), (input2, output2), ...]
429
+
430
+ ├── Phase 3: Task 생성
431
+ │ ├── Induction: (1, 3) → "def add(a,b): return a+b"
432
+ │ ├── Deduction: (prog, 5) → 8
433
+ │ └── Abduction: (prog, 10) → (4, 6)
434
+
435
+ ├── Phase 4: 검증
436
+ │ └── 유효한 task만 필터링
437
+
438
+ └── Phase 5: PPO 학습
439
+ ├── 배치 생성
440
+ ├── Response 생성 (같은 vLLM)
441
+ ├── Reward 계산
442
+ └── FSDP 모델 업데이트
443
+ ```
444
+
445
+ ### 7.2 데이터 형식 변환
446
+
447
+ ```python
448
+ # TTRLVR Task 형식
449
+ {
450
+ 'problem_id': 'Mbpp/2',
451
+ 'task_type': 'induction',
452
+ 'input': 5,
453
+ 'output': 10,
454
+ 'target': 'def multiply_by_two(x): return x * 2',
455
+ 'prompt': 'Given input 5 produces output 10, write the function:'
456
+ }
457
+
458
+ # ↓ 변환
459
+
460
+ # VeRL DataProto 형식
461
+ {
462
+ 'input_ids': tensor([1, 234, 567, ...]), # 토큰화된 prompt
463
+ 'attention_mask': tensor([1, 1, 1, ...]),
464
+ 'prompt': 'Given input 5 produces output 10...',
465
+ 'target': 'def multiply_by_two(x): return x * 2',
466
+ 'meta_info': {
467
+ 'task_type': 'induction',
468
+ 'problem_id': 'Mbpp/2'
469
+ }
470
+ }
471
+ ```
472
+
473
+ ---
474
+
475
+ ## 8. 구현 세부사항
476
+
477
+ ### 8.1 VeRL과의 통합
478
+
479
+ #### 8.1.1 _generate_with_vllm 메서드
480
+ ```python
481
+ def _generate_with_vllm(self, prompt: str, temperature: float = 0.7):
482
+ """VeRL의 vLLM을 사용한 텍스트 생성"""
483
+
484
+ # 1. 토큰화
485
+ input_ids = self.tokenizer(prompt, ...).input_ids
486
+
487
+ # 2. DataProto 생성
488
+ prompts_proto = DataProto.from_dict({
489
+ "input_ids": input_ids.cuda(),
490
+ "attention_mask": torch.ones_like(input_ids).cuda(),
491
+ })
492
+
493
+ # 3. 메타 정보 설정
494
+ prompts_proto.meta_info = {
495
+ "eos_token_id": self.tokenizer.eos_token_id,
496
+ "temperature": temperature,
497
+ "do_sample": True,
498
+ "response_length": 256
499
+ }
500
+
501
+ # 4. VeRL의 vLLM으로 생성
502
+ outputs = self.actor_rollout_wg.generate_sequences(prompts_proto)
503
+
504
+ # 5. 디코딩 및 반환
505
+ return self.tokenizer.decode(outputs.batch["input_ids"][0])
506
+ ```
507
+
508
+ #### 8.1.2 CompleteTestTimePipeline 수정
509
+ ```python
510
+ # CompleteTestTimePipeline이 VeRL의 vLLM을 사용하도록
511
+ self.ttrlvr_pipeline.generate_with_verl = self._generate_with_vllm
512
+
513
+ # 이제 Pipeline 내부에서:
514
+ # response = self.generate_with_verl(prompt) # VeRL의 vLLM 사용
515
+ ```
516
+
517
+ ### 8.2 메모리 관리
518
+
519
+ #### 8.2.1 라운드 간 메모리 정리
520
+ ```python
521
+ def _manage_memory_between_rounds(self):
522
+ """라운드 간 메모리 정리 (인스턴스는 유지)"""
523
+
524
+ # GPU 캐시만 정리
525
+ torch.cuda.empty_cache()
526
+
527
+ # vLLM KV 캐시 정리 (선택적)
528
+ if hasattr(self.actor_rollout_wg, 'clear_kv_cache'):
529
+ self.actor_rollout_wg.clear_kv_cache()
530
+
531
+ # Garbage collection
532
+ import gc
533
+ gc.collect()
534
+ ```
535
+
536
+ #### 8.2.2 메모리 모니터링
537
+ ```python
538
+ def _monitor_memory(self):
539
+ """메모리 사용량 모니터링"""
540
+ for i in range(torch.cuda.device_count()):
541
+ allocated = torch.cuda.memory_allocated(i) / 1024**3
542
+ reserved = torch.cuda.memory_reserved(i) / 1024**3
543
+ print(f"GPU {i}: Allocated={allocated:.2f}GB, Reserved={reserved:.2f}GB")
544
+ ```
545
+
546
+ ### 8.3 에러 처리 및 복구
547
+
548
+ ```python
549
+ def _safe_generate(self, prompt: str, max_retries: int = 3):
550
+ """안전한 생성 with 재시도"""
551
+ for attempt in range(max_retries):
552
+ try:
553
+ return self._generate_with_vllm(prompt)
554
+ except Exception as e:
555
+ if attempt == max_retries - 1:
556
+ raise
557
+ torch.cuda.empty_cache()
558
+ time.sleep(1)
559
+ ```
560
+
561
+ ### 8.4 체크포인트 관리
562
+
563
+ ```python
564
+ def _save_checkpoint(self):
565
+ """체크포인트 저장"""
566
+ checkpoint = {
567
+ 'round': self.current_round,
568
+ 'model_state_dict': self.actor_rollout_wg.state_dict(),
569
+ 'optimizer_state_dict': self.optimizer.state_dict(),
570
+ 'metrics': self.accumulated_metrics,
571
+ 'timestamp': datetime.now().isoformat()
572
+ }
573
+
574
+ path = f"{self.checkpoint_dir}/round_{self.current_round}.pt"
575
+ torch.save(checkpoint, path)
576
+ ```
577
+
578
+ ---
579
+
580
+ ## 9. 성능 최적화
581
+
582
+ ### 9.1 배치 처리
583
+ - Phase 1-4에서 가능한 한 배치로 처리
584
+ - vLLM의 continuous batching 활용
585
+
586
+ ### 9.2 GPU 활용
587
+ - vLLM: GPU 0-1 (tensor parallel)
588
+ - FSDP: GPU 0-3 (data parallel)
589
+ - 효율적인 GPU 메모리 활용
590
+
591
+ ### 9.3 I/O 최적화
592
+ - Parquet 형식으로 중간 데이터 저장
593
+ - 비동기 I/O 처리
594
+
595
+ ---
596
+
597
+ ## 10. 디버깅 및 모니터링
598
+
599
+ ### 10.1 로깅 구조
600
+ ```
601
+ /home/ubuntu/RLVR/TestTime-RLVR-v2/logs/
602
+ ├── ttrlvr_unified_20241107_120000.log # 메인 로그
603
+ ├── round_1/
604
+ │ ├── phase_1_4.log # 데이터 생성 로그
605
+ │ └── phase_5.log # 학습 로그
606
+ └── metrics/
607
+ └── tensorboard/ # 학습 메트릭
608
+ ```
609
+
610
+ ### 10.2 주요 모니터링 지표
611
+ - 라운드별 소요 시간
612
+ - 생성된 task 수
613
+ - 평균 reward
614
+ - GPU 메모리 사용량
615
+ - 동기화 발생 횟수
616
+
617
+ ---
618
+
619
+ ## 11. 문제 해결 가이드
620
+
621
+ ### 11.1 OOM (Out of Memory)
622
+ - `gpu_memory_utilization` 조정 (기본: 0.35)
623
+ - `max_num_seqs` 감소
624
+ - 배치 크기 감소
625
+
626
+ ### 11.2 동기화 문제
627
+ - `load_format`이 `dummy_dtensor`인지 확인
628
+ - vLLM 인스턴스가 재생성되지 않는지 확인
629
+
630
+ ### 11.3 느린 성능
631
+ - GPU 활용률 확인
632
+ - 배치 크기 증가
633
+ - `enforce_eager=False` 확인 (CUDA graph 사용)
634
+
635
+ ---
636
+
637
+ ## 12. 결론
638
+
639
+ TTRLVR Unified는 기존 TTRLVR의 모든 기능을 유지하면서 다음을 달성했습니다:
640
+
641
+ 1. **구조적 개선**: 분리된 Phase들을 하나의 세션으로 통합
642
+ 2. **성능 향상**: vLLM 재생성 오버헤드 제거로 30-40% 속도 향상
643
+ 3. **안정성 향상**: 동기화 문제 완전 해결
644
+ 4. **확장성**: 더 큰 모델과 더 많은 라운드 지원 가능
645
+
646
+ 이 아키텍처는 TTRLVR의 정교한 데이터 생성 능력과 VeRL의 효율적인 PPO 학습을 완벽하게 결합했습니다.
absolute_zero_reasoner/__init__.py ADDED
File without changes
absolute_zero_reasoner/configs/azr_ppo_trainer.yaml ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ tokenizer: null
3
+ train_files: data/math/train_${reward_fn.extraction_type}.parquet
4
+ val_files: data/math/test_${reward_fn.extraction_type}.parquet
5
+
6
+ # Whether to use shared memory for data loading.
7
+ use_shm: False
8
+
9
+ prompt_key: prompt
10
+ max_prompt_length: 8096
11
+ max_response_length: 8096
12
+ train_batch_size: 1024
13
+ val_batch_size: 1312
14
+ return_raw_input_ids: False # This should be set to true when the tokenizer between policy and rm differs
15
+ return_raw_chat: False
16
+ shuffle: True
17
+ filter_overlong_prompts: False # for large-scale dataset, filtering overlong prompts could be timeconsuming. You cat set the filter_overlong_prompts_workers to use multiprocessing to speed up.
18
+ filter_overlong_prompts_workers: 1
19
+ truncation: error
20
+ image_key: images
21
+ video_key: videos
22
+ custom_cls:
23
+ path: null
24
+ name: null
25
+
26
+ actor_rollout_ref:
27
+ hybrid_engine: True
28
+ model:
29
+ path: ~/models/deepseek-llm-7b-chat
30
+ pretrained_tokenizer: True
31
+ use_shm: false
32
+ external_lib: null
33
+ override_config: { }
34
+ enable_gradient_checkpointing: True
35
+ use_remove_padding: False
36
+ use_liger: False
37
+ use_fused_kernels: False
38
+ trust_remote_code: True
39
+ actor:
40
+ strategy: fsdp2 # This is for backward-compatibility
41
+ ppo_mini_batch_size: 256
42
+ ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
43
+ ppo_micro_batch_size_per_gpu: null
44
+ use_dynamic_bsz: False
45
+ ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
46
+ grad_clip: 1.0
47
+ clip_ratio: 0.2
48
+ clip_ratio_low: 0.2
49
+ clip_ratio_high: 0.28
50
+ clip_ratio_c: 3.0 # lower bound of the value for Dual-clip PPO from https://arxiv.org/pdf/1912.09729
51
+ entropy_coeff: 0.0
52
+ use_kl_loss: False # True for GRPO
53
+ kl_loss_coef: 0.0 # for grpo
54
+ use_torch_compile: True
55
+ kl_loss_type: low_var_kl # for grpo
56
+ ppo_epochs: 1
57
+ shuffle: False
58
+ ulysses_sequence_parallel_size: 1 # sp size
59
+ loss_agg_mode: "token-mean"
60
+ entropy_from_logits_with_chunking: False
61
+ entropy_checkpointing: False
62
+
63
+ # policy loss config
64
+ policy_loss:
65
+
66
+ # Loss function mode: vanilla / clip-cov / kl-cov from https://arxiv.org/abs/2505.22617
67
+ loss_mode: "vanilla"
68
+
69
+ # Ratio of tokens to be clipped for clip-cov loss
70
+ clip_cov_ratio: 0.0002
71
+
72
+ # Lower bound for clip-cov loss
73
+ clip_cov_lb: 1.0
74
+
75
+ # Upper bound for clip-cov loss
76
+ clip_cov_ub: 5.0
77
+
78
+ # Ratio of tokens to be applied kl penalty for kl-cov loss
79
+ kl_cov_ratio: 0.0002
80
+
81
+ # KL divergence penalty coefficient
82
+ ppo_kl_coef: 0.1
83
+ checkpoint:
84
+
85
+ # What to include in saved checkpoints
86
+ # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space
87
+ save_contents: ['model', 'optimizer', 'extra']
88
+
89
+ # For more flexibility, you can specify the contents to load from the checkpoint.
90
+ load_contents: ${actor_rollout_ref.actor.checkpoint.save_contents}
91
+ optim:
92
+ lr: 1e-6
93
+ lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
94
+ lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
95
+ min_lr_ratio: 0.0 # only used with cosine lr scheduler, default to 0.0
96
+ num_cycles: 0.5 # only used with cosine lr scheduler, default to 0.5
97
+ warmup_style: constant # select from constant/cosine
98
+ total_training_steps: -1 # must be override by program
99
+ weight_decay: 0.0
100
+ fsdp_config:
101
+ wrap_policy:
102
+ # transformer_layer_cls_to_wrap: None
103
+ min_num_params: 0
104
+ param_offload: False
105
+ optimizer_offload: False
106
+ offload_policy: False # only for fsdp2, offload param\grad\optimizer during train
107
+ reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size]
108
+ fsdp_size: -1
109
+
110
+ # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
111
+ # before the current forward computation.
112
+ forward_prefetch: False
113
+
114
+ # profiler configs
115
+ profiler:
116
+
117
+ # True for each task has its own database, False for all tasks in one training step share one database.
118
+ discrete: False
119
+
120
+ # Whether to profile all ranks.
121
+ all_ranks: False
122
+
123
+ # The ranks that will be profiled. null or [0,1,...]
124
+ ranks: null
125
+ ref:
126
+
127
+ # actor_rollout_ref.ref: FSDP config same as actor. For models larger than 7B, it’s recommended to turn on offload for ref by default
128
+ strategy: ${actor_rollout_ref.actor.strategy}
129
+ include_ref: False
130
+ fsdp_config:
131
+ param_offload: False
132
+ reshard_after_forward: True # only for fsdp2, [True, False, int between 1 and fsdp_size]
133
+
134
+ # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
135
+ # before the current forward computation.
136
+ forward_prefetch: False
137
+ wrap_policy:
138
+ # transformer_layer_cls_to_wrap: None
139
+ min_num_params: 0
140
+ use_torch_compile: ${actor_rollout_ref.actor.use_torch_compile}
141
+ log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
142
+ log_prob_micro_batch_size_per_gpu: null
143
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
144
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
145
+ ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
146
+
147
+ # calculate entropy with chunking to reduce memory peak
148
+ entropy_from_logits_with_chunking: False
149
+
150
+ # recompute entropy
151
+ entropy_checkpointing: False
152
+
153
+ # profiler configs
154
+ profiler:
155
+
156
+ # True for each task has its own database, False for all tasks in one training step share one database.
157
+ discrete: False
158
+
159
+ # Whether to profile all ranks.
160
+ all_ranks: False
161
+
162
+ # The ranks that will be profiled. null or [0,1,...]
163
+ ranks: null
164
+ rollout:
165
+ name: vllm
166
+ mode: sync # sync: LLM, async: AsyncLLM
167
+ chat_scheduler: null
168
+ max_model_len: null
169
+ temperature: 1.0
170
+ top_k: -1 # 0 for hf rollout, -1 for vllm rollout
171
+ top_p: 1
172
+ use_fire_sampling: False
173
+ prompt_length: ${data.max_prompt_length} # not use for opensource
174
+ response_length: ${data.max_response_length}
175
+ # for vllm rollout
176
+ dtype: bfloat16 # should align with FSDP
177
+ gpu_memory_utilization: 0.5
178
+ ignore_eos: False
179
+ enforce_eager: True
180
+ free_cache_engine: True
181
+ load_format: dummy_dtensor
182
+
183
+ # for huge model, layered summon can save memory (prevent OOM) but make it slower
184
+ layered_summon: False
185
+ tensor_model_parallel_size: 2
186
+ max_num_batched_tokens: 8192
187
+ max_num_seqs: 1024
188
+ log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
189
+ log_prob_micro_batch_size_per_gpu: null
190
+ log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
191
+ log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
192
+ disable_log_stats: True
193
+ enable_chunked_prefill: True # could get higher throughput
194
+ # for hf rollout
195
+ do_sample: True
196
+ n: 1 # > 1 for grpo
197
+
198
+ multi_stage_wake_up: false
199
+
200
+ # Extra inference engine arguments (vllm, sglang).
201
+ engine_kwargs:
202
+
203
+ # for vllm
204
+ vllm:
205
+
206
+ # Swap space (in GB) used by inference engine. null uses default (e.g., 4 GB).
207
+ swap_space: null
208
+
209
+ # Whether to disable the preprocessor cache for multimodel models.
210
+ disable_mm_preprocessor_cache: False
211
+
212
+ # for sglang
213
+ sglang:
214
+
215
+ # The attention backend for sglang engine. Options: flashinfer, triton, flashmla, null for default.
216
+ attention_backend: null
217
+
218
+ val_kwargs:
219
+ # sampling parameters for validation
220
+ top_k: -1 # 0 for hf rollout, -1 for vllm rollout
221
+ top_p: 1.0
222
+ temperature: 0
223
+ n: 1
224
+ do_sample: False # default eager for validation
225
+ # number of responses (i.e. num sample times)
226
+ multi_turn:
227
+ enable: False # should set rollout.name to sglang_async if True
228
+ max_turns: null # null for no limit (default max_length // 3)
229
+ tool_config_path: null # null for no tool
230
+ format: chatml # chatml, more formats will be supported in the future
231
+
232
+ # support logging rollout prob for debugging purpose
233
+ calculate_log_probs: False
234
+
235
+ # profiler configs
236
+ profiler:
237
+
238
+ # True for each task has its own database, False for all tasks in one training step share one database.
239
+ discrete: False
240
+
241
+ # Whether to profile all ranks.
242
+ all_ranks: False
243
+
244
+ # The ranks that will be profiled. null or [0,1,...]
245
+ ranks: null
246
+
247
+ # [Experimental] agent loop based rollout configs
248
+ agent:
249
+
250
+ # Number of agent loop workers
251
+ num_workers: 8
252
+
253
+ critic:
254
+
255
+ # Number of rollouts per update (mirrors actor rollout_n)
256
+ rollout_n: ${actor_rollout_ref.rollout.n}
257
+
258
+ # fsdp or fsdp2 strategy used for critic model training
259
+ strategy: ${actor_rollout_ref.actor.strategy}
260
+ optim:
261
+ lr: 1e-5
262
+ lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
263
+ min_lr_ratio: null # only useful for warmup with cosine
264
+ warmup_style: constant # select from constant/cosine
265
+ total_training_steps: -1 # must be override by program
266
+ weight_decay: 0.01
267
+ model:
268
+ path: ~/models/deepseek-llm-7b-chat
269
+
270
+ use_shm: False
271
+ tokenizer_path: ${actor_rollout_ref.model.path}
272
+ override_config: { }
273
+ external_lib: ${actor_rollout_ref.model.external_lib}
274
+ enable_gradient_checkpointing: True
275
+ use_remove_padding: False
276
+ fsdp_config:
277
+ param_offload: False
278
+ grad_offload: False
279
+ optimizer_offload: False
280
+ wrap_policy:
281
+ # transformer_layer_cls_to_wrap: None
282
+ min_num_params: 0
283
+
284
+ # Only for FSDP2: offload param/grad/optimizer during train
285
+ offload_policy: False
286
+
287
+ # Only for FSDP2: Reshard after forward pass to reduce memory footprint
288
+ reshard_after_forward: True
289
+
290
+ # Number of GPUs in each FSDP shard group; -1 means auto
291
+ fsdp_size: -1
292
+
293
+ # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather
294
+ # before the current forward computation.
295
+ forward_prefetch: False
296
+ ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
297
+ ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
298
+ ppo_micro_batch_size_per_gpu: null
299
+ forward_micro_batch_size: ${critic.ppo_micro_batch_size}
300
+ forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
301
+ use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
302
+ ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2
303
+ forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
304
+ ulysses_sequence_parallel_size: 1 # sp size
305
+ ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
306
+ shuffle: ${actor_rollout_ref.actor.shuffle}
307
+ grad_clip: 1.0
308
+ cliprange_value: 0.5
309
+
310
+ reward_model:
311
+ enable: False
312
+ strategy: fsdp
313
+ model:
314
+ input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
315
+ path: ~/models/FsfairX-LLaMA3-RM-v0.1
316
+ external_lib: ${actor_rollout_ref.model.external_lib}
317
+ use_remove_padding: False
318
+ fsdp_config:
319
+ min_num_params: 0
320
+ param_offload: False
321
+ fsdp_size: -1
322
+ micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu
323
+ micro_batch_size_per_gpu: null # set a number
324
+ max_length: null
325
+ ulysses_sequence_parallel_size: 1 # sp size
326
+ use_dynamic_bsz: ${critic.use_dynamic_bsz}
327
+ forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
328
+
329
+
330
+ # Cloud/local sandbox fusion configuration for custom reward logic
331
+ sandbox_fusion:
332
+
333
+ # Cloud/local function URL for sandbox execution
334
+ url: null
335
+
336
+ # Max concurrent requests allowed to sandbox
337
+ max_concurrent: 64
338
+
339
+ # Max memory limit for each sandbox process in MB
340
+ memory_limit_mb: 1024
341
+
342
+ # profiler configs
343
+ profiler:
344
+
345
+ # True for each task has its own database, False for all tasks in one training step share one database.
346
+ discrete: False
347
+
348
+ # Whether to profile all ranks.
349
+ all_ranks: False
350
+
351
+ # The ranks that will be profiled. null or [0,1,...]
352
+ ranks: null
353
+
354
+ algorithm:
355
+ gamma: 1.0
356
+ lam: 1.0
357
+ adv_estimator: gae
358
+ norm_adv_by_std_in_grpo: True
359
+ use_kl_in_reward: False
360
+ kl_penalty: kl # how to estimate kl divergence
361
+ kl_ctrl:
362
+ type: fixed
363
+ kl_coef: 0.0
364
+ horizon: 10000
365
+ target_kl: 0.0
366
+
367
+ # Whether to enable preference feedback PPO
368
+ use_pf_ppo: False
369
+
370
+ # Preference feedback PPO settings
371
+ pf_ppo:
372
+
373
+ # Method for reweighting samples: "pow", "max_min", or "max_random"
374
+ reweight_method: pow
375
+
376
+ # Power used for weight scaling in "pow" method
377
+ weight_pow: 2.0
378
+
379
+ ray_init:
380
+ num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.
381
+
382
+ trainer:
383
+ balance_batch: True
384
+ debug: False
385
+ debug_port: 5678
386
+ wandb_run_id: null
387
+ total_epochs: 30
388
+
389
+ # The steps that will be profiled. null means no profiling. null or [1,2,5,...]
390
+ profile_steps: null
391
+ total_training_steps: null
392
+
393
+ # controller Nvidia Nsight Systems Options. Must set when profile_steps is not None.
394
+ ## reference https://docs.nvidia.com/nsight-systems/UserGuide/index.html
395
+ ## reference https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html
396
+ controller_nsight_options:
397
+
398
+ # Select the API(s) to be traced.
399
+ trace: "cuda,nvtx,cublas,ucx"
400
+
401
+ # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false".
402
+ cuda-memory-usage: "true"
403
+
404
+ # CUDA graphs will be traced as a whole
405
+ cuda-graph-trace: "graph"
406
+
407
+ # worker Nvidia Nsight Systems Options. Must set when profile_steps is not None.
408
+ worker_nsight_options:
409
+
410
+ # Select the API(s) to be traced.
411
+ trace: "cuda,nvtx,cublas,ucx"
412
+
413
+ # Track the GPU memory usage by CUDA kernels. Must be string type "true" or "false".
414
+ cuda-memory-usage: "true"
415
+
416
+ # CUDA graphs will be traced as a whole
417
+ cuda-graph-trace: "graph"
418
+
419
+ # Profiling only in a range of torch.cuda.profiler.start and stop. Do not change this config.
420
+ capture-range: "cudaProfilerApi"
421
+
422
+ # Specify the desired behavior when a capture range ends.
423
+ # In verl we need the orch.cuda.profiler.start/stop pair to repeats n times.
424
+ # valid values are "repeat-shutdown:n" or null.
425
+ # For normal whole step profiling, n = len(profile_steps);
426
+ # but for discrete profiling, n = len(profile_steps) * Number(subtasks).
427
+ # Or you can just leave it null and the program will use n = len(profile_steps) * 6;
428
+ capture-range-end: null
429
+
430
+ # Send signal to the target application's process group. We let the program to exit by itself.
431
+ kill: none
432
+
433
+ project_name: verl_examples
434
+ experiment_name: gsm8k
435
+ logger: [ 'console', 'wandb' ]
436
+ # Number of generations to log during validation
437
+ log_val_generations: 0
438
+
439
+ # Directory for logging rollout data; no dump if null
440
+ rollout_data_dir: null
441
+
442
+ # Directory for logging validation data; no dump if null
443
+ validation_data_dir: null
444
+
445
+ # Number of nodes used in the training
446
+ nnodes: 1
447
+ n_gpus_per_node: 8
448
+ save_freq: -1
449
+ # auto: find the last ckpt to resume. If can't find, start from scratch
450
+ resume_mode: auto # or auto or resume_path if
451
+ resume_from_path: False
452
+
453
+ # ESI redundant time (in seconds) for model checkpointsAdd commentMore actions
454
+ esi_redundant_time: 0
455
+ test_freq: -1
456
+ critic_warmup: 0
457
+ default_hdfs_dir: null
458
+ default_local_dir: checkpoints/code_io/${trainer.project_name}/${trainer.experiment_name}
459
+ remove_previous_ckpt_in_save: False
460
+ del_local_ckpt_after_load: False
461
+ wandb_tags: null
462
+
463
+ # Maximum number of actor checkpoints to keep
464
+ max_actor_ckpt_to_keep: null
465
+
466
+ # Maximum number of critic checkpoints to keep
467
+ max_critic_ckpt_to_keep: null
468
+
469
+ # Timeout (in seconds) for Ray worker to wait for registration
470
+ ray_wait_register_center_timeout: 300
471
+
472
+ # Device to run training on (e.g., "cuda", "cpu")
473
+ device: cuda
474
+
475
+ reward_fn:
476
+ extraction_type: answer_addition
477
+ math_metric: deepscaler #[math_verify|deepscaler|union]
478
+ splitter: "Assistant:"
479
+ boxed_retry: False
480
+
481
+ azr:
482
+ seed: 1
483
+ executor_max_workers: 1
484
+ executor_cleanup_frequency: 1
485
+ problem_types:
486
+ - code_i
487
+ - code_o
488
+ - code_f
489
+ pred_data_mix_strategy: "max_new" # [uniform_total, max_new, half_new, step]
490
+ gen_data_probabilities_strategy: "uniform" # [uniform, step]
491
+ past_epoch_window: ${azr.data_selection_strategy.update_iteration}
492
+ seed_dataset: null
493
+ error_seed_dataset: null
494
+ output_seed_path: null
495
+ output_error_seed_path: null
496
+ output_code_f_seed_path: null
497
+ code_f_seed_dataset: null
498
+ pretrain_pred_steps: -1
499
+ executor: qwq # [qwq, sandboxfusion]
500
+ ast_check: True
501
+ execute_max_timeout: 10 # seconds
502
+ random_print_max_programs: 3
503
+ train_propose: True
504
+ use_china_mirror: True # used for sandboxfusion executor for people in China
505
+
506
+ # Data saving options
507
+ save_generated_data: True # Enable/disable saving generated data
508
+ save_data_path: "./generated_programs" # Path to save generated data (if null, don't save)
509
+ save_valid_data: True # Save valid programs
510
+ save_invalid_data: True # Save invalid programs
511
+ save_frequency: 1 # Save every N steps (1 = every step)
512
+ save_final_datasets: False # Save complete datasets at training end
513
+ data_selection_strategy:
514
+ io_n: 6
515
+ update_iteration: 1
516
+ data_len: null # dummy set
517
+ seed_batch_factor: 4
518
+ content_max_length: 8096
519
+ valid_program_filter: all # [all (all valids), non_one (all valids except 100% accuracy), non_extremes (all valids except 0% and 100% accuracy)]
520
+ max_programs: null
521
+ batched_estimate: False
522
+ composite_function_n_min: -1
523
+ composite_function_n_max: -1
524
+ composite_chance: 0.5
525
+ composite_start_step: -1
526
+ max_programs_initial: ${azr.data_selection_strategy.composite_function_n_max}
527
+ composite_chance_initial: ${azr.data_selection_strategy.composite_chance}
528
+ composite_scheduler:
529
+ enabled: False
530
+ update_num_programs_start: 101
531
+ update_num_programs_interval: 50
532
+ num_programs_max: 3
533
+ update_probability_start: 101
534
+ update_probability_interval: 50
535
+ update_probability_max: 0.8
536
+ update_probability_increment: 0.01
537
+ num_inputs: 10 # for code_f, how many inputs to generate
538
+ banned_words:
539
+ - logging
540
+ - random
541
+ - multiprocessing
542
+ - pebble
543
+ - subprocess
544
+ - threading
545
+ - datetime
546
+ - time
547
+ - hashlib
548
+ - hmac
549
+ - bcrypt
550
+ - os.sys
551
+ - os.path
552
+ - sys.exit
553
+ - os.environ
554
+ - calendar
555
+ - datetime
556
+ banned_keywords_for_errors_and_exceptions:
557
+ # - raise
558
+ # - assert
559
+ # - try
560
+ # - except
561
+ reward:
562
+ n_samples: 8
563
+ extract_code_block: True
564
+ code_f_reward_type: binary # [accuracy, binary]
565
+ generation_reward_config:
566
+ format_reward: True
567
+ reject_multiple_functions: True
568
+ reject_test_input_in_code: False
569
+ f_replace_location: not_first # [not_first, any_last, any_first, not_last]
570
+ intrinsic_combine_method: sum # [sum, multiply, sum_multiply]
571
+ remove_after_return: False # remove global variables
572
+ remove_comments: False
573
+ remove_print: False
574
+ use_original_code_as_ref: False
575
+ generation_accuracy_convertion: one_minus
576
+ remove_input_from_snippet: False # prompting
577
+ include_references: True # ablation for unconditional generation
578
+ code_location: first # [first, last]
579
+ complexity_reward:
580
+ enabled: False
581
+ coef: 0.0
582
+ max: 0.5
583
+ mean_edit_distance_reward:
584
+ enabled: False
585
+ coef: 0.0
586
+ max: 0.5
587
+ halstead_reward:
588
+ enabled: False
589
+ coef: 0.0
590
+ max: 0.5
591
+ answer_diversity_reward:
592
+ enabled: False
593
+ coef: 0.0
594
+ max: 0.5
595
+ hierarchical: False
596
+ f_input_answer_diversity_reward:
597
+ enabled: False
598
+ coef: 0.0
599
+ max: 0.5
600
+ hierarchical: False
601
+ f_output_answer_diversity_reward:
602
+ enabled: False
603
+ coef: 0.0
604
+ max: 0.5
605
+ hierarchical: False
absolute_zero_reasoner/data_construction/__init__.py ADDED
File without changes
absolute_zero_reasoner/data_construction/constructor.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict
2
+
3
+ from numpy import random
4
+ import pandas as pd
5
+ from transformers import AutoTokenizer
6
+
7
+ from absolute_zero_reasoner.data_construction.prompts import get_code_problem_generator_prompt, get_code_problem_predictor_prompt
8
+ from absolute_zero_reasoner.data_construction.process_data import boxed_instruction, instruction_following
9
+ from absolute_zero_reasoner.utils.code_utils.parsers import replace_main_function_name
10
+
11
+
12
+ def get_gen_code_io_data(
13
+ io_data: List[Dict],
14
+ target_data_len: int,
15
+ problem_type: str,
16
+ instruction_type: str,
17
+ content_max_length: int,
18
+ io_n: int,
19
+ output_path: str,
20
+ split: str,
21
+ tokenizer: AutoTokenizer,
22
+ banned_keywords: List[str],
23
+ banned_assertion_keywords: List[str],
24
+ weights: List[float] = None,
25
+ enable_composite_function: bool = False,
26
+ composite_function_n_min: int = -1,
27
+ composite_function_n_max: int = -1,
28
+ composite_chance: float = 0.5,
29
+ remove_after_return: bool = False,
30
+ num_inputs: int = 10,
31
+ remove_input_from_snippet: bool = False,
32
+ include_references: bool = True,
33
+ ):
34
+ return_io_data = []
35
+ if instruction_type.startswith('boxed'):
36
+ instruction_template = boxed_instruction
37
+ elif instruction_type.startswith('answer'):
38
+ instruction_template = instruction_following
39
+ elif instruction_type.startswith('none'):
40
+ instruction_template = '{}'
41
+ else:
42
+ raise ValueError(f"Invalid instruction type: {instruction_type}")
43
+
44
+ if weights is None:
45
+ probabilities = [1.0 / len(io_data)] * len(io_data)
46
+ else:
47
+ # Normalize weights to form a probability distribution
48
+ probabilities = [float(w)/sum(weights) for w in weights]
49
+
50
+ idx = 0
51
+
52
+ while len(return_io_data) < target_data_len:
53
+ if not include_references and problem_type != 'code_f':
54
+ chosen_references = []
55
+ else:
56
+ chosen_references = random.choice(io_data, size=min(io_n, len(io_data)), replace=False, p=probabilities)
57
+ # composite functions is not used for code_f problem type
58
+ if problem_type != 'code_f' and composite_function_n_max > 0 and enable_composite_function and random.random() <= composite_chance and len(chosen_references) > composite_function_n_max:
59
+ # TODO: we only allow composite to sample from code snippets without composite functions
60
+ io_without_composite_function_indices = [i for i in range(len(io_data)) if not io_data[i]['composite_functions']]
61
+ io_without_composite_function_data = [io_data[i] for i in io_without_composite_function_indices]
62
+ io_without_composite_function_weights = [probabilities[i] for i in io_without_composite_function_indices]
63
+ # normalize the weights
64
+ io_without_composite_function_probabilities = [w / sum(io_without_composite_function_weights) for w in io_without_composite_function_weights]
65
+ # number of composite functions to sample is either fixed or random
66
+ composite_function_n = composite_function_n_min if composite_function_n_min == composite_function_n_max else random.randint(composite_function_n_min, composite_function_n_max)
67
+ composite_functions = random.choice(io_without_composite_function_data, size=composite_function_n, replace=False, p=io_without_composite_function_probabilities)
68
+ for i, composite_function in enumerate(composite_functions):
69
+ # TODO: need to also replace recursively called composite functions, ignore functions that have f as the last letter, only for function call f()
70
+ composite_functions[i]['snippet'] = replace_main_function_name(composite_function['snippet'], 'f', f'g_{i}')
71
+ imports = []
72
+ else:
73
+ composite_functions = []
74
+ if include_references:
75
+ imports = chosen_references[0]['imports']
76
+ else:
77
+ imports = []
78
+ io_prompt = instruction_template.format(
79
+ get_code_problem_generator_prompt(
80
+ problem_type=problem_type,
81
+ reference_snippets=chosen_references,
82
+ banned_keywords=banned_keywords,
83
+ banned_assertion_keywords=banned_assertion_keywords,
84
+ composite_functions=composite_functions,
85
+ remove_after_return=remove_after_return,
86
+ num_inputs=num_inputs,
87
+ remove_input_from_snippet=remove_input_from_snippet,
88
+ )
89
+ )
90
+ if len(tokenizer(io_prompt)['input_ids']) <= content_max_length:
91
+ io_item = {
92
+ "data_source": 'gen_' + problem_type,
93
+ "prompt": [{
94
+ "role": "user",
95
+ "content": io_prompt,
96
+ }],
97
+ "problem": '',
98
+ "ability": "code",
99
+ "reward_model": {
100
+ "style": "rule",
101
+ "ground_truth": '',
102
+ },
103
+ "extra_info": {
104
+ 'split': split,
105
+ 'index': idx,
106
+ 'metric': 'gen_' + problem_type,
107
+ 'chosen_references': chosen_references,
108
+ 'composite_functions': composite_functions,
109
+ 'imports': imports,
110
+ }
111
+ }
112
+ return_io_data.append(io_item)
113
+ idx += 1
114
+
115
+ if len(return_io_data) >= target_data_len:
116
+ break
117
+
118
+ # if io_data is not full, we sample upsample random data
119
+ while len(return_io_data) < target_data_len:
120
+ io_item = io_data[random.randint(0, len(io_data))]
121
+ return_io_data.append(io_item)
122
+
123
+ # output to parquet
124
+ df = pd.DataFrame(return_io_data)
125
+ df.to_parquet(output_path)
126
+
127
+
128
+ def get_pred_code_io_data(
129
+ io_data: List[Dict],
130
+ target_data_len: int,
131
+ problem_type: str,
132
+ instruction_type: str,
133
+ content_max_length: int,
134
+ output_path: str,
135
+ split: str,
136
+ tokenizer: AutoTokenizer,
137
+ ):
138
+ return_io_data = []
139
+ if instruction_type.startswith('boxed'):
140
+ instruction_template = boxed_instruction
141
+ elif instruction_type.startswith('answer'):
142
+ instruction_template = instruction_following
143
+ elif instruction_type.startswith('none'):
144
+ instruction_template = '{}'
145
+ else:
146
+ raise ValueError(f"Invalid instruction type: {instruction_type}")
147
+
148
+ for idx, io_item in enumerate(io_data):
149
+ if problem_type == 'code_i':
150
+ ground_truth = io_item['input']
151
+ elif problem_type == 'code_o':
152
+ ground_truth = io_item['output']
153
+ elif problem_type == 'code_e':
154
+ ground_truth = io_item['output']
155
+ elif problem_type == 'code_f':
156
+ ground_truth = io_item['snippet']
157
+ else:
158
+ raise ValueError(f"Invalid problem type: {problem_type}")
159
+ if problem_type == 'code_f':
160
+ num_given_inputs = len(io_item['inputs']) // 2
161
+ num_given_outputs = len(io_item['outputs']) // 2
162
+ given_inputs = list(io_item['inputs'][:num_given_inputs])
163
+ given_outputs = list(io_item['outputs'][:num_given_outputs])
164
+ hidden_inputs = list(io_item['inputs'][num_given_inputs:])
165
+ hidden_outputs = list(io_item['outputs'][num_given_outputs:])
166
+ io_prompt = instruction_template.format(
167
+ get_code_problem_predictor_prompt(
168
+ problem_type=problem_type,
169
+ snippet=io_item['snippet'],
170
+ message=io_item['message'],
171
+ input_output_pairs=zip(given_inputs, given_outputs),
172
+ )
173
+ )
174
+ else:
175
+ io_prompt = instruction_template.format(
176
+ get_code_problem_predictor_prompt(
177
+ problem_type=problem_type,
178
+ snippet=io_item['snippet'],
179
+ input_args=io_item['input'],
180
+ output=io_item['output'],
181
+ )
182
+ )
183
+ if len(tokenizer(io_prompt)['input_ids']) <= content_max_length:
184
+ output_io_item = {
185
+ "data_source": 'pred_' + problem_type,
186
+ "prompt": [{
187
+ "role": "user",
188
+ "content": io_prompt,
189
+ }],
190
+ "problem": io_item['snippet'],
191
+ "ability": "code",
192
+ "reward_model": {
193
+ "style": "rule",
194
+ "ground_truth": ground_truth,
195
+ },
196
+ "extra_info": {
197
+ 'split': split,
198
+ 'index': idx,
199
+ 'metric': 'pred_' + problem_type,
200
+ 'imports': io_item['imports'],
201
+ }
202
+ }
203
+ if problem_type == 'code_f': # for code_f, we need to split the inputs and outputs into given and hidden, only show part of the inputs and outputs to the model
204
+ output_io_item['extra_info']['given_inputs'] = given_inputs
205
+ output_io_item['extra_info']['given_outputs'] = given_outputs
206
+ output_io_item['extra_info']['hidden_inputs'] = hidden_inputs
207
+ output_io_item['extra_info']['hidden_outputs'] = hidden_outputs
208
+ output_io_item['extra_info']['message'] = io_item['message']
209
+ else:
210
+ output_io_item['extra_info']['input'] = io_item['input']
211
+ output_io_item['extra_info']['output'] = io_item['output']
212
+ return_io_data.append(output_io_item)
213
+
214
+ if len(return_io_data) >= target_data_len:
215
+ break
216
+
217
+ # if io_data is not full, we sample upsample random data
218
+ while len(return_io_data) < target_data_len:
219
+ io_item = return_io_data[random.randint(0, len(return_io_data))]
220
+ return_io_data.append(io_item)
221
+
222
+ # output to parquet
223
+ df = pd.DataFrame(return_io_data)
224
+ df.to_parquet(output_path)
225
+
absolute_zero_reasoner/data_construction/process_code_reasoning_data.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import argparse
3
+ import re
4
+
5
+ from datasets import load_dataset
6
+ from tqdm import tqdm
7
+ import pandas as pd
8
+
9
+ from absolute_zero_reasoner.rewards.code_reward import format_python_code
10
+ from absolute_zero_reasoner.data_construction.prompts import get_code_problem_predictor_prompt
11
+ from absolute_zero_reasoner.data_construction.process_data import instruction_following
12
+
13
+ def process_livecodebench_execution(row):
14
+ # Extract all function names from the code
15
+ program_name_matches = re.findall(r'def\s+(\w+)\s*\(', row['problem'])
16
+ if not program_name_matches:
17
+ raise ValueError("Could not find any function names in code")
18
+
19
+ # Extract the function name from the input
20
+ input_match = re.search(r'(\w+)\(', row['input'])
21
+ if not input_match:
22
+ raise ValueError("Could not find function name in input")
23
+
24
+ input_function_name = input_match.group(1)
25
+
26
+ # Check if the function name from input appears in any of the defined functions
27
+ if input_function_name not in program_name_matches:
28
+ raise ValueError(f"Function '{input_function_name}' from input not found in code. Available functions: {program_name_matches}")
29
+
30
+ # Use the function name from input for replacement
31
+ program_name = input_function_name
32
+
33
+ # Replace the program name with `f` in the code
34
+ row['problem'] = re.sub(r'def\s+' + re.escape(program_name) + r'\s*\(', 'def f(', row['problem'])
35
+
36
+ # Process the input: remove the function name and keep only the parameters
37
+ row['input'] = re.sub(r'^\w+\s*\(|\)$', '', row['input']).strip()
38
+
39
+ return row
40
+
41
+
42
+ def add_imports(problem):
43
+ # Add necessary imports based on the content of the problem
44
+ if 'collections' in problem:
45
+ problem = 'import collections\n' + problem
46
+ if 'Counter' in problem:
47
+ problem = 'from collections import Counter\n' + problem
48
+ if 'gcd' in problem:
49
+ problem = 'from math import gcd\n' + problem
50
+ if 'deque' in problem:
51
+ problem = 'from collections import deque\n' + problem
52
+ if '@cache' in problem:
53
+ problem = 'from functools import cache\n' + problem
54
+ if '= inf' in problem or '[inf]' in problem or 'inf)' in problem:
55
+ problem = 'from math import inf\n' + problem
56
+ if 'accumulate' in problem:
57
+ problem = 'from itertools import accumulate\n' + problem
58
+ if '@lru_cache' in problem:
59
+ problem = 'from functools import lru_cache\n' + problem
60
+ if 'defaultdict' in problem:
61
+ problem = 'from collections import defaultdict\n' + problem
62
+ if 'bisect' in problem:
63
+ problem = 'import bisect\n' + problem
64
+ if 'islice' in problem:
65
+ problem = 'from itertools import islice\n' + problem
66
+ if 'math.inf' in problem:
67
+ problem = 'import math\n' + problem
68
+ if 'prod(' in problem:
69
+ problem = 'from math import prod\n' + problem
70
+ if 'heapify(' in problem:
71
+ problem = 'from heapq import heapify, heappop, heappush\n' + problem
72
+ if 'reduce(' in problem:
73
+ problem = 'from functools import reduce\n' + problem
74
+ if 'comb(' in problem:
75
+ problem = 'from math import comb\n' + problem
76
+ problem = problem.replace('List', 'list').replace('Dict', 'dict').replace('Tuple', 'tuple').replace('Set', 'set')
77
+ problem = problem.replace('from typing import list', 'from typing import List')
78
+ return problem
79
+
80
+
81
+ if __name__ == '__main__':
82
+ parser = argparse.ArgumentParser()
83
+ parser.add_argument('--max_length', type=int, default=-1)
84
+ args = parser.parse_args()
85
+
86
+ # 283, 452, 510
87
+ ds = load_dataset('cruxeval-org/cruxeval')['test']
88
+ ds = ds.map(lambda x: {'problem': format_python_code(x['code'])})
89
+ output_data = []
90
+ for i, data in enumerate(tqdm(ds, desc="Processing CruxEval")):
91
+ prompt = get_code_problem_predictor_prompt('code_i', data['problem'], data['input'], data['output'])
92
+ formatted_question = instruction_following.format(prompt)
93
+ output_data.append({
94
+ "data_source": 'cruxeval_i',
95
+ "prompt": [{
96
+ "role": "user",
97
+ "content": formatted_question
98
+ }],
99
+ "problem": data['problem'],
100
+ "ability": "math",
101
+ "reward_model": {
102
+ "style": "rule",
103
+ "ground_truth": data['output']
104
+ },
105
+ "extra_info": {
106
+ 'split': 'test',
107
+ 'index': i,
108
+ 'metric': 'pred_code_i',
109
+ 'problem_type': 'code_i',
110
+ 'input': data['input'],
111
+ 'output': data['output'],
112
+ }
113
+ })
114
+ prompt = get_code_problem_predictor_prompt('code_o', data['problem'], data['input'], data['output'])
115
+ formatted_question = instruction_following.format(prompt)
116
+ output_data.append({
117
+ "data_source": 'cruxeval_o',
118
+ "prompt": [{
119
+ "role": "user",
120
+ "content": formatted_question
121
+ }],
122
+ "problem": data['problem'],
123
+ "ability": "math",
124
+ "reward_model": {
125
+ "style": "rule",
126
+ "ground_truth": data['output']
127
+ },
128
+ "extra_info": {
129
+ 'split': 'test',
130
+ 'index': i + len(data),
131
+ 'metric': 'pred_code_o',
132
+ 'problem_type': 'code_o',
133
+ 'input': data['input'],
134
+ 'output': data['output'],
135
+ }
136
+ })
137
+
138
+ # another ds:
139
+ ds = load_dataset('livecodebench/execution')['test']
140
+ ds = ds.map(lambda x: {'problem': format_python_code(x['code'])})
141
+ ds = ds.remove_columns(['code'])
142
+ ds = ds.map(process_livecodebench_execution)
143
+ # normalize the code
144
+ ds = ds.map(lambda x: {'problem': add_imports(x['problem'])})
145
+ for i, data in enumerate(tqdm(ds, desc="Processing LiveCodeBench")):
146
+ prompt = get_code_problem_predictor_prompt('code_i', data['problem'], data['input'], data['output'])
147
+ formatted_question = instruction_following.format(prompt)
148
+ output_data.append({
149
+ "data_source": 'livecodebench',
150
+ "prompt": [{
151
+ "role": "user",
152
+ "content": formatted_question
153
+ }],
154
+ "problem": data['problem'],
155
+ "ability": "math",
156
+ "reward_model": {
157
+ "style": "rule",
158
+ "ground_truth": data['output']
159
+ },
160
+ "extra_info": {
161
+ 'split': 'test',
162
+ 'index': i + len(data),
163
+ 'metric': 'pred_code_i',
164
+ 'problem_type': 'code_i',
165
+ 'input': data['input'],
166
+ 'output': data['output'],
167
+ }
168
+ })
169
+
170
+ df = pd.DataFrame(output_data)
171
+ if args.max_length > 0:
172
+ df = df.iloc[:args.max_length]
173
+ path = Path('data/code_reason')
174
+ path.mkdir(parents=True, exist_ok=True)
175
+ df.to_parquet(path / f'test_answer{"_" + str(args.max_length) if args.max_length > 0 else ""}.parquet')
absolute_zero_reasoner/data_construction/process_data.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Preprocess the GSM8k dataset to parquet format
16
+ """
17
+
18
+ import os
19
+ import datasets
20
+ from glob import glob
21
+ import argparse
22
+
23
+ from verl.utils.hdfs_io import copy, makedirs
24
+ from verl.utils.reward_score.math import remove_boxed, last_boxed_only_string
25
+
26
+
27
+ def extract_solution(solution_str):
28
+ return remove_boxed(last_boxed_only_string(solution_str))
29
+
30
+
31
+ METRIC_MAP = {
32
+ 'aime2024': 'math',
33
+ 'aime2025': 'math',
34
+ 'gpqa': 'mc',
35
+ 'amc2023': 'math',
36
+ 'math500': 'math',
37
+ 'minerva': 'math',
38
+ 'olympiadbench': 'math',
39
+ 'math': 'math',
40
+ 'orz': 'math',
41
+ 'simplerl': 'math',
42
+ 'hmmt_2025': 'math',
43
+ 'hmmt_2024': 'math',
44
+ 'live_math_bench': 'math',
45
+ 'big_math': 'math',
46
+ 'deepscaler': 'math',
47
+ "math3to5": 'math',
48
+ 'dapo': 'math',
49
+ }
50
+
51
+ instruction_following = "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. User: {}\nAssistant: <think>"
52
+ boxed_instruction = "{}\nPlease reason step by step, and put your final answer within \\boxed{{}}."
53
+
54
+
55
+ # add a row to each data item that represents a unique id
56
+ def make_map_fn(split, question_key, answer_key, do_extract_solution, reward_fn_extraction_type, nothink = False):
57
+
58
+ def process_fn(example, idx):
59
+ question = example.pop(question_key)
60
+
61
+ if reward_fn_extraction_type == 'answer':
62
+ formatted_question = (instruction_following if not nothink else instruction_following.strip(' <think>')).format(question)
63
+ elif reward_fn_extraction_type == 'boxed':
64
+ formatted_question = boxed_instruction.format(question)
65
+ elif reward_fn_extraction_type == 'none':
66
+ formatted_question = question
67
+ # gpqa has this string in the question
68
+ if reward_fn_extraction_type != 'boxed':
69
+ remove_string = "\n\nPlease reason step-by-step and put your choice letter without any other text with \\boxed{} in the end."
70
+ replacement_string = '\n\nPlease reason step-by-step and put your choice letter without any other text with <answer> </answer> in the end.'
71
+ formatted_question = formatted_question.replace(remove_string, replacement_string)
72
+
73
+ answer = example.pop(answer_key)
74
+ if do_extract_solution:
75
+ solution = extract_solution(answer)
76
+ else:
77
+ solution = answer
78
+ data_source = example.pop('data_source')
79
+ data = {
80
+ "data_source": data_source,
81
+ "prompt": [{
82
+ "role": "user",
83
+ "content": formatted_question
84
+ }],
85
+ "problem": question,
86
+ "ability": "math",
87
+ "reward_model": {
88
+ "style": "rule",
89
+ "ground_truth": solution
90
+ },
91
+ "extra_info": {
92
+ 'split': split,
93
+ 'index': idx,
94
+ 'metric': METRIC_MAP[data_source],
95
+ }
96
+ }
97
+ return data
98
+
99
+ return process_fn
100
+
101
+
102
+ def process_data(args):
103
+ # 'lighteval/MATH' is no longer available on huggingface.
104
+ # Use mirror repo: DigitalLearningGmbH/MATH-lighteval
105
+ if args.train_set == 'math':
106
+ dataset = datasets.load_dataset('DigitalLearningGmbH/MATH-lighteval', trust_remote_code=True)
107
+ elif args.train_set == 'orz':
108
+ dataset = datasets.load_dataset('json', data_files='data/orz_math_57k_collected.json')
109
+ dataset = dataset.map(lambda x: {'problem': x['0']['value'], 'solution': x['1']['ground_truth']['value']})
110
+ elif args.train_set == 'simplerl':
111
+ dataset = datasets.load_dataset('json', data_files='data/math_level3to5_data_processed_with_qwen_prompt.json')
112
+ dataset = dataset.map(lambda x: {'problem': x['input'].replace('<|im_start|>system\nPlease reason step by step, and put your final answer within \\boxed{}.<|im_end|>\n<|im_start|>user\n', '').replace('<|im_end|>\n<|im_start|>assistant', ''), 'solution': x['gt_answer']})
113
+ elif args.train_set == 'big_math':
114
+ dataset = datasets.load_dataset('SynthLabsAI/Big-Math-RL-Verified')
115
+ dataset = dataset.rename_column('answer', 'solution')
116
+ elif args.train_set == 'deepscaler':
117
+ dataset = datasets.load_dataset('agentica-org/DeepScaleR-Preview-Dataset')
118
+ dataset = dataset.remove_columns(['solution'])
119
+ dataset = dataset.rename_column('answer', 'solution')
120
+ elif args.train_set == 'dapo':
121
+ remove_string = "Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.\n\n"
122
+ remove_string_2 = "\n\nRemember to put your answer on its own line after \"Answer:\"."
123
+ dataset = datasets.load_dataset('YouJiacheng/DAPO-Math-17k-dedup')
124
+ dataset = dataset.map(lambda x: {'problem': x['prompt'][0]['content'].replace(remove_string, '').replace(remove_string_2, '').strip(), 'solution': x['reward_model']['ground_truth']})
125
+ else:
126
+ raise ValueError(f"Invalid train_set: {args.train_set}")
127
+
128
+ if not args.test_only:
129
+ train_dataset = dataset['train']
130
+ train_dataset = train_dataset.add_column('data_source', [args.train_set] * len(train_dataset))
131
+ if args.filter_key is not None and args.filter_value is not None:
132
+ train_dataset = train_dataset.filter(lambda x: x[args.filter_key] == args.filter_value)
133
+ train_dataset = train_dataset.remove_columns([k for k in train_dataset.column_names if k not in ['problem', 'solution', 'data_source']])
134
+
135
+ test_datasources = glob('data/*.jsonl')
136
+ test_datasets = []
137
+ for test_datasource in test_datasources:
138
+ if 'seed_io' in test_datasource or 'MbppPlus' in test_datasource or 'HumanEvalPlus' in test_datasource:
139
+ continue
140
+ temp_ds = datasets.load_dataset('json', data_files=test_datasource, split='train')
141
+ if 'question' in temp_ds.column_names and 'problem' not in temp_ds.column_names:
142
+ temp_ds = temp_ds.rename_column('question', 'problem')
143
+ temp_ds = temp_ds.remove_columns([col for col in temp_ds.column_names if col not in ['problem', 'answer']])
144
+ temp_ds = temp_ds.add_column('data_source', [test_datasource.split('/')[-1].split('.')[0]] * len(temp_ds))
145
+ temp_ds = temp_ds.cast_column('problem', datasets.Value('string'))
146
+ temp_ds = temp_ds.cast_column('answer', datasets.Value('string'))
147
+ temp_ds = temp_ds.cast_column('data_source', datasets.Value('string'))
148
+ test_datasets.append(temp_ds)
149
+ live_math_bench_datasets = ['v202412_AMC_en', 'v202412_CCEE_en', 'v202412_CNMO_en', 'v202412_WLPMC_en', 'v202412_hard_en']
150
+ for dataset_name in live_math_bench_datasets:
151
+ live_math_bench_ds = datasets.load_dataset('opencompass/LiveMathBench', dataset_name)['test']
152
+ live_math_bench_ds = live_math_bench_ds.rename_column('question', 'problem')
153
+ live_math_bench_ds = live_math_bench_ds.remove_columns([col for col in live_math_bench_ds.column_names if col not in ['problem', 'answer']])
154
+ live_math_bench_ds = live_math_bench_ds.add_column('data_source', ['live_math_bench'] * len(live_math_bench_ds))
155
+ test_datasets.append(live_math_bench_ds)
156
+ test_dataset = datasets.concatenate_datasets(test_datasets)
157
+
158
+ if not args.test_only:
159
+ train_dataset = train_dataset.map(
160
+ function=make_map_fn(args.train_split_key, 'problem', 'solution', args.train_set == 'math', args.reward_fn_extraction_type, args.nothink),
161
+ with_indices=True, num_proc=16,
162
+ )
163
+ test_dataset = test_dataset.map(
164
+ function=make_map_fn(args.eval_split_key, 'problem', 'answer', False, args.reward_fn_extraction_type, args.nothink),
165
+ with_indices=True, num_proc=16,
166
+ )
167
+
168
+ if args.length_limit != -1 and not args.test_only:
169
+ train_dataset = train_dataset.select(range(args.length_limit))
170
+ test_dataset = test_dataset.select(range(args.length_limit))
171
+
172
+ local_dir = args.local_dir + f'/{args.train_set}{"_nothink" if args.nothink else ""}'
173
+ hdfs_dir = args.hdfs_dir
174
+
175
+ if args.filter_key is not None:
176
+ filter_key = f"_{args.filter_key}_{args.filter_value}"
177
+ else:
178
+ filter_key = ""
179
+
180
+ if not args.test_only:
181
+ train_dataset.to_parquet(os.path.join(local_dir, f'train_{args.reward_fn_extraction_type}{"" if args.length_limit == -1 else f"_{args.length_limit}"}{filter_key}.parquet'))
182
+ test_dataset.to_parquet(os.path.join(local_dir, f'test_{args.reward_fn_extraction_type}{"_ood" if args.ood_testsets else ""}{"" if args.length_limit == -1 else f"_{args.length_limit}"}{filter_key}.parquet'))
183
+
184
+ if hdfs_dir is not None:
185
+ makedirs(hdfs_dir)
186
+
187
+ copy(src=local_dir, dst=hdfs_dir)
188
+
189
+ if __name__ == '__main__':
190
+ parser = argparse.ArgumentParser()
191
+ parser.add_argument('--local_dir', default='data')
192
+ parser.add_argument(
193
+ '--reward_fn_extraction_type',
194
+ default='answer',
195
+ choices=['answer', 'boxed', 'none']
196
+ )
197
+ parser.add_argument('--length_limit', default=-1, type=int)
198
+ parser.add_argument('--hdfs_dir', default=None)
199
+ parser.add_argument('--train_set', default='math', choices=['math', 'orz', 'simplerl', 'big_math', 'deepscaler', 'dapo'])
200
+ parser.add_argument('--test_only', default=False, action='store_true')
201
+ parser.add_argument('--train_split_key', default='train', type=str)
202
+ parser.add_argument('--eval_split_key', default='test', type=str)
203
+ parser.add_argument('--filter_key', default=None, type=str)
204
+ parser.add_argument('--filter_value', default=None, type=str)
205
+ parser.add_argument('--nothink', default=False, action='store_true')
206
+
207
+ args = parser.parse_args()
208
+ print(args)
209
+
210
+ process_data(args)
absolute_zero_reasoner/data_construction/prompts.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Tuple
2
+
3
+ code_input_prompt = """
4
+ ## Task: Create a Python Code Snippet (where custom classes are allowed, which should be defined at the top of the code snippet) with one Matching Input
5
+
6
+ Using the reference code snippets provided below as examples, design a new and unique Python code snippet that demands deep algorithmic reasoning to deduce one possible input from a given output. Your submission should include both a code snippet and test input pair, where the input will be plugged into the code snippet to produce the output, which that function output be given to a test subject to come up with any input that will produce the same function output. This is meant to be an I.Q. test.
7
+
8
+ ### Code Requirements:
9
+ - Name the entry function `f` (e.g., `def f(...): ...`), you can have nested definitions inside `f`
10
+ - Ensure the function returns a value
11
+ - Include at least one input parameter
12
+ - Make the function deterministic
13
+ - Make the snippet require state tracking across multiple data transformations, ensuring the task requires long multi step reasoning
14
+ - AVOID THE FOLLOWING:
15
+ * Random functions or variables
16
+ * Date/time operations
17
+ * I/O operations (reading files, network requests)
18
+ * Printing or logging
19
+ * Any external state
20
+ - Ensure execution completes within 10 seconds on a modern CPU
21
+ - All imports and class definitions should be at the very top of the code snippet
22
+ - The snippet should end with a return statement from the main function `f`, anything after will be removed
23
+ {remove_input_from_snippet_prompt}{remove_after_return_prompt}
24
+ ### Input Requirements:
25
+ - Provide exactly one test input for your function
26
+ - Format multiple arguments with commas between them
27
+ - Remember to add quotes around string arguments
28
+
29
+ ### Formatting:
30
+ - Format your code with: ```python
31
+ def f(...):
32
+ # your code here
33
+ return ...
34
+ ```
35
+ - Format your input with: ```input
36
+ arg1, arg2, ...
37
+ ```
38
+
39
+ ### Example Format:
40
+ ```python
41
+ def f(name: str, info: dict):
42
+ # code logic here
43
+ return result
44
+ ```
45
+
46
+ ```input
47
+ 'John', {{'age': 20, 'city': 'New York'}}
48
+ ```
49
+
50
+ ### Evaluation Criteria:
51
+ - Executability, your code should be executable given your input
52
+ - Difficulty in predicting the output from your provided input and code snippet. Focus on either algorithmic reasoning or logic complexity. For example, you can define complex data structure classes and operate on them like trees, heaps, stacks, queues, graphs, etc, or use complex control flow, dynamic programming, recursions, divide and conquer, greedy, backtracking, etc
53
+ - Creativity, the code needs to be sufficiently different from the provided reference snippets
54
+ - Restricted usage of certain keywords and packages, you are not allowed to use the following words in any form, even in comments: <|BANNED_KEYWORDS|>
55
+
56
+ First, carefully devise a clear plan: e.g., identify how your snippet will be challenging, distinct from reference snippets, and creative. Then, write the final code snippet and its inputs.
57
+
58
+ ### Reference Code Snippets:
59
+ """
60
+
61
+ code_output_prompt = """
62
+ ## Task: Create a New Python Code Snippet (where custom classes are allowed, which should be defined at the top of the code snippet) with one Matching Input
63
+
64
+ Using the reference code snippets provided below as examples, design a new and unique Python code snippet that demands deep algorithmic reasoning to deduce the output from the input. Your submission should include a code snippet and a test input pair, where the input will be plugged into the code snippet to produce the output. The input will be given to a test subject to deduce the output, which is meant to be an I.Q. test.
65
+
66
+ ### Code Requirements:
67
+ - Name the entry function `f` (e.g., `def f(...): ...`), you can have nested definitions inside `f`
68
+ - Ensure the function returns a value
69
+ - Include at least one input parameter
70
+ - Make the function deterministic
71
+ - Make the snippet require state tracking across multiple data transformations, ensuring the task requires long multi step reasoning
72
+ - AVOID THE FOLLOWING:
73
+ * Random functions or variables
74
+ * Date/time operations
75
+ * I/O operations (reading files, network requests)
76
+ * Printing or logging
77
+ * Any external state
78
+ - Ensure execution completes within 10 seconds on a modern CPU
79
+ - All imports and class definitions should be at the very top of the code snippet
80
+ - The snippet should end with a return statement from the main function `f`, anything after will be removed
81
+ {remove_input_from_snippet_prompt}{remove_after_return_prompt}
82
+ ### Input Requirements:
83
+ - Provide exactly one test input for your function
84
+ - Format multiple arguments with commas between them
85
+ - Remember to add quotes around string arguments
86
+
87
+ ### Formatting:
88
+ - Format your code with:
89
+ ```python
90
+ def f(...):
91
+ # your code here
92
+ return ...
93
+ ```
94
+ - Format your input with:
95
+ ```input
96
+ arg1, arg2, ...
97
+ ```
98
+
99
+ ### Example Format:
100
+ ```python
101
+ def f(name: str, info: dict):
102
+ # code logic here
103
+ return result
104
+ ```
105
+
106
+ ```input
107
+ 'John', {{'age': 20, 'city': 'New York'}}
108
+ ```
109
+
110
+ ### Evaluation Criteria:
111
+ - Executability, your code should be executable given your input
112
+ - Difficulty in predicting your ```input``` from 1) your ```python``` code and 2) the deterministic ```output``` that will be obtained from your ```input```. Focus on either algorithmic reasoning or logic complexity. For example, you can define complex data structure classes and operate on them like trees, heaps, stacks, queues, graphs, etc, or use complex control flow, dynamic programming, recursions, divide and conquer, greedy, backtracking, etc
113
+ - Creativity, the code needs to be sufficiently different from the provided reference snippets
114
+ - Restricted usage of certain keywords and packages, you are not allowed to use the following words in any form, even in comments: <|BANNED_KEYWORDS|>
115
+
116
+ First, carefully devise a clear plan: e.g., identify how your snippet will be challenging, distinct from reference snippets, and creative. Then, write the final code snippet and its inputs.
117
+
118
+ ### Reference Code Snippets:
119
+ """
120
+
121
+ code_error_prompt = """
122
+ ## Task: Create a New Python Code Snippet (where custom classes are allowed, which should be defined at the top of the code snippet) with one Matching Input
123
+
124
+ Using the reference code snippets provided below as examples, design a new and unique Python code snippet that demands deep algorithmic reasoning to deduce what type of error will be raised when the code is executed. Your submission should include a code snippet and a test input pair, where the input will be plugged into the code snippet to produce the error. You can also choose to include a custom error type in your code snippet. However, the code can also be designed to raise no error. The input and the code will be given to a test subject to deduce the error type, which is meant to be an I.Q. test.
125
+
126
+ ### Code Requirements:
127
+ - Name the entry function `f` (e.g., `def f(...): ...`), you can have nested definitions inside `f`
128
+ - Ensure the function returns a value
129
+ - Include at least one input parameter
130
+ - Make the function deterministic
131
+ - Make the snippet require state tracking across multiple data transformations, ensuring the task requires long multi step reasoning
132
+ - AVOID THE FOLLOWING:
133
+ * Random functions or variables
134
+ * Date/time operations
135
+ * I/O operations (reading files, network requests)
136
+ * Printing or logging
137
+ * Any external state
138
+ - Ensure execution completes within 10 seconds on a modern CPU
139
+ - All imports and class definitions should be at the very top of the code snippet
140
+ - The snippet should end with a return statement from the main function `f`, anything after will be removed
141
+ {remove_after_return_prompt}
142
+ ### Input Requirements:
143
+ - Provide exactly one test input for your function
144
+ - Format multiple arguments with commas between them
145
+ - Remember to add quotes around string arguments
146
+
147
+ ### Formatting:
148
+ - Format your code with:
149
+ ```python
150
+ def f(...):
151
+ # your code here
152
+ return ...
153
+ ```
154
+ - Format your input with:
155
+ ```input
156
+ arg1, arg2, ...
157
+ ```
158
+
159
+ ### Example Format:
160
+ ```python
161
+ def f(name: str, info: dict):
162
+ # code logic here
163
+ return result
164
+ ```
165
+
166
+ ```input
167
+ 'John', {{'age': 20, 'city': 'New York'}}
168
+ ```
169
+
170
+ ### Evaluation Criteria:
171
+ - Executability, your code should be executable given your input
172
+ - Difficulty in deducing the error type (or no error) from 1) your ```python``` code and ```input```. Focus on either algorithmic reasoning or logic complexity. For example, you can define complex data structure classes and operate on them like trees, heaps, stacks, queues, graphs, etc, or use complex control flow, dynamic programming, recursions, divide and conquer, greedy, backtracking, etc
173
+ - Creativity, the code needs to be sufficiently different from the provided reference snippets
174
+ - Restricted usage of certain keywords and packages, you are not allowed to use the following words in any form, even in comments: <|BANNED_KEYWORDS|>
175
+ <|BANNED_ASSERTION_KEYWORDS|>
176
+ First, carefully devise a clear plan: e.g., identify how your snippet will be challenging, distinct from reference snippets, and creative. Then, write the final code snippet and its inputs. The code needs to compile and pass AST checks, but it is intended to raise an error or not.
177
+
178
+ ### Reference Code Snippets:
179
+ """
180
+
181
+ code_function_prompt = """
182
+ ## Task: Output {num_inputs} Inputs that can be plugged into the following Code Snippet to produce diverse Outputs, and give a message related to the given snippet.
183
+
184
+ Using the code snippet provided below, design {num_inputs} inputs that can be plugged into the code snippet to produce a diverse set of outputs. A subset of your given input and its deterministically produced outputs will be given to a test subject to deduce the function, which is meant to be an I.Q. test. You can also leave a message to the test subject to help them deduce the code snippet.
185
+
186
+ ### Input Requirements:
187
+ - Provide {num_inputs} valid inputs for the code snippet
188
+ - For each input, format multiple arguments with commas between them
189
+ - Remember to add quotes around string arguments
190
+ - Each input should be individually wrapped in ```input``` tags
191
+
192
+ ### Message Requirements:
193
+ - Leave a message to the test subject to help them deduce the code snippet
194
+ - The message should be wrapped in ```message``` tags
195
+ - The message can be in any form, can even be formed into a coding question, or a natural language instruction what the code snippet does
196
+ - You cannot provide the code snippet in the message
197
+
198
+ ### Formatting:
199
+ - Format your input with:
200
+ ```input
201
+ arg1, arg2, ...
202
+ ```
203
+
204
+ ### Example Format:
205
+ ```input
206
+ 'John', {{'age': 20, 'city': 'New York'}}
207
+ ```
208
+ ```input
209
+ 'Sammy', {{'age': 37, 'city': 'Los Angeles'}}
210
+ ```
211
+
212
+ ### Evaluation Criteria:
213
+ - Executability, your code should be executable given your inputs
214
+ - Coverage, the inputs and outputs should cover the whole input space of the code snippet, able to deduce the code snippet from the inputs and outputs
215
+ - Creativity, the inputs need to be sufficiently different from each other
216
+ - The overall selection of inputs and message combined should be challenging for the test subject, but not impossible for them to solve
217
+ First, carefully devise a clear plan: e.g., understand the code snippet, then identify how your proposed inputs have high coverage, and why the inputs will be challenging and creative. Then, write the inputs and message. Remember to wrap your inputs in ```input``` tags, and your message in ```message``` tags.
218
+
219
+ ### Code Snippet:
220
+ ```python
221
+ {snippet}
222
+ ```
223
+ """
224
+
225
+ # code_input_predictor_prompt = """
226
+ # # Task: Provide One Possible Input of a Python Code Snippet Given the Code and Output
227
+ # Given the following Code Snippet and the Output, think step by step then provide one possible input that produced the output. The input needs to be wrapped in ```input``` tags. Remember if an argument is a string, wrap it in quotes. If the function requires multiple arguments, separate them with commas.
228
+
229
+ # # Code Snippet:
230
+ # ```python
231
+ # {snippet}
232
+ # ```
233
+
234
+ # # Output:
235
+ # ```output
236
+ # {output}
237
+ # ```
238
+
239
+ # # Output Format:
240
+ # ```input
241
+ # arg1, arg2, ...
242
+ # ```
243
+ # # Example Output:
244
+ # ```input
245
+ # 'John', {{'age': 20, 'city': 'New York'}}
246
+ # ```
247
+ # """
248
+
249
+ # code_output_predictor_prompt = """
250
+ # # Task: Deduce the Output of a Python Code Snippet Given the Code and Input
251
+ # Given the following Code Snippet and the Input, think step by step then deduce the output that will be produced from plugging the Input into the Code Snippet. Put your output in ```output``` tags. Remember if the output is a string, wrap it in quotes. If the function returns multiple values, remember to use a tuple to wrap them.
252
+
253
+ # # Code Snippet:
254
+ # ```python
255
+ # {snippet}
256
+ # ```
257
+
258
+ # # Input:
259
+ # ```input
260
+ # {input_args}
261
+ # ```
262
+
263
+ # # Example Output:
264
+ # ```output
265
+ # {{'age': 20, 'city': 'New York'}}
266
+ # ```
267
+ # """
268
+
269
+ code_error_predictor_prompt = """
270
+ # Task: Deduce the Error Type of a Python Code Snippet Given the Code and Input
271
+ Given the following Code Snippet and the Input, think step by step to deduce the error type that will be raised when the code is executed. Put your final output in ```output``` tags. If there are no errors, put "NoError" in the ```output``` tags.
272
+
273
+ # Code Snippet:
274
+ ```python
275
+ {snippet}
276
+ ```
277
+
278
+ # Input:
279
+ ```input
280
+ {input_args}
281
+ ```
282
+
283
+ # Example Output:
284
+ ```output
285
+ ValueError
286
+ ```
287
+ """
288
+
289
+ # code_suffix = "\nf(<|YOUR INPUT WILL BE PLUGGED HERE|>)"
290
+
291
+ # code_function_predictor_prompt = """
292
+ # # Task: Deduce the Function that Produced the Outputs from the Inputs
293
+ # Given a set of input/output pairs and a message that describes the function, think through the problem step by step to deduce a general code snippet. This code should produce the hidden outputs from the hidden inputs, matching the original data-generating code that created the input/output pairs. Place your final answer inside python tags! It may be helpful to work through each input/output pair individually to test your function. If your function doesn’t work as expected, revise it until it does. The final code snippet will be used to evaluate your response, which is wrapped in ```python``` tags.
294
+
295
+ # # Code Requirements:
296
+ # - Name the entry function `f` (e.g., `def f(...): ...`), you can have nested definitions inside `f`
297
+ # - Ensure the function returns a value
298
+ # - Include at least one input parameter
299
+ # - Make the function deterministic
300
+ # - AVOID THE FOLLOWING:
301
+ # * Random functions or variables
302
+ # * Date/time operations
303
+ # * I/O operations (reading files, network requests)
304
+ # * Printing or logging
305
+ # * Any external state
306
+ # - Ensure execution completes within 10 seconds on a modern CPU
307
+ # - All imports and class definitions should be at the very top of the code snippet
308
+ # - The snippet should end with a return statement from the main function `f()`, anything after will be removed
309
+
310
+ # # Input and Output Pairs:
311
+ # {input_output_pairs}
312
+
313
+ # # Message:
314
+ # ```message
315
+ # {message}
316
+ # ```
317
+
318
+ # # Example Output:
319
+ # ```python
320
+ # def f(a):
321
+ # return a
322
+ # ```
323
+
324
+ # Name your entry function `f()`!!!
325
+ # """
326
+
327
+
328
+
329
+ #################################
330
+ # Changed Prompt #
331
+ #################################
332
+
333
+
334
+ code_input_predictor_prompt = """
335
+ A conversation between User and Assistant.
336
+ The User provides a Python code snippet and its observed output. The Assistant must:
337
+
338
+ 1. **Privately think step-by-step** about which input produces that output.
339
+ 2. **Output exactly one** `<think>...</think>` block containing your full reasoning.
340
+ 3. **Then output exactly one** `<answer>...</answer>` block containing **only** the input values—no labels, no comments, no extra text.
341
+ 4. **Do not** generate any text outside these two blocks.
342
+ 5. Adhere to the **input rules**.
343
+
344
+ # Input Rules:
345
+ - If an argument is a string, wrap it in quotes.
346
+ - For multiple arguments, separate by commas.
347
+ - Use Python literal notation for lists, dicts, tuples.
348
+ - Boolean values must be `True` or `False`.
349
+
350
+ User:
351
+ # Python Code Snippet:
352
+ {snippet}
353
+
354
+ # Observed Output:
355
+ {output}
356
+
357
+ # Assitant should follow this format:
358
+
359
+ # Example Response format:
360
+ <think>
361
+ # 1. Analyze the function signature.
362
+ # 2. Walk through the code to see how the observed output arises.
363
+ # 3. Identify specific input values that yield that output.
364
+ </think>
365
+ <answer>
366
+ <your input here>
367
+ </answer>
368
+
369
+ Assistant:
370
+ """
371
+
372
+ code_output_predictor_prompt = """
373
+ A conversation between User and Assistant.
374
+ The User provides a Python code snippet and specific input values. The Assistant must:
375
+
376
+ 1. **Privately think step-by-step** about how the code executes with the given inputs.
377
+ 2. **Output exactly one** `<think>...</think>` block containing your full reasoning.
378
+ 3. **Then output exactly one** `<answer>...</answer>` block containing **only** the output values—no labels, no comments, no extra text.
379
+ 4. **Do not** generate any text outside these two blocks.
380
+ 5. Adhere to the **output rules**.
381
+
382
+ # Output Rules:
383
+ - If the output is a string, wrap it in quotes.
384
+ - For dicts, lists, and other literals, use valid Python literal notation.
385
+
386
+ User:
387
+ # Python Code Snippet:
388
+ {snippet}
389
+
390
+ # Input:
391
+ {input_args}
392
+
393
+ # Assitant should follow this format:
394
+ <think>
395
+ # 1. Examine the code and input.
396
+ # 2. Walk through execution step by step.
397
+ # 3. Determine the exact output produced.
398
+ </think>
399
+ <answer>
400
+ <your output here>
401
+ </answer>
402
+
403
+ Assistant:
404
+ """
405
+
406
+ code_suffix = "\nf(<|YOUR INPUT WILL BE PLUGGED HERE|>)"
407
+
408
+ code_function_predictor_prompt = """
409
+ A conversation between User and Assistant.
410
+ The User provides a set of input/output pairs and a message describing the hidden function. The Assistant must:
411
+
412
+ 1. **Privately think step-by-step** about how to reconstruct the general function based on the provided examples.
413
+ 2. **Output exactly one** `<think>...</think>` block containing the full reasoning process.
414
+ 3. **Then output exactly one** `<answer>...</answer>` block containing **only** the Python code snippet defining the function `f`—no labels, no comments, no extra text.
415
+ 4. **Do not** generate any text outside these two blocks.
416
+ 5. Follow to the **code requirements** and **formatting rules**.
417
+
418
+ # Code Requirements:
419
+ - Name the entry function `f` (e.g., `def f(...): ...`), you may include nested definitions inside `f`.
420
+ - Ensure the function returns a value.
421
+ - Include at least one input parameter.
422
+ - Make the function deterministic.
423
+ - AVOID the FOLLOWING:
424
+ * Random functions or variables
425
+ * Date/time operations
426
+ * I/O operations (reading files, network requests)
427
+ * Printing or logging
428
+ * Any external state
429
+ - Ensure execution completes within 10 seconds on a modern CPU.
430
+ - All imports and custom class definitions must be at the very top of the code snippet.
431
+ - The snippet must end with a return statement from the main function `f`; anything after will be removed.
432
+
433
+ User:
434
+ # Input and Output Pairs:
435
+ {input_output_pairs}
436
+
437
+ # Message:
438
+ {message}
439
+
440
+ # Assistant should follow this format:
441
+ <think>
442
+ # 1. Review each input/output pair and the message to understand the pattern.
443
+ # 2. Infer the general algorithm or transformation being applied.
444
+ # 3. Outline the structure of function `f` that would reproduce all examples.
445
+ # 4. Ensure the function meets all requirements.
446
+ </think>
447
+
448
+ <answer>
449
+ def f(...):
450
+ # your code here
451
+ return ...
452
+ </answer>
453
+
454
+ Assistant:
455
+ """
456
+
457
+
458
+
459
+ # composite_requirements_prompt = "\n[IMPORTANT CRITERIA!!!] The main function `f` MUST make calls to ALL these functions {function_names} in its body, and you SHOULD NOT provide the definition of {function_names} in your output code snippet. You should first reason step by step about what these functions, {function_names}, do, then write the code snippet.\n" + '\n### The Functions that Must ALL be Called in your Code Snippet: \n```python\n{composite_functions}\n```\n'
460
+
461
+ composite_requirements_prompt = "\n[IMPORTANT CRITERIA!!!] The main function `f` MUST make calls to ALL these functions {function_names} in its body, and you SHOULD NOT provide the definition of {function_names} in your output code snippet. The function `f` should build on top of {function_names} with extra functionalities, not just a simple wrapper. You should first reason step by step about what these functions, {function_names}, do, then write the code snippet.\n" + '\n### The Functions that Must ALL be Called in your Code Snippet: \n```python\n{composite_functions}\n```\n'
462
+
463
+ remove_input_from_snippet_prompt = "- Do not have the test input anywhere in the code snippet, provide it in the input section."
464
+
465
+ remove_singleton_variables_prompt = "- All variable declarations must be inside the main function `f` or within functions `f` make calls to. Any variables declared outside of functions will be removed.\n"
466
+
467
+ def get_code_problem_generator_prompt(
468
+ problem_type: str,
469
+ reference_snippets: List[Dict[str, str]],
470
+ banned_keywords: List[str],
471
+ banned_assertion_keywords: List[str],
472
+ composite_functions: List[str] = None,
473
+ remove_after_return: bool = False,
474
+ num_inputs: int = 10,
475
+ remove_input_from_snippet: bool = False,
476
+ ) -> str:
477
+ # assert not (remove_after_return and not remove_input_from_snippet)
478
+ composite_functions = list(composite_functions)
479
+ snippet_string = ""
480
+ if problem_type != 'code_f':
481
+ output_key = 'output' if problem_type != 'code_e' else 'error'
482
+ for i, snippet in enumerate(reference_snippets):
483
+ snippet_string += f"<snippet_{i}>\n```python\n{snippet['snippet']}\n```\n```input\n{snippet['input']}\n```\n```{output_key}\n{snippet['output']}\n```\n</snippet_{i}>\n"
484
+ if problem_type == "code_i":
485
+ return code_input_prompt.format(
486
+ remove_after_return_prompt=(remove_singleton_variables_prompt if remove_after_return else '\n'),
487
+ remove_input_from_snippet_prompt=(remove_input_from_snippet_prompt if remove_input_from_snippet else '')
488
+ ).replace(
489
+ '<|BANNED_KEYWORDS|>', ', '.join(banned_keywords)
490
+ ) + snippet_string + (
491
+ composite_requirements_prompt.format(
492
+ function_names=', '.join([f'`g_{i}`' for i in range(len(composite_functions))]),
493
+ composite_functions="\n".join([d['snippet'] for d in composite_functions])
494
+ ) if composite_functions else '\n'
495
+ )
496
+ elif problem_type == "code_o":
497
+ return code_output_prompt.format(
498
+ remove_after_return_prompt=(remove_singleton_variables_prompt if remove_after_return else '\n'),
499
+ remove_input_from_snippet_prompt=(remove_input_from_snippet_prompt if remove_input_from_snippet else '')
500
+ ).replace(
501
+ '<|BANNED_KEYWORDS|>', ', '.join(banned_keywords)
502
+ ) + snippet_string + (
503
+ composite_requirements_prompt.format(
504
+ function_names=', '.join([f'`g_{i}`' for i in range(len(composite_functions))]),
505
+ composite_functions="\n".join([d['snippet'] for d in composite_functions])
506
+ ) if composite_functions else '\n'
507
+ )
508
+ elif problem_type == "code_f":
509
+ return code_function_prompt.format(
510
+ num_inputs=num_inputs,
511
+ snippet=reference_snippets[0]['snippet'] + code_suffix,
512
+ )
513
+ elif problem_type == "code_e":
514
+ if banned_assertion_keywords:
515
+ assertion_keywords_string = '- The following error handling keywords are not allowed to be used in the code snippet: ' + ', '.join(banned_assertion_keywords) + '\n'
516
+ else:
517
+ assertion_keywords_string = '\n'
518
+ return code_error_prompt.format(
519
+ remove_after_return_prompt=(remove_singleton_variables_prompt if remove_after_return else '\n'),
520
+ ).replace(
521
+ '<|BANNED_KEYWORDS|>', ', '.join(banned_keywords)
522
+ ).replace(
523
+ '<|BANNED_ASSERTION_KEYWORDS|>', assertion_keywords_string
524
+ ) + snippet_string + (
525
+ composite_requirements_prompt.format(
526
+ function_names=', '.join([f'`g_{i}`' for i in range(len(composite_functions))]),
527
+ composite_functions="\n".join([d['snippet'] for d in composite_functions])
528
+ ) if composite_functions else '\n'
529
+ )
530
+ else:
531
+ raise ValueError(f"Invalid problem type: {problem_type}")
532
+
533
+ def get_code_problem_predictor_prompt(problem_type: str, snippet: str, input_args: str = None, output: str = None, message: str = None, input_output_pairs: List[Tuple[str, str]] = None) -> str:
534
+ if problem_type.endswith("code_i"):
535
+ return code_input_predictor_prompt.format(snippet=snippet, output=output)
536
+ elif problem_type.endswith("code_o"):
537
+ return code_output_predictor_prompt.format(snippet=snippet, input_args=input_args)
538
+ elif problem_type.endswith("code_f"):
539
+ input_output_pairs_string = ""
540
+ for i, (input, output) in enumerate(input_output_pairs):
541
+ input_output_pairs_string += f"```input_{i}\n{input}\n```\n```output_{i}\n{output}\n```\n"
542
+ return code_function_predictor_prompt.format(input_output_pairs=input_output_pairs_string, message=message)
543
+ elif problem_type.endswith("code_e"):
544
+ return code_error_predictor_prompt.format(snippet=snippet, input_args=input_args)
545
+ else:
546
+ raise ValueError(f"Invalid problem type: {problem_type}")
absolute_zero_reasoner/main_azr_ppo.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
16
+ """
17
+ import ray
18
+ import hydra
19
+ from pathlib import Path
20
+ from pprint import pprint
21
+
22
+ from omegaconf import OmegaConf
23
+ from verl.utils.fs import copy_local_path_from_hdfs
24
+ from verl.utils import hf_tokenizer
25
+ from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
26
+
27
+ from absolute_zero_reasoner.trainer.ppo.azr_ray_trainer import CodeIORayPPOTrainer
28
+ from absolute_zero_reasoner.rewards.reward_managers import CodeIORewardManager
29
+
30
+
31
+ @hydra.main(config_path='configs', config_name='azr_ppo_trainer', version_base=None)
32
+ def main(config):
33
+ run_ppo(config)
34
+
35
+
36
+ # Define a function to run the PPO-like training process
37
+ def run_ppo(config) -> None:
38
+ # Check if Ray is not initialized
39
+ if not ray.is_initialized():
40
+ # Initialize Ray with a local cluster configuration
41
+ # Set environment variables in the runtime environment to control tokenizer parallelism,
42
+ # NCCL debug level, VLLM logging level, and allow runtime LoRA updating
43
+ # `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration
44
+ import os
45
+ cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
46
+ ray.init(
47
+ runtime_env={"env_vars": {
48
+ "TOKENIZERS_PARALLELISM": "true",
49
+ "NCCL_DEBUG": "WARN",
50
+ "VLLM_LOGGING_LEVEL": "WARN",
51
+ "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true",
52
+ "CUDA_VISIBLE_DEVICES": cuda_visible_devices
53
+ }},
54
+ num_cpus=config.ray_init.num_cpus,
55
+ )
56
+
57
+ # Create a remote instance of the TaskRunner class, and
58
+ # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
59
+ if OmegaConf.select(config.trainer, "profile_steps") is not None and len(OmegaConf.select(config.trainer, "profile_steps")) > 0:
60
+ nsight_options = OmegaConf.to_container(config.trainer.controller_nsight_options)
61
+ runner = TaskRunner.options(runtime_env={
62
+ "nsight": nsight_options,
63
+ "env_vars": {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}
64
+ }).remote()
65
+ else:
66
+ runner = TaskRunner.options(runtime_env={
67
+ "env_vars": {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}
68
+ }).remote()
69
+ ray.get(runner.run.remote(config))
70
+
71
+ # [Optional] get the path of the timeline trace file from the configuration, default to None
72
+ # This file is used for performance analysis
73
+ timeline_json_file = config.ray_init.get("timeline_json_file", None)
74
+ if timeline_json_file:
75
+ ray.timeline(filename=timeline_json_file)
76
+
77
+
78
+ @ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
79
+ class TaskRunner:
80
+ def run(self, config):
81
+ pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
82
+ OmegaConf.resolve(config)
83
+
84
+ if config.trainer.debug:
85
+ import debugpy
86
+ debugpy.listen(("0.0.0.0", config.trainer.debug_port))
87
+ print(f"Debugger listening on port {config.trainer.debug_port}")
88
+ debugpy.wait_for_client()
89
+ print("Debugger attached!")
90
+
91
+ # generator one batch, solver one batch
92
+ config.actor_rollout_ref.actor.ppo_mini_batch_size = config.data.train_batch_size * len(config.azr.problem_types) * (2 if config.azr.train_propose else 1)
93
+ pprint(f"auto setting ppo_mini_batch_size: {config.actor_rollout_ref.actor.ppo_mini_batch_size}")
94
+ config.azr.data_selection_strategy.data_len = config.data.train_batch_size * config.azr.data_selection_strategy.update_iteration
95
+ pprint(f"auto setting data_len: {config.azr.data_selection_strategy.data_len}")
96
+
97
+ config.trainer.default_local_dir = (Path(config.trainer.default_local_dir) / config.data.train_files.split('/')[-1].split('.')[0] / config.actor_rollout_ref.model.path.split('/')[-1] / config.reward_fn.extraction_type).as_posix()
98
+
99
+ assert not (not config.azr.reward.generation_reward_config.reject_multiple_functions and config.azr.data_selection_strategy.composite_function_n_min > 0), "If reject_multiple_functions is False, composite_function_n_min must be 0"
100
+
101
+ # download the checkpoint from hdfs
102
+ local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
103
+
104
+ # Instantiate the tokenizer and processor.
105
+ from verl.utils import hf_processor, hf_tokenizer
106
+
107
+ trust_remote_code = config.data.get("trust_remote_code", False)
108
+ tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
109
+
110
+ # base model chat template
111
+ if config.actor_rollout_ref.model.pretrained_tokenizer:
112
+ tokenizer.chat_template = "{%- for message in messages -%}{{- '\n' if not loop.first -}}{{- message['content'] -}}{%- endfor -%}"
113
+
114
+ # Used for multimodal LLM, could be None
115
+ processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
116
+
117
+ # Version validation for vllm.
118
+ if config.actor_rollout_ref.rollout.name in ["vllm"]:
119
+ from verl.utils.vllm_utils import is_version_ge
120
+
121
+ if config.actor_rollout_ref.model.get("lora_rank", 0) > 0:
122
+ if not is_version_ge(pkg="vllm", minver="0.7.3"):
123
+ raise NotImplementedError("PPO LoRA is not supported before vllm 0.7.3")
124
+
125
+ # Define worker classes based on the actor strategy.
126
+ if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
127
+ assert config.critic.strategy in ["fsdp", "fsdp2"]
128
+ from verl.single_controller.ray import RayWorkerGroup
129
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
130
+
131
+ actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
132
+ ray_worker_group_cls = RayWorkerGroup
133
+
134
+ elif config.actor_rollout_ref.actor.strategy == "megatron":
135
+ assert config.actor_rol# lout_ref.actor.strategy == config.critic.strategy
136
+ from verl.single_controller.ray.megatron import NVMegatronRayWorkerGroup
137
+ from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
138
+
139
+ actor_rollout_cls = AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
140
+ ray_worker_group_cls = NVMegatronRayWorkerGroup
141
+
142
+ else:
143
+ raise NotImplementedError
144
+
145
+ from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
146
+
147
+ # Map roles to their corresponding remote worker classes.
148
+ role_worker_mapping = {
149
+ Role.ActorRollout: ray.remote(actor_rollout_cls),
150
+ Role.Critic: ray.remote(CriticWorker),
151
+ }
152
+
153
+ # Define the resource pool specification.
154
+ # Map roles to the resource pool.
155
+ global_pool_id = "global_pool"
156
+ resource_pool_spec = {
157
+ global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
158
+ }
159
+ mapping = {
160
+ Role.ActorRollout: global_pool_id,
161
+ Role.Critic: global_pool_id,
162
+ }
163
+
164
+ # We should adopt a multi-source reward function here:
165
+ # - for rule-based rm, we directly call a reward score
166
+ # - for model-based rm, we call a model
167
+ # - for code related prompt, we send to a sandbox if there are test cases
168
+ # finally, we combine all the rewards together
169
+ # The reward type depends on the tag of the data
170
+ if config.reward_model.enable:
171
+ if config.reward_model.strategy in ["fsdp", "fsdp2"]:
172
+ from verl.workers.fsdp_workers import RewardModelWorker
173
+ elif config.reward_model.strategy == "megatron":
174
+ from verl.workers.megatron_workers import RewardModelWorker
175
+ else:
176
+ raise NotImplementedError
177
+ role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
178
+ mapping[Role.RewardModel] = global_pool_id
179
+
180
+ # Add a reference policy worker if KL loss or KL reward is used.
181
+ if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
182
+ role_worker_mapping[Role.RefPolicy] = ray.remote(ActorRolloutRefWorker)
183
+ mapping[Role.RefPolicy] = global_pool_id
184
+
185
+ reward_fn = CodeIORewardManager(
186
+ tokenizer=tokenizer,
187
+ num_examine=0,
188
+ reward_fn_extraction_type=config.reward_fn.extraction_type,
189
+ math_metric=config.reward_fn.math_metric,
190
+ split='train',
191
+ splitter=config.reward_fn.splitter,
192
+ output_path=config.trainer.default_local_dir,
193
+ max_prompt_length=config.data.max_prompt_length,
194
+ generation_reward_config=config.azr.reward.generation_reward_config,
195
+ valid_program_filter=config.azr.data_selection_strategy.valid_program_filter,
196
+ debug=config.trainer.debug,
197
+ extract_code_block=config.azr.reward.extract_code_block,
198
+ code_f_reward_type=config.azr.reward.code_f_reward_type,
199
+ boxed_retry=config.reward_fn.boxed_retry,
200
+ )
201
+
202
+ # Note that we always use function-based RM for validation
203
+ val_reward_fn = CodeIORewardManager(
204
+ tokenizer=tokenizer,
205
+ num_examine=1,
206
+ reward_fn_extraction_type=config.reward_fn.extraction_type,
207
+ math_metric=config.reward_fn.math_metric,
208
+ split='test',
209
+ splitter=config.reward_fn.splitter,
210
+ output_path=config.trainer.default_local_dir,
211
+ max_prompt_length=config.data.max_prompt_length,
212
+ generation_reward_config=config.azr.reward.generation_reward_config,
213
+ valid_program_filter=config.azr.data_selection_strategy.valid_program_filter,
214
+ debug=config.trainer.debug,
215
+ extract_code_block=config.azr.reward.extract_code_block,
216
+ code_f_reward_type=config.azr.reward.code_f_reward_type,
217
+ boxed_retry=config.reward_fn.boxed_retry,
218
+ )
219
+
220
+ resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
221
+
222
+ wandb_tags = [
223
+ 'codeio', config.azr.pred_data_mix_strategy, 'executor-' + config.azr.executor,
224
+ config.azr.data_selection_strategy.valid_program_filter, config.azr.gen_data_probabilities_strategy,
225
+ ]
226
+ wandb_tags.extend(config.azr.problem_types)
227
+ if config.trainer.wandb_tags is not None:
228
+ config.trainer.wandb_tags = wandb_tags + config.trainer.wandb_tags.split(',')
229
+ else:
230
+ config.trainer.wandb_tags = wandb_tags
231
+
232
+ trainer = CodeIORayPPOTrainer(
233
+ past_epoch_window=config.azr.past_epoch_window,
234
+ config=config,
235
+ tokenizer=tokenizer,
236
+ processor=processor,
237
+ role_worker_mapping=role_worker_mapping,
238
+ resource_pool_manager=resource_pool_manager,
239
+ ray_worker_group_cls=ray_worker_group_cls,
240
+ reward_fn=reward_fn,
241
+ val_reward_fn=val_reward_fn,
242
+ )
243
+
244
+ trainer.init_workers()
245
+ trainer.fit()
246
+
247
+
248
+ if __name__ == '__main__':
249
+ try:
250
+ main()
251
+ except KeyboardInterrupt:
252
+ import sys
253
+ import traceback
254
+ traceback.print_exc()
255
+ sys.exit(0)
256
+ except Exception as e:
257
+ import os
258
+ import traceback
259
+ traceback.print_exc()
260
+ os._exit(1)
absolute_zero_reasoner/rewards/__init__.py ADDED
File without changes
absolute_zero_reasoner/rewards/code_reward.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/huggingface/open-r1
3
+ """
4
+
5
+ import re
6
+ import json
7
+ from typing import Dict, Any, List, Tuple
8
+ import ast
9
+ import difflib
10
+ import json
11
+
12
+ from complexipy import code_complexity
13
+ import black
14
+ import autopep8
15
+
16
+ from absolute_zero_reasoner.utils.code_utils.parsers import (
17
+ parse_imports,
18
+ remove_comments_and_docstrings,
19
+ remove_any_not_definition_imports,
20
+ remove_print_statements,
21
+ )
22
+
23
+
24
+ def format_python_code(code: str) -> str:
25
+ """Formats Python code with proper indentation using autopep8."""
26
+ try:
27
+ # First try to use black for formatting
28
+ formatted = black.format_str(code, mode=black.Mode())
29
+ return formatted
30
+ except:
31
+ # Fallback to a simpler approach that handles the specific test case
32
+ # Parse the code line by line
33
+ formatted_lines = []
34
+ in_function = False
35
+ function_indent = 0
36
+ empty_line_after_return = False
37
+
38
+ for line in code.split('\n'):
39
+ stripped = line.strip()
40
+
41
+ # Skip empty lines but remember them for context
42
+ if not stripped:
43
+ if in_function and empty_line_after_return:
44
+ # Empty line after return statement likely means end of function
45
+ in_function = False
46
+ formatted_lines.append('')
47
+ continue
48
+
49
+ # Detect function definition
50
+ if stripped.startswith('def ') and stripped.endswith(':'):
51
+ in_function = True
52
+ function_indent = 0
53
+ formatted_lines.append(stripped)
54
+ continue
55
+
56
+ # Handle indentation inside functions
57
+ if in_function:
58
+ # Check for return statement
59
+ if stripped.startswith('return '):
60
+ formatted_lines.append(' ' + stripped)
61
+ empty_line_after_return = True
62
+ continue
63
+
64
+ # Check if this is likely a line outside the function
65
+ if empty_line_after_return and not stripped.startswith((' ', '\t')):
66
+ in_function = False
67
+ formatted_lines.append(stripped)
68
+ continue
69
+
70
+ # Regular function body line
71
+ formatted_lines.append(' ' + stripped)
72
+ else:
73
+ # Line outside any function
74
+ formatted_lines.append(stripped)
75
+
76
+ # Apply autopep8 for final cleanup
77
+ return autopep8.fix_code(
78
+ '\n'.join(formatted_lines),
79
+ options={'aggressive': 1, 'indent_size': 4}
80
+ )
81
+
82
+
83
+ def extract_code(completion: str) -> str:
84
+ pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
85
+ matches = pattern.findall(completion)
86
+ extracted_answer = matches[-1] if len(matches) >= 1 else ""
87
+ return extracted_answer
88
+
89
+
90
+ def parse_to_ast(code_snippet: str) -> ast.AST:
91
+ """
92
+ Parse a Python code snippet into an Abstract Syntax Tree (AST).
93
+
94
+ Args:
95
+ code_snippet: A string containing Python code
96
+
97
+ Returns:
98
+ An AST object representing the code
99
+
100
+ Raises:
101
+ SyntaxError: If the code snippet contains syntax errors
102
+ """
103
+ try:
104
+ return ast.parse(code_snippet)
105
+ except SyntaxError as e:
106
+ print(f"Syntax error in code: {e}")
107
+ raise
108
+
109
+
110
+ def ast_to_dict(node: ast.AST) -> Dict[str, Any]:
111
+ """
112
+ Convert an AST node to a dictionary representation for easier comparison.
113
+
114
+ Args:
115
+ node: An AST node
116
+
117
+ Returns:
118
+ A dictionary representing the node and its children
119
+ """
120
+ if isinstance(node, ast.AST):
121
+ # Extract node type and fields
122
+ result = {"node_type": node.__class__.__name__}
123
+
124
+ # Add children nodes
125
+ for field, value in ast.iter_fields(node):
126
+ if field == "ctx": # Skip context objects as they vary unnecessarily
127
+ continue
128
+
129
+ # Handle different types of field values
130
+ if isinstance(value, list):
131
+ result[field] = [ast_to_dict(item) for item in value if isinstance(item, ast.AST)]
132
+ elif isinstance(value, ast.AST):
133
+ result[field] = ast_to_dict(value)
134
+ elif value is not None:
135
+ # Keep primitive values unchanged
136
+ result[field] = value
137
+
138
+ return result
139
+ else:
140
+ return {"value": str(node)}
141
+
142
+
143
+ def ast_edit_distance(code1: str, code2: str) -> float:
144
+ """
145
+ Calculate the edit distance between two Abstract Syntax Trees.
146
+
147
+ Args:
148
+ ast1: First AST
149
+ ast2: Second AST
150
+
151
+ Returns:
152
+ A float value representing the normalized edit distance (0.0 = identical, 1.0 = completely different)
153
+ """
154
+ try:
155
+ ast1 = parse_to_ast(format_python_code(code1))
156
+ ast2 = parse_to_ast(format_python_code(code2))
157
+
158
+ # Convert ASTs to dictionary representation
159
+ dict1 = ast_to_dict(ast1)
160
+ dict2 = ast_to_dict(ast2)
161
+
162
+ # Convert to strings for difflib comparison
163
+ str1 = json.dumps(dict1, sort_keys=True, indent=2)
164
+ str2 = json.dumps(dict2, sort_keys=True, indent=2)
165
+
166
+ # Calculate similarity ratio using difflib
167
+ similarity = difflib.SequenceMatcher(None, str1, str2).ratio()
168
+
169
+ # Convert similarity to distance (1.0 - similarity)
170
+ distance = 1.0 - similarity
171
+
172
+ return distance
173
+ except Exception as e:
174
+ print(f"Error in ast_edit_distance: {e}")
175
+ return 0.0
176
+
177
+
178
+ def ast_edit_operations(ast1: ast.AST, ast2: ast.AST) -> List[Dict[str, Any]]:
179
+ """
180
+ Generate a list of edit operations needed to transform ast1 into ast2.
181
+
182
+ Args:
183
+ ast1: First AST
184
+ ast2: Second AST
185
+
186
+ Returns:
187
+ A list of edit operations (insert, delete, modify)
188
+ """
189
+ # Convert ASTs to dictionary representation
190
+ dict1 = ast_to_dict(ast1)
191
+ dict2 = ast_to_dict(ast2)
192
+
193
+ # Convert to strings for difflib comparison
194
+ str1 = json.dumps(dict1, sort_keys=True, indent=2).splitlines()
195
+ str2 = json.dumps(dict2, sort_keys=True, indent=2).splitlines()
196
+
197
+ # Calculate differences
198
+ diff = list(difflib.unified_diff(str1, str2, n=0))
199
+
200
+ # Parse diff into operations
201
+ operations = []
202
+ for line in diff[2:]: # Skip the header lines
203
+ if line.startswith('+'):
204
+ operations.append({
205
+ "operation": "insert",
206
+ "content": line[1:].strip()
207
+ })
208
+ elif line.startswith('-'):
209
+ operations.append({
210
+ "operation": "delete",
211
+ "content": line[1:].strip()
212
+ })
213
+ elif line.startswith(' '):
214
+ # Context line, no operation needed
215
+ pass
216
+
217
+ return operations
218
+
219
+
220
+ def get_code_complexity_reward(code_snippet: str) -> float:
221
+ """
222
+ Calculate the complexity of a Python code snippet using the `code_complexity` function from the `complexipy` library.
223
+
224
+ Args:
225
+ code_snippet: A string containing Python code
226
+
227
+ Returns:
228
+ A float value representing the complexity of the code snippet
229
+ """
230
+ try:
231
+ return code_complexity(format_python_code(code_snippet)).complexity / 15
232
+ except Exception as e:
233
+ return 0.0
234
+
235
+
236
+ def get_halstead_reward(code_snippet: str,
237
+ effort_max: float = 10000,
238
+ complexity_max: float = 10,
239
+ volume_max: float = 500) -> float:
240
+ """
241
+ Calculate the Halstead reward for a Python code snippet.
242
+
243
+ Args:
244
+ code_snippet: A string containing Python code
245
+
246
+ Returns:
247
+ A float value representing the Halstead reward of the code snippet
248
+ """
249
+ try:
250
+ from radon.metrics import h_visit
251
+ from radon.complexity import cc_visit
252
+
253
+ code = format_python_code(code_snippet)
254
+
255
+ h = h_visit(code).total
256
+ effort = h.effort
257
+ volume = h.volume
258
+ cc_blocks = cc_visit(code)
259
+ complexity = max((b.complexity for b in cc_blocks), default=1)
260
+ effort_norm = min(effort / effort_max, 1.0)
261
+ complexity_norm = min(complexity / complexity_max, 1.0)
262
+ volume_norm = min(volume / volume_max, 1.0)
263
+
264
+ w1, w2, w3 = 0.5, 0.3, 0.2
265
+
266
+ score = w1 * effort_norm + w2 * complexity_norm + w3 * volume_norm
267
+ return round(score, 3)
268
+ except Exception as e:
269
+ return 0.0
270
+
271
+
272
+ def has_test_input(snippet_code: str) -> bool:
273
+ test_patterns = [
274
+ r"(?i)#\s*(test|example)", # Match any test/example comment
275
+ r"\b(input|test_input|sample_input)\b\s*=", # Common test variable names
276
+ r"\b\w*input\w*\s*=\s*", # Match any variable containing "input"
277
+ r"\b(expected|output|result)\s*=\s*",
278
+ r"\bassert\b",
279
+ r"print\s*\(\s*f\(",
280
+ r"f\(\[.*\]\)",
281
+ r"f\([^)]*\)\s*(#|$)",
282
+ r"^\s*input\s*$", # Match lines containing only "input"
283
+ ]
284
+
285
+ return any(
286
+ re.search(pattern, snippet_code, re.MULTILINE)
287
+ for pattern in test_patterns
288
+ )
289
+
290
+
291
+ def parse_code_input_output(
292
+ input_str: str,
293
+ parse_input: bool = True,
294
+ parse_output: bool = True,
295
+ remove_after_return: bool = False,
296
+ remove_comments: bool = False,
297
+ remove_print: bool = False,
298
+ reject_multiple_functions: bool = True,
299
+ reject_test_input_in_code: bool = False,
300
+ f_replace_location: str = 'not_first',
301
+ code_location: str = 'first',
302
+ ) -> Tuple[bool, Dict[str, str]]:
303
+ """
304
+ Parse the input and output of a code snippet.
305
+
306
+ Args:
307
+ input_str: A string containing the code snippet
308
+ parse_input: Whether to parse the input
309
+ parse_output: Whether to parse the output
310
+ """
311
+ # Improved regex patterns with better whitespace handling and optional language specifiers
312
+ code_pattern = r"```(?:python\s*)?\n?(.*?)\n?```"
313
+ input_pattern = r"```input\s*\n?(.*?)\n?```"
314
+ output_pattern = r"```output\s*\n?(.*?)\n?```"
315
+
316
+ # Use flags for case-insensitive matching and dotall
317
+ flags = re.DOTALL | re.IGNORECASE
318
+
319
+ if code_location == 'last':
320
+ code_matches = list(re.finditer(code_pattern, input_str, flags))
321
+ if not code_matches:
322
+ code_match = None
323
+ else:
324
+ code_match = code_matches[-1]
325
+ elif code_location == 'first':
326
+ code_match = re.search(code_pattern, input_str, flags)
327
+ else:
328
+ raise ValueError(f"Invalid code_location: {code_location}. Must be 'first' or 'last'.")
329
+
330
+ # Check required blocks
331
+ if parse_input:
332
+ input_match = re.search(input_pattern, input_str, flags)
333
+ if not input_match:
334
+ # Try alternative pattern without explicit input block
335
+ input_match = re.search(r"# Input:\s*(.*?)(?=\n```|$)", input_str, flags)
336
+ if parse_output:
337
+ output_match = re.search(output_pattern, input_str, flags)
338
+ if not output_match:
339
+ # Try alternative pattern without explicit output block
340
+ output_match = re.search(r"# Output:\s*(.*?)(?=\n```|$)", input_str, flags)
341
+
342
+ # Validate required components
343
+ if not code_match or (parse_input and not input_match) or (parse_output and not output_match):
344
+ return False, {}
345
+
346
+ # Extract and clean components
347
+ code_snippet = code_match.group(1).strip()
348
+ input_snippet = input_match.group(1).strip() if parse_input else ""
349
+ output_snippet = output_match.group(1).strip() if parse_output else ""
350
+
351
+ # Enhanced function detection and validation
352
+ function_defs = re.findall(r"^\s*def\s+(\w+)\s*\(", code_snippet, re.MULTILINE)
353
+ if not function_defs:
354
+ return False, {}
355
+
356
+ if reject_multiple_functions and len(function_defs) > 1:
357
+ return False, {} # Reject multiple function definitions
358
+
359
+ if reject_test_input_in_code and has_test_input(code_snippet):
360
+ return False, {}
361
+
362
+ # Standardize function name to 'f'
363
+ if f_replace_location == 'not_first':
364
+ original_name = function_defs[0]
365
+ elif f_replace_location == 'any_last':
366
+ original_name = function_defs[-1] if 'f' not in function_defs else 'f'
367
+ elif f_replace_location == 'any_first':
368
+ original_name = function_defs[0] if 'f' not in function_defs else 'f'
369
+ elif f_replace_location == 'not_last':
370
+ original_name = function_defs[-1]
371
+ else:
372
+ raise ValueError(f'Invalid f_replace_location: {f_replace_location}')
373
+ if original_name != 'f':
374
+ code_snippet = re.sub(
375
+ rf"def\s+{re.escape(original_name)}\s*\(",
376
+ "def f(",
377
+ code_snippet,
378
+ count=0
379
+ )
380
+ # Replace all calls to the function as well (for recursive functions)
381
+ code_snippet = re.sub(
382
+ rf"\b{re.escape(original_name)}\s*\(",
383
+ "f(",
384
+ code_snippet
385
+ )
386
+
387
+ imports: List[str] = parse_imports(code_snippet)
388
+
389
+ # before_remove_comments = code_snippet
390
+ # remove comments and docstrings
391
+ if remove_comments:
392
+ code_snippet = remove_comments_and_docstrings(code_snippet)
393
+
394
+ # remove anything after return
395
+ if remove_after_return:
396
+ code_snippet = remove_any_not_definition_imports(code_snippet)
397
+
398
+ # remove print statements
399
+ if remove_print:
400
+ code_snippet = remove_print_statements(code_snippet)
401
+
402
+ # if before_remove_comments != code_snippet:
403
+ # with open("changed_content.jsonl", "a") as f:
404
+ # f.write(json.dumps({"before": before_remove_comments, "after": code_snippet}) + "\n")
405
+ return True, {"code": code_snippet, "input": input_snippet, "output": output_snippet, "imports": imports}
406
+
407
+
408
+ def parse_inputs_message(
409
+ input_str: str,
410
+ num_inputs: int,
411
+ ) -> Tuple[bool, Dict[str, Any]]:
412
+ """
413
+ Parse the last num_inputs inputs and message from a string.
414
+
415
+ Args:
416
+ input_str: A string containing the inputs and message
417
+ num_inputs: Number of most recent inputs to parse
418
+
419
+ Returns:
420
+ A tuple of (success, dict) where dict contains:
421
+ - inputs: List of last num_inputs input strings
422
+ - message: The message string
423
+ Returns (False, {}) if there aren't enough inputs or message is missing
424
+ """
425
+ # Improved regex patterns with better whitespace handling and optional language specifiers
426
+ input_pattern = r"```input\s*\n?(.*?)\n?```"
427
+ message_pattern = r"```message\s*\n?(.*?)\n?```"
428
+
429
+ # Use flags for case-insensitive matching and dotall
430
+ flags = re.DOTALL | re.IGNORECASE
431
+
432
+ # Check required blocks
433
+ input_matches = re.finditer(input_pattern, input_str, flags)
434
+ if not input_matches:
435
+ # Try alternative pattern without explicit input block
436
+ input_matches = re.finditer(r"# Input:\s*(.*?)(?=\n```|$)", input_str, flags)
437
+
438
+ # Get all inputs and take the last num_inputs
439
+ inputs = [match.group(1).strip() for match in input_matches]
440
+
441
+ # Return early if not enough inputs
442
+ if len(inputs) < num_inputs:
443
+ return False, {}
444
+
445
+ inputs = inputs[-num_inputs:] # Take last num_inputs
446
+
447
+ message_match = re.search(message_pattern, input_str, flags)
448
+
449
+ # Try parsing message between <message> </message> tags if previous methods failed
450
+ if not message_match:
451
+ message_match = re.search(r"<message>\s*(.*?)\s*</message>", input_str, flags)
452
+
453
+ if not message_match:
454
+ # Try alternative pattern without explicit message block
455
+ message_match = re.search(r"# Message:\s*(.*?)(?=\n```|$)", input_str, flags)
456
+
457
+ # Return early if message not found
458
+ if not message_match:
459
+ return False, {}
460
+
461
+ # Extract and clean message
462
+ message = message_match.group(1).strip()
463
+
464
+ return True, {"inputs": inputs, "message": message}
465
+
466
+
467
+ def parse_code_function(input_str: str) -> Tuple[bool, str]:
468
+ """
469
+ Parse the code function from a string.
470
+
471
+ Args:
472
+ input_str: A string containing the code function
473
+ """
474
+ # Improved regex patterns with better whitespace handling and optional language specifiers
475
+ code_pattern = r"```(?:python\s*)?\n?(.*?)\n?```"
476
+
477
+ flags = re.DOTALL | re.IGNORECASE
478
+
479
+ # find and output the last code block in the input string
480
+ code_matches = list(re.finditer(code_pattern, input_str, flags))
481
+ if not code_matches:
482
+ return False, ''
483
+ code_snippet = code_matches[-1].group(1).strip()
484
+
485
+ return True, code_snippet
486
+
487
+
488
+ def valid_code(solution_str: str, executor, banned_words: List[str]) -> Tuple[bool, str]:
489
+ success, result = parse_code_input_output(solution_str, parse_output=False)
490
+ if success:
491
+ try:
492
+ output, status = executor.apply(result['code'] + f'\nf({result["input"]})')
493
+ if 'error' in status.lower():
494
+ return False, None
495
+ for banned_word in banned_words:
496
+ if banned_word.lower() in result['code'].lower():
497
+ return False, None
498
+ return True, output
499
+ except Exception:
500
+ return False, None
501
+ return False, None
502
+
503
+
504
+ def get_type_counts_reward(answer: str, type_counters: Dict[str, Dict[str, int]], hierarchical: bool = False) -> float:
505
+ """
506
+ Calculate the type counts reward for a Python code snippet.
507
+
508
+ Args:
509
+ answer: A string containing the answer
510
+ type_counters: A dictionary of type counters
511
+ hierarchical: Whether to use hierarchical type counts
512
+ """
513
+ if hierarchical:
514
+ # we do not flatten we first have a distribution of the types, then we have a distribution of the elements within each type
515
+ # we want to maximize the suprise of the answer
516
+ # first, we get the distribution of the types
517
+ type_distribution = {}
518
+ for key, value in type_counters.items():
519
+ type_distribution[key] = sum(value.values())
520
+
521
+ # try to get the type, if failed default it as a string
522
+ try:
523
+ answer_type = type(eval(answer)).__name__
524
+ except:
525
+ answer_type = 'str'
526
+
527
+ # then, we get the "suprise" of the answer, sum of 1 - probability of answer_type and 1 - probability of the element within the type
528
+ suprise = 0
529
+ if answer_type in type_distribution:
530
+ suprise += 1 - (type_distribution[answer_type] / sum(type_distribution.values()))
531
+ else:
532
+ suprise += 1.0
533
+ if answer_type in type_counters:
534
+ if answer in type_counters[answer_type]:
535
+ suprise += 1 - (type_counters[answer_type][answer] / sum(type_counters[answer_type].values()))
536
+ else:
537
+ suprise += 1.0
538
+ else:
539
+ suprise += 1.0
540
+ return suprise / 2
541
+ else:
542
+ # first flatten the type_counters, use the counts of each element as a categorical distribution, then, we get the "suprise" of the answer
543
+ # we want to maximize the suprise
544
+ # first, flatten the type_counters
545
+ flattened_type_counters = {}
546
+ for _, value in type_counters.items():
547
+ for sub_key, sub_value in value.items():
548
+ flattened_type_counters[sub_key] = sub_value
549
+ # then, we get the "suprise" of the answer
550
+
551
+ if answer in flattened_type_counters:
552
+ suprise = 1 - (flattened_type_counters[answer] / sum(flattened_type_counters.values()))
553
+ return suprise
554
+ return 1.0
absolute_zero_reasoner/rewards/custom_evaluate.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py
15
+
16
+ import re
17
+ from collections import Counter
18
+ from typing import Tuple, List, Dict
19
+
20
+ from math_verify import parse, verify
21
+
22
+ from absolute_zero_reasoner.rewards.math_utils import grade_answer_mathd, grade_answer_sympy
23
+
24
+
25
+ def choice_answer_clean(pred: str):
26
+ """https://github.com/hkust-nlp/simpleRL-reason/blob/main/eval/grader.py"""
27
+ pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
28
+ # Clean the answer based on the dataset
29
+ tmp = re.findall(r"\b(A|B|C|D|E|F|G|H|I|J|K|L|M|N|O|P|Q|R|S|T|U|V|W|X|Y|Z)\b", pred.upper())
30
+ if tmp:
31
+ pred = tmp
32
+ else:
33
+ pred = [pred.strip().strip(".")]
34
+ pred = pred[-1]
35
+ # Remove the period at the end, again!
36
+ pred = pred.rstrip(".").rstrip("/")
37
+ return pred
38
+
39
+
40
+ def extract_code(completion: str, language: str = "python") -> str:
41
+ pattern = re.compile(rf"```{language}\n(.*?)```", re.DOTALL)
42
+ matches = pattern.findall(completion)
43
+ extracted_answer = matches[-1] if len(matches) >= 1 else ""
44
+ return extracted_answer
45
+
46
+
47
+ def get_gt_reward(solution_str: str, ground_truth: str, extraction_type: str, metric: str, math_metric: str = 'deepscaler', boxed_retry: bool = False) -> float:
48
+ answer = extract_answer(solution_str, extraction_type, boxed_retry=boxed_retry)
49
+ if metric == 'mc':
50
+ mc_answer = choice_answer_clean(answer)
51
+ if mc_answer == ground_truth:
52
+ return 1.0
53
+ if grade_answer_sympy(answer, ground_truth) or grade_answer_mathd(answer, ground_truth):
54
+ return 1.0
55
+ return 0.0
56
+ elif metric == 'math':
57
+ if math_metric == 'math_verify':
58
+ gold = parse('\\boxed{' + ground_truth + '}')
59
+ answer = parse('\\boxed{' + answer + '}')
60
+ return 1.0 if verify(gold, answer) else 0.0
61
+ elif math_metric == 'deepscaler':
62
+ if grade_answer_sympy(answer, ground_truth) or grade_answer_mathd(answer, ground_truth):
63
+ return 1.0
64
+ return 0.0
65
+ elif math_metric == 'union':
66
+ math_verify_gold = parse('\\boxed{' + ground_truth + '}')
67
+ math_verify_answer = parse('\\boxed{' + answer + '}')
68
+ if grade_answer_sympy(answer, ground_truth) or grade_answer_mathd(answer, ground_truth) or verify(math_verify_gold, math_verify_answer):
69
+ return 1.0
70
+ return 0.0
71
+ else:
72
+ raise ValueError(f"Invalid math metric: {math_metric}")
73
+ elif metric == 'code_eval':
74
+ try:
75
+ answer = eval(answer.strip())
76
+ except Exception:
77
+ return 0.0
78
+ ground_truth = eval(ground_truth.strip())
79
+ if answer == ground_truth:
80
+ return 1.0
81
+ return 0.0
82
+ else:
83
+ raise ValueError(f"Invalid metric: {metric}")
84
+
85
+
86
+ def extract_answer(solution_str: str, extraction_type: str, boxed_retry: bool = False) -> str:
87
+ if extraction_type.startswith('answer'):
88
+ if "<answer>" in solution_str:
89
+ answer = solution_str.split("<answer>")[-1].split("</answer>")[0]
90
+ else:
91
+ if boxed_retry:
92
+ boxed_answer = last_boxed_only_string(solution_str)
93
+ answer = boxed_answer if boxed_answer is not None else solution_str
94
+ else:
95
+ return ''
96
+ # Strip LaTeX math delimiters and whitespace
97
+ answer = answer.strip()
98
+ return answer
99
+ elif extraction_type.startswith('boxed'):
100
+ answer = last_boxed_only_string(solution_str)
101
+ return answer.strip() if answer is not None else ''
102
+ else:
103
+ raise ValueError(f"Invalid extraction type: {extraction_type}")
104
+
105
+
106
+ def extract_thought(solution_str: str) -> str:
107
+ if "<think>" in solution_str:
108
+ return solution_str.split("<think>")[-1].split("</think>")[0]
109
+ else:
110
+ return solution_str
111
+
112
+
113
+ def get_format_reward(
114
+ solution_str: str,
115
+ extraction_type: str,
116
+ ) -> float:
117
+ if extraction_type.startswith('answer'):
118
+ pattern = r"(?s)<think>.*?</think>\s*<answer>.*?</answer>"
119
+ matched = re.match(pattern, solution_str)
120
+ if matched:
121
+ return 1.
122
+ else:
123
+ return 0.
124
+ elif extraction_type.startswith('boxed'):
125
+ if last_boxed_only_string(solution_str) is not None:
126
+ return 1.
127
+ else:
128
+ return 0.
129
+ else:
130
+ raise ValueError(f"Invalid extraction type: {extraction_type}")
131
+
132
+
133
+ def extract_code_content(solution_str):
134
+ # Check if the string starts with an XML code block
135
+ xml_pattern = r'^```\s*xml\n(.*?)```'
136
+ xml_match = re.match(xml_pattern, solution_str, re.DOTALL | re.IGNORECASE)
137
+
138
+ if xml_match:
139
+ # XML code block found at start
140
+ return xml_match.group(1).strip()
141
+
142
+ # Check if the string starts with any code block
143
+ generic_pattern = r'^```\s*\w*\n(.*?)```'
144
+ generic_match = re.match(generic_pattern, solution_str, re.DOTALL)
145
+
146
+ if generic_match:
147
+ # Some other code block found at start
148
+ return generic_match.group(1).strip()
149
+
150
+ # No code block found at start, return the original string
151
+ return solution_str.strip()
152
+
153
+
154
+ def get_reward(
155
+ solution_str: str,
156
+ ground_truth: str,
157
+ extra_info: dict,
158
+ extraction_type: str,
159
+ splitter: str,
160
+ math_metric: str = 'deepscaler',
161
+ boxed_retry: bool = False,
162
+ ) -> Tuple[float, Dict[str, float]]:
163
+ solution_str = solution_str.split(splitter)[1].strip()
164
+ solution_str = solution_str.strip('\"\'')
165
+ gt_reward = get_gt_reward(solution_str, ground_truth, extraction_type, extra_info['metric'], math_metric, boxed_retry=boxed_retry)
166
+ format_reward = get_format_reward(solution_str, extraction_type)
167
+ if extra_info['split'] == 'train':
168
+ if extraction_type.startswith('answer') or extraction_type.startswith('boxed'):
169
+ if extraction_type.endswith('conditional'):
170
+ # R(answer) =
171
+ # 1 if correct formatting and correct answer
172
+ # -0.5 if correct formatting and incorrect answer
173
+ # -1 if incorrect formatting
174
+ if not format_reward:
175
+ return -1., {'gt': gt_reward, 'format': format_reward}
176
+ # correct formatting
177
+ else:
178
+ return 1. if gt_reward else -0.5, {'gt': gt_reward, 'format': format_reward}
179
+ elif extraction_type.endswith('addition'):
180
+ return (0.5 if format_reward else 0.) + gt_reward, {'gt': gt_reward, 'format': format_reward}
181
+ elif extraction_type.endswith('multiply'):
182
+ return format_reward * gt_reward, {'gt': gt_reward, 'format': format_reward}
183
+ else:
184
+ raise ValueError(f"Invalid extraction type: {extraction_type}")
185
+ elif extra_info['split'] == 'test':
186
+ return gt_reward, {'gt': gt_reward, 'format': format_reward}
187
+ else:
188
+ raise ValueError(f"Invalid split: {extra_info['split']}")
189
+
190
+
191
+ # string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
192
+ def is_equiv(str1: str, str2: str, verbose: bool = False) -> bool:
193
+ if str1 is None and str2 is None:
194
+ print("WARNING: Both None")
195
+ return True
196
+ if str1 is None or str2 is None:
197
+ return False
198
+
199
+ try:
200
+ ss1 = strip_string(str1)
201
+ ss2 = strip_string(str2)
202
+ if verbose:
203
+ print(ss1, ss2)
204
+ return ss1 == ss2
205
+ except Exception:
206
+ return str1 == str2
207
+
208
+
209
+ def remove_boxed(s: str) -> str:
210
+ if "\\boxed " in s:
211
+ left = "\\boxed "
212
+ assert s[:len(left)] == left
213
+ return s[len(left):]
214
+
215
+ left = "\\boxed{"
216
+
217
+ assert s[:len(left)] == left
218
+ assert s[-1] == "}"
219
+
220
+ return s[len(left):-1]
221
+
222
+
223
+ def last_boxed_only_string(string: str) -> str:
224
+ idx = string.rfind("\\boxed")
225
+ if "\\boxed " in string:
226
+ return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
227
+ if idx < 0:
228
+ idx = string.rfind("\\fbox")
229
+ if idx < 0:
230
+ return None
231
+
232
+ i = idx
233
+ right_brace_idx = None
234
+ num_left_braces_open = 0
235
+ while i < len(string):
236
+ if string[i] == "{":
237
+ num_left_braces_open += 1
238
+ if string[i] == "}":
239
+ num_left_braces_open -= 1
240
+ if num_left_braces_open == 0:
241
+ right_brace_idx = i
242
+ break
243
+ i += 1
244
+
245
+ if right_brace_idx is None:
246
+ retval = None
247
+ else:
248
+ retval = string[idx:right_brace_idx + 1]
249
+
250
+ return retval
251
+
252
+
253
+ def fix_fracs(string: str) -> str:
254
+ substrs = string.split("\\frac")
255
+ new_str = substrs[0]
256
+ if len(substrs) > 1:
257
+ substrs = substrs[1:]
258
+ for substr in substrs:
259
+ new_str += "\\frac"
260
+ if substr[0] == "{":
261
+ new_str += substr
262
+ else:
263
+ try:
264
+ assert len(substr) >= 2
265
+ except AssertionError:
266
+ return string
267
+ a = substr[0]
268
+ b = substr[1]
269
+ if b != "{":
270
+ if len(substr) > 2:
271
+ post_substr = substr[2:]
272
+ new_str += "{" + a + "}{" + b + "}" + post_substr
273
+ else:
274
+ new_str += "{" + a + "}{" + b + "}"
275
+ else:
276
+ if len(substr) > 2:
277
+ post_substr = substr[2:]
278
+ new_str += "{" + a + "}" + b + post_substr
279
+ else:
280
+ new_str += "{" + a + "}" + b
281
+ string = new_str
282
+ return string
283
+
284
+
285
+ def fix_a_slash_b(string: str) -> str:
286
+ if len(string.split("/")) != 2:
287
+ return string
288
+ a = string.split("/")[0]
289
+ b = string.split("/")[1]
290
+ try:
291
+ a = int(a)
292
+ b = int(b)
293
+ assert string == "{}/{}".format(a, b)
294
+ new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
295
+ return new_string
296
+ except AssertionError:
297
+ return string
298
+
299
+
300
+ def remove_right_units(string: str) -> str:
301
+ # "\\text{ " only ever occurs (at least in the val set) when describing units
302
+ if "\\text{ " in string:
303
+ splits = string.split("\\text{ ")
304
+ assert len(splits) == 2
305
+ return splits[0]
306
+ else:
307
+ return string
308
+
309
+
310
+ def fix_sqrt(string: str) -> str:
311
+ if "\\sqrt" not in string:
312
+ return string
313
+ splits = string.split("\\sqrt")
314
+ new_string = splits[0]
315
+ for split in splits[1:]:
316
+ if split[0] != "{":
317
+ a = split[0]
318
+ new_substr = "\\sqrt{" + a + "}" + split[1:]
319
+ else:
320
+ new_substr = "\\sqrt" + split
321
+ new_string += new_substr
322
+ return new_string
323
+
324
+
325
+ def strip_string(string: str) -> str:
326
+ # linebreaks
327
+ string = string.replace("\n", "")
328
+
329
+ # remove inverse spaces
330
+ string = string.replace("\\!", "")
331
+
332
+ # replace \\ with \
333
+ string = string.replace("\\\\", "\\")
334
+
335
+ # replace tfrac and dfrac with frac
336
+ string = string.replace("tfrac", "frac")
337
+ string = string.replace("dfrac", "frac")
338
+
339
+ # remove \left and \right
340
+ string = string.replace("\\left", "")
341
+ string = string.replace("\\right", "")
342
+
343
+ # Remove circ (degrees)
344
+ string = string.replace("^{\\circ}", "")
345
+ string = string.replace("^\\circ", "")
346
+
347
+ # remove dollar signs
348
+ string = string.replace("\\$", "")
349
+
350
+ # remove units (on the right)
351
+ string = remove_right_units(string)
352
+
353
+ # remove percentage
354
+ string = string.replace("\\%", "")
355
+ string = string.replace("\%", "") # noqa: W605
356
+
357
+ # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
358
+ string = string.replace(" .", " 0.")
359
+ string = string.replace("{.", "{0.")
360
+ # if empty, return empty string
361
+ if len(string) == 0:
362
+ return string
363
+ if string[0] == ".":
364
+ string = "0" + string
365
+
366
+ # to consider: get rid of e.g. "k = " or "q = " at beginning
367
+ if len(string.split("=")) == 2:
368
+ if len(string.split("=")[0]) <= 2:
369
+ string = string.split("=")[1]
370
+
371
+ # fix sqrt3 --> sqrt{3}
372
+ string = fix_sqrt(string)
373
+
374
+ # remove spaces
375
+ string = string.replace(" ", "")
376
+
377
+ # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
378
+ string = fix_fracs(string)
379
+
380
+ # manually change 0.5 --> \frac{1}{2}
381
+ if string == "0.5":
382
+ string = "\\frac{1}{2}"
383
+
384
+ # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
385
+ string = fix_a_slash_b(string)
386
+
387
+ return string
absolute_zero_reasoner/rewards/math_utils.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ https://github.com/agentica-project/deepscaler/blob/main/deepscaler/rewards/math_utils/utils.py
3
+ """
4
+ import re
5
+ from pylatexenc import latex2text
6
+ import sympy
7
+ from sympy.parsing import sympy_parser
8
+ from typing import Optional
9
+
10
+
11
+ # Dan Hendrycks' code
12
+ def mathd_normalize_answer(answer: Optional[str]) -> Optional[str]:
13
+ if answer is None:
14
+ return None
15
+ answer = answer.strip()
16
+ try:
17
+ # Remove enclosing `\text{}`.
18
+ m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer)
19
+ if m is not None:
20
+ answer = m.group("text").strip()
21
+ return _strip_string(answer)
22
+ except:
23
+ return answer
24
+
25
+ def _strip_string(string):
26
+ def _fix_fracs(string):
27
+ substrs = string.split("\\frac")
28
+ new_str = substrs[0]
29
+ if len(substrs) > 1:
30
+ substrs = substrs[1:]
31
+ for substr in substrs:
32
+ new_str += "\\frac"
33
+ if substr[0] == "{":
34
+ new_str += substr
35
+ else:
36
+ try:
37
+ assert len(substr) >= 2
38
+ except:
39
+ return string
40
+ a = substr[0]
41
+ b = substr[1]
42
+ if b != "{":
43
+ if len(substr) > 2:
44
+ post_substr = substr[2:]
45
+ new_str += "{" + a + "}{" + b + "}" + post_substr
46
+ else:
47
+ new_str += "{" + a + "}{" + b + "}"
48
+ else:
49
+ if len(substr) > 2:
50
+ post_substr = substr[2:]
51
+ new_str += "{" + a + "}" + b + post_substr
52
+ else:
53
+ new_str += "{" + a + "}" + b
54
+ string = new_str
55
+ return string
56
+
57
+
58
+ def _fix_a_slash_b(string):
59
+ if len(string.split("/")) != 2:
60
+ return string
61
+ a = string.split("/")[0]
62
+ b = string.split("/")[1]
63
+ try:
64
+ a = int(a)
65
+ b = int(b)
66
+ assert string == "{}/{}".format(a, b)
67
+ new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
68
+ return new_string
69
+ except:
70
+ return string
71
+
72
+
73
+ def _remove_right_units(string):
74
+ # "\\text{ " only ever occurs (at least in the val set) when describing units
75
+ if "\\text{ " in string:
76
+ splits = string.split("\\text{ ")
77
+ assert len(splits) == 2
78
+ return splits[0]
79
+ else:
80
+ return string
81
+
82
+
83
+ def _fix_sqrt(string):
84
+ if "\\sqrt" not in string:
85
+ return string
86
+ splits = string.split("\\sqrt")
87
+ new_string = splits[0]
88
+ for split in splits[1:]:
89
+ if split[0] != "{":
90
+ a = split[0]
91
+ new_substr = "\\sqrt{" + a + "}" + split[1:]
92
+ else:
93
+ new_substr = "\\sqrt" + split
94
+ new_string += new_substr
95
+ return new_string
96
+ # linebreaks
97
+ string = string.replace("\n", "")
98
+ # print(string)
99
+
100
+ # remove inverse spaces
101
+ string = string.replace("\\!", "")
102
+ # print(string)
103
+
104
+ # replace \\ with \
105
+ string = string.replace("\\\\", "\\")
106
+ # print(string)
107
+
108
+ # replace tfrac and dfrac with frac
109
+ string = string.replace("tfrac", "frac")
110
+ string = string.replace("dfrac", "frac")
111
+ # print(string)
112
+
113
+ # remove \left and \right
114
+ string = string.replace("\\left", "")
115
+ string = string.replace("\\right", "")
116
+ # print(string)
117
+
118
+ # Remove circ (degrees)
119
+ string = string.replace("^{\\circ}", "")
120
+ string = string.replace("^\\circ", "")
121
+
122
+ # remove dollar signs
123
+ string = string.replace("\\$", "")
124
+
125
+ # remove units (on the right)
126
+ string = _remove_right_units(string)
127
+
128
+ # remove percentage
129
+ string = string.replace("\\%", "")
130
+ string = string.replace("\%", "")
131
+
132
+ # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
133
+ string = string.replace(" .", " 0.")
134
+ string = string.replace("{.", "{0.")
135
+ # if empty, return empty string
136
+ if len(string) == 0:
137
+ return string
138
+ if string[0] == ".":
139
+ string = "0" + string
140
+
141
+ # to consider: get rid of e.g. "k = " or "q = " at beginning
142
+ if len(string.split("=")) == 2:
143
+ if len(string.split("=")[0]) <= 2:
144
+ string = string.split("=")[1]
145
+
146
+ # fix sqrt3 --> sqrt{3}
147
+ string = _fix_sqrt(string)
148
+
149
+ # remove spaces
150
+ string = string.replace(" ", "")
151
+
152
+ # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
153
+ string = _fix_fracs(string)
154
+
155
+ # manually change 0.5 --> \frac{1}{2}
156
+ if string == "0.5":
157
+ string = "\\frac{1}{2}"
158
+
159
+ # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
160
+ string = _fix_a_slash_b(string)
161
+
162
+ return string
163
+
164
+
165
+ # sympy might hang -- we don't care about trying to be lenient in these cases
166
+ BAD_SUBSTRINGS = ["^{", "^("]
167
+ BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"]
168
+ TUPLE_CHARS = "()[]"
169
+
170
+
171
+ def _sympy_parse(expr: str):
172
+ """Parses an expression with sympy."""
173
+ py_expr = expr.replace("^", "**")
174
+ return sympy_parser.parse_expr(
175
+ py_expr,
176
+ transformations=(
177
+ sympy_parser.standard_transformations
178
+ + (sympy_parser.implicit_multiplication_application,)
179
+ ),
180
+ )
181
+
182
+
183
+ def _parse_latex(expr: str) -> str:
184
+ """Attempts to parse latex to an expression sympy can read."""
185
+ expr = expr.replace("\\tfrac", "\\frac")
186
+ expr = expr.replace("\\dfrac", "\\frac")
187
+ expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers.
188
+ expr = latex2text.LatexNodes2Text().latex_to_text(expr)
189
+
190
+ # Replace the specific characters that this parser uses.
191
+ expr = expr.replace("√", "sqrt")
192
+ expr = expr.replace("π", "pi")
193
+ expr = expr.replace("∞", "inf")
194
+ expr = expr.replace("∪", "U")
195
+ expr = expr.replace("·", "*")
196
+ expr = expr.replace("×", "*")
197
+
198
+ return expr.strip()
199
+
200
+
201
+ def _is_float(num: str) -> bool:
202
+ try:
203
+ float(num)
204
+ return True
205
+ except ValueError:
206
+ return False
207
+
208
+
209
+ def _is_int(x: float) -> bool:
210
+ try:
211
+ return abs(x - int(round(x))) <= 1e-7
212
+ except:
213
+ return False
214
+
215
+
216
+ def _is_frac(expr: str) -> bool:
217
+ return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
218
+
219
+
220
+ def _str_is_int(x: str) -> bool:
221
+ try:
222
+ x = _strip_properly_formatted_commas(x)
223
+ x = float(x)
224
+ return abs(x - int(round(x))) <= 1e-7
225
+ except:
226
+ return False
227
+
228
+
229
+ def _str_to_int(x: str) -> bool:
230
+ x = x.replace(",", "")
231
+ x = float(x)
232
+ return int(x)
233
+
234
+
235
+ def _inject_implicit_mixed_number(step: str):
236
+ """
237
+ Automatically make a mixed number evalable
238
+ e.g. 7 3/4 => 7+3/4
239
+ """
240
+ p1 = re.compile("([0-9]) +([0-9])")
241
+ step = p1.sub("\\1+\\2", step) ## implicit mults
242
+ return step
243
+
244
+
245
+ def _strip_properly_formatted_commas(expr: str):
246
+ # We want to be careful because we don't want to strip tuple commas
247
+ p1 = re.compile("(\d)(,)(\d\d\d)($|\D)")
248
+ while True:
249
+ next_expr = p1.sub("\\1\\3\\4", expr)
250
+ if next_expr == expr:
251
+ break
252
+ expr = next_expr
253
+ return next_expr
254
+
255
+
256
+ def _normalize(expr: str) -> str:
257
+ """Normalize answer expressions."""
258
+ if expr is None:
259
+ return None
260
+
261
+ # Remove enclosing `\text{}`.
262
+ m = re.search("^\\\\text\{(?P<text>.+?)\}$", expr)
263
+ if m is not None:
264
+ expr = m.group("text")
265
+
266
+ expr = expr.replace("\\%", "%")
267
+ expr = expr.replace("\\$", "$")
268
+ expr = expr.replace("$", "")
269
+ expr = expr.replace("%", "")
270
+ expr = expr.replace(" or ", " , ")
271
+ expr = expr.replace(" and ", " , ")
272
+
273
+ expr = expr.replace("million", "*10^6")
274
+ expr = expr.replace("billion", "*10^9")
275
+ expr = expr.replace("trillion", "*10^12")
276
+
277
+ for unit in [
278
+ "degree",
279
+ "cm",
280
+ "centimeter",
281
+ "meter",
282
+ "mile",
283
+ "second",
284
+ "minute",
285
+ "hour",
286
+ "day",
287
+ "week",
288
+ "month",
289
+ "year",
290
+ "foot",
291
+ "feet",
292
+ "inch",
293
+ "yard",
294
+ ]:
295
+ expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
296
+ expr = re.sub(f"\^ *\\\\circ", "", expr)
297
+
298
+ if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
299
+ expr = expr[1:-1]
300
+
301
+ expr = re.sub(",\\\\! *", "", expr)
302
+ if _is_float(expr) and _is_int(float(expr)):
303
+ expr = str(int(round(float(expr))))
304
+ if "\\" in expr:
305
+ try:
306
+ expr = _parse_latex(expr)
307
+ except:
308
+ pass
309
+
310
+ # edge case with mixed numbers and negative signs
311
+ expr = re.sub("- *", "-", expr)
312
+
313
+ expr = _inject_implicit_mixed_number(expr)
314
+ expr = expr.replace(" ", "")
315
+
316
+ # if we somehow still have latex braces here, just drop them
317
+ expr = expr.replace("{", "")
318
+ expr = expr.replace("}", "")
319
+
320
+ # don't be case sensitive for text answers
321
+ expr = expr.lower()
322
+
323
+ if _str_is_int(expr):
324
+ expr = str(_str_to_int(expr))
325
+
326
+ return expr
327
+
328
+
329
+ def count_unknown_letters_in_expr(expr: str):
330
+ expr = expr.replace("sqrt", "")
331
+ expr = expr.replace("frac", "")
332
+ letters_in_expr = set([x for x in expr if x.isalpha()])
333
+ return len(letters_in_expr)
334
+
335
+
336
+ def should_allow_eval(expr: str):
337
+ # we don't want to try parsing unknown text or functions of more than two variables
338
+ if count_unknown_letters_in_expr(expr) > 2:
339
+ return False
340
+
341
+ for bad_string in BAD_SUBSTRINGS:
342
+ if bad_string in expr:
343
+ return False
344
+
345
+ for bad_regex in BAD_REGEXES:
346
+ if re.search(bad_regex, expr) is not None:
347
+ return False
348
+
349
+ return True
350
+
351
+
352
+ def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
353
+ are_equal = False
354
+ try:
355
+ expr = f"({ground_truth_normalized})-({given_normalized})"
356
+ if should_allow_eval(expr):
357
+ sympy_diff = _sympy_parse(expr)
358
+ simplified = sympy.simplify(sympy_diff)
359
+ if simplified == 0:
360
+ are_equal = True
361
+ except:
362
+ pass
363
+ return are_equal
364
+
365
+
366
+ def split_tuple(expr: str):
367
+ """
368
+ Split the elements in a tuple/interval, while handling well-formatted commas in large numbers
369
+ """
370
+ expr = _strip_properly_formatted_commas(expr)
371
+ if len(expr) == 0:
372
+ return []
373
+ if (
374
+ len(expr) > 2
375
+ and expr[0] in TUPLE_CHARS
376
+ and expr[-1] in TUPLE_CHARS
377
+ and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])
378
+ ):
379
+ elems = [elem.strip() for elem in expr[1:-1].split(",")]
380
+ else:
381
+ elems = [expr]
382
+ return elems
383
+
384
+
385
+ def last_boxed_only_string(string):
386
+ idx = string.rfind("\\boxed")
387
+ if idx < 0:
388
+ idx = string.rfind("\\fbox")
389
+ if idx < 0:
390
+ return None
391
+
392
+ i = idx
393
+ right_brace_idx = None
394
+ num_left_braces_open = 0
395
+ while i < len(string):
396
+ if string[i] == "{":
397
+ num_left_braces_open += 1
398
+ if string[i] == "}":
399
+ num_left_braces_open -= 1
400
+ if num_left_braces_open == 0:
401
+ right_brace_idx = i
402
+ break
403
+ i += 1
404
+
405
+ if right_brace_idx == None:
406
+ retval = None
407
+ else:
408
+ retval = string[idx:right_brace_idx + 1]
409
+
410
+ return retval
411
+
412
+ def remove_boxed(s):
413
+ left = "\\boxed{"
414
+ try:
415
+ assert s[:len(left)] == left
416
+ assert s[-1] == "}"
417
+ return s[len(left):-1]
418
+ except:
419
+ return None
420
+
421
+
422
+ def extract_boxed_answer(solution: str) -> str:
423
+ """Extract the answer from inside a LaTeX \\boxed{} command"""
424
+ solution = last_boxed_only_string(solution)
425
+ solution = remove_boxed(solution)
426
+ return solution
427
+
428
+ def grade_answer_sympy(given_answer: str, ground_truth: str) -> bool:
429
+ ground_truth_normalized = _normalize(ground_truth)
430
+ given_normalized = _normalize(given_answer)
431
+
432
+ if ground_truth_normalized is None:
433
+ return False
434
+
435
+ if ground_truth_normalized == given_normalized:
436
+ return True
437
+
438
+ if len(given_normalized) == 0:
439
+ return False
440
+
441
+ ground_truth_elems = split_tuple(ground_truth_normalized)
442
+ given_elems = split_tuple(given_normalized)
443
+
444
+ if len(ground_truth_elems) > 1 and (
445
+ ground_truth_normalized[0] != given_normalized[0]
446
+ or ground_truth_normalized[-1] != given_normalized[-1]
447
+ ):
448
+ is_correct = False
449
+ elif len(ground_truth_elems) != len(given_elems):
450
+ is_correct = False
451
+ else:
452
+ for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
453
+ if _is_frac(ground_truth_elem) and _is_frac(given_elem):
454
+ # if fractions aren't reduced, then shouldn't be marked as correct
455
+ # so, we don't want to allow sympy.simplify in this case
456
+ is_correct = ground_truth_elem == given_elem
457
+ elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
458
+ # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify)
459
+ is_correct = False
460
+ else:
461
+ is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
462
+ if not is_correct:
463
+ break
464
+
465
+ return is_correct
466
+
467
+ def grade_answer_mathd(given_answer: str, ground_truth: str) -> bool:
468
+ ground_truth_normalized_mathd = mathd_normalize_answer(ground_truth)
469
+ given_answer_normalized_mathd = mathd_normalize_answer(given_answer)
470
+
471
+ # be at least as lenient as mathd
472
+ if ground_truth_normalized_mathd == given_answer_normalized_mathd:
473
+ return True
474
+ return False
475
+
476
+ def extract_answer(passage: str) -> str:
477
+ if "\\boxed" in passage:
478
+ return extract_boxed_answer(passage)
479
+ return None
480
+
481
+ def grade_answer_verl(solution_str, ground_truth):
482
+ if not ground_truth:
483
+ return False
484
+ if '\\boxed' in ground_truth:
485
+ ground_truth = extract_answer(ground_truth)
486
+ given_answer = extract_answer(solution_str)
487
+ if given_answer is None:
488
+ return False
489
+ return grade_answer_mathd(given_answer, ground_truth) \
490
+ or grade_answer_sympy(given_answer, ground_truth)
absolute_zero_reasoner/rewards/reward_managers.py ADDED
@@ -0,0 +1,898 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from functools import partial
3
+ from typing import Dict, Any, List, Tuple
4
+ from collections import defaultdict
5
+ import re
6
+ import uuid
7
+ from functools import partial
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch
12
+ from transformers import AutoTokenizer
13
+ from verl import DataProto
14
+ from verl.protocol import DataProtoItem
15
+ from verl.utils.dataset.rl_dataset import collate_fn
16
+ from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
17
+
18
+ import absolute_zero_reasoner.rewards.custom_evaluate as custom_evaluate
19
+ from absolute_zero_reasoner.rewards.code_reward import (
20
+ parse_code_input_output,
21
+ parse_inputs_message,
22
+ parse_code_function,
23
+ ast_edit_distance,
24
+ get_code_complexity_reward,
25
+ get_halstead_reward,
26
+ get_type_counts_reward,
27
+ )
28
+ from absolute_zero_reasoner.rewards.custom_evaluate import get_format_reward, extract_answer, extract_thought
29
+ from absolute_zero_reasoner.data_construction.process_data import boxed_instruction, instruction_following
30
+ from absolute_zero_reasoner.data_construction.constructor import get_code_problem_predictor_prompt
31
+ from absolute_zero_reasoner.utils.dataset.rl_dataset import RLHFDataset
32
+ from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter
33
+ from absolute_zero_reasoner.utils.code_utils.checks import check_composite_function, check_no_definitions
34
+
35
+
36
+ class CodeIORewardManager():
37
+ """The reward manager."""
38
+ def __init__(
39
+ self,
40
+ tokenizer: AutoTokenizer,
41
+ num_examine: int,
42
+ split: str,
43
+ reward_fn_extraction_type: str,
44
+ math_metric: str,
45
+ splitter: str,
46
+ output_path: str,
47
+ generation_reward_config: Dict[str, Any],
48
+ debug: bool = False,
49
+ max_prompt_length: int = 8192,
50
+ valid_program_filter: str = 'all',
51
+ batched_estimate: bool = False,
52
+ extract_code_block: bool = True,
53
+ num_inputs: int = 10,
54
+ code_f_reward_type: str = 'accuracy',
55
+ boxed_retry: bool = False,
56
+ ):
57
+ self.tokenizer = tokenizer
58
+ self.num_examine = num_examine # the number of batches of decoded responses to print to the console
59
+ self.compute_score = partial(custom_evaluate.get_reward, math_metric=math_metric, boxed_retry=boxed_retry)
60
+ self.reward_fn_extraction_type = reward_fn_extraction_type
61
+ self.split = split
62
+ self.splitter = splitter
63
+ self.output_path = output_path
64
+ self.max_prompt_length = max_prompt_length
65
+ self.generation_reward_config = generation_reward_config
66
+ self.valid_program_filter = valid_program_filter
67
+ self.batched_estimate = batched_estimate
68
+ self.debug = debug
69
+ self.extract_code_block = extract_code_block
70
+ self.use_original_code_as_ref = generation_reward_config.use_original_code_as_ref
71
+ self.num_inputs = num_inputs
72
+ self.code_f_reward_type = code_f_reward_type
73
+ self.boxed_retry = boxed_retry
74
+
75
+ @staticmethod
76
+ def extract_input_output(extracted_content: str, return_input: bool = True, return_output: bool = False) -> Tuple[str, str]:
77
+ input_pattern = r"```input\s*\n?(.*?)\n?```"
78
+ output_pattern = r"```output\s*\n?(.*?)\n?```"
79
+ assert not (return_input and return_output), "Cannot return both input and output"
80
+ assert return_input or return_output, "Must return at least one of input or output"
81
+
82
+ # Use flags for case-insensitive matching and dotall
83
+ flags = re.DOTALL | re.IGNORECASE
84
+ if return_input:
85
+ input_matches = list(re.finditer(input_pattern, extracted_content, flags))
86
+ if not input_matches:
87
+ # Try alternative pattern without explicit input block
88
+ input_matches = list(re.finditer(r"# Input:\s*(.*?)(?=\n```|$)", extracted_content, flags))
89
+ if not input_matches:
90
+ # Match input() function call and preserve quotes
91
+ input_matches = list(re.finditer(r'input\s*\((.*?)\)', extracted_content, flags))
92
+ if not input_matches:
93
+ # Match <input> tag with optional closing tag, strip spaces
94
+ input_matches = list(re.finditer(r"<input>\s*(.*?)(?:</input>|\s*$)", extracted_content, flags))
95
+ if not input_matches:
96
+ # Match "The input is" pattern case-insensitively
97
+ input_matches = list(re.finditer(r"the input is\s*(.*?)\.?$", extracted_content, flags))
98
+ # if still no input matches, use the extracted answer as the input
99
+ # Don't strip() here to preserve quotes
100
+ input_snippet = input_matches[-1].group(1) if input_matches else extracted_content
101
+ return input_snippet
102
+
103
+ if return_output:
104
+ output_matches = list(re.finditer(output_pattern, extracted_content, flags))
105
+ if not output_matches:
106
+ # Try alternative pattern without explicit output block
107
+ output_matches = list(re.finditer(r"# Output:\s*(.*?)(?=\n```|$)", extracted_content, flags))
108
+ if not output_matches:
109
+ # Match output() function call and preserve quotes
110
+ output_matches = list(re.finditer(r'output\s*\((.*?)\)', extracted_content, flags))
111
+ if not output_matches:
112
+ # Match <output> tag with optional closing tag, strip spaces
113
+ output_matches = list(re.finditer(r"<output>\s*(.*?)(?:</output>|\s*$)", extracted_content, flags))
114
+ if not output_matches:
115
+ # Match "The output is" pattern case-insensitively, strip space after "is" and period at end
116
+ output_matches = list(re.finditer(r"the output is\s*(.*?)\.?$", extracted_content, flags))
117
+ # if still no output matches, use the extracted answer as the output
118
+ output_snippet = output_matches[-1].group(1) if output_matches else extracted_content
119
+ return output_snippet
120
+
121
+ def _get_data_dict(self, data_item: DataProtoItem, problem_type: str, executor, banned_words: List[str], uid: str, banned_assertion_keywords: List[str]) -> Dict:
122
+ prompt_ids = data_item.batch['prompts']
123
+
124
+ prompt_length = prompt_ids.shape[-1]
125
+
126
+ valid_prompt_length = data_item.batch['attention_mask'][:prompt_length].sum()
127
+ valid_prompt_ids = prompt_ids[-valid_prompt_length:]
128
+
129
+ response_ids = data_item.batch['responses']
130
+ valid_response_length = data_item.batch['attention_mask'][prompt_length:].sum()
131
+ valid_response_ids = response_ids[:valid_response_length]
132
+
133
+ # decode
134
+ sequences = torch.cat((valid_prompt_ids, valid_response_ids))
135
+ sequences_str = self.tokenizer.decode(sequences)
136
+
137
+ ground_truth = data_item.non_tensor_batch['reward_model']['ground_truth']
138
+ data_source = data_item.non_tensor_batch['data_source']
139
+ extra_info = data_item.non_tensor_batch['extra_info']
140
+ non_special_tokens_sequences_str = self.tokenizer.decode(self.tokenizer.encode(sequences_str), skip_special_tokens=True)
141
+
142
+ generation = non_special_tokens_sequences_str.split(self.splitter)[1].strip().strip('\"\'')
143
+ extracted_content = extract_answer(generation, self.reward_fn_extraction_type, boxed_retry=self.boxed_retry)
144
+ thought = extract_thought(generation)
145
+
146
+ data_dict = {
147
+ 'generation': generation,
148
+ 'data_source': data_source,
149
+ 'ground_truth': ground_truth,
150
+ 'extra_info': extra_info,
151
+ 'non_special_tokens_sequences_str': non_special_tokens_sequences_str,
152
+ 'valid_response_length': valid_response_length,
153
+ 'extracted_content': extracted_content,
154
+ 'thought': thought,
155
+ 'uid': uid,
156
+ }
157
+ if problem_type.startswith('gen'):
158
+ data_dict['references'] = [ref['snippet'] for ref in data_item.non_tensor_batch['extra_info']['chosen_references']]
159
+ if problem_type != 'gen_code_f':
160
+ data_dict['composite_functions'] = data_item.non_tensor_batch['extra_info']['composite_functions'].tolist()
161
+ else:
162
+ data_dict['imports'] = [ref['imports'] for ref in data_item.non_tensor_batch['extra_info']['chosen_references']]
163
+ if self.use_original_code_as_ref:
164
+ data_dict['original_references'] = [ref['original_snippet'] for ref in data_item.non_tensor_batch['extra_info']['chosen_references']]
165
+ elif problem_type.startswith('pred') and 'code_f' not in problem_type:
166
+ data_dict['program'] = data_item.non_tensor_batch['problem']
167
+ data_dict['input'] = data_item.non_tensor_batch['extra_info']['input']
168
+ data_dict['output'] = data_item.non_tensor_batch['extra_info']['output']
169
+ data_dict['imports'] = data_item.non_tensor_batch['extra_info'].get('imports', [])
170
+ elif problem_type.startswith('pred') and 'code_f' in problem_type:
171
+ data_dict['program'] = data_item.non_tensor_batch['problem']
172
+ data_dict['given_inputs'] = data_item.non_tensor_batch['extra_info']['given_inputs']
173
+ data_dict['given_outputs'] = data_item.non_tensor_batch['extra_info']['given_outputs']
174
+ data_dict['hidden_inputs'] = data_item.non_tensor_batch['extra_info']['hidden_inputs']
175
+ data_dict['hidden_outputs'] = data_item.non_tensor_batch['extra_info']['hidden_outputs']
176
+ data_dict['message'] = data_item.non_tensor_batch['extra_info']['message']
177
+ data_dict['imports'] = data_item.non_tensor_batch['extra_info'].get('imports', [])
178
+
179
+ # if QA task, we only need to check the format
180
+ if problem_type is None:
181
+ format_score = get_format_reward(solution_str=generation, extraction_type=self.reward_fn_extraction_type) if self.generation_reward_config.format_reward else 1.
182
+ data_dict['format_score'] = format_score
183
+ return data_dict
184
+ # first go through, we only checking the format
185
+ elif problem_type.startswith('gen') and 'code_f' not in problem_type:
186
+ success, result = parse_code_input_output(
187
+ extracted_content,
188
+ parse_output=False,
189
+ remove_after_return=self.generation_reward_config.remove_after_return and self.split == 'train',
190
+ remove_comments=self.generation_reward_config.remove_comments and self.split == 'train',
191
+ remove_print=self.generation_reward_config.remove_print and self.split == 'train',
192
+ reject_multiple_functions=self.generation_reward_config.reject_multiple_functions,
193
+ f_replace_location=self.generation_reward_config.f_replace_location,
194
+ reject_test_input_in_code=self.generation_reward_config.reject_test_input_in_code,
195
+ code_location=self.generation_reward_config.code_location,
196
+ )
197
+ if len(data_dict['composite_functions']) > 0 and success:
198
+ # first, check if the composite function names are redefined in the code, which we do not allow
199
+ success = check_no_definitions(result['code'], [f'g_{i}' for i in range(len(data_dict['composite_functions']))])
200
+ if not success: # if the composite function names are redefined, we do not allow the code
201
+ data_dict['code_validity'] = False
202
+ data_dict['format_score'] = 0.
203
+ return data_dict
204
+
205
+ composite_imports = '\n'.join(
206
+ '\n'.join(list(d['imports'])) if list(d['imports']) else '' for d in data_dict['composite_functions']
207
+ ).strip()
208
+
209
+ composite_snippets = '\n\n'.join(d['snippet'] for d in data_dict['composite_functions']).strip()
210
+
211
+ # cache the original code
212
+ result['original_code'] = result['code']
213
+
214
+ result['code'] = f"{composite_imports}\n\n{composite_snippets}\n\n{result['code']}".strip()
215
+ # TODO: composite function check
216
+ success = check_composite_function(
217
+ code = result['code'],
218
+ composite_functions = [d['snippet'] for d in data_dict['composite_functions']],
219
+ )
220
+ if success:
221
+ code_validity, output = executor.check_all(
222
+ code=result['code'],
223
+ inputs=result['input'],
224
+ banned_keywords=banned_words,
225
+ check_determinism=True,
226
+ imports=list(set(result['imports'])),
227
+ check_error=problem_type == 'gen_code_e',
228
+ banned_keywords_for_errors_and_exceptions=banned_assertion_keywords,
229
+ )
230
+ if not code_validity:
231
+ data_dict['code_validity'] = False
232
+ data_dict['format_score'] = 0.
233
+ return data_dict
234
+ # means the code is valid, we append any good programs, but we eval format separately
235
+ data_dict['answer'] = {
236
+ 'snippet': result['code'],
237
+ 'original_snippet': result['original_code'] if 'original_code' in result else result['code'],
238
+ 'input': result['input'],
239
+ 'output': output,
240
+ 'imports': result['imports'],
241
+ 'thought': thought,
242
+ 'composite_functions': data_dict['composite_functions']
243
+ }
244
+ format_score = get_format_reward(solution_str=generation, extraction_type=self.reward_fn_extraction_type) if self.generation_reward_config.format_reward else 1.
245
+ data_dict['format_score'] = format_score
246
+ data_dict['code_validity'] = True
247
+ return data_dict
248
+ else:
249
+ data_dict['code_validity'] = False
250
+ data_dict['format_score'] = 0.
251
+ return data_dict
252
+
253
+ elif problem_type == 'gen_code_f':
254
+ success, result = parse_inputs_message(
255
+ extracted_content,
256
+ num_inputs=self.num_inputs,
257
+ )
258
+ if success and len(result['inputs']) == self.num_inputs: # for code_f, we need to ensure the number of inputs is correct
259
+ outputs = []
260
+ for inpt in result['inputs']:
261
+ code_validity, output = executor.check_all(
262
+ code=data_dict['references'][0],
263
+ inputs=inpt,
264
+ banned_keywords=[],
265
+ check_determinism=True,
266
+ imports=data_dict['imports'][0],
267
+ check_error=False,
268
+ banned_keywords_for_errors_and_exceptions=[],
269
+ )
270
+ if not code_validity:
271
+ data_dict['code_validity'] = False
272
+ data_dict['format_score'] = 0.
273
+ return data_dict
274
+ outputs.append(output)
275
+ data_dict['answer'] = {
276
+ 'snippet': data_dict['references'][0],
277
+ 'inputs': result['inputs'],
278
+ 'outputs': outputs,
279
+ 'message': result['message'],
280
+ 'imports': data_dict['imports'][0],
281
+ 'thought': thought,
282
+ }
283
+ format_score = get_format_reward(solution_str=generation, extraction_type=self.reward_fn_extraction_type) if self.generation_reward_config.format_reward else 1.
284
+ data_dict['format_score'] = format_score
285
+ data_dict['code_validity'] = True
286
+ return data_dict
287
+ else:
288
+ data_dict['code_validity'] = False
289
+ data_dict['format_score'] = 0.
290
+ return data_dict
291
+
292
+ # if prediction is the task
293
+ elif problem_type.startswith('pred'):
294
+ # Check required blocks
295
+ if problem_type.endswith('code_i'): # parse input
296
+ input_snippet = self.extract_input_output(extracted_content, return_input=True, return_output=False) \
297
+ if self.extract_code_block else extracted_content
298
+ if input_snippet is None:
299
+ data_dict['format_score'] = 0.
300
+ return data_dict
301
+ format_score = get_format_reward(solution_str=generation, extraction_type=self.reward_fn_extraction_type) if self.generation_reward_config.format_reward else 1.
302
+ data_dict['format_score'] = format_score
303
+ data_dict['answer'] = input_snippet
304
+ return data_dict
305
+ elif problem_type.endswith('code_o') or problem_type.endswith('code_e'): # parse output, code_e format is same as code_o
306
+ output_snippet = self.extract_input_output(extracted_content, return_input=False, return_output=True) \
307
+ if self.extract_code_block else extracted_content
308
+ if output_snippet is None:
309
+ data_dict['format_score'] = 0.
310
+ return data_dict
311
+ format_score = get_format_reward(solution_str=generation, extraction_type=self.reward_fn_extraction_type) if self.generation_reward_config.format_reward else 1.
312
+ data_dict['format_score'] = format_score
313
+ data_dict['answer'] = output_snippet
314
+ return data_dict
315
+ elif problem_type.endswith('code_f'):
316
+ success, code_snippet = parse_code_function(extracted_content)
317
+ if not success:
318
+ data_dict['format_score'] = 0.
319
+ return data_dict
320
+ format_score = get_format_reward(solution_str=generation, extraction_type=self.reward_fn_extraction_type) if self.generation_reward_config.format_reward else 1.
321
+ data_dict['format_score'] = format_score
322
+ data_dict['answer'] = {
323
+ 'snippet': code_snippet,
324
+ 'given_inputs': data_dict['given_inputs'],
325
+ 'given_outputs': data_dict['given_outputs'],
326
+ 'hidden_inputs': data_dict['hidden_inputs'],
327
+ 'hidden_outputs': data_dict['hidden_outputs'],
328
+ 'message': data_dict['message'],
329
+ 'imports': data_dict['imports'],
330
+ 'thought': thought,
331
+ 'gold_program': data_dict['program'],
332
+ }
333
+ return data_dict
334
+ else:
335
+ raise ValueError(f"Invalid problem type: {problem_type}")
336
+ else:
337
+ raise ValueError(f"Invalid problem type: {problem_type}")
338
+
339
+ def __call__(
340
+ self,
341
+ data: DataProto,
342
+ problem_type: str = None,
343
+ executor = None,
344
+ rollout_actor_wg = None,
345
+ banned_words: List[str] = [],
346
+ banned_assertion_keywords: List[str] = [],
347
+ n_samples: int = 1,
348
+ input_type_counters: Dict[str, Dict[str, int]] = None,
349
+ output_type_counters: Dict[str, Dict[str, int]] = None,
350
+ error_type_counters: Dict[str, Dict[str, int]] = None,
351
+ ) -> Tuple[torch.Tensor, Dict, List[Dict], List[Dict]]:
352
+ """We will expand this function gradually based on the available datasets"""
353
+
354
+ # If there is rm score, we directly return rm score. Otherwise, we compute via rm_score_fn
355
+ if 'rm_scores' in data.batch.keys():
356
+ return data.batch['rm_scores']
357
+
358
+ reward_tensor = torch.zeros_like(data.batch['responses'], dtype=torch.float32)
359
+
360
+ all_scores = defaultdict(list)
361
+ data_dicts = []
362
+ valid_programs = [] # for gen tasks, we need to store the valid programs for later use, ignore this if prediction task
363
+ invalid_programs = [] # for gen tasks, we need to store the invalid programs for analysis
364
+ correct_predictions = []
365
+ uids = np.array([str(uuid.uuid4()) for _ in range(len(data))], dtype=object)
366
+ if problem_type is None:
367
+ problem_types = [d.non_tensor_batch['extra_info']['metric'] for d in data]
368
+ problem_type = 'pred' # dummy set
369
+ else:
370
+ problem_types = [problem_type] * len(data)
371
+ PrettyPrinter.section_header("Getting Data Dicts")
372
+ for i in range(len(data)): # get format score
373
+ data_dict = self._get_data_dict(data[i], problem_types[i], executor, banned_words, uids[i], banned_assertion_keywords)
374
+ data_dicts.append(data_dict)
375
+
376
+ if problem_type.startswith('gen') and rollout_actor_wg is not None: # get generation rewards
377
+ PrettyPrinter.section_header("Generating Rewards for Generation Tasks")
378
+ rewards, valid_programs, invalid_programs = self._get_problem_generator_rewards_and_valid_programs(
379
+ data_dicts=data_dicts,
380
+ problem_type=problem_type,
381
+ n_samples=n_samples,
382
+ rollout_actor_wg=rollout_actor_wg,
383
+ executor=executor,
384
+ input_type_counters=input_type_counters,
385
+ output_type_counters=output_type_counters,
386
+ error_type_counters=error_type_counters,
387
+ )
388
+ PrettyPrinter.section_header("Combining Rewards for Generation Tasks")
389
+ for i in range(len(data_dicts)):
390
+ uid = data_dicts[i]['uid']
391
+ valid_response_length = data_dicts[i]['valid_response_length']
392
+ acc_reward = rewards[uid]['accuracy']
393
+ format_reward = data_dicts[i]['format_score']
394
+ if format_reward > 0:
395
+ if acc_reward > 0:
396
+ # Helper function for safe reward combination
397
+ def _combine_rewards(acc, intrinsic_components, method):
398
+ components = [c for c in intrinsic_components if c is not None]
399
+
400
+ if method == 'sum':
401
+ return acc + sum(components) if components else acc
402
+ elif method == 'multiply':
403
+ return acc * np.prod([c for c in components]) if components else acc
404
+ elif method == 'sum_multiply':
405
+ return acc + np.prod([c for c in components]) if components else acc
406
+ elif method == 'multiply_sum':
407
+ return acc * sum(components) if components else acc
408
+ else:
409
+ raise ValueError(f"Unknown combination method: {method}")
410
+
411
+ intrinsic_reward_components = []
412
+ if problem_type.endswith('code_f'):
413
+ if self.generation_reward_config.f_input_answer_diversity_reward.enabled:
414
+ intrinsic_reward_components.append(min(self.generation_reward_config.f_input_answer_diversity_reward.coef * rewards[uid]['input_type_counts'],
415
+ self.generation_reward_config.f_input_answer_diversity_reward.max))
416
+ if self.generation_reward_config.f_output_answer_diversity_reward.enabled:
417
+ intrinsic_reward_components.append(min(self.generation_reward_config.f_output_answer_diversity_reward.coef * rewards[uid]['output_type_counts'],
418
+ self.generation_reward_config.f_output_answer_diversity_reward.max))
419
+ else:
420
+ if self.generation_reward_config.complexity_reward.enabled:
421
+ intrinsic_reward_components.append(min(self.generation_reward_config.complexity_reward.coef * rewards[uid]['complexity'],
422
+ self.generation_reward_config.complexity_reward.max))
423
+ if self.generation_reward_config.mean_edit_distance_reward.enabled:
424
+ intrinsic_reward_components.append(min(self.generation_reward_config.mean_edit_distance_reward.coef * rewards[uid]['mean_edit_distance'],
425
+ self.generation_reward_config.mean_edit_distance_reward.max))
426
+ if self.generation_reward_config.halstead_reward.enabled:
427
+ intrinsic_reward_components.append(min(self.generation_reward_config.halstead_reward.coef * rewards[uid]['halstead'],
428
+ self.generation_reward_config.halstead_reward.max))
429
+ if self.generation_reward_config.answer_diversity_reward.enabled:
430
+ intrinsic_reward_components.append(min(self.generation_reward_config.answer_diversity_reward.coef * rewards[uid]['type_counts'],
431
+ self.generation_reward_config.answer_diversity_reward.max))
432
+
433
+ final_reward = _combine_rewards(acc_reward, intrinsic_reward_components, self.generation_reward_config.intrinsic_combine_method)
434
+ reward_tensor[i, valid_response_length - 1] = final_reward
435
+ else:
436
+ reward_tensor[i, valid_response_length - 1] = -0.5
437
+ else:
438
+ reward_tensor[i, valid_response_length - 1] = -1.0
439
+ all_scores['accuracy'] = [rewards[uid]['accuracy'] for uid in rewards]
440
+ all_scores['format_score'] = [data_dicts[i]['format_score'] for i in range(len(data))]
441
+ if 'code_f' not in problem_type:
442
+ all_scores['answer_diversity'] = [rewards[uid]['type_counts'] for uid in rewards]
443
+ all_scores['complexity'] = [rewards[uid]['complexity'] for uid in rewards]
444
+ all_scores['mean_edit_distance'] = [rewards[uid]['mean_edit_distance'] for uid in rewards]
445
+ all_scores['halstead'] = [rewards[uid]['halstead'] for uid in rewards]
446
+ else:
447
+ all_scores['input_answer_diversity'] = [rewards[uid]['input_type_counts'] for uid in rewards]
448
+ all_scores['output_answer_diversity'] = [rewards[uid]['output_type_counts'] for uid in rewards]
449
+ elif problem_type.startswith('pred'): # get prediction rewards
450
+ PrettyPrinter.section_header("Getting Prediction Rewards")
451
+ all_scores['none_count'] = 0
452
+ acc_rewards = []
453
+ for i, data_dict in enumerate(data_dicts):
454
+ valid_response_length = data_dict['valid_response_length']
455
+ imports = data_dict['imports']
456
+ if not problem_type.endswith('code_f'):
457
+ answer = data_dict['answer']
458
+ gold_input = data_dict['input']
459
+ gold_output = data_dict['output']
460
+ program = data_dict['program']
461
+ else:
462
+ hidden_inputs = data_dict['hidden_inputs']
463
+ hidden_outputs = data_dict['hidden_outputs']
464
+ if not data_dicts[i]['format_score']: # early stop if the format is not correct
465
+ acc_reward = 0.
466
+ elif problem_types[i].endswith('code_i'):
467
+ acc_reward = executor.eval_input_prediction(code=program, gold_output=gold_output, agent_input=answer, imports=list(set(imports)))
468
+ # problematic, but we did not encounter too much of this
469
+ if acc_reward is None:
470
+ all_scores['none_count'] += 1
471
+ acc_reward = 0.
472
+ print(f"error in pred_code_i, not in [0, 1], acc_reward={acc_reward}\nprogram:\n{program}\n---\nanswer:\n{answer}\n---\nimports:\n{imports}\n---\n")
473
+ if acc_reward > 0.0:
474
+ correct_predictions.append(data_dict)
475
+ elif problem_types[i].endswith('code_o'):
476
+ acc_reward = executor.eval_output_prediction(code=program, gold_output=gold_output, agent_output=answer, imports=list(set(imports)))
477
+ # problematic, but we did not encounter too much of this
478
+ if acc_reward is None:
479
+ all_scores['none_count'] += 1
480
+ acc_reward = 0.
481
+ print(f"error in pred_code_o, not in [0, 1], acc_reward={acc_reward}\nprogram:\n{program}\n---\nanswer:\n{answer}\n---\nimports:\n{imports}\n---\n")
482
+ if acc_reward > 0.0:
483
+ correct_predictions.append(data_dict)
484
+ elif problem_types[i].endswith('code_e'): # string matching for errors
485
+ answer = answer.split(' ')[0].split(':')[0]
486
+ if answer.lower() == gold_output.lower():
487
+ acc_reward = 1.0
488
+ correct_predictions.append(data_dict)
489
+ else:
490
+ acc_reward = 0.0
491
+ elif problem_types[i].endswith('code_f'):
492
+ input_output_accs = []
493
+ program = data_dict['answer']['snippet']
494
+ for inpt, outpt in zip(hidden_inputs, hidden_outputs):
495
+ input_output_acc = executor.eval_input_prediction(
496
+ code=program,
497
+ gold_output=outpt,
498
+ agent_input=inpt,
499
+ imports=list(set(imports)),
500
+ )
501
+ if input_output_acc is not None:
502
+ input_output_accs.append(input_output_acc)
503
+ acc_reward = np.mean(input_output_accs) if input_output_accs else 0.0
504
+ if self.code_f_reward_type == 'binary':
505
+ acc_reward = 1.0 if acc_reward == 1.0 else 0.0
506
+ elif self.code_f_reward_type == 'if_one_correct':
507
+ acc_reward = 1.0 if acc_reward > 0 else 0.0
508
+ # note that if code_f_reward_type==accuracy, it is already handled in the above
509
+ if acc_reward > 0:
510
+ correct_predictions.append(data_dict)
511
+ else:
512
+ raise ValueError(f"Invalid problem type: {problem_types[i]}")
513
+
514
+ if self.split == 'train':
515
+ if data_dicts[i]['format_score'] > 0:
516
+ if acc_reward > 0:
517
+ reward_tensor[i, valid_response_length - 1] = acc_reward
518
+ else:
519
+ reward_tensor[i, valid_response_length - 1] = -0.5
520
+ else:
521
+ reward_tensor[i, valid_response_length - 1] = -1.0
522
+ elif self.split == 'test': # only acc reward for eval
523
+ if acc_reward > 0:
524
+ reward_tensor[i, valid_response_length - 1] = 1.0
525
+ else:
526
+ reward_tensor[i, valid_response_length - 1] = 0.0
527
+ acc_rewards.append(acc_reward)
528
+ all_scores['accuracy'] = acc_rewards
529
+ all_scores['format_score'] = [data_dicts[i]['format_score'] for i in range(len(data))]
530
+ all_scores['none_ratio'] = all_scores['none_count'] / len(data)
531
+ return reward_tensor, all_scores, valid_programs, correct_predictions, invalid_programs
532
+
533
+ def _get_problem_generator_rewards_and_valid_programs(
534
+ self,
535
+ data_dicts: List[Dict],
536
+ problem_type: str,
537
+ n_samples: int,
538
+ rollout_actor_wg,
539
+ executor,
540
+ input_type_counters: Dict[str, Dict[str, int]] = None,
541
+ output_type_counters: Dict[str, Dict[str, int]] = None,
542
+ error_type_counters: Dict[str, Dict[str, int]] = None,
543
+ ) -> Tuple[Dict[str, Dict[str, float]], List[Dict[str, str]]]:
544
+ """This function uses samples to estimate the accuracy reward for each program, also computes the code complexity and mean edit distance of generated programs.
545
+ Also returns the valid programs using filters.
546
+ Args:
547
+ data_dicts: List[Dict]: A list of data dictionaries.
548
+ problem_type: str: The type of problem.
549
+ n_samples: int: The number of samples to use.
550
+ rollout_actor_wg: RolloutActorWG: The rollout actor.
551
+ executor: PythonExecutor/CodeBoxExecutor: The executor.
552
+ type_counters: Dict[str, Dict[str, int]]: The type counters.
553
+ Returns:
554
+ rewards: Dict[str, Dict[str, float]]: A dictionary of rewards for each program.
555
+ valid_programs: List[Dict[str, str]]: A list of valid programs.
556
+ """
557
+ if problem_type.endswith('code_i'):
558
+ type_counters = input_type_counters
559
+ elif problem_type.endswith('code_o'):
560
+ type_counters = output_type_counters
561
+ elif problem_type.endswith('code_e'):
562
+ type_counters = error_type_counters
563
+ valid_data_dicts = [data_dict for data_dict in data_dicts if data_dict['code_validity']]
564
+ uid2valid_dict_idx = {data_dict['uid']: i for i, data_dict in enumerate(valid_data_dicts)}
565
+ valid_uids = [data_dict['uid'] for data_dict in data_dicts if data_dict['code_validity']]
566
+ invalid_uids = [data_dict['uid'] for data_dict in data_dicts if not data_dict['code_validity']]
567
+ assert len(valid_uids) + len(invalid_uids) == len(data_dicts)
568
+ accuracies = {uid: 1.0 for uid in invalid_uids} # for invalid uids, we give maximum accuracy to the model
569
+ rewards = defaultdict(dict)
570
+ valid_programs = []
571
+ invalid_programs = []
572
+ if len(valid_uids) > 0:
573
+ if self.reward_fn_extraction_type.startswith('boxed'):
574
+ instruction_template = boxed_instruction
575
+ elif self.reward_fn_extraction_type.startswith('answer'):
576
+ instruction_template = instruction_following
577
+ elif self.reward_fn_extraction_type.startswith('none'):
578
+ instruction_template = '{}'
579
+ else:
580
+ raise ValueError(f"Invalid instruction type: {self.reward_fn_extraction_type}")
581
+ prompts = []
582
+ if problem_type.endswith('code_i'):
583
+ pt = 'code_i'
584
+ elif problem_type.endswith('code_o'):
585
+ pt = 'code_o'
586
+ elif problem_type.endswith('code_e'):
587
+ pt = 'code_e'
588
+ elif problem_type.endswith('code_f'):
589
+ pt = 'code_f'
590
+ else:
591
+ raise ValueError(f"Invalid problem type: {problem_type}")
592
+ for data_dict in valid_data_dicts:
593
+ if pt == 'code_f':
594
+ num_given_inputs = len(data_dict['answer']['inputs']) // 2
595
+ num_given_outputs = len(data_dict['answer']['outputs']) // 2
596
+ data_dict['answer']['given_inputs'] = data_dict['answer']['inputs'][:num_given_inputs]
597
+ data_dict['answer']['given_outputs'] = data_dict['answer']['outputs'][:num_given_outputs]
598
+ data_dict['answer']['hidden_inputs'] = data_dict['answer']['inputs'][num_given_inputs:]
599
+ data_dict['answer']['hidden_outputs'] = data_dict['answer']['outputs'][num_given_outputs:]
600
+ io_prompt = instruction_template.format(
601
+ get_code_problem_predictor_prompt(
602
+ problem_type=problem_type,
603
+ snippet=data_dict['answer']['snippet'],
604
+ message=data_dict['answer']['message'],
605
+ input_output_pairs=zip(data_dict['answer']['given_inputs'], data_dict['answer']['given_outputs']),
606
+ )
607
+ )
608
+ else:
609
+ io_prompt = instruction_template.format(
610
+ get_code_problem_predictor_prompt(
611
+ problem_type=pt,
612
+ snippet=data_dict['answer']['snippet'],
613
+ input_args=data_dict['answer']['input'],
614
+ output=data_dict['answer']['output'],
615
+ )
616
+ )
617
+ prompts_dict = {
618
+ 'prompt': [{'role': 'user', 'content': io_prompt}],
619
+ 'uid': data_dict['uid'],
620
+ 'problem': data_dict['answer'],
621
+ 'data_source': data_dict['data_source'],
622
+ 'ground_truth': data_dict['answer']['output'] if pt != 'code_f' else data_dict['answer']['snippet'],
623
+ 'extra_info': data_dict['extra_info'],
624
+ 'program': data_dict['answer']['snippet'],
625
+ 'imports': data_dict['answer']['imports'],
626
+ 'references': data_dict['references'],
627
+ }
628
+ if pt == 'code_f':
629
+ prompts_dict.update({
630
+ 'given_inputs': data_dict['answer']['given_inputs'],
631
+ 'given_outputs': data_dict['answer']['given_outputs'],
632
+ 'hidden_inputs': data_dict['answer']['hidden_inputs'],
633
+ 'hidden_outputs': data_dict['answer']['hidden_outputs'],
634
+ 'message': data_dict['answer']['message'],
635
+ })
636
+ else:
637
+ prompts_dict.update({
638
+ 'input': data_dict['answer']['input'],
639
+ 'output': data_dict['answer']['output'],
640
+ 'original_program': data_dict['answer']['original_snippet'],
641
+ 'composite_functions': data_dict['answer']['composite_functions'],
642
+ })
643
+ prompts.append(prompts_dict)
644
+
645
+ # sampling to estimate the accuracy
646
+ PrettyPrinter.section_header("Sampling to Estimate Accuracy")
647
+ prompts = prompts * n_samples # repeat the prompts n_samples times
648
+ pd.DataFrame(prompts).to_parquet(f'{self.output_path}/temp.parquet') # RLHFDataset expects parquet
649
+ temp_data = RLHFDataset(
650
+ parquet_files=f'{self.output_path}/temp.parquet',
651
+ tokenizer=self.tokenizer,
652
+ prompt_key='prompt',
653
+ max_prompt_length=self.max_prompt_length,
654
+ filter_prompts=True,
655
+ return_raw_chat=False,
656
+ truncation='error'
657
+ )
658
+ os.remove(f'{self.output_path}/temp.parquet') # we do not need this file after we load in the dataset
659
+ sampler = torch.utils.data.SequentialSampler(data_source=temp_data)
660
+
661
+ dataloader = torch.utils.data.DataLoader(
662
+ dataset=temp_data,
663
+ batch_size=len(temp_data),
664
+ drop_last=False,
665
+ shuffle=False,
666
+ collate_fn=collate_fn,
667
+ sampler=sampler,
668
+ )
669
+ assert len(dataloader) == 1
670
+ data = next(iter(dataloader))
671
+ batch = DataProto.from_single_dict(data)
672
+ gen_batch = batch.pop(['input_ids', 'attention_mask', 'position_ids'])
673
+ gen_batch.meta_info = {
674
+ 'eos_token_id': self.tokenizer.eos_token_id,
675
+ 'pad_token_id': self.tokenizer.pad_token_id,
676
+ 'recompute_log_prob': False,
677
+ 'do_sample': True,
678
+ 'validate': True,
679
+ }
680
+ # pad to be divisible by dp_size
681
+ gen_batch_padded, pad_size = pad_dataproto_to_divisor(gen_batch, rollout_actor_wg.world_size)
682
+ output_gen_batch_padded = rollout_actor_wg.generate_sequences(gen_batch_padded)
683
+ # unpad
684
+ output_gen_batch = unpad_dataproto(output_gen_batch_padded, pad_size=pad_size)
685
+ print('validation generation end')
686
+
687
+ # Store generated outputs
688
+ batch = batch.union(output_gen_batch)
689
+ batched_responses = []
690
+ for b in batch:
691
+ batch_dict = {
692
+ 'extracted_answers': extract_answer(
693
+ self.tokenizer.decode(b.batch['responses'], skip_special_tokens=True),
694
+ self.reward_fn_extraction_type,
695
+ boxed_retry=self.boxed_retry,
696
+ ),
697
+ 'uid': b.non_tensor_batch['uid'],
698
+ 'problem': b.non_tensor_batch['problem'],
699
+ 'data_source': b.non_tensor_batch['data_source'],
700
+ 'extra_info': b.non_tensor_batch['extra_info'],
701
+ 'program': b.non_tensor_batch['program'],
702
+ 'references': b.non_tensor_batch['references'],
703
+ 'imports': b.non_tensor_batch['imports'],
704
+ }
705
+ if pt == 'code_f':
706
+ batch_dict.update({
707
+ 'given_inputs': b.non_tensor_batch['given_inputs'],
708
+ 'given_outputs': b.non_tensor_batch['given_outputs'],
709
+ 'hidden_inputs': b.non_tensor_batch['hidden_inputs'],
710
+ 'hidden_outputs': b.non_tensor_batch['hidden_outputs'],
711
+ 'message': b.non_tensor_batch['message'],
712
+ })
713
+ else:
714
+ batch_dict.update({
715
+ 'input': b.non_tensor_batch['input'],
716
+ 'output': b.non_tensor_batch['output'],
717
+ 'original_program': b.non_tensor_batch['original_program'],
718
+ 'composite_functions': b.non_tensor_batch['composite_functions'].tolist(),
719
+ })
720
+ batched_responses.append(batch_dict)
721
+ df = pd.DataFrame(batched_responses)
722
+
723
+ # estimating accuracy using python executor
724
+ PrettyPrinter.section_header("Estimating Accuracy Using Python Executor")
725
+ for valid_uid in valid_uids:
726
+ df_valid = df[df['uid'] == valid_uid]
727
+ if df_valid.empty: # the prompt got filtered out TODO: check
728
+ accuracies[valid_uid] = 0.0
729
+ continue
730
+ if pt != 'code_f':
731
+ answers = [self.extract_input_output(
732
+ answer,
733
+ return_input=problem_type.endswith('code_i'),
734
+ return_output=(problem_type.endswith('code_o') or problem_type.endswith('code_e')) # code_e output format is same as code_o
735
+ ) for answer in df_valid['extracted_answers'].tolist()]
736
+ else:
737
+ answers = [parse_code_function(answer) for answer in df_valid['extracted_answers'].tolist()]
738
+ answer_cache = {} # for the same uid, the answer is the same and the program is assumed to be deterministic, therefore we cache the answer -> accuracy mapping
739
+ if pt == 'code_f':
740
+ hidden_outputs = df_valid['hidden_outputs'].tolist()[0].tolist()
741
+ hidden_inputs = df_valid['hidden_inputs'].tolist()[0].tolist()
742
+ else:
743
+ gold_output = df_valid['output'].tolist()[0]
744
+ program = df_valid['program'].tolist()[0]
745
+ # gold_input = df_valid['input'].tolist()[0]
746
+ imports = df_valid['imports'].tolist()[0]
747
+ problem_accuracies = []
748
+ if problem_type.endswith('code_i'):
749
+ if self.batched_estimate:
750
+ problem_accuracies = executor.eval_k_input_prediction(code=program, gold_output=gold_output, k_agent_inputs=answers, imports=list(set(imports)))
751
+ else:
752
+ for answer in answers:
753
+ if answer in answer_cache:
754
+ problem_accuracies.append(answer_cache[answer])
755
+ continue
756
+ acc_reward = executor.eval_input_prediction(code=program, gold_output=gold_output, agent_input=answer, imports=list(set(imports)))
757
+ if acc_reward is not None:
758
+ problem_accuracies.append(acc_reward)
759
+ answer_cache[answer] = acc_reward
760
+ # if self.debug:
761
+ # batched_problem_accuracies = executor.eval_k_input_prediction(code=program, gold_output=gold_output, k_agent_inputs=answers, imports=list(set(imports)))
762
+ # assert np.mean(batched_problem_accuracies) == np.mean(problem_accuracies), f"Gen I batch accuracy: {np.mean(batched_problem_accuracies)}, Single accuracy: {np.mean(problem_accuracies)}"
763
+ elif problem_type.endswith('code_o'):
764
+ if self.batched_estimate:
765
+ problem_accuracies = executor.eval_k_output_prediction(code=program, gold_output=gold_output, k_agent_outputs=answers, imports=list(set(imports)))
766
+ else:
767
+ for answer in answers:
768
+ if answer in answer_cache:
769
+ problem_accuracies.append(answer_cache[answer])
770
+ continue
771
+ acc_reward = executor.eval_output_prediction(code=program, gold_output=gold_output, agent_output=answer, imports=list(set(imports)))
772
+ if acc_reward is not None:
773
+ problem_accuracies.append(acc_reward)
774
+ answer_cache[answer] = acc_reward
775
+ # if self.debug:
776
+ # batched_problem_accuracies = executor.eval_k_output_prediction(code=program, gold_output=gold_output, k_agent_outputs=answers, imports=list(set(imports)))
777
+ # assert np.mean(batched_problem_accuracies) == np.mean(problem_accuracies), f"Gen O batch accuracy: {np.mean(batched_problem_accuracies)}, Single accuracy: {np.mean(problem_accuracies)}"
778
+ elif problem_type.endswith('code_e'): # string matching for errors
779
+ for answer in answers:
780
+ answer = answer.split(' ')[0].split(':')[0]
781
+ if answer.lower() == gold_output.lower():
782
+ problem_accuracies.append(1.0)
783
+ else:
784
+ problem_accuracies.append(0.0)
785
+ elif problem_type.endswith('code_f'):
786
+ for parsed, answer in answers: # for each input/output set, we sampled n codes to estimate the accuracy
787
+ if not parsed: # the code answer is not parsed, we assume the code is not valid
788
+ problem_accuracies.append(0.0)
789
+ continue
790
+ code_accuracies = []
791
+ for inpt, outpt in zip(hidden_inputs, hidden_outputs):
792
+ code_accuracies.append(executor.eval_input_prediction(code=answer, gold_output=outpt, agent_input=inpt, imports=list(set(imports))))
793
+ answer_acc = np.mean([a for a in code_accuracies if a is not None]) if code_accuracies else 0.0
794
+ if self.code_f_reward_type == 'binary':
795
+ problem_accuracies.append(1.0 if answer_acc == 1.0 else 0.0)
796
+ elif self.code_f_reward_type == 'if_one_correct':
797
+ problem_accuracies.append(1.0 if answer_acc > 0 else 0.0)
798
+ elif self.code_f_reward_type == 'accuracy':
799
+ problem_accuracies.append(answer_acc)
800
+ else:
801
+ raise ValueError(f"Invalid code_f_reward_type: {self.code_f_reward_type}")
802
+ accuracies[valid_uid] = sum(problem_accuracies) / len(problem_accuracies) if problem_accuracies else 0.0
803
+
804
+ # filtering valid programs
805
+ if self.valid_program_filter == 'all':
806
+ valid_programs.append(valid_data_dicts[uid2valid_dict_idx[valid_uid]]['answer'])
807
+ elif self.valid_program_filter == 'non_one':
808
+ if accuracies[valid_uid] < 1.0:
809
+ valid_programs.append(valid_data_dicts[uid2valid_dict_idx[valid_uid]]['answer'])
810
+ elif self.valid_program_filter == 'non_extremes':
811
+ if accuracies[valid_uid] > 0.0 and accuracies[valid_uid] < 1.0:
812
+ valid_programs.append(valid_data_dicts[uid2valid_dict_idx[valid_uid]]['answer'])
813
+ else:
814
+ raise ValueError(f"Invalid valid program filter: {self.valid_program_filter}")
815
+
816
+ # collecting invalid programs for analysis
817
+ invalid_data_dicts = [data_dict for data_dict in data_dicts if not data_dict['code_validity']]
818
+ for i, invalid_data_dict in enumerate(invalid_data_dicts):
819
+ # Create a unique label for each invalid problem
820
+ problem_label = f"{problem_type}_invalid_{invalid_data_dict.get('uid', i)}"
821
+
822
+ # Store the full LLM prompt and response for analysis
823
+ invalid_program = {
824
+ 'problem_id': problem_label,
825
+ 'llm_prompt': invalid_data_dict.get('prompt', ''),
826
+ 'llm_response': invalid_data_dict.get('response', ''),
827
+ 'invalid_reason': invalid_data_dict.get('error_message', 'Parsing or validation failed'),
828
+ 'format_score': invalid_data_dict.get('format_score', 0.0),
829
+ 'problem_type': problem_type,
830
+ 'timestamp': invalid_data_dict.get('timestamp', ''),
831
+ 'raw_data': invalid_data_dict # Keep original data for debugging
832
+ }
833
+ invalid_programs.append(invalid_program)
834
+
835
+ # getting other rewards
836
+ PrettyPrinter.section_header("Getting Other Rewards")
837
+ # outputting rewards
838
+ for d in data_dicts:
839
+ uid = d['uid']
840
+ if self.generation_reward_config.generation_accuracy_convertion == 'one_minus':
841
+ rewards[uid]['accuracy'] = (1 - accuracies[uid]) if accuracies[uid] > 0 else 0.0
842
+ elif self.generation_reward_config.generation_accuracy_convertion == 'inverse':
843
+ rewards[uid]['accuracy'] = 1 - accuracies[uid]
844
+ else:
845
+ raise ValueError(f"Invalid generation accuracy convertion: {self.generation_reward_config.generation_accuracy_convertion}")
846
+
847
+ if not problem_type.endswith('code_f'):
848
+ code_key = 'original_snippet' if self.use_original_code_as_ref else 'snippet'
849
+ reference_key = 'original_references' if self.use_original_code_as_ref else 'references'
850
+ if problem_type.endswith('code_i'):
851
+ type_counter_key = 'input'
852
+ elif problem_type.endswith('code_o'):
853
+ type_counter_key = 'output'
854
+ elif problem_type.endswith('code_e'):
855
+ type_counter_key = 'error'
856
+ else:
857
+ raise ValueError(f"Invalid problem type: {problem_type}")
858
+ for data_dict in data_dicts:
859
+ rewards[data_dict['uid']]['complexity'] = get_code_complexity_reward(data_dict['answer'][code_key]) if 'answer' in data_dict else 0.0
860
+ for data_dict in data_dicts:
861
+ rewards[data_dict['uid']]['mean_edit_distance'] = np.mean([ast_edit_distance(data_dict['answer'][code_key], ref) for ref in data_dict[reference_key]]) if 'answer' in data_dict else 0.0
862
+ for data_dict in data_dicts:
863
+ rewards[data_dict['uid']]['halstead'] = get_halstead_reward(data_dict['answer'][code_key]) if 'answer' in data_dict else 0.0
864
+ for data_dict in data_dicts:
865
+ rewards[data_dict['uid']]['type_counts'] = get_type_counts_reward(
866
+ data_dict['answer'][type_counter_key],
867
+ type_counters,
868
+ hierarchical=self.generation_reward_config.answer_diversity_reward.hierarchical
869
+ ) if 'answer' in data_dict else 0.0
870
+ if self.debug:
871
+ for data_dict in data_dicts:
872
+ if 'answer' in data_dict:
873
+ continue
874
+ else:
875
+ for data_dict in data_dicts:
876
+ rewards[data_dict['uid']]['input_type_counts'] = []
877
+ rewards[data_dict['uid']]['output_type_counts'] = []
878
+ if 'answer' in data_dict:
879
+ for inpt, outpt in zip(data_dict['answer']['inputs'], data_dict['answer']['outputs']):
880
+ rewards[data_dict['uid']]['input_type_counts'].append(get_type_counts_reward(
881
+ inpt,
882
+ input_type_counters,
883
+ hierarchical=self.generation_reward_config.answer_diversity_reward.hierarchical
884
+ ))
885
+ rewards[data_dict['uid']]['output_type_counts'].append(get_type_counts_reward(
886
+ outpt,
887
+ output_type_counters,
888
+ hierarchical=self.generation_reward_config.answer_diversity_reward.hierarchical
889
+ ))
890
+ rewards[data_dict['uid']]['input_type_counts'] = np.mean(rewards[data_dict['uid']]['input_type_counts'])
891
+ rewards[data_dict['uid']]['output_type_counts'] = np.mean(rewards[data_dict['uid']]['output_type_counts'])
892
+ else:
893
+ rewards[data_dict['uid']]['input_type_counts'] = 0.0
894
+ rewards[data_dict['uid']]['output_type_counts'] = 0.0
895
+
896
+ # turn into normal dict
897
+ rewards = dict(rewards)
898
+ return rewards, valid_programs, invalid_programs
absolute_zero_reasoner/rewards/ttrlvr_reward_manager.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TTRLVR Reward Manager for AZR Integration
3
+
4
+ TTRLVR의 complete_pipeline.py에 있는 _compute_rewards_with_azr 로직을
5
+ 완전히 동일하게 AZR에서 사용할 수 있도록 통합
6
+ """
7
+
8
+ import re
9
+ from typing import Dict, Any, List, Optional, Tuple
10
+ import torch
11
+ from transformers import AutoTokenizer
12
+
13
+ from .reward_managers import CodeIORewardManager
14
+ from ..utils.code_utils.python_executor import PythonExecutor
15
+ from ..utils.code_utils.templates import EVAL_INPUT_PREDICTION_TEMPLATE
16
+
17
+
18
+ class TTRLVRRewardManager(CodeIORewardManager):
19
+ """TTRLVR 전용 Reward Manager - complete_pipeline.py의 로직 그대로 사용"""
20
+
21
+ def __init__(self, tokenizer: AutoTokenizer, **kwargs):
22
+ super().__init__(tokenizer=tokenizer, **kwargs)
23
+ self.executor = PythonExecutor()
24
+
25
+ def compute_rewards(self,
26
+ prompts: List[str],
27
+ responses: List[str],
28
+ metadata: List[Dict[str, Any]]) -> List[float]:
29
+ """
30
+ TTRLVR complete_pipeline.py의 _compute_rewards_with_azr과 동일한 로직
31
+
32
+ Args:
33
+ prompts: 프롬프트 리스트
34
+ responses: 모델이 생성한 응답 리스트
35
+ metadata: 각 task의 메타데이터 (task_type, evaluation_data 등)
36
+
37
+ Returns:
38
+ rewards: 각 응답에 대한 reward 리스트
39
+ """
40
+ rewards = []
41
+
42
+ # 간단한 디버깅 정보 출력
43
+ if metadata and len(metadata) > 0:
44
+ # Task type 분포 확인
45
+ task_types = [m.get('task_type', 'unknown') for m in metadata]
46
+ from collections import Counter
47
+ task_count = Counter(task_types)
48
+ print(f"\n[TTRLVR] Processing batch - Task distribution: {dict(task_count)}")
49
+
50
+ for prompt, response, meta in zip(prompts, responses, metadata):
51
+ task_type = meta.get('task_type', 'unknown')
52
+ evaluation_data = meta.get('evaluation_data', {})
53
+ expected = meta.get('expected_solution', '')
54
+
55
+ # complete_pipeline.py:458과 동일
56
+ extracted_answer = self._extract_answer_by_task_type(response, task_type)
57
+
58
+ # 실제 코드 실행 기반 평가 (complete_pipeline.py:461-584와 동일)
59
+ try:
60
+ if task_type == 'abduction':
61
+ # complete_pipeline.py:462-548 그대로
62
+ code = evaluation_data['function_code']
63
+ expected_output = evaluation_data['expected_output']
64
+ agent_input = extracted_answer
65
+
66
+ # 함수 정의만 추출
67
+ code = self._extract_function_definition(code)
68
+
69
+ # 함수명 추출 및 f로 변경
70
+ func_name_match = re.search(r'def\s+(\w+)\s*\(', code)
71
+ if func_name_match:
72
+ original_func_name = func_name_match.group(1)
73
+ code = re.sub(r'def\s+' + re.escape(original_func_name) + r'\s*\(', 'def f(', code)
74
+
75
+ # expected_output을 실제 값으로 변환
76
+ try:
77
+ expected_output_value = eval(expected_output)
78
+ except:
79
+ expected_output_value = expected_output
80
+
81
+ # EVAL_INPUT_PREDICTION_TEMPLATE 사용
82
+ try:
83
+ code_snippet = EVAL_INPUT_PREDICTION_TEMPLATE.format(
84
+ code=code,
85
+ gold_output=expected_output_value,
86
+ agent_input=agent_input
87
+ )
88
+ result, status = self.executor.apply(code_snippet)
89
+
90
+ if 'error' in status.lower():
91
+ accuracy = 0.0
92
+ else:
93
+ # 실행 결과와 expected output 비교
94
+ try:
95
+ if isinstance(result, bool):
96
+ agent_output = result
97
+ else:
98
+ agent_output = eval(result)
99
+ accuracy = 1.0 if agent_output else 0.0
100
+ except:
101
+ accuracy = 0.0
102
+ except:
103
+ accuracy = 0.0
104
+
105
+ elif task_type == 'deduction':
106
+ # complete_pipeline.py:549-558 그대로
107
+ expected_output = expected
108
+ agent_output = extracted_answer
109
+
110
+ # 간단한 eval 비교
111
+ try:
112
+ accuracy = 1.0 if eval(expected_output) == eval(agent_output) else 0.0
113
+ except:
114
+ accuracy = 0.0
115
+
116
+ elif task_type == 'induction':
117
+ # complete_pipeline.py:560-575 그대로
118
+ input_output_pairs = evaluation_data.get('input_output_pairs', [])
119
+ agent_code = extracted_answer
120
+
121
+ # numpy array 처리: JSON/Parquet 저장으로 인한 데이터 형식 변환 해결
122
+ import numpy as np
123
+ if isinstance(input_output_pairs, np.ndarray):
124
+ # 이중 중첩된 numpy array 해제 (parquet 저장 시 발생)
125
+ # array([array(['input', 'output'])]) -> array(['input', 'output'])
126
+ if len(input_output_pairs) == 1 and isinstance(input_output_pairs[0], np.ndarray):
127
+ # 단일 테스트 케이스: array를 리스트로 변환 후 다시 리스트로 감싸기
128
+ inner_array = input_output_pairs[0]
129
+ input_output_pairs = [inner_array.tolist()] # [['input', 'output']]
130
+ else:
131
+ # 여러 테스트 케이스: 각각을 리스트로 변환
132
+ input_output_pairs = [item.tolist() if isinstance(item, np.ndarray) else item
133
+ for item in input_output_pairs]
134
+
135
+ # 리스트인데 직접 ['input', 'output'] 형태인 경우 처리
136
+ if isinstance(input_output_pairs, list) and len(input_output_pairs) == 2:
137
+ if isinstance(input_output_pairs[0], str) and isinstance(input_output_pairs[1], str):
138
+ # 단일 테스트 케이스로 감싸기
139
+ input_output_pairs = [input_output_pairs]
140
+
141
+ # 모든 input-output 쌍에 대해 테스트
142
+ accuracies = []
143
+
144
+ for i, pair in enumerate(input_output_pairs):
145
+ try:
146
+ # numpy array인 경우 리스트로 변환
147
+ if isinstance(pair, np.ndarray):
148
+ pair = pair.tolist()
149
+
150
+ # 리스트 또는 튜플에서 입력과 출력 추출
151
+ if isinstance(pair, (list, tuple)) and len(pair) >= 2:
152
+ test_input, expected_output = pair[0], pair[1]
153
+ else:
154
+ # 잘못된 형식의 경우 0점 처리
155
+ accuracies.append(0.0)
156
+ continue
157
+
158
+ # 실제 코드 실행 및 평가
159
+ accuracy = self.executor.eval_input_prediction(agent_code, expected_output, test_input)
160
+ accuracies.append(accuracy if accuracy is not None else 0.0)
161
+ except Exception as e:
162
+ # 예외 발생 시 0점 처리
163
+ accuracies.append(0.0)
164
+
165
+ # 평균 정확도 계산
166
+ accuracy = sum(accuracies) / len(accuracies) if accuracies else 0.0
167
+
168
+ else:
169
+ # complete_pipeline.py:578-579 그대로
170
+ accuracy = 1.0 if expected.strip() == extracted_answer.strip() else 0.0
171
+
172
+ except Exception as e:
173
+ accuracy = 0.0
174
+
175
+ rewards.append(accuracy)
176
+
177
+ # 계산된 rewards 요약 (간단히)
178
+ if rewards:
179
+ import numpy as np
180
+ mean_reward = np.mean(rewards)
181
+ print(f"[TTRLVR] Batch rewards - Mean: {mean_reward:.4f}, Min: {min(rewards):.4f}, Max: {max(rewards):.4f}")
182
+
183
+ return rewards
184
+
185
+ def _extract_answer_by_task_type(self, llm_response: str, task_type: str) -> str:
186
+ """complete_pipeline.py의 _extract_answer_by_task_type와 동일"""
187
+ # <answer>...</answer> 태그 추출
188
+ match = re.search(r'<answer>(.*?)</answer>', llm_response, re.DOTALL)
189
+ if match:
190
+ answer_content = match.group(1).strip()
191
+
192
+ # Task 타입별 후처리
193
+ if task_type == 'induction':
194
+ # 코드 블록에서 def f(...) 추출
195
+ if 'def f(' in answer_content:
196
+ return answer_content
197
+ # 코드 블록 마커 제거
198
+ answer_content = answer_content.replace('```python', '').replace('```', '').strip()
199
+ return answer_content
200
+
201
+ elif task_type == 'deduction':
202
+ # 출력값 정리
203
+ return answer_content.strip()
204
+
205
+ elif task_type == 'abduction':
206
+ # 입��값 정리
207
+ return answer_content.strip()
208
+
209
+ # 태그가 없으면 전체 응답 반환
210
+ return llm_response.strip()
211
+
212
+ def _extract_function_definition(self, code: str) -> str:
213
+ """complete_pipeline.py:470-502와 동일한 함수"""
214
+ lines = code.split('\n')
215
+ import_lines = []
216
+ func_lines = []
217
+ in_function = False
218
+ base_indent = None
219
+
220
+ for line in lines:
221
+ # import 문 수집
222
+ if line.strip().startswith('from ') or line.strip().startswith('import '):
223
+ import_lines.append(line)
224
+ # 함수 정의 시작
225
+ elif line.strip().startswith('def '):
226
+ in_function = True
227
+ base_indent = len(line) - len(line.lstrip())
228
+ func_lines.append(line)
229
+ elif in_function:
230
+ # 빈 줄이거나 함수 내부인 경우
231
+ if line.strip() == '':
232
+ func_lines.append(line)
233
+ elif line.startswith(' ' * (base_indent + 1)) or line.startswith('\t'):
234
+ # 함수 내부 (들여쓰기가 더 깊음)
235
+ func_lines.append(line)
236
+ else:
237
+ # 함수 외부 코드 (assert 문 등) - 중단
238
+ break
239
+
240
+ # import문과 함수를 합쳐서 반환
241
+ if import_lines:
242
+ return '\n'.join(import_lines) + '\n\n' + '\n'.join(func_lines)
243
+ else:
244
+ return '\n'.join(func_lines)
absolute_zero_reasoner/testtime/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TestTime RLVR Components
3
+
4
+ This module contains all TestTime-specific components adapted from AZR:
5
+ - BenchmarkProblemLoader: 벤치마크 문제 로딩
6
+ - IPOTripleExtractor: (Input, Program, Output) 트리플 추출
7
+ - TestTimeTaskGenerator: Induction/Deduction/Abduction 태스크 생성
8
+ - TestTimeRLVRTrainer: TestTime 특화 RLVR 학습
9
+ - TestTimeRewardManager: TestTime 보상 계산
10
+ - TestTimeLogger: 포괄적 로깅 시스템
11
+ """
12
+
13
+ from .benchmark_loader import BenchmarkProblemLoader
14
+ from .solution_generator import InitialSolutionGenerator
15
+ from .ipo_extractor import IPOTripleExtractor
16
+ from .task_generator import TestTimeTaskGenerator
17
+ from .logger import TestTimeLogger
18
+ from .config import TestTimeConfig, BenchmarkConfig
19
+
20
+ # 향후 구현 예정
21
+ # from .testtime_trainer import TestTimeRLVRTrainer
22
+ # from .reward_manager import TestTimeRewardManager
23
+
24
+ __all__ = [
25
+ 'BenchmarkProblemLoader',
26
+ 'InitialSolutionGenerator',
27
+ 'IPOTripleExtractor',
28
+ 'TestTimeTaskGenerator',
29
+ 'TestTimeLogger',
30
+ 'TestTimeConfig',
31
+ 'BenchmarkConfig'
32
+ # 'TestTimeRLVRTrainer',
33
+ # 'TestTimeRewardManager',
34
+ ]
absolute_zero_reasoner/testtime/benchmark_loader.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Benchmark Problem Loader
3
+
4
+ AZR 기반 TestTime RLVR을 위한 벤치마크 문제 로딩 시스템
5
+ 기존 Test-Time-RLVR의 load_humaneval_problem 함수를 확장
6
+ """
7
+
8
+ import json
9
+ import os
10
+ from typing import Dict, List, Any, Tuple, Optional
11
+ from pathlib import Path
12
+
13
+ from .config import BenchmarkConfig, TestTimeConfig
14
+ from .logger import TestTimeLogger
15
+
16
+
17
+ class BenchmarkProblemLoader:
18
+ """벤치마크 문제 로딩 및 관리 (EvalPlus 표준 방식 사용)"""
19
+
20
+ def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None):
21
+ self.config = config
22
+ self.logger = logger or TestTimeLogger()
23
+ self.loaded_problems = {} # 캐시
24
+ self.evalplus_cache = {} # EvalPlus 데이터 캐시
25
+
26
+ def _load_evalplus_data(self, benchmark_name: str) -> Dict[str, Dict[str, Any]]:
27
+ """EvalPlus 데이터 로드 및 캐시"""
28
+ if benchmark_name in self.evalplus_cache:
29
+ return self.evalplus_cache[benchmark_name]
30
+
31
+ try:
32
+ if benchmark_name == 'mbpp':
33
+ from evalplus.data.mbpp import get_mbpp_plus
34
+ problems = get_mbpp_plus() # 자동으로 mbpp_deserialize_inputs 적용됨
35
+ self.logger.log_info(f"✅ MBPP+ EvalPlus 데이터 로드 성공: {len(problems)}개 문제")
36
+ elif benchmark_name == 'humaneval':
37
+ from evalplus.data.humaneval import get_human_eval_plus
38
+ problems = get_human_eval_plus() # EvalPlus 표준 방식
39
+ self.logger.log_info(f"✅ HumanEval+ EvalPlus 데이터 로드 성공: {len(problems)}개 문제")
40
+ else:
41
+ raise ValueError(f"Unsupported benchmark for EvalPlus: {benchmark_name}")
42
+
43
+ self.evalplus_cache[benchmark_name] = problems
44
+ return problems
45
+
46
+ except Exception as e:
47
+ self.logger.log_error(f"❌ {benchmark_name.upper()}+ EvalPlus 로딩 실패: {e}")
48
+ return {}
49
+
50
+ def load_problem(self, benchmark_config: BenchmarkConfig, problem_id: str) -> Dict[str, Any]:
51
+ """특정 벤치마크 문제 로드 (EvalPlus 표준 방식 우선 사용)"""
52
+
53
+ cache_key = f"{benchmark_config.name}_{problem_id}"
54
+ if cache_key in self.loaded_problems:
55
+ return self.loaded_problems[cache_key]
56
+
57
+ # EvalPlus 방식 시도
58
+ if benchmark_config.name in ['mbpp', 'humaneval']:
59
+ evalplus_problems = self._load_evalplus_data(benchmark_config.name)
60
+ if problem_id in evalplus_problems:
61
+ problem = evalplus_problems[problem_id].copy()
62
+ # 추가 메타데이터 설정
63
+ problem['benchmark_name'] = benchmark_config.name
64
+ problem['benchmark_config'] = benchmark_config
65
+
66
+ # 캐시에 저장
67
+ self.loaded_problems[cache_key] = problem
68
+ self.logger.log_info(f"✅ Problem loaded: {problem_id} from {benchmark_config.name} (EvalPlus)")
69
+ return problem
70
+
71
+ # Fallback: 기존 방식
72
+ self.logger.log_info(f"⚠️ {problem_id} EvalPlus 로딩 실패, 기존 방식 사용")
73
+ problem_file = benchmark_config.data_path
74
+
75
+ # 파일 존재 확인
76
+ if not os.path.exists(problem_file):
77
+ raise FileNotFoundError(f"Benchmark file not found: {problem_file}")
78
+
79
+ # JSONL 파일 로드 (기존 방식과 동일)
80
+ with open(problem_file, 'r', encoding='utf-8') as f:
81
+ problems = [json.loads(line) for line in f]
82
+
83
+ # 문제 ID로 검색
84
+ for problem in problems:
85
+ if problem['task_id'] == problem_id:
86
+ # 추가 메타데이터 설정
87
+ problem['benchmark_name'] = benchmark_config.name
88
+ problem['benchmark_config'] = benchmark_config
89
+
90
+ # 캐시에 저장
91
+ self.loaded_problems[cache_key] = problem
92
+
93
+ self.logger.log_info(f"✅ Problem loaded: {problem_id} from {benchmark_config.name} (Original)")
94
+ return problem
95
+
96
+ raise ValueError(f"Problem {problem_id} not found in {problem_file}")
97
+
98
+ def load_problem_batch(self, benchmark_config: BenchmarkConfig,
99
+ problem_ids: List[str]) -> List[Dict[str, Any]]:
100
+ """여러 문제 배치 로딩"""
101
+ problems = []
102
+ for problem_id in problem_ids:
103
+ try:
104
+ problem = self.load_problem(benchmark_config, problem_id)
105
+ problems.append(problem)
106
+ except Exception as e:
107
+ self.logger.log_error(f"Failed to load {problem_id}: {e}")
108
+
109
+ return problems
110
+
111
+ def get_test_cases(self, problem: Dict[str, Any]) -> List[Tuple[str, str]]:
112
+ """문제에서 테스트 케이스 추출"""
113
+ test_cases = []
114
+
115
+ # 기본 테스트 케이스 (test 필드)
116
+ if 'test' in problem:
117
+ test_code = problem['test']
118
+ # assert 문에서 입력-출력 쌍 추출
119
+ test_cases.extend(self._parse_assert_statements(test_code))
120
+
121
+ # Plus 테스트 케이스 (plus_input, plus_output)
122
+ if 'plus_input' in problem and 'plus_output' in problem:
123
+ plus_inputs = problem['plus_input']
124
+ plus_outputs = problem['plus_output']
125
+
126
+ if isinstance(plus_inputs, str):
127
+ plus_inputs = json.loads(plus_inputs)
128
+ if isinstance(plus_outputs, str):
129
+ plus_outputs = json.loads(plus_outputs)
130
+
131
+ for inp, out in zip(plus_inputs, plus_outputs):
132
+ test_cases.append((str(inp), str(out)))
133
+
134
+ return test_cases
135
+
136
+ def _parse_assert_statements(self, test_code: str) -> List[Tuple[str, str]]:
137
+ """assert 문에서 입력-출력 쌍 추출"""
138
+ import re
139
+
140
+ test_cases = []
141
+ lines = test_code.strip().split('\n')
142
+
143
+ for line in lines:
144
+ line = line.strip()
145
+ if line.startswith('assert '):
146
+ # assert function(args) == expected 형태 파싱
147
+ match = re.match(r'assert\s+(\w+)\(([^)]*)\)\s*==\s*(.+)', line)
148
+ if match:
149
+ func_name, args, expected = match.groups()
150
+ test_cases.append((args.strip(), expected.strip()))
151
+
152
+ return test_cases
153
+
154
+ def validate_solution(self, problem: Dict[str, Any], solution: str) -> Dict[str, Any]:
155
+ """솔루션 검증 (AZR Python Executor 사용 예정)"""
156
+
157
+ validation_result = {
158
+ 'problem_id': problem['task_id'],
159
+ 'solution': solution,
160
+ 'syntax_valid': False,
161
+ 'test_results': [],
162
+ 'overall_success': False,
163
+ 'error_message': None
164
+ }
165
+
166
+ try:
167
+ # 1. 구문 검증
168
+ compile(solution, '<string>', 'exec')
169
+ validation_result['syntax_valid'] = True
170
+
171
+ # 2. 테스트 케이스 실행 (향후 AZR Python Executor 연동)
172
+ test_cases = self.get_test_cases(problem)
173
+ validation_result['test_results'] = [
174
+ {'input': inp, 'expected': out, 'passed': False}
175
+ for inp, out in test_cases
176
+ ]
177
+
178
+ # 임시: 구문만 통과하면 성공으로 처리
179
+ validation_result['overall_success'] = True
180
+
181
+ except SyntaxError as e:
182
+ validation_result['error_message'] = f"Syntax Error: {e}"
183
+ except Exception as e:
184
+ validation_result['error_message'] = f"Validation Error: {e}"
185
+
186
+ return validation_result
187
+
188
+ def get_sequential_problems(self, benchmark_config: BenchmarkConfig,
189
+ num_problems: int) -> List[Dict[str, Any]]:
190
+ """순차적으로 N개 문제 로드"""
191
+ problems = []
192
+
193
+ for i in range(num_problems):
194
+ problem_index = benchmark_config.start_index + i
195
+ problem_id = f"{benchmark_config.problem_prefix}/{problem_index}"
196
+
197
+ try:
198
+ problem = self.load_problem(benchmark_config, problem_id)
199
+ problems.append(problem)
200
+ except Exception as e:
201
+ self.logger.log_error(f"Failed to load {problem_id}: {e}")
202
+ continue
203
+
204
+ return problems
205
+
206
+ def get_problem_statistics(self, benchmark_config: BenchmarkConfig) -> Dict[str, Any]:
207
+ """벤치마크 통계 정보"""
208
+ problem_file = benchmark_config.data_path
209
+
210
+ if not os.path.exists(problem_file):
211
+ return {"error": f"File not found: {problem_file}"}
212
+
213
+ with open(problem_file, 'r', encoding='utf-8') as f:
214
+ problems = [json.loads(line) for line in f]
215
+
216
+ stats = {
217
+ 'total_problems': len(problems),
218
+ 'benchmark_name': benchmark_config.name,
219
+ 'data_file': problem_file,
220
+ 'sample_problem_ids': [p['task_id'] for p in problems[:5]]
221
+ }
222
+
223
+ return stats
absolute_zero_reasoner/testtime/complete_pipeline.py ADDED
The diff for this file is too large to render. See raw diff
 
absolute_zero_reasoner/testtime/config.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TestTime RLVR Configuration
3
+
4
+ AZR 기반 TestTime RLVR을 위한 설정 클래스
5
+ """
6
+
7
+ from dataclasses import dataclass
8
+ from typing import Optional, List, Dict, Any
9
+ import torch
10
+
11
+
12
+ @dataclass
13
+ class TestTimeConfig:
14
+ """TestTime RLVR 전용 설정"""
15
+
16
+ # ============================================================================
17
+ # 기본 모델 설정 (AZR 기반)
18
+ # ============================================================================
19
+ model_name: str = "Qwen/Qwen2.5-7B"
20
+ device: str = "auto"
21
+ torch_dtype: torch.dtype = torch.bfloat16
22
+ use_flash_attention: bool = True
23
+ enable_gradient_checkpointing: bool = True
24
+
25
+ # ============================================================================
26
+ # TestTime 학습 설정
27
+ # ============================================================================
28
+ max_adaptation_steps: int = 10 # AZR 대비 짧은 적응 학습
29
+ adaptation_batch_size: int = 1 # 소규모 배치
30
+ gradient_accumulation_steps: int = 4
31
+ learning_rate: float = 1e-6 # AZR과 동일
32
+
33
+ # ============================================================================
34
+ # 반복 제어 설정
35
+ # ============================================================================
36
+ max_cycles: int = 3 # 최대 반복 횟수
37
+ min_improvement_threshold: float = 0.05 # 최소 개선 임계값
38
+ early_stopping_patience: int = 2 # Early stopping
39
+
40
+ # ============================================================================
41
+ # IPO 추출 설정
42
+ # ============================================================================
43
+ max_ipo_triples: int = 10 # 추출할 최대 트리플 수
44
+ python_executor_timeout: int = 5 # AZR보다 짧은 타임아웃
45
+ validate_triples: bool = True # 트리플 검증 여부
46
+
47
+ # ============================================================================
48
+ # 다중 프로그램 생성 설정
49
+ # ============================================================================
50
+ num_program_variations: int = 4 # 생성할 다양한 프로그램 수
51
+ baseline_evaluation_rounds: int = 5 # 베이스라인 성능 측정 횟수
52
+ diverse_generation_temperature: float = 0.7 # 다양한 프로그램 생성용 temperature
53
+ baseline_generation_temperature: float = 0.05 # 베이스라인 측정용 temperature
54
+
55
+ # ============================================================================
56
+ # 태스크 생성 설정
57
+ # ============================================================================
58
+ task_distribution: Dict[str, float] = None # induction:deduction:abduction 비율
59
+ max_tasks_per_type: int = 5 # 타입별 최대 태스크 수
60
+ use_azr_templates: bool = True # AZR 템플릿 사용
61
+ skip_task_evaluation: bool = True # Task evaluation(4단계) 스킵 여부 (VeRL에서 수행)
62
+
63
+ # ============================================================================
64
+ # 보상 설정 (AZR 기반)
65
+ # ============================================================================
66
+ use_accuracy_reward: bool = True
67
+ use_improvement_reward: bool = True # TestTime 전용 개선도 보상
68
+ use_complexity_reward: bool = True
69
+ accuracy_weight: float = 1.0
70
+ improvement_weight: float = 0.5 # 개선도 가중치
71
+ complexity_weight: float = 0.1
72
+
73
+ # ============================================================================
74
+ # 로깅 설정
75
+ # ============================================================================
76
+ log_level: str = "INFO"
77
+ save_intermediate_results: bool = True
78
+ log_ipo_details: bool = True
79
+ log_task_details: bool = True
80
+ log_training_metrics: bool = True
81
+
82
+ # ============================================================================
83
+ # 메모리 최적화 설정 (AZR 기반)
84
+ # ============================================================================
85
+ gpu_memory_utilization: float = 0.4
86
+ max_workers: int = 2 # Python executor workers
87
+ use_memory_efficient_attention: bool = True
88
+
89
+ def __post_init__(self):
90
+ """설정 후처리"""
91
+ if self.task_distribution is None:
92
+ # 기본 태스크 분포: 균등 분배
93
+ self.task_distribution = {
94
+ "induction": 0.33,
95
+ "deduction": 0.33,
96
+ "abduction": 0.34
97
+ }
98
+
99
+ # device 자동 설정
100
+ if self.device == "auto":
101
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
102
+
103
+ # dtype 설정
104
+ if self.device == "cpu":
105
+ self.torch_dtype = torch.float32
106
+
107
+ def to_dict(self) -> Dict[str, Any]:
108
+ """설정을 딕셔너리로 변환"""
109
+ return {
110
+ "model_name": self.model_name,
111
+ "device": self.device,
112
+ "torch_dtype": str(self.torch_dtype),
113
+ "max_adaptation_steps": self.max_adaptation_steps,
114
+ "max_cycles": self.max_cycles,
115
+ "learning_rate": self.learning_rate,
116
+ "task_distribution": self.task_distribution,
117
+ "reward_weights": {
118
+ "accuracy": self.accuracy_weight,
119
+ "improvement": self.improvement_weight,
120
+ "complexity": self.complexity_weight
121
+ }
122
+ }
123
+
124
+ @classmethod
125
+ def from_dict(cls, config_dict: Dict[str, Any]) -> 'TestTimeConfig':
126
+ """딕셔너리에서 설정 로드"""
127
+ return cls(**config_dict)
128
+
129
+
130
+ @dataclass
131
+ class BenchmarkConfig:
132
+ """벤치마크별 설정"""
133
+
134
+ name: str # "humaneval", "mbpp", "livecodebase"
135
+ data_path: str
136
+ problem_prefix: str # "HumanEval", "Mbpp"
137
+ start_index: int = 0 # MBPP는 2부터 시작
138
+ max_problems: int = 5 # 테스트할 문제 수
139
+
140
+ # 벤치마크별 특화 설정
141
+ test_timeout: int = 10
142
+ use_plus_version: bool = True # HumanEval+, MBPP+ 사용
143
+
144
+ @classmethod
145
+ def get_humaneval_config(cls) -> 'BenchmarkConfig':
146
+ return cls(
147
+ name="humaneval",
148
+ data_path="evaluation/code_eval/data/HumanEvalPlus.jsonl",
149
+ problem_prefix="HumanEval",
150
+ start_index=0,
151
+ max_problems=5
152
+ )
153
+
154
+ @classmethod
155
+ def get_mbpp_config(cls) -> 'BenchmarkConfig':
156
+ return cls(
157
+ name="mbpp",
158
+ data_path="evaluation/code_eval/data/MbppPlus.jsonl",
159
+ problem_prefix="Mbpp",
160
+ start_index=2, # MBPP는 2번부터
161
+ max_problems=5
162
+ )
absolute_zero_reasoner/testtime/ipo_extractor.py ADDED
@@ -0,0 +1,1235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ IPO Triple Extractor
3
+
4
+ AZR Python Executor 기반 (Input, Program, Output) 트리플 추출 시스템
5
+ 요구사항 2: "AZR Python Executor를 이용하여 (i,p,o) pair를 만든다"
6
+ """
7
+
8
+ import ast
9
+ import re
10
+ import json
11
+ from typing import Dict, List, Any, Tuple, Optional
12
+ from concurrent.futures import TimeoutError
13
+
14
+ from ..utils.code_utils.python_executor import PythonExecutor
15
+ from .config import TestTimeConfig
16
+ from .logger import TestTimeLogger
17
+ from .solution_generator import InitialSolutionGenerator
18
+
19
+
20
+ class IPOBuffer:
21
+ """IPO triple을 저장하고 관리하는 버퍼"""
22
+
23
+ def __init__(self):
24
+ self.buffer = {} # {problem_id: [ipo_triples]}
25
+
26
+ def add(self, problem_id: str, ipo_triple: Dict[str, Any]):
27
+ """IPO triple을 버퍼에 추가"""
28
+ if problem_id not in self.buffer:
29
+ self.buffer[problem_id] = []
30
+ self.buffer[problem_id].append(ipo_triple)
31
+
32
+ def get_all(self, problem_id: str) -> List[Dict[str, Any]]:
33
+ """특정 문제의 모든 IPO triple 반환"""
34
+ return self.buffer.get(problem_id, [])
35
+
36
+ def clear(self, problem_id: str = None):
37
+ """버퍼 초기화"""
38
+ if problem_id:
39
+ self.buffer.pop(problem_id, None)
40
+ else:
41
+ self.buffer.clear()
42
+
43
+ def size(self, problem_id: str = None) -> int:
44
+ """버퍼 크기 반환"""
45
+ if problem_id:
46
+ return len(self.buffer.get(problem_id, []))
47
+ return sum(len(triples) for triples in self.buffer.values())
48
+
49
+
50
+ class IPOTripleExtractor:
51
+ """(Input, Program, Output) 트리플 추출 및 검증"""
52
+
53
+ def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None,
54
+ model=None, tokenizer=None):
55
+ self.config = config
56
+ self.logger = logger or TestTimeLogger()
57
+ self.model = model
58
+ self.tokenizer = tokenizer
59
+
60
+ # AZR Python Executor 초기화 (기존 방식)
61
+ self.executor = PythonExecutor(
62
+ timeout_length=config.python_executor_timeout,
63
+ ast_check=True, # AZR 기본 설정
64
+ max_workers=config.max_workers
65
+ )
66
+
67
+ self.extracted_triples = []
68
+
69
+ # 입력 생성 프롬프트와 응답 저장용
70
+ self.last_generation_prompt = ""
71
+ self.last_generation_response = ""
72
+
73
+ # VLLM 배치 처리를 위한 참조
74
+ self.solution_generator = None
75
+
76
+ def extract_triples(self, problem: Dict[str, Any], solution: str) -> List[Dict[str, Any]]:
77
+ """벤치마크 문제와 솔루션에서 IPO 트리플 추출"""
78
+
79
+ problem_id = problem.get('task_id', 'unknown')
80
+ self.logger.log_info(f"🔍 Extracting IPO triples for {problem_id}")
81
+
82
+ triples = []
83
+
84
+ try:
85
+ # 1. 함수 정보 추출 (entry point 우선)
86
+ entry_point = problem.get('entry_point', 'unknown')
87
+ func_info = self._extract_function_info(solution, entry_point)
88
+ if not func_info:
89
+ self.logger.log_error(f"Failed to extract function info from solution")
90
+ return []
91
+
92
+ # 2. 테스트 케이스에서 입력-출력 쌍 생성 (LLM 솔루션 기반)
93
+ test_cases = self._extract_test_cases(problem, solution)
94
+
95
+ # 3. 솔루션 실행으로 IPO 트리플 생성
96
+ for i, (test_input_str, expected_output) in enumerate(test_cases):
97
+ if len(triples) >= self.config.max_ipo_triples:
98
+ break
99
+
100
+ # test_input_str에서 실제 인자 추출 (예: "strlen('')" -> "''")
101
+ import re
102
+ match = re.match(rf'{entry_point}\((.*)\)', test_input_str)
103
+ if match:
104
+ actual_args = match.group(1)
105
+ else:
106
+ actual_args = test_input_str # fallback
107
+
108
+ triple = self._create_ipo_triple(
109
+ func_info['full_code'], # 🔧 수정: 전체 코드 사용 (도우미 함수 포함)
110
+ func_info,
111
+ actual_args, # 실제 인자만 전달
112
+ expected_output,
113
+ triple_id=f"{problem_id}_triple_{i}",
114
+ full_input_str=test_input_str # 전체 입력 문자열도 전달
115
+ )
116
+
117
+ if triple:
118
+ triples.append(triple)
119
+
120
+ # 🔧 수정: Synthetic 트리플 생성 제거 (단일 예시만 사용하여 치팅 방지)
121
+ # Synthetic 트리플 생성 로직을 제거하여 진짜 단일 예시만 사용
122
+
123
+ # 검증 및 로깅
124
+ validation_results = [self._validate_triple(triple) for triple in triples]
125
+ self.logger.log_ipo_extraction(problem_id, triples, validation_results)
126
+
127
+ # 유효한 트리플만 반환
128
+ valid_triples = [triple for triple, valid in zip(triples, validation_results) if valid]
129
+
130
+ self.logger.log_info(f"✅ Extracted {len(valid_triples)}/{len(triples)} valid IPO triples")
131
+ return valid_triples
132
+
133
+ except Exception as e:
134
+ self.logger.log_error(f"IPO extraction failed: {e}")
135
+ return []
136
+
137
+ def _extract_function_info(self, solution: str, entry_point: str = None) -> Optional[Dict[str, str]]:
138
+ """솔루션에서 함수 정보 추출 (entry point 우선)"""
139
+
140
+ try:
141
+ # 🔧 개선: Raw LLM response인지 확인하고 함수 코드 추출
142
+ processed_solution = solution
143
+ if "LLM GENERATED SOLUTION:" in solution:
144
+ self.logger.log_info("📝 Raw LLM response detected, extracting function code")
145
+ processed_solution = self._extract_function_from_llm_response(solution)
146
+ if not processed_solution:
147
+ self.logger.log_error("Failed to extract function from LLM response")
148
+ return None
149
+
150
+ # AST로 함수 정의 파싱
151
+ tree = ast.parse(processed_solution)
152
+
153
+ # 🔧 수정: Entry point 함수 우선 검색
154
+ target_function = None
155
+ all_functions = []
156
+
157
+ for node in ast.walk(tree):
158
+ if isinstance(node, ast.FunctionDef):
159
+ func_info = {
160
+ 'name': node.name,
161
+ 'args': [arg.arg for arg in node.args.args],
162
+ 'signature': f"def {node.name}({', '.join([arg.arg for arg in node.args.args])}):",
163
+ 'full_code': processed_solution
164
+ }
165
+ all_functions.append(func_info)
166
+
167
+ # Entry point와 일치하는 함수 우선 선택
168
+ if entry_point and node.name == entry_point:
169
+ target_function = func_info
170
+ # 이 로그는 너무 자주 출력되므로 debug 레벨로 변경
171
+ self.logger.log_debug(f"🎯 Found entry point function: {entry_point}")
172
+ break
173
+
174
+ # Entry point 함수를 찾았으면 반환
175
+ if target_function:
176
+ return target_function
177
+
178
+ # Entry point를 찾지 못했으면 첫 번째 함수 반환 (기존 방식)
179
+ if all_functions:
180
+ self.logger.log_warning(f"⚠️ Entry point '{entry_point}' not found, using first function: {all_functions[0]['name']}")
181
+ return all_functions[0]
182
+
183
+ return None
184
+
185
+ except Exception as e:
186
+ self.logger.log_error(f"Function parsing failed: {e}")
187
+ return None
188
+
189
+ def _extract_function_from_llm_response(self, llm_response: str) -> str:
190
+ """Raw LLM response에서 함수 코드 추출 (solution_generator와 동일한 로직)"""
191
+
192
+ lines = llm_response.split('\n')
193
+ solution_lines = []
194
+ in_solution = False
195
+
196
+ # "LLM GENERATED SOLUTION:" 섹션 추출 (수정된 로직)
197
+ for i, line in enumerate(lines):
198
+ if "LLM GENERATED SOLUTION:" in line:
199
+ in_solution = True
200
+ continue
201
+ elif in_solution:
202
+ # "===============" 라인이 나오면 종료하되, 첫 번째 "==============="는 건너뛰기
203
+ if "===============" in line:
204
+ # 실제 솔루션 라인들이 있는지 확인
205
+ if solution_lines and any(l.strip() for l in solution_lines):
206
+ break
207
+ else:
208
+ # 아직 솔루션 라인이 없으면 계속 진행 (첫 번째 구분선 건너뛰기)
209
+ continue
210
+ solution_lines.append(line)
211
+
212
+ if not solution_lines:
213
+ return "" # 추출 실패시 빈 문자열 반환
214
+
215
+ extracted_solution = '\n'.join(solution_lines).strip()
216
+
217
+ # 함수 정의와 import 추출 (solution_generator 로직과 동일)
218
+ lines = extracted_solution.split('\n')
219
+ import_lines = []
220
+ func_lines = []
221
+ in_function = False
222
+ indent_level = 0
223
+
224
+ # 1. import 문 수집
225
+ for line in lines:
226
+ stripped = line.strip()
227
+ if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'):
228
+ import_lines.append(line)
229
+
230
+ # 2. 함수 정의 찾기
231
+ for line in lines:
232
+ if line.strip().startswith('def '):
233
+ in_function = True
234
+ func_lines = [line]
235
+ indent_level = len(line) - len(line.lstrip())
236
+ elif in_function:
237
+ if not line.strip() or (line.strip() and len(line) - len(line.lstrip()) > indent_level):
238
+ func_lines.append(line)
239
+ else:
240
+ break
241
+
242
+ # 3. import + function 결합
243
+ if func_lines:
244
+ result_lines = import_lines + [''] + func_lines if import_lines else func_lines
245
+ return '\n'.join(result_lines)
246
+ else:
247
+ return extracted_solution
248
+
249
+ def _fix_humaneval_canonical_solution(self, problem: Dict[str, Any]) -> str:
250
+ """HumanEval canonical solution 복원 (함수 시그니처 추가)"""
251
+
252
+ canonical_code = problem.get('canonical_solution', '')
253
+ entry_point = problem.get('entry_point', '')
254
+ prompt = problem.get('prompt', '')
255
+
256
+ # HumanEval인지 확인
257
+ task_id = problem.get('task_id', '')
258
+ if not task_id.startswith('HumanEval/'):
259
+ return canonical_code
260
+
261
+ # 이미 함수 시그니처가 있는지 확인
262
+ if f"def {entry_point}" in canonical_code:
263
+ return canonical_code
264
+
265
+ try:
266
+ # Prompt에서 함수 시그니처 추출
267
+ import re
268
+ def_pattern = rf'def\s+{re.escape(entry_point)}\s*\([^)]*\)[^:]*:'
269
+ match = re.search(def_pattern, prompt, re.MULTILINE)
270
+
271
+ if match:
272
+ function_signature = match.group(0)
273
+
274
+ # Import 문도 추출 (있다면)
275
+ import_lines = []
276
+ for line in prompt.split('\n'):
277
+ stripped = line.strip()
278
+ if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'):
279
+ import_lines.append(line)
280
+
281
+ # 완전한 canonical solution 구성
282
+ if import_lines:
283
+ complete_canonical = '\n'.join(import_lines) + '\n\n' + function_signature + canonical_code
284
+ else:
285
+ complete_canonical = function_signature + canonical_code
286
+
287
+ self.logger.log_info(f"🔧 Fixed HumanEval canonical solution for {entry_point}")
288
+ return complete_canonical
289
+ else:
290
+ self.logger.log_warning(f"⚠️ Could not extract function signature for {entry_point}")
291
+ return canonical_code
292
+
293
+ except Exception as e:
294
+ self.logger.log_error(f"Failed to fix HumanEval canonical solution: {e}")
295
+ return canonical_code
296
+
297
+ def _extract_single_prompt_example(self, problem: Dict[str, Any]) -> Optional[Tuple[str, str]]:
298
+ """🔧 새로운 메서드: 프롬프트의 단일 예시만 추출 (치팅 방지)"""
299
+
300
+ try:
301
+ # base_input의 첫 번째 항목을 단일 예시로 사용
302
+ if 'base_input' in problem and problem['base_input']:
303
+ first_input = problem['base_input'][0]
304
+ entry_point = problem['entry_point']
305
+
306
+ self.logger.log_info(f"📥 Using first base_input as single example: {first_input}")
307
+
308
+ # 🔧 수정: HumanEval canonical solution 복원
309
+ canonical_code = self._fix_humaneval_canonical_solution(problem)
310
+ if canonical_code:
311
+ actual_output = self._execute_llm_solution(canonical_code, entry_point, first_input)
312
+
313
+ if actual_output is not None:
314
+ # 입력 문자열 형식 생성
315
+ if isinstance(first_input, list):
316
+ if len(first_input) == 1 and isinstance(first_input[0], list):
317
+ # [[args]] -> 단일 리스트 인자로 표시
318
+ input_str = repr(first_input[0])
319
+ elif len(first_input) == 1:
320
+ # [단일인자] -> 단일인자
321
+ input_str = repr(first_input[0])
322
+ else:
323
+ # [다중인자] -> 다중인자
324
+ input_str = ', '.join(repr(arg) for arg in first_input)
325
+ else:
326
+ input_str = repr(first_input)
327
+
328
+ result = (input_str, str(actual_output))
329
+ self.logger.log_info(f"✅ Single example extracted: Input={input_str}, Output={actual_output}")
330
+ return result
331
+ else:
332
+ self.logger.log_warning("❌ Failed to compute output with canonical solution")
333
+ else:
334
+ self.logger.log_warning("❌ No canonical solution available")
335
+ else:
336
+ self.logger.log_warning("❌ No base_input available")
337
+
338
+ except Exception as e:
339
+ self.logger.log_error(f"Single example extraction failed: {e}")
340
+
341
+ return None
342
+
343
+ def _extract_docstring_examples(self, prompt: str, func_name: str) -> List[Tuple[str, str]]:
344
+ """docstring에서 >>> 예제 추출"""
345
+
346
+ examples = []
347
+ lines = prompt.split('\n')
348
+
349
+ i = 0
350
+ while i < len(lines):
351
+ line = lines[i].strip()
352
+ # >>> func_name(...) 패턴 찾기
353
+ if line.startswith('>>>') and func_name in line:
354
+ # 입력 추출
355
+ input_line = line[3:].strip() # >>> 제거
356
+
357
+ # 다음 줄에서 출력 추출
358
+ if i + 1 < len(lines):
359
+ output_line = lines[i + 1].strip()
360
+ # 출력이 >>> 로 시작하지 않으면 출력값
361
+ if not output_line.startswith('>>>'):
362
+ examples.append((input_line, output_line))
363
+ i += 2
364
+ continue
365
+ i += 1
366
+ else:
367
+ i += 1
368
+
369
+ return examples
370
+
371
+ def _extract_test_cases(self, problem: Dict[str, Any], solution: str) -> List[Tuple[str, str]]:
372
+ """docstring의 예제에서 테스트 케이스 추출 (치팅 방지)"""
373
+
374
+ test_cases = []
375
+ func_name = problem.get('entry_point', 'unknown')
376
+ problem_id = problem.get('task_id', '')
377
+
378
+ # HumanEval과 MBPP 모두 docstring 예제만 사용
379
+ self.logger.log_info(f"🎯 Extracting docstring examples for {problem_id}")
380
+
381
+ # 프롬프트에서 docstring 예제 추출
382
+ prompt = problem.get('prompt', '')
383
+ examples = self._extract_docstring_examples(prompt, func_name)
384
+
385
+ if examples:
386
+ self.logger.log_info(f"📝 Found {len(examples)} docstring examples")
387
+ for i, (input_str, expected_output) in enumerate(examples):
388
+ try:
389
+ # 입력 파싱 (func_name(args) 형태에서 args 추출)
390
+ import ast
391
+ # "func_name(args)" -> args 추출
392
+ if input_str.startswith(func_name + '(') and input_str.endswith(')'):
393
+ args_str = input_str[len(func_name)+1:-1]
394
+ # 안전한 평가를 위해 ast.literal_eval 사용
395
+ try:
396
+ # 단일 인자인 경우
397
+ input_args = ast.literal_eval(args_str)
398
+ if not isinstance(input_args, tuple):
399
+ input_args = (input_args,)
400
+ except:
401
+ # 여러 인자인 경우
402
+ input_args = ast.literal_eval(f"({args_str})")
403
+
404
+ # LLM 솔루션 실행
405
+ actual_output = self._execute_llm_solution(solution, func_name, list(input_args))
406
+ if actual_output is not None:
407
+ test_cases.append((input_str, str(actual_output)))
408
+ self.logger.log_info(f"✅ Example {i+1}: {input_str} -> {actual_output}")
409
+ else:
410
+ self.logger.log_warning(f"❌ Example {i+1} execution failed")
411
+
412
+ except Exception as e:
413
+ self.logger.log_error(f"Example {i+1} parsing failed: {e}")
414
+ else:
415
+ self.logger.log_warning(f"⚠️ No docstring examples found, falling back to first base_input")
416
+ # docstring 예제가 없으면 첫 번째 base_input만 사용 (MBPP처럼)
417
+ if 'base_input' in problem and problem['base_input']:
418
+ inp_args = problem['base_input'][0]
419
+ # 입력 문자열 생성
420
+ if isinstance(inp_args, list):
421
+ args_str = ', '.join(repr(arg) for arg in inp_args)
422
+ input_str = f"{func_name}({args_str})"
423
+ else:
424
+ input_str = f"{func_name}({repr(inp_args)})"
425
+
426
+ actual_output = self._execute_llm_solution(solution, func_name, inp_args)
427
+ if actual_output is not None:
428
+ test_cases.append((input_str, str(actual_output)))
429
+
430
+ self.logger.log_info(f"📊 Extracted {len(test_cases)} test cases from docstring examples")
431
+ return test_cases
432
+
433
+ def _execute_llm_solution(self, llm_solution: str, func_name: str, input_args) -> Optional[str]:
434
+ """LLM 생성 솔루션을 실행하여 실제 출력 계산"""
435
+
436
+ try:
437
+ if not llm_solution or func_name == 'unknown':
438
+ return None
439
+
440
+ # 🔧 수정: 실행용 코드 구성 (MBPP+ 이중 리스트 처리)
441
+ if isinstance(input_args, list):
442
+ # MBPP+ 데이터가 이중 리스트로 감싸진 경우 처리
443
+ if len(input_args) == 1 and isinstance(input_args[0], list):
444
+ # [[args]] -> 단일 리스트 인자로 전달
445
+ args_str = repr(input_args[0])
446
+ elif len(input_args) == 1:
447
+ # [단일인자] -> 단일 인자로 전달
448
+ args_str = repr(input_args[0])
449
+ else:
450
+ # [다중인자] -> 다중 인자로 전달
451
+ args_str = ', '.join(repr(arg) for arg in input_args)
452
+ else:
453
+ args_str = repr(input_args)
454
+
455
+ execution_code = f"""
456
+ {llm_solution}
457
+
458
+ # Execute LLM solution
459
+ try:
460
+ result = {func_name}({args_str})
461
+ print(repr(result))
462
+ except Exception as e:
463
+ print(f"EXECUTION_ERROR: {{e}}")
464
+ """
465
+
466
+ # AZR Python Executor로 실행
467
+ output, status = self.executor.apply(execution_code)
468
+
469
+ if 'error' in status.lower() or 'EXECUTION_ERROR' in output:
470
+ return None
471
+
472
+ # 출력에서 결과 추출
473
+ output_lines = output.strip().split('\n')
474
+ if output_lines:
475
+ result_line = output_lines[-1].strip()
476
+ # repr()로 출력된 결과를 그대로 반환
477
+ return result_line
478
+
479
+ return None
480
+
481
+ except Exception as e:
482
+ self.logger.log_error(f"LLM solution execution failed: {e}")
483
+ return None
484
+
485
+ def _create_ipo_triple(self, solution: str, func_info: Dict[str, str],
486
+ test_input: str, expected_output: str,
487
+ triple_id: str, full_input_str: str = None) -> Optional[Dict[str, Any]]:
488
+ """IPO 트리플 생성 및 검증 (AZR Python Executor 사용)"""
489
+
490
+ try:
491
+ # 1. 솔루션 실행으로 실제 출력 확인
492
+ actual_output = self._execute_function(solution, func_info['name'], test_input)
493
+
494
+ if actual_output is None:
495
+ return None
496
+
497
+ # 2. IPO 트리플 구성
498
+ triple = {
499
+ 'id': triple_id,
500
+ 'input': test_input, # 실제 인자만 저장 (예: "''", "3.5")
501
+ 'full_input_str': full_input_str or f"{func_info['name']}({test_input})", # 전체 입력 문자열은 별도 필드에
502
+ 'program': solution, # 이미 func_info['full_code']가 전달됨
503
+ 'expected_output': expected_output,
504
+ 'actual_output': actual_output,
505
+ 'function_name': func_info['name'],
506
+ 'function_args': func_info['args'],
507
+ 'is_correct': str(actual_output) == str(expected_output),
508
+ 'extraction_method': 'test_case'
509
+ }
510
+
511
+ return triple
512
+
513
+ except Exception as e:
514
+ self.logger.log_error(f"Triple creation failed for {triple_id}: {e}")
515
+ return None
516
+
517
+ def _execute_function(self, code: str, func_name: str, inputs: str) -> Optional[str]:
518
+ """AZR Python Executor로 함수 실행"""
519
+
520
+ try:
521
+ # 실행용 코드 구성 (AZR 템플릿 스타일)
522
+ execution_code = f"""
523
+ {code}
524
+
525
+ # Execute function with inputs
526
+ try:
527
+ result = {func_name}({inputs})
528
+ print(repr(result))
529
+ except Exception as e:
530
+ print(f"EXECUTION_ERROR: {{e}}")
531
+ """
532
+
533
+ # AZR 방식으로 실행
534
+ output, status = self.executor.apply(execution_code)
535
+
536
+ if 'error' in status.lower() or 'EXECUTION_ERROR' in output:
537
+ return None
538
+
539
+ # 출력에서 결과 추출
540
+ output_lines = output.strip().split('\n')
541
+ if output_lines:
542
+ return output_lines[-1].strip()
543
+
544
+ return None
545
+
546
+ except Exception as e:
547
+ self.logger.log_error(f"Function execution failed: {e}")
548
+ return None
549
+
550
+ # 🔧 제거: Synthetic 트리플 생성 메서드들 제거
551
+ # 단일 예시만 사용하여 치팅 방지 목적에 맞게 불필요한 메서드들 제거
552
+
553
+ def _validate_triple(self, triple: Dict[str, Any]) -> bool:
554
+ """IPO 트리플 검증"""
555
+
556
+ if not self.config.validate_triples:
557
+ return True
558
+
559
+ try:
560
+ # 1. 기본 필드 존재 확인
561
+ required_fields = ['input', 'program', 'expected_output', 'function_name']
562
+ if not all(field in triple for field in required_fields):
563
+ return False
564
+
565
+ # 2. 코드 구문 검증
566
+ try:
567
+ ast.parse(triple['program'])
568
+ except SyntaxError:
569
+ return False
570
+
571
+ # 3. 재실행으로 일관성 검증 (AZR 방식)
572
+ # 이제 triple['input']은 이미 실제 인자만 포함
573
+ actual_output = self._execute_function(
574
+ triple['program'],
575
+ triple['function_name'],
576
+ triple['input']
577
+ )
578
+
579
+ if actual_output is None:
580
+ return False
581
+
582
+ # 4. 출력 일치 확인
583
+ return str(actual_output) == str(triple['expected_output'])
584
+
585
+ except Exception as e:
586
+ self.logger.log_error(f"Triple validation failed: {e}")
587
+ return False
588
+
589
+ def get_triple_statistics(self) -> Dict[str, Any]:
590
+ """추출된 트리플 통계"""
591
+
592
+ if not self.extracted_triples:
593
+ return {"total": 0, "valid": 0, "invalid": 0}
594
+
595
+ valid_count = sum(1 for triple in self.extracted_triples if triple.get('is_correct', False))
596
+
597
+ return {
598
+ "total": len(self.extracted_triples),
599
+ "valid": valid_count,
600
+ "invalid": len(self.extracted_triples) - valid_count,
601
+ "extraction_methods": {
602
+ "test_case": sum(1 for t in self.extracted_triples if t.get('extraction_method') == 'test_case'),
603
+ "synthetic": sum(1 for t in self.extracted_triples if t.get('extraction_method') == 'synthetic')
604
+ }
605
+ }
606
+
607
+ def generate_diverse_inputs(self, problem: Dict[str, Any], solution: str,
608
+ existing_examples: List[Tuple[str, str]]) -> List[Dict[str, Any]]:
609
+ """LLM을 사용하여 다양한 입력 생성"""
610
+
611
+ problem_id = problem.get('task_id', 'unknown')
612
+ self.logger.log_info(f"🎲 Generating diverse inputs for {problem_id}")
613
+
614
+ try:
615
+ # 1. 함수 정보 추출
616
+ entry_point = problem.get('entry_point', 'unknown')
617
+ func_info = self._extract_function_info(solution, entry_point)
618
+ if not func_info:
619
+ self.logger.log_error("Failed to extract function info for input generation")
620
+ return []
621
+
622
+ # 2. 인자 타입 정보 추론
623
+ arg_type_info = self._infer_argument_types(func_info, existing_examples, solution)
624
+
625
+ # 3. 프롬프트 생성
626
+ prompt = self._create_input_generation_prompt(
627
+ problem_description=problem.get('prompt', ''),
628
+ existing_examples=existing_examples,
629
+ full_code=solution,
630
+ arg_type_info=arg_type_info
631
+ )
632
+
633
+ # 4. LLM으로 입력 생성
634
+ generated_inputs = self._call_llm_for_inputs(prompt, existing_examples, func_info, arg_type_info)
635
+
636
+ # 5. 생성된 입력 검증
637
+ valid_inputs = self._validate_generated_inputs(generated_inputs, func_info, solution)
638
+
639
+ self.logger.log_info(f"✅ Generated {len(valid_inputs)} valid diverse inputs")
640
+ return valid_inputs
641
+
642
+ except Exception as e:
643
+ self.logger.log_error(f"Failed to generate diverse inputs: {e}")
644
+ return []
645
+
646
+ def generate_diverse_inputs_batch(self, program_input_pairs: List[Dict[str, Any]]) -> Tuple[List[List[Dict[str, Any]]], List[Optional[Dict[str, Any]]]]:
647
+ """배치로 여러 프로그램의 diverse input 생성"""
648
+
649
+ if not self.solution_generator:
650
+ self.logger.log_error("Solution generator not set for batch processing")
651
+ return [], []
652
+
653
+ self.logger.log_info(f"🎲 Generating diverse inputs for {len(program_input_pairs)} programs (BATCH)")
654
+
655
+ try:
656
+ # 모든 프로그램의 입력 생성 프롬프트 생성
657
+ batch_prompts = []
658
+ program_contexts = []
659
+
660
+ for pair in program_input_pairs:
661
+ problem = pair['problem']
662
+ solution = pair['solution']
663
+ existing_examples = pair['existing_examples']
664
+
665
+ # 함수 정보 추출
666
+ entry_point = problem.get('entry_point', 'unknown')
667
+ func_info = self._extract_function_info(solution, entry_point)
668
+ if not func_info:
669
+ program_contexts.append(None)
670
+ batch_prompts.append("")
671
+ continue
672
+
673
+ # 인자 타입 정보 추론
674
+ arg_type_info = self._infer_argument_types(func_info, existing_examples, solution)
675
+
676
+ # 프롬프트 생성
677
+ prompt = self._create_input_generation_prompt(
678
+ problem_description=problem.get('prompt', ''),
679
+ existing_examples=existing_examples,
680
+ full_code=solution,
681
+ arg_type_info=arg_type_info
682
+ )
683
+
684
+ batch_prompts.append(prompt)
685
+ program_contexts.append({
686
+ 'func_info': func_info,
687
+ 'solution': solution,
688
+ 'problem': problem
689
+ })
690
+
691
+ # VLLM 배치로 LLM 호출
692
+ if not batch_prompts or all(not p for p in batch_prompts):
693
+ return [], []
694
+
695
+ self.logger.log_info(f"🔍 Sending {len(batch_prompts)} prompts to VLLM for input generation")
696
+ self.logger.log_info(f"🔍 First prompt preview: {batch_prompts[0][:200]}..." if batch_prompts else "No prompts")
697
+
698
+ # Input generation은 코드 생성이 아니므로 후처리 없이 원시 응답 사용
699
+ # generate_batch의 후처리(함수 추출 등)는 input generation에 부적합
700
+ batch_responses = self.solution_generator._generate_batch_with_vllm(
701
+ batch_prompts,
702
+ temperature=0.7 # Input generation에는 약간의 랜덤성 필요
703
+ )
704
+
705
+ self.logger.log_info(f"🔍 Received {len(batch_responses)} responses from VLLM")
706
+ for i, response in enumerate(batch_responses[:2]): # 처음 2개만 로깅
707
+ self.logger.log_info(f"🔍 Response {i} preview: {response[:200]}...")
708
+
709
+ # 각 응답을 파싱하여 입력 생성
710
+ batch_results = []
711
+ batch_generation_info = [] # 각 프로그램의 input generation 정보 저장
712
+
713
+ for i, (response, context) in enumerate(zip(batch_responses, program_contexts)):
714
+ if context is None:
715
+ batch_results.append([])
716
+ batch_generation_info.append(None)
717
+ continue
718
+
719
+ try:
720
+ # 응답에서 입력 추출
721
+ generated_inputs = self._parse_llm_input_response(
722
+ response,
723
+ context['func_info'],
724
+ context['problem'].get('task_id', 'unknown')
725
+ )
726
+
727
+ # 디버깅: 파싱된 입력 개수 로깅
728
+ self.logger.log_info(f"🔍 Parsed {len(generated_inputs)} inputs from response {i}")
729
+ if generated_inputs:
730
+ self.logger.log_info(f"🔍 First parsed input: {generated_inputs[0]}")
731
+
732
+ # 생성된 입력 검증
733
+ valid_inputs = self._validate_generated_inputs(
734
+ generated_inputs,
735
+ context['func_info'],
736
+ context['solution']
737
+ )
738
+
739
+ # 디버깅: 검증 후 입력 개수 로깅
740
+ self.logger.log_info(f"🔍 {len(valid_inputs)} inputs passed validation from response {i}")
741
+
742
+ batch_results.append(valid_inputs)
743
+
744
+ # Input generation 정보 저장
745
+ generation_info = {
746
+ 'prompt': batch_prompts[i] if i < len(batch_prompts) else '',
747
+ 'llm_response': response,
748
+ 'extracted_inputs': generated_inputs,
749
+ 'valid_inputs': valid_inputs,
750
+ 'existing_examples': program_input_pairs[i]['existing_examples'] if i < len(program_input_pairs) else [],
751
+ 'function_info': context['func_info'],
752
+ 'arg_type_info': self._infer_argument_types(
753
+ context['func_info'],
754
+ program_input_pairs[i]['existing_examples'] if i < len(program_input_pairs) else [],
755
+ context['solution']
756
+ )
757
+ }
758
+ batch_generation_info.append(generation_info)
759
+
760
+ except Exception as e:
761
+ self.logger.log_error(f"Failed to process batch item {i}: {e}")
762
+ # 더 자세한 디버깅 정보 추가
763
+ self.logger.log_error(f"Response preview: {response[:200]}...")
764
+ import traceback
765
+ self.logger.log_error(f"Traceback: {traceback.format_exc()}")
766
+ batch_results.append([])
767
+
768
+ # 에러 정보도 저장
769
+ batch_generation_info.append({
770
+ 'error': str(e),
771
+ 'prompt': batch_prompts[i] if i < len(batch_prompts) else '',
772
+ 'llm_response': response,
773
+ 'traceback': traceback.format_exc()
774
+ })
775
+
776
+ total_generated = sum(len(inputs) for inputs in batch_results)
777
+ self.logger.log_info(f"✅ Generated {total_generated} diverse inputs across {len(program_input_pairs)} programs")
778
+
779
+ # Return both inputs and generation info as a tuple
780
+ return batch_results, batch_generation_info
781
+
782
+ except Exception as e:
783
+ self.logger.log_error(f"Batch input generation failed: {e}")
784
+ return [], []
785
+
786
+ def _parse_llm_input_response(self, llm_response: str, func_info: Dict[str, Any], problem_id: str) -> List[Dict[str, Any]]:
787
+ """LLM 응답에서 입력 예제 파싱"""
788
+
789
+ self.logger.log_info(f"🔍 Parsing LLM response for {problem_id}, response length: {len(llm_response)}")
790
+
791
+ try:
792
+ # ```python ... ``` 블록에서 코드 추출
793
+ import re
794
+ code_pattern = r'```python\n(.*?)\n```'
795
+ matches = re.findall(code_pattern, llm_response, re.DOTALL)
796
+
797
+ if not matches:
798
+ self.logger.log_info("🔍 No code block found, searching for examples = [")
799
+ # 블록이 없으면 전체 응답에서 examples = 찾기
800
+ if 'examples = [' in llm_response:
801
+ start = llm_response.find('examples = [')
802
+ # 균형잡힌 괄호 찾기
803
+ bracket_count = 0
804
+ end = start
805
+ for i, char in enumerate(llm_response[start:]):
806
+ if char == '[':
807
+ bracket_count += 1
808
+ elif char == ']':
809
+ bracket_count -= 1
810
+ if bracket_count == 0:
811
+ end = start + i + 1
812
+ break
813
+
814
+ if end > start:
815
+ code = llm_response[start:end]
816
+ self.logger.log_info(f"🔍 Found examples code: {code[:100]}...")
817
+ exec_globals = {}
818
+ exec(code, exec_globals)
819
+ examples = exec_globals.get('examples', [])
820
+ self.logger.log_info(f"🔍 Extracted {len(examples)} examples")
821
+ return examples
822
+ else:
823
+ self.logger.log_info("🔍 No 'examples = [' found in response")
824
+ else:
825
+ # 코드 블록에서 examples 추출
826
+ self.logger.log_info(f"🔍 Found {len(matches)} code blocks")
827
+ code = matches[0]
828
+ self.logger.log_info(f"🔍 Code block preview: {code[:100]}...")
829
+ exec_globals = {}
830
+ exec(code, exec_globals)
831
+ examples = exec_globals.get('examples', [])
832
+ self.logger.log_info(f"🔍 Extracted {len(examples)} examples from code block")
833
+
834
+ # examples가 dict가 아닌 경우 처리
835
+ if examples and len(examples) > 0:
836
+ self.logger.log_info(f"🔍 First example type: {type(examples[0])}")
837
+ if isinstance(examples[0], dict):
838
+ # expected_output, description 등 불필요한 키 제거
839
+ cleaned_examples = []
840
+ for ex in examples:
841
+ cleaned = {k: v for k, v in ex.items()
842
+ if k not in ['expected_output', 'description']}
843
+ if cleaned: # 빈 dict가 아닌 경우만 추가
844
+ cleaned_examples.append(cleaned)
845
+ self.logger.log_info(f"🔍 Cleaned {len(cleaned_examples)} examples")
846
+ return cleaned_examples
847
+
848
+ return examples
849
+
850
+ return []
851
+
852
+ except Exception as e:
853
+ self.logger.log_error(f"Failed to parse generated examples for {problem_id}: {e}")
854
+ import traceback
855
+ self.logger.log_error(f"Traceback: {traceback.format_exc()}")
856
+ return []
857
+
858
+ def _infer_argument_types(self, func_info: Dict[str, str],
859
+ examples: List[Tuple[str, str]],
860
+ solution: str) -> Dict[str, str]:
861
+ """기존 예제와 AST 분석으로 인자 타입 추론"""
862
+
863
+ arg_types = {}
864
+ func_name = func_info['name']
865
+ arg_names = func_info['args']
866
+
867
+ # 1. AST에서 type annotation 추출
868
+ try:
869
+ tree = ast.parse(solution)
870
+ for node in ast.walk(tree):
871
+ if isinstance(node, ast.FunctionDef) and node.name == func_name:
872
+ for i, arg in enumerate(node.args.args):
873
+ if i < len(arg_names) and arg.annotation:
874
+ # Type annotation이 있는 경우
875
+ arg_types[arg_names[i]] = ast.unparse(arg.annotation)
876
+ except:
877
+ pass
878
+
879
+ # 2. 기존 예제에서 타입 추론
880
+ if examples:
881
+ for input_str, _ in examples:
882
+ # "func_name(args)" 형태에서 args 추출
883
+ if input_str.startswith(func_name + '(') and input_str.endswith(')'):
884
+ args_str = input_str[len(func_name)+1:-1]
885
+ try:
886
+ # 인자 파싱
887
+ parsed_args = eval(f"({args_str},)")
888
+ if not isinstance(parsed_args, tuple):
889
+ parsed_args = (parsed_args,)
890
+
891
+ # 각 인자의 타입 추론
892
+ for i, arg_value in enumerate(parsed_args):
893
+ if i < len(arg_names):
894
+ arg_name = arg_names[i]
895
+ arg_type = type(arg_value).__name__
896
+
897
+ # 특별한 케이스 처리
898
+ if isinstance(arg_value, list):
899
+ if arg_value and all(isinstance(x, type(arg_value[0])) for x in arg_value):
900
+ inner_type = type(arg_value[0]).__name__
901
+ arg_type = f"List[{inner_type}]"
902
+ else:
903
+ arg_type = "List"
904
+
905
+ # 기존 타입과 병합
906
+ if arg_name not in arg_types:
907
+ arg_types[arg_name] = arg_type
908
+ except:
909
+ pass
910
+
911
+ # 3. 타입 정보 딕셔너리로 반환
912
+ # arg_types가 비어있으면 unknown 타입으로 채우기
913
+ for arg_name in arg_names:
914
+ if arg_name not in arg_types:
915
+ arg_types[arg_name] = "Any (type unknown)"
916
+
917
+ return arg_types
918
+
919
+ def _create_input_generation_prompt(self, problem_description: str,
920
+ existing_examples: List[Tuple[str, str]],
921
+ full_code: str,
922
+ arg_type_info: Dict[str, str]) -> str:
923
+ """입력 생성을 위한 프롬프트 생성"""
924
+
925
+ # 모든 기존 예제를 포맷팅
926
+ examples_text = ""
927
+ for i, (input_str, output_str) in enumerate(existing_examples):
928
+ examples_text += f"Example {i+1}:\n"
929
+ examples_text += f"Input: {input_str}\n"
930
+ examples_text += f"Output: {output_str}\n\n"
931
+
932
+ # arg_type_info를 문자열로 포맷팅
933
+ arg_type_text = "Argument types:\n"
934
+ for arg, arg_type in arg_type_info.items():
935
+ arg_type_text += f"- {arg}: {arg_type}\n"
936
+
937
+ prompt = f"""Given the following problem description and its Python function implementation, first analyze the types and valid ranges of the function arguments, then write **5 different example inputs** for the function that cover a diverse mix of typical (general) cases and edge/boundary cases.
938
+
939
+ Problem Description:
940
+ '''
941
+ {problem_description}
942
+ '''
943
+
944
+ Existing Examples from Problem:
945
+ {examples_text}
946
+
947
+ Function Implementation:
948
+ ```python
949
+ {full_code}
950
+ ```
951
+
952
+ {arg_type_text}
953
+
954
+ Based on the existing examples above, generate 5 NEW diverse test inputs that are different from the existing ones. Each input should be a Python dict where:
955
+ - Keys are the exact parameter names from the function signature
956
+ - Values are appropriate test values for each parameter
957
+
958
+ Format your response as:
959
+ ```python
960
+ examples = [
961
+ {{dict_with_all_function_parameters}}, # Description of this test case
962
+ {{dict_with_all_function_parameters}}, # Description of this test case
963
+ ... # Continue for all 5 examples
964
+ ]
965
+ ```
966
+
967
+ Ensure your examples include:
968
+ - At least 2 typical/general cases
969
+ - At least 2 edge/boundary cases
970
+ - 1 special case (empty, zero, maximum values, etc.)
971
+ - All examples should be DIFFERENT from the existing examples shown above"""
972
+
973
+ return prompt
974
+
975
+ def _call_llm_for_inputs(self, prompt: str, existing_examples: List[Tuple[str, str]],
976
+ func_info: Dict[str, Any], arg_type_info: str) -> List[Dict[str, Any]]:
977
+ """LLM을 호출하여 입력 생성 및 파싱"""
978
+
979
+ # 프롬프트 저장
980
+ self.last_generation_prompt = prompt
981
+
982
+ try:
983
+ # Input 생성용 전용 LLM 호출 (temperature=0.5)
984
+ if self.model is not None and self.tokenizer is not None:
985
+ # VLLM 사용 확인
986
+ try:
987
+ from vllm import LLM
988
+ if isinstance(self.model, LLM):
989
+ response = self._generate_with_vllm_for_inputs(prompt)
990
+ else:
991
+ response = self._generate_with_hf_for_inputs(prompt)
992
+ except ImportError:
993
+ response = self._generate_with_hf_for_inputs(prompt)
994
+
995
+ # 응답 저장
996
+ self.last_generation_response = response
997
+
998
+ # 응답에서 examples 추출
999
+ parsed_inputs = self._parse_generated_examples(response)
1000
+
1001
+ # 입력 생성 정보 저장
1002
+ self.last_input_generation_info = {
1003
+ 'prompt': prompt,
1004
+ 'llm_response': response,
1005
+ 'extracted_inputs': parsed_inputs,
1006
+ 'existing_examples': existing_examples,
1007
+ 'function_info': func_info,
1008
+ 'arg_type_info': arg_type_info
1009
+ }
1010
+
1011
+ return parsed_inputs
1012
+ else:
1013
+ # 모델이 없으면 빈 리스트 반환 (테스트 환경)
1014
+ self.logger.log_warning("No model available for input generation")
1015
+ self.last_generation_response = "No model available"
1016
+
1017
+ # 실패한 경우에도 정보 저장
1018
+ self.last_input_generation_info = {
1019
+ 'prompt': prompt,
1020
+ 'llm_response': "No model available",
1021
+ 'extracted_inputs': [],
1022
+ 'existing_examples': existing_examples,
1023
+ 'function_info': func_info,
1024
+ 'arg_type_info': arg_type_info,
1025
+ 'error': "No model available"
1026
+ }
1027
+ return []
1028
+
1029
+ except Exception as e:
1030
+ self.logger.log_error(f"Failed to call LLM for inputs: {e}")
1031
+ self.last_generation_response = f"Error: {str(e)}"
1032
+
1033
+ # 에러 발생 시에도 정보 저장
1034
+ self.last_input_generation_info = {
1035
+ 'prompt': locals().get('prompt', 'N/A'),
1036
+ 'llm_response': f"Error: {str(e)}",
1037
+ 'extracted_inputs': [],
1038
+ 'existing_examples': locals().get('existing_examples', []),
1039
+ 'function_info': locals().get('func_info', {}),
1040
+ 'arg_type_info': locals().get('arg_type_info', 'N/A'),
1041
+ 'error': str(e)
1042
+ }
1043
+ return []
1044
+
1045
+ def _generate_with_vllm_for_inputs(self, prompt: str) -> str:
1046
+ """Input 생성용 VLLM 백엔드 (temperature=0.5로 다양성 확보)"""
1047
+ try:
1048
+ from vllm import SamplingParams
1049
+
1050
+ # Input 생성용 높은 temperature 설정
1051
+ sampling_params = SamplingParams(
1052
+ temperature=0.5, # 다양한 입력 생성을 위한 높은 temperature
1053
+ max_tokens=2048,
1054
+ top_p=0.95, # 다양성을 위해 top_p 사용
1055
+ stop=["\n```\n"], # 코드 블록 종료 시 정지
1056
+ )
1057
+
1058
+ outputs = self.model.generate([prompt], sampling_params, use_tqdm=False)
1059
+ return outputs[0].outputs[0].text.replace("\t", " ").strip()
1060
+
1061
+ except Exception as e:
1062
+ self.logger.log_error(f"VLLM input generation failed: {e}")
1063
+ return ""
1064
+
1065
+ def _generate_with_hf_for_inputs(self, prompt: str) -> str:
1066
+ """Input 생성용 HuggingFace 백엔드 (temperature=0.5로 다양성 확보)"""
1067
+ try:
1068
+ import torch
1069
+
1070
+ # 토크나이저 처리
1071
+ inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096)
1072
+
1073
+ # attention mask 명시적으로 설정
1074
+ if 'attention_mask' not in inputs:
1075
+ inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
1076
+
1077
+ # 디바이스 이동
1078
+ inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
1079
+
1080
+ with torch.no_grad():
1081
+ # 메모리 정리
1082
+ if torch.cuda.is_available():
1083
+ torch.cuda.empty_cache()
1084
+
1085
+ # Input 생성용 sampling 설정
1086
+ outputs = self.model.generate(
1087
+ inputs['input_ids'],
1088
+ attention_mask=inputs['attention_mask'],
1089
+ max_new_tokens=2048,
1090
+ do_sample=True, # sampling 활성화
1091
+ temperature=0.5, # 다양한 입력 생성을 위한 temperature
1092
+ top_p=0.95, # 다양성을 위해 top_p 사용
1093
+ pad_token_id=self.tokenizer.eos_token_id,
1094
+ eos_token_id=self.tokenizer.eos_token_id
1095
+ )
1096
+
1097
+ # 응답 추출
1098
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
1099
+ response = response[len(prompt):].strip()
1100
+ return response
1101
+
1102
+ except Exception as e:
1103
+ self.logger.log_error(f"HuggingFace input generation failed: {e}")
1104
+ return ""
1105
+
1106
+ def _parse_generated_examples(self, llm_response: str) -> List[Dict[str, Any]]:
1107
+ """LLM 응답에서 예제 파싱"""
1108
+
1109
+ try:
1110
+ # ```python ... ``` 블록에서 코드 추출
1111
+ import re
1112
+ code_pattern = r'```python\n(.*?)\n```'
1113
+ matches = re.findall(code_pattern, llm_response, re.DOTALL)
1114
+
1115
+ if not matches:
1116
+ # 블록이 없으면 전체 응답에서 examples = 찾기
1117
+ if 'examples = [' in llm_response:
1118
+ start = llm_response.find('examples = [')
1119
+ # 균형잡힌 괄호 찾기
1120
+ bracket_count = 0
1121
+ end = start
1122
+ for i, char in enumerate(llm_response[start:]):
1123
+ if char == '[':
1124
+ bracket_count += 1
1125
+ elif char == ']':
1126
+ bracket_count -= 1
1127
+ if bracket_count == 0:
1128
+ end = start + i + 1
1129
+ break
1130
+
1131
+ if end > start:
1132
+ code = llm_response[start:end]
1133
+ exec_globals = {}
1134
+ exec(code, exec_globals)
1135
+ return exec_globals.get('examples', [])
1136
+ else:
1137
+ # 코드 블록에서 examples 추출
1138
+ code = matches[0]
1139
+ exec_globals = {}
1140
+ exec(code, exec_globals)
1141
+ return exec_globals.get('examples', [])
1142
+
1143
+ return []
1144
+
1145
+ except Exception as e:
1146
+ self.logger.log_error(f"Failed to parse generated examples: {e}")
1147
+ return []
1148
+
1149
+ def _validate_generated_inputs(self, generated_inputs: List[Dict[str, Any]],
1150
+ func_info: Dict[str, str],
1151
+ solution: str) -> List[Dict[str, Any]]:
1152
+ """생성된 입력의 유효성 검증"""
1153
+
1154
+ valid_inputs = []
1155
+ func_name = func_info['name']
1156
+
1157
+ for i, input_dict in enumerate(generated_inputs):
1158
+ try:
1159
+ # 1. 필수 인자 확인
1160
+ required_args = set(func_info['args'])
1161
+ provided_args = set(input_dict.keys())
1162
+
1163
+ if not required_args.issubset(provided_args):
1164
+ self.logger.log_warning(f"Input {i+1} missing required args: {required_args - provided_args}")
1165
+ continue
1166
+
1167
+ # 2. 실제 실행으로 검증
1168
+ # 인자를 순서대로 배열
1169
+ args = [input_dict[arg] for arg in func_info['args'] if arg in input_dict]
1170
+
1171
+ # 실행 테스트
1172
+ output = self._execute_llm_solution(solution, func_name, args)
1173
+ if output is not None:
1174
+ valid_inputs.append(input_dict)
1175
+ self.logger.log_info(f"✅ Valid input {i+1}: {input_dict}")
1176
+ else:
1177
+ self.logger.log_warning(f"❌ Input {i+1} execution failed")
1178
+
1179
+ except Exception as e:
1180
+ self.logger.log_error(f"Input {i+1} validation error: {e}")
1181
+
1182
+ return valid_inputs
1183
+
1184
+ def create_ipo_from_input(self, problem: Dict[str, Any],
1185
+ solution: str,
1186
+ input_dict: Dict[str, Any]) -> Optional[Dict[str, Any]]:
1187
+ """새로운 입력으로 IPO triple 생성"""
1188
+
1189
+ try:
1190
+ problem_id = problem.get('task_id', 'unknown')
1191
+ entry_point = problem.get('entry_point', 'unknown')
1192
+
1193
+ # 함수 정보 추출
1194
+ func_info = self._extract_function_info(solution, entry_point)
1195
+ if not func_info:
1196
+ return None
1197
+
1198
+ # 인자를 순서대로 배열
1199
+ args = [input_dict[arg] for arg in func_info['args'] if arg in input_dict]
1200
+
1201
+ # 실행하여 출력 얻기
1202
+ output = self._execute_llm_solution(solution, func_info['name'], args)
1203
+ if output is None:
1204
+ return None
1205
+
1206
+ # 입력 문자열 생성
1207
+ args_str = ', '.join(repr(arg) for arg in args)
1208
+ full_input_str = f"{func_info['name']}({args_str})"
1209
+
1210
+ # IPO triple 생성
1211
+ triple_id = f"{problem_id}_generated_{len(self.extracted_triples)}"
1212
+
1213
+ triple = {
1214
+ 'id': triple_id,
1215
+ 'input': args_str, # 실제 인자만
1216
+ 'full_input_str': full_input_str, # 전체 함수 호출
1217
+ 'program': solution,
1218
+ 'expected_output': output,
1219
+ 'actual_output': output,
1220
+ 'function_name': func_info['name'],
1221
+ 'function_args': func_info['args'],
1222
+ 'is_correct': True, # 생성된 것은 항상 정확
1223
+ 'extraction_method': 'generated'
1224
+ }
1225
+
1226
+ return triple
1227
+
1228
+ except Exception as e:
1229
+ self.logger.log_error(f"Failed to create IPO from input: {e}")
1230
+ return None
1231
+
1232
+ def cleanup(self):
1233
+ """리소스 정리"""
1234
+ if hasattr(self.executor, 'cleanup'):
1235
+ self.executor.cleanup()
absolute_zero_reasoner/testtime/logger.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TestTime Logger
3
+
4
+ TestTime RLVR을 위한 포괄적 로깅 시스템
5
+ 요구사항에 따른 모든 단계별 로그 기록
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import time
11
+ from datetime import datetime
12
+ from typing import Dict, List, Any, Optional
13
+ from pathlib import Path
14
+ import logging
15
+
16
+
17
+ class TestTimeLogger:
18
+ """TestTime RLVR 전용 로거"""
19
+
20
+ def __init__(self, log_dir: str = "logs", log_level: str = "INFO", task_output_dir: str = None, log_file: str = None):
21
+ # 설계된 구조에 맞는 로그 디렉토리 설정
22
+ if task_output_dir:
23
+ # TTRLVR 통합 모드: 설계된 디렉토리 구조 사용
24
+ self.log_dir = Path(task_output_dir)
25
+ self.use_integrated_structure = True
26
+ else:
27
+ # 기존 모드: 기본 logs 디렉토리 사용
28
+ self.log_dir = Path(log_dir)
29
+ self.use_integrated_structure = False
30
+
31
+ self.log_dir.mkdir(parents=True, exist_ok=True)
32
+
33
+ # 디렉토리 구조에 따른 서브 디렉토리 생성
34
+ if self.use_integrated_structure:
35
+ # 설계된 구조: round_N 하위에 세부 디렉토리
36
+ (self.log_dir / "current_evaluation").mkdir(exist_ok=True)
37
+ (self.log_dir / "diverse_programs").mkdir(exist_ok=True)
38
+ (self.log_dir / "llm_responses").mkdir(exist_ok=True)
39
+ (self.log_dir / "azr_training_data").mkdir(exist_ok=True)
40
+ # 기존 구조에서는 서브 디렉토리를 생성하지 않음 (메인 로그 파일만)
41
+
42
+ # 기본 로거 설정
43
+ self.logger = logging.getLogger("TestTimeRLVR")
44
+ self.logger.setLevel(getattr(logging, log_level))
45
+
46
+ # 핸들러 설정
47
+ if not self.logger.handlers:
48
+ # 파일 핸들러
49
+ if log_file:
50
+ # 특정 로그 파일 경로가 주어진 경우 (Ray worker에서 사용)
51
+ self.log_file_path = log_file
52
+ file_handler = logging.FileHandler(log_file, mode='a') # append mode
53
+ else:
54
+ # 기본 로그 파일 생성
55
+ self.log_file_path = str(self.log_dir / f"testtime_rlvr_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
56
+ file_handler = logging.FileHandler(self.log_file_path)
57
+ file_handler.setLevel(logging.DEBUG)
58
+
59
+ # 콘솔 핸들러
60
+ console_handler = logging.StreamHandler()
61
+ console_handler.setLevel(getattr(logging, log_level))
62
+
63
+ # 포맷터
64
+ formatter = logging.Formatter(
65
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
66
+ )
67
+ file_handler.setFormatter(formatter)
68
+ console_handler.setFormatter(formatter)
69
+
70
+ self.logger.addHandler(file_handler)
71
+ self.logger.addHandler(console_handler)
72
+
73
+ def _get_timestamp(self) -> str:
74
+ """현재 타임스탬프 반환"""
75
+ return datetime.now().isoformat()
76
+
77
+ def _save_json_log(self, subdirectory: str, filename: str, data: Dict[str, Any]):
78
+ """JSON 로그 파일 저장"""
79
+ if self.use_integrated_structure:
80
+ # 설계된 구조: 각 카테고리별로 적절한 디렉토리에 저장
81
+ if subdirectory == "ipo_extraction":
82
+ # IPO 추출 로그는 diverse_programs 하위에 별도로 저장
83
+ log_path = self.log_dir / "diverse_programs" / f"{filename}.json"
84
+ elif subdirectory == "task_generation":
85
+ # Task generation 로그는 round 레벨에 저장 (모든 task 종류 포함)
86
+ log_path = self.log_dir / f"{filename}.json"
87
+ elif subdirectory == "problems":
88
+ log_path = self.log_dir / "current_evaluation" / f"{filename}.json"
89
+ elif subdirectory == "performance":
90
+ log_path = self.log_dir / "current_evaluation" / f"{filename}.json"
91
+ elif subdirectory == "training":
92
+ log_path = self.log_dir / "azr_training_data" / f"{filename}.json"
93
+ else:
94
+ # 기본값
95
+ log_path = self.log_dir / subdirectory / f"{filename}.json"
96
+ else:
97
+ # 기존 구조
98
+ log_path = self.log_dir / subdirectory / f"{filename}.json"
99
+
100
+ # 디렉토리 생성 (없다면)
101
+ log_path.parent.mkdir(parents=True, exist_ok=True)
102
+
103
+ # 기존 로그 로드 (있다면)
104
+ if log_path.exists():
105
+ with open(log_path, 'r', encoding='utf-8') as f:
106
+ existing_logs = json.load(f)
107
+ else:
108
+ existing_logs = []
109
+
110
+ # 새 로그 추가
111
+ data['timestamp'] = self._get_timestamp()
112
+ existing_logs.append(data)
113
+
114
+ # 저장
115
+ with open(log_path, 'w', encoding='utf-8') as f:
116
+ json.dump(existing_logs, f, indent=2, ensure_ascii=False)
117
+
118
+ # ============================================================================
119
+ # 1. 벤치마크 문제 로깅 (요구사항 1)
120
+ # ============================================================================
121
+
122
+ def log_problem_attempt(self, problem: Dict[str, Any], solution: str,
123
+ is_correct: bool, validation_result: Optional[Dict] = None):
124
+ """벤치마크 문제와 LLM 답변, 정답 여부 로그"""
125
+
126
+ log_data = {
127
+ 'problem_id': problem.get('task_id', 'unknown'),
128
+ 'benchmark': problem.get('benchmark_name', 'unknown'),
129
+ 'problem_prompt': problem.get('prompt', ''),
130
+ 'canonical_solution': problem.get('canonical_solution', ''),
131
+ 'llm_solution': solution,
132
+ 'is_correct': is_correct,
133
+ 'validation_result': validation_result or {}
134
+ }
135
+
136
+ self._save_json_log("problems", f"problem_{problem.get('task_id', 'unknown').replace('/', '_')}", log_data)
137
+
138
+ status = "✅ CORRECT" if is_correct else "❌ INCORRECT"
139
+ self.logger.info(f"Problem {problem.get('task_id', 'unknown')}: {status}")
140
+
141
+ def log_problem_loaded(self, problem_id: str, benchmark_name: str, method: str = "Original"):
142
+ """문제 로딩 로그 (EvalPlus/Original 방식 구분)"""
143
+ self.logger.info(f"Loaded problem {problem_id} from {benchmark_name} ({method} method)")
144
+
145
+ # ============================================================================
146
+ # 2. IPO 추출 로깅 (요구사항 2)
147
+ # ============================================================================
148
+
149
+ def log_ipo_extraction(self, problem_id: str, extracted_triples: List[Dict],
150
+ validation_results: List[bool]):
151
+ """생성된 (i,p,o) 트리플과 검증 결과 로그"""
152
+
153
+ log_data = {
154
+ 'problem_id': problem_id,
155
+ 'num_triples': len(extracted_triples),
156
+ 'triples': extracted_triples,
157
+ 'validation_results': validation_results,
158
+ 'valid_triples': sum(validation_results),
159
+ 'invalid_triples': len(validation_results) - sum(validation_results)
160
+ }
161
+
162
+ self._save_json_log("ipo_extraction", f"ipo_{problem_id.replace('/', '_')}", log_data)
163
+
164
+ self.logger.info(f"IPO Extraction for {problem_id}: {len(extracted_triples)} triples, "
165
+ f"{sum(validation_results)} valid")
166
+
167
+ # ============================================================================
168
+ # 3. 태스크 생성 로깅 (요구사항 2)
169
+ # ============================================================================
170
+
171
+ def log_task_generation(self, problem_id: str, induction_tasks: List[Dict],
172
+ deduction_tasks: List[Dict], abduction_tasks: List[Dict]):
173
+ """생성된 induction, deduction, abduction 문제 로그"""
174
+
175
+ log_data = {
176
+ 'problem_id': problem_id,
177
+ 'induction_tasks': {
178
+ 'count': len(induction_tasks),
179
+ 'tasks': induction_tasks
180
+ },
181
+ 'deduction_tasks': {
182
+ 'count': len(deduction_tasks),
183
+ 'tasks': deduction_tasks
184
+ },
185
+ 'abduction_tasks': {
186
+ 'count': len(abduction_tasks),
187
+ 'tasks': abduction_tasks
188
+ },
189
+ 'total_tasks': len(induction_tasks) + len(deduction_tasks) + len(abduction_tasks)
190
+ }
191
+
192
+ self._save_json_log("task_generation", f"tasks_{problem_id.replace('/', '_')}", log_data)
193
+
194
+ total_tasks = log_data['total_tasks']
195
+ self.logger.info(f"Task Generation for {problem_id}: {total_tasks} tasks "
196
+ f"(I:{len(induction_tasks)}, D:{len(deduction_tasks)}, A:{len(abduction_tasks)})")
197
+
198
+ # ============================================================================
199
+ # 4. 학습 메트릭 로깅 (요구사항 3, 4)
200
+ # ============================================================================
201
+
202
+ def log_task_accuracy(self, problem_id: str, task_type: str, accuracy: float,
203
+ rewards: List[float], step: int):
204
+ """induction/deduction/abduction 태스크 정확도와 reward 로그"""
205
+
206
+ log_data = {
207
+ 'problem_id': problem_id,
208
+ 'task_type': task_type, # 'induction', 'deduction', 'abduction'
209
+ 'step': step,
210
+ 'accuracy': accuracy,
211
+ 'rewards': rewards,
212
+ 'avg_reward': sum(rewards) / len(rewards) if rewards else 0.0,
213
+ 'max_reward': max(rewards) if rewards else 0.0,
214
+ 'min_reward': min(rewards) if rewards else 0.0
215
+ }
216
+
217
+ self._save_json_log("training", f"accuracy_{problem_id.replace('/', '_')}", log_data)
218
+
219
+ self.logger.info(f"Step {step} - {task_type.capitalize()} accuracy: {accuracy:.4f}, "
220
+ f"avg reward: {log_data['avg_reward']:.4f}")
221
+
222
+ def log_verl_training(self, problem_id: str, step: int, loss: float,
223
+ learning_rate: float, metrics: Dict[str, Any]):
224
+ """VeRL 학습 진행 상황 로그"""
225
+
226
+ log_data = {
227
+ 'problem_id': problem_id,
228
+ 'step': step,
229
+ 'loss': loss,
230
+ 'learning_rate': learning_rate,
231
+ 'metrics': metrics
232
+ }
233
+
234
+ self._save_json_log("training", f"verl_{problem_id.replace('/', '_')}", log_data)
235
+
236
+ self.logger.info(f"VeRL Training Step {step}: loss={loss:.6f}, lr={learning_rate:.2e}")
237
+
238
+ # ============================================================================
239
+ # 5. 성능 변화 로깅
240
+ # ============================================================================
241
+
242
+ def log_performance_change(self, problem_id: str, cycle: int,
243
+ before_accuracy: float, after_accuracy: float,
244
+ improvement: float):
245
+ """매 사이클별 성능 변화 로그"""
246
+
247
+ log_data = {
248
+ 'problem_id': problem_id,
249
+ 'cycle': cycle,
250
+ 'before_accuracy': before_accuracy,
251
+ 'after_accuracy': after_accuracy,
252
+ 'improvement': improvement,
253
+ 'improvement_percentage': improvement * 100
254
+ }
255
+
256
+ self._save_json_log("performance", f"cycle_{problem_id.replace('/', '_')}", log_data)
257
+
258
+ direction = "↗️" if improvement > 0 else "↘️" if improvement < 0 else "→"
259
+ self.logger.info(f"Cycle {cycle} Performance: {before_accuracy:.4f} → {after_accuracy:.4f} "
260
+ f"({direction} {improvement:+.4f})")
261
+
262
+ # ============================================================================
263
+ # 일반 로깅
264
+ # ============================================================================
265
+
266
+ def log_info(self, message: str):
267
+ """일반 정보 로그"""
268
+ self.logger.info(message)
269
+
270
+ def log_error(self, message: str):
271
+ """에러 로그"""
272
+ self.logger.error(message)
273
+
274
+ def log_warning(self, message: str):
275
+ """경고 로그"""
276
+ self.logger.warning(message)
277
+
278
+ def log_debug(self, message: str):
279
+ """디버그 로그"""
280
+ self.logger.debug(message)
281
+
282
+ def get_log_summary(self) -> Dict[str, Any]:
283
+ """로그 요약 정보 반환"""
284
+ summary = {
285
+ 'log_directory': str(self.log_dir),
286
+ 'subdirectories': {
287
+ 'problems': len(list((self.log_dir / "problems").glob("*.json"))),
288
+ 'ipo_extraction': len(list((self.log_dir / "ipo_extraction").glob("*.json"))),
289
+ 'task_generation': len(list((self.log_dir / "task_generation").glob("*.json"))),
290
+ 'training': len(list((self.log_dir / "training").glob("*.json"))),
291
+ 'performance': len(list((self.log_dir / "performance").glob("*.json")))
292
+ }
293
+ }
294
+
295
+ return summary
absolute_zero_reasoner/testtime/prompts.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TestTime RLVR 프롬프트 중앙 관리 시스템
3
+
4
+ 모든 프롬프트를 한 곳에서 관리하여 일관성과 유지보수성을 향상시킵니다.
5
+ """
6
+
7
+ from typing import Dict, List, Any
8
+ from dataclasses import dataclass
9
+ from enum import Enum
10
+
11
+
12
+ class PromptType(Enum):
13
+ """프롬프트 유형 정의"""
14
+ SOLUTION_GENERATION = "solution_generation"
15
+ DIVERSE_GENERATION = "diverse_generation"
16
+ INPUT_GENERATION = "input_generation"
17
+ TASK_GENERATION = "task_generation"
18
+ TASK_EVALUATION = "task_evaluation"
19
+
20
+
21
+ class BenchmarkType(Enum):
22
+ """벤치마크 유형 정의"""
23
+ HUMANEVAL = "humaneval"
24
+ MBPP = "mbpp"
25
+ GENERAL = "general"
26
+
27
+
28
+ @dataclass
29
+ class PromptTemplate:
30
+ """프롬프트 템플릿 데이터 클래스"""
31
+ name: str
32
+ template: str
33
+ description: str
34
+ benchmark: BenchmarkType
35
+ temperature: float = 0.05
36
+ variables: List[str] = None
37
+
38
+ def __post_init__(self):
39
+ if self.variables is None:
40
+ self.variables = []
41
+
42
+
43
+ class PromptManager:
44
+ """프롬프트 중앙 관리 클래스"""
45
+
46
+ def __init__(self):
47
+ self.prompts = self._initialize_prompts()
48
+
49
+ def _initialize_prompts(self) -> Dict[str, PromptTemplate]:
50
+ """모든 프롬프트 템플릿 초기화"""
51
+
52
+ prompts = {}
53
+
54
+ # ================================================================================
55
+ # 1. SOLUTION GENERATION PROMPTS (Current Evaluation - 베이스라인)
56
+ # ================================================================================
57
+
58
+ # HumanEval 기본 솔루션 생성
59
+ prompts["solution_humaneval_basic"] = PromptTemplate(
60
+ name="HumanEval 기본 솔루션 생성",
61
+ benchmark=BenchmarkType.HUMANEVAL,
62
+ temperature=0.05,
63
+ description="HumanEval 문제에 대한 기본 솔루션 생성 (greedy)",
64
+ variables=["problem_prompt"],
65
+ template="""You are a Python writing assistant. Complete the following Python function.
66
+
67
+ {problem_prompt}
68
+
69
+ Please provide a complete implementation of the function."""
70
+ )
71
+
72
+ # HumanEval 다중 함수 처리
73
+ prompts["solution_humaneval_multi"] = PromptTemplate(
74
+ name="HumanEval 다중 함수 솔루션 생성",
75
+ benchmark=BenchmarkType.HUMANEVAL,
76
+ temperature=0.05,
77
+ description="여러 함수가 있는 HumanEval 문제 처리",
78
+ variables=["problem_prompt", "entry_point"],
79
+ template="""You are a Python writing assistant. Complete the following Python function.
80
+
81
+ {problem_prompt}
82
+
83
+ Please provide ONLY the implementation for the function `{entry_point}`.
84
+ Complete the body of the `{entry_point}` function where it is incomplete.
85
+ Do not modify or reimplement other functions that are already complete."""
86
+ )
87
+
88
+ # MBPP 기본 솔루션 생성
89
+ prompts["solution_mbpp_basic"] = PromptTemplate(
90
+ name="MBPP 기본 솔루션 생성",
91
+ benchmark=BenchmarkType.MBPP,
92
+ temperature=0.05,
93
+ description="MBPP 문제에 대한 기본 솔루션 생성",
94
+ variables=["problem_prompt"],
95
+ template="""
96
+ Please generate a complete, self-contained Python script that solves the following problem.
97
+
98
+ CRITICAL REQUIREMENTS:
99
+ - You MUST maintain the EXACT function signature as shown in the examples
100
+ - The function name, parameter names, parameter types, and parameter count MUST match exactly with the examples
101
+ - Look at the assert statements carefully to understand the expected function signature
102
+ - DO NOT change the number of parameters or their types from what is shown in the examples
103
+
104
+ Instructions:
105
+ - Wrap the entire script in a Markdown code block with syntax highlighting (```python ... ```).
106
+ - For each function, include a concise docstring enclosed in triple single quotes (''' ... '''), placed immediately below the def line.
107
+ The docstring should briefly describe:
108
+ • The function's purpose
109
+ • Input parameters
110
+ • Return value
111
+
112
+ Problem statement:
113
+ {problem_prompt}
114
+ """
115
+ )
116
+
117
+ # ================================================================================
118
+ # 2. DIVERSE GENERATION PROMPTS (다양한 프로그램 생성)
119
+ # ================================================================================
120
+
121
+ # HumanEval 다양성 솔루션
122
+ prompts["diverse_humaneval_basic"] = PromptTemplate(
123
+ name="HumanEval 다양성 솔루션 생성",
124
+ benchmark=BenchmarkType.HUMANEVAL,
125
+ temperature=0.7,
126
+ description="HumanEval 문제에 대한 다양한 접근법 솔루션",
127
+ variables=["diversity_instruction", "problem_prompt"],
128
+ template="""You are a Python writing assistant. {diversity_instruction}
129
+
130
+ {problem_prompt}
131
+
132
+ Please provide a complete implementation of the function."""
133
+ )
134
+
135
+ # HumanEval 다양성 다중 함수
136
+ prompts["diverse_humaneval_multi"] = PromptTemplate(
137
+ name="HumanEval 다양성 다중 함수 솔루션",
138
+ benchmark=BenchmarkType.HUMANEVAL,
139
+ temperature=0.7,
140
+ description="다중 함수 HumanEval에 대한 다양성 솔루션",
141
+ variables=["diversity_instruction", "problem_prompt", "entry_point"],
142
+ template="""You are a Python writing assistant. {diversity_instruction}
143
+
144
+ {problem_prompt}
145
+
146
+ Please provide ONLY the implementation for the function `{entry_point}`.
147
+ Complete the body of the `{entry_point}` function where it is incomplete.
148
+ Do not modify or reimplement other functions that are already complete."""
149
+ )
150
+
151
+ # MBPP 다양성 솔루션
152
+ prompts["diverse_mbpp_basic"] = PromptTemplate(
153
+ name="MBPP 다양성 솔루션 생성",
154
+ benchmark=BenchmarkType.MBPP,
155
+ temperature=0.7,
156
+ description="MBPP 문제에 대한 다양한 접근법 솔루션",
157
+ variables=["diversity_instruction", "problem_prompt"],
158
+ template="""Please generate a complete, self-contained Python script that solves the following problem.
159
+
160
+ CRITICAL REQUIREMENTS:
161
+ - You MUST maintain the EXACT function signature as shown in the examples
162
+ - The function name, parameter names, parameter types, and parameter count MUST match exactly with the examples
163
+ - Look at the assert statements carefully to understand the expected function signature
164
+ - DO NOT change the number of parameters or their types from what is shown in the examples
165
+
166
+ Instructions:
167
+ - Wrap the entire script in a Markdown code block with syntax highlighting (```python ... ```).
168
+ - For each function, include a concise docstring enclosed in triple single quotes (''' ... '''), placed immediately below the def line.
169
+ The docstring should briefly describe:
170
+ • The function's purpose
171
+ • Input parameters
172
+ • Return value
173
+
174
+ {diversity_instruction}
175
+
176
+ Problem statement:
177
+ {problem_prompt}
178
+ """
179
+ )
180
+
181
+ # ================================================================================
182
+ # 3. INPUT GENERATION PROMPTS (입력 증강)
183
+ # ================================================================================
184
+
185
+ prompts["input_generation_basic"] = PromptTemplate(
186
+ name="기본 입력 생성",
187
+ benchmark=BenchmarkType.GENERAL,
188
+ temperature=0.5,
189
+ description="기존 IPO 예제를 바탕으로 새로운 입력 생성",
190
+ variables=["problem_description", "existing_examples", "full_code", "arg_type_info"],
191
+ template="""Given the following problem description and its Python function implementation, first analyze the types and valid ranges of the function arguments, then write **5 different example inputs** for the function that cover a diverse mix of typical (general) cases and edge/boundary cases.
192
+
193
+ Problem Description:
194
+ '''
195
+ {problem_description}
196
+ '''
197
+
198
+ Existing Examples from Problem:
199
+ {existing_examples}
200
+
201
+ Function Implementation:
202
+ ```python
203
+ {full_code}
204
+ ```
205
+
206
+ {arg_type_info}
207
+
208
+ Based on the existing examples above, generate 5 NEW diverse test inputs that are different from the existing ones. Each input should be a Python dict where:
209
+ - Keys are the exact parameter names from the function signature
210
+ - Values are appropriate test values for each parameter
211
+
212
+ Format your response as:
213
+ ```python
214
+ examples = [
215
+ {{dict_with_all_function_parameters}}, # Description of this test case
216
+ {{dict_with_all_function_parameters}}, # Description of this test case
217
+ ... # Continue for all 5 examples
218
+ ]
219
+ ```
220
+
221
+ Ensure your examples include:
222
+ - At least 2 typical/general cases
223
+ - At least 2 edge/boundary cases
224
+ - 1 special case (empty, zero, maximum values, etc.)
225
+ - All examples should be DIFFERENT from the existing examples shown above"""
226
+ )
227
+
228
+ # ================================================================================
229
+ # 4. TASK GENERATION PROMPTS (IPO → 추론 태스크)
230
+ # ================================================================================
231
+
232
+ prompts["task_induction"] = PromptTemplate(
233
+ name="Induction 태스크 생성 (AZR code_f)",
234
+ benchmark=BenchmarkType.GENERAL,
235
+ temperature=0.05,
236
+ description="주어진 입력-출력으로부터 프로그램 추론 (AZR 원본)",
237
+ variables=["input_output_pairs", "message"],
238
+ template="""A conversation between User and Assistant.
239
+ The User provides a set of input/output pairs and a message describing the hidden function. The Assistant must:
240
+ 1. **Privately think step-by-step** about how to reconstruct the general function based on the provided examples.
241
+ 2. **Output exactly one** `<think>...</think>` block containing the full reasoning process.
242
+ 3. **Then output exactly one** `<answer>...</answer>` block containing **only** the Python code snippet defining the function `f`—no labels, no comments, no extra text.
243
+ 4. **Do not** generate any text outside these two blocks.
244
+ 5. Follow to the **code requirements** and **formatting rules**.
245
+
246
+ # Code Requirements:
247
+ - Name the entry function `f` (e.g., `def f(...): ...`), you may include nested definitions inside `f`.
248
+ - Ensure the function returns a value.
249
+ - Include at least one input parameter.
250
+ - Make the function deterministic.
251
+ - AVOID the FOLLOWING:
252
+ * Random functions or variables
253
+ * Date/time operations
254
+ * I/O operations (reading files, network requests)
255
+ * Printing or logging
256
+ * Any external state
257
+ - Ensure execution completes within 10 seconds on a modern CPU.
258
+ - All imports and custom class definitions must be at the very top of the code snippet.
259
+ - The snippet must end with a return statement from the main function `f`; anything after will be removed.
260
+
261
+ User:
262
+ # Input and Output Pairs:
263
+ {input_output_pairs}
264
+
265
+ # Message:
266
+ {message}"""
267
+ )
268
+
269
+ prompts["task_deduction"] = PromptTemplate(
270
+ name="Deduction 태스크 생성 (AZR code_o)",
271
+ benchmark=BenchmarkType.GENERAL,
272
+ temperature=0.05,
273
+ description="주어진 프로그램과 입력으로부터 출력 추론 (AZR 원본)",
274
+ variables=["snippet", "input_args"],
275
+ template="""A conversation between User and Assistant.
276
+ The User provides a Python code snippet and specific input values. The Assistant must:
277
+ 1. **Privately think step-by-step** about how the code executes with the given inputs.
278
+ 2. **Output exactly one** `<think>...</think>` block containing your full reasoning.
279
+ 3. **Then output exactly one** `<answer>...</answer>` block containing **only** the output values—no labels, no comments, no extra text.
280
+ 4. **Do not** generate any text outside these two blocks.
281
+ 5. Adhere to the **output rules**.
282
+
283
+ # Output Rules:
284
+ - If the output is a string, wrap it in quotes.
285
+ - For dicts, lists, and other literals, use valid Python literal notation.
286
+
287
+ User:
288
+ # Python Code Snippet:
289
+ {snippet}
290
+
291
+ # Input:
292
+ {input_args}"""
293
+ )
294
+
295
+ prompts["task_abduction"] = PromptTemplate(
296
+ name="Abduction 태스크 생성 (AZR code_i)",
297
+ benchmark=BenchmarkType.GENERAL,
298
+ temperature=0.05,
299
+ description="주어진 프로그램과 출력으로부터 입력 추론 (AZR 원본)",
300
+ variables=["snippet", "output"],
301
+ template="""A conversation between User and Assistant.
302
+ The User provides a Python code snippet and its observed output. The Assistant must:
303
+ 1. **Privately think step-by-step** about which input produces that output.
304
+ 2. **Output exactly one** `<think>...</think>` block containing your full reasoning.
305
+ 3. **Then output exactly one** `<answer>...</answer>` block containing **only** the input values—no labels, no comments, no extra text.
306
+ 4. **Do not** generate any text outside these two blocks.
307
+ 5. Adhere to the **input rules**.
308
+
309
+ # Input Rules:
310
+ - If an argument is a string, wrap it in quotes.
311
+ - For multiple arguments, separate by commas.
312
+ - Use Python literal notation for lists, dicts, tuples.
313
+ - Boolean values must be `True` or `False`.
314
+
315
+ User:
316
+ # Python Code Snippet:
317
+ {snippet}
318
+
319
+ # Observed Output:
320
+ {output}"""
321
+ )
322
+
323
+ # ================================================================================
324
+ # 5. TASK EVALUATION PROMPTS (LLM 태스크 응답)
325
+ # ================================================================================
326
+
327
+ prompts["task_evaluation_basic"] = PromptTemplate(
328
+ name="기본 태스크 평가",
329
+ benchmark=BenchmarkType.GENERAL,
330
+ temperature=0.05,
331
+ description="생성된 추론 태스크에 대한 LLM 응답",
332
+ variables=["task_prompt"],
333
+ template="{task_prompt}"
334
+ )
335
+
336
+ return prompts
337
+
338
+ def get_prompt(self, prompt_key: str, **kwargs) -> str:
339
+ """프롬프트 키로 템플릿을 가져와 변수를 채움"""
340
+ if prompt_key not in self.prompts:
341
+ raise ValueError(f"Unknown prompt key: {prompt_key}")
342
+
343
+ template = self.prompts[prompt_key]
344
+
345
+ # 필수 변수 확인
346
+ missing_vars = []
347
+ for var in template.variables:
348
+ if var not in kwargs:
349
+ missing_vars.append(var)
350
+
351
+ if missing_vars:
352
+ raise ValueError(f"Missing required variables for prompt '{prompt_key}': {missing_vars}")
353
+
354
+ # 템플릿 포맷팅
355
+ try:
356
+ return template.template.format(**kwargs)
357
+ except KeyError as e:
358
+ raise ValueError(f"Template formatting error for prompt '{prompt_key}': {e}")
359
+
360
+ def get_temperature(self, prompt_key: str) -> float:
361
+ """프롬프트의 권장 temperature 반환"""
362
+ if prompt_key not in self.prompts:
363
+ raise ValueError(f"Unknown prompt key: {prompt_key}")
364
+ return self.prompts[prompt_key].temperature
365
+
366
+ def get_diversity_instruction(self, variation_id: int) -> str:
367
+ """variation_id에 따른 다양성 지시문 반환"""
368
+ diversity_instructions = [
369
+ "", # 기본
370
+ "",
371
+ "",
372
+ ""
373
+ ]
374
+
375
+ # diversity_instructions = [
376
+ # "", # 기본
377
+ # "Implement this in a robust way that works well for various examples",
378
+ # "Provide an alternative solution with a unique implementation style:",
379
+ # "Try to implement using a different approach, algorithm, or coding style than typical solutions."
380
+ # ]
381
+
382
+ return diversity_instructions[variation_id % len(diversity_instructions)]
383
+
384
+ def list_prompts(self) -> Dict[str, PromptTemplate]:
385
+ """모든 프롬프트 템플릿 목록 반환"""
386
+ return self.prompts.copy()
387
+
388
+ def get_prompts_by_type(self, benchmark: BenchmarkType) -> Dict[str, PromptTemplate]:
389
+ """벤치마크 타입별 프롬프트 반환"""
390
+ return {
391
+ key: template for key, template in self.prompts.items()
392
+ if template.benchmark == benchmark or template.benchmark == BenchmarkType.GENERAL
393
+ }
394
+
395
+
396
+ # 전역 프롬프트 매니저 인스턴스
397
+ prompt_manager = PromptManager()
398
+
399
+
400
+ # 편의 함수들
401
+ def get_prompt(prompt_key: str, **kwargs) -> str:
402
+ """프롬프트 가져오기 편의 함수"""
403
+ return prompt_manager.get_prompt(prompt_key, **kwargs)
404
+
405
+
406
+ def get_temperature(prompt_key: str) -> float:
407
+ """프롬프트 temperature 가져오기 편의 함수"""
408
+ return prompt_manager.get_temperature(prompt_key)
409
+
410
+
411
+ def get_diversity_instruction(variation_id: int) -> str:
412
+ """다양성 지시문 가져오기 편의 함수"""
413
+ return prompt_manager.get_diversity_instruction(variation_id)
absolute_zero_reasoner/testtime/solution_generator.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Initial Solution Generator
3
+
4
+ AZR 기반 TestTime RLVR을 위한 초기 솔루션 생성기
5
+ 기존 Test-Time-RLVR의 generate_initial_solution 함수를 클래스화하여 확장
6
+ """
7
+
8
+ import re
9
+ import torch
10
+ from typing import Dict, Any, Optional, Tuple, List
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM
12
+
13
+ from .config import TestTimeConfig
14
+ from .logger import TestTimeLogger
15
+ from .prompts import get_prompt, get_temperature, get_diversity_instruction
16
+
17
+ # AZR에서 사용하는 코드 추출 함수 직접 임포트
18
+ from ..rewards.custom_evaluate import extract_code
19
+
20
+ # VLLM 최적화 지원
21
+ try:
22
+ from vllm import LLM, SamplingParams
23
+ VLLM_AVAILABLE = True
24
+ except ImportError:
25
+ VLLM_AVAILABLE = False
26
+
27
+
28
+ class InitialSolutionGenerator:
29
+ """벤치마크 문제에 대한 초기 솔루션 생성"""
30
+
31
+ def __init__(self, model, tokenizer, config: TestTimeConfig,
32
+ logger: Optional[TestTimeLogger] = None, use_vllm: bool = True):
33
+ self.model = model
34
+ self.tokenizer = tokenizer
35
+ self.config = config
36
+ self.logger = logger or TestTimeLogger()
37
+ self.use_vllm = use_vllm and VLLM_AVAILABLE
38
+
39
+ # VLLM 사용 가능 여부 확인 및 로깅
40
+ if use_vllm and not VLLM_AVAILABLE:
41
+ self.logger.log_info("⚠️ VLLM requested but not available, falling back to HuggingFace")
42
+ elif self.use_vllm:
43
+ self.logger.log_info("🚀 Using VLLM for optimized inference")
44
+ else:
45
+ self.logger.log_info("🔧 Using HuggingFace Transformers for inference")
46
+
47
+ def generate(self, problem: Dict[str, Any]) -> str:
48
+ """문제에 대한 초기 솔루션 생성 (AZR 코드 평가 프롬프트 사용)"""
49
+
50
+ problem_prompt = problem['prompt']
51
+ problem_id = problem.get('task_id', 'unknown')
52
+
53
+ # AZR 코드 평가에서 사용하는 프롬프트 포맷 적용
54
+ # prompt = f"Please provide a self-contained Python script that solves the following problem in a markdown code block:\n\n{problem_prompt}"
55
+
56
+ # 중앙 프롬프트 시스템 사용
57
+ if 'HumanEval' in problem_id:
58
+ # entry_point 함수명 찾기
59
+ entry_point = problem.get('entry_point', 'unknown')
60
+
61
+ # 프롬프트에서 함수가 여러 개 있는지 확인
62
+ import re
63
+ function_count = len(re.findall(r'^\s*def\s+\w+', problem_prompt, re.MULTILINE))
64
+
65
+ if function_count > 1:
66
+ # 다중 함수 프롬프트 사용
67
+ prompt = get_prompt("solution_humaneval_multi",
68
+ problem_prompt=problem_prompt,
69
+ entry_point=entry_point)
70
+ else:
71
+ # 단일 함수 프롬프트 사용
72
+ prompt = get_prompt("solution_humaneval_basic",
73
+ problem_prompt=problem_prompt)
74
+ else:
75
+ # MBPP 프롬프트 사용
76
+ prompt = get_prompt("solution_mbpp_basic",
77
+ problem_prompt=problem_prompt)
78
+
79
+ self.logger.log_info(f"🔍 Generating initial solution for {problem_id}")
80
+ self.logger.log_info(f"📋 Full prompt: {prompt}")
81
+
82
+ # VLLM 또는 HuggingFace 백엔드 선택
83
+ if self.use_vllm and isinstance(self.model, LLM):
84
+ solution = self._generate_with_vllm(prompt)
85
+ else:
86
+ solution = self._generate_with_huggingface(prompt)
87
+
88
+ # 마크다운 코드 블록에서 Python 코드 추출 (개선된 방식)
89
+ extracted_solution = self._extract_python_code(solution)
90
+
91
+ # 코드 추출 결과 로깅
92
+ if extracted_solution and extracted_solution != solution:
93
+ self.logger.log_info(f"🔍 Extracted Python code from markdown block")
94
+ solution = extracted_solution
95
+ elif not extracted_solution:
96
+ self.logger.log_info(f"🔍 No markdown code block found, using original text")
97
+
98
+ # HumanEval의 경우 프롬프트에서 import 추출하여 추가 (EvalPlus 방식)
99
+ if 'HumanEval' in problem_id:
100
+ solution = self._add_imports_from_prompt(solution, problem_prompt)
101
+
102
+ # 함수 정의 복구 (AZR 로직 그대로)
103
+ solution = self._fix_function_definition(solution, prompt, problem_id)
104
+
105
+ self.logger.log_info(f"✅ Generated solution ({len(solution)} chars)")
106
+ self.logger.log_info(f"🔍 Solution preview: {solution[:200]}...")
107
+
108
+ # 디버깅: 실제 솔루션 내용 로깅
109
+ self.logger.log_info(f"🔍 Full solution for debugging:")
110
+ self.logger.log_info(f"--- START SOLUTION ---")
111
+ self.logger.log_info(solution)
112
+ self.logger.log_info(f"--- END SOLUTION ---")
113
+
114
+ return solution
115
+
116
+ def generate_diverse(self, problem: Dict[str, Any], temperature: float = 0.7, variation_id: int = 0) -> str:
117
+ """다양한 솔루션 생성 (높은 temperature 사용)"""
118
+
119
+ problem_prompt = problem['prompt']
120
+ problem_id = problem.get('task_id', 'unknown')
121
+
122
+ # 중앙 관리 다양성 프롬프트 시스템 사용
123
+ diversity_instruction = get_diversity_instruction(variation_id)
124
+
125
+ # HumanEval에 대해서는 함수 완성 요청 (다양성 버전)
126
+ if 'HumanEval' in problem_id:
127
+ entry_point = problem.get('entry_point', 'unknown')
128
+
129
+ import re
130
+ function_count = len(re.findall(r'^\s*def\s+\w+', problem_prompt, re.MULTILINE))
131
+
132
+ if function_count > 1:
133
+ prompt = get_prompt("diverse_humaneval_multi",
134
+ diversity_instruction=diversity_instruction,
135
+ problem_prompt=problem_prompt,
136
+ entry_point=entry_point)
137
+ else:
138
+ prompt = get_prompt("diverse_humaneval_basic",
139
+ diversity_instruction=diversity_instruction,
140
+ problem_prompt=problem_prompt)
141
+ else:
142
+ # MBPP 다양성 프롬프트 사용
143
+ prompt = get_prompt("diverse_mbpp_basic",
144
+ diversity_instruction=diversity_instruction,
145
+ problem_prompt=problem_prompt)
146
+
147
+ self.logger.log_info(f"🎨 Generating diverse solution #{variation_id+1} for {problem_id}")
148
+
149
+ # 다양성 생성 메서드 사용
150
+ try:
151
+ from vllm import LLM
152
+ if isinstance(self.model, LLM):
153
+ solution = self._generate_with_vllm_diverse(prompt, temperature)
154
+ else:
155
+ solution = self._generate_with_huggingface_diverse(prompt, temperature)
156
+ except ImportError:
157
+ solution = self._generate_with_huggingface_diverse(prompt, temperature)
158
+
159
+ # 코드 추출 및 후처리 (기존과 동일)
160
+ extracted_solution = self._extract_python_code(solution)
161
+ if extracted_solution and extracted_solution != solution:
162
+ self.logger.log_info(f"🔍 Extracted Python code from markdown block")
163
+ solution = extracted_solution
164
+
165
+ if 'HumanEval' in problem_id:
166
+ solution = self._add_imports_from_prompt(solution, problem_prompt)
167
+
168
+ solution = self._fix_function_definition(solution, prompt, problem_id)
169
+
170
+ self.logger.log_info(f"✅ Generated diverse solution #{variation_id+1} ({len(solution)} chars)")
171
+
172
+ return solution
173
+
174
+ def _generate_with_vllm(self, prompt: str) -> str:
175
+ """VLLM 백엔드로 생성 (AZR 방식)"""
176
+
177
+ # AZR evaluation과 동일한 SamplingParams 설정
178
+ sampling_params = SamplingParams(
179
+ temperature=0.05,
180
+ max_tokens=2048, # AZR 평가 설정
181
+ top_p=1.0, # greedy mode
182
+ stop=["\n```\n"], # 코드 블록 종료 시 정지
183
+ )
184
+
185
+ # VLLM 생성
186
+ outputs = self.model.generate([prompt], sampling_params, use_tqdm=False)
187
+ solution = outputs[0].outputs[0].text.replace("\t", " ") # AZR 방식 탭 처리
188
+
189
+ return solution.strip()
190
+
191
+ def _generate_with_vllm_diverse(self, prompt: str, temperature: float = 0.7) -> str:
192
+ """다양한 솔루션 생성용 VLLM 백엔드 (높은 temperature)"""
193
+
194
+ # 다양성을 위한 SamplingParams 설정
195
+ sampling_params = SamplingParams(
196
+ temperature=temperature, # 높은 temperature로 다양성 확보
197
+ max_tokens=2048,
198
+ top_p=0.95, # 다양성을 위해 top_p 사용
199
+ stop=["\n```\n"], # 코드 블록 종료 시 정지
200
+ )
201
+
202
+ # VLLM 생성
203
+ outputs = self.model.generate([prompt], sampling_params, use_tqdm=False)
204
+ solution = outputs[0].outputs[0].text.replace("\t", " ")
205
+
206
+ return solution.strip()
207
+
208
+ def generate_batch(self, prompts: List[str], temperature: float = 0.7) -> List[str]:
209
+ """배치로 여러 프롬프트 동시 처리"""
210
+
211
+ # 실제 모델 타입 확인 (VLLM 로딩 실패 시 HuggingFace 모델이 로드됨)
212
+ if self.use_vllm and isinstance(self.model, LLM):
213
+ raw_solutions = self._generate_batch_with_vllm(prompts, temperature)
214
+ else:
215
+ # HuggingFace는 순차 처리 (fallback)
216
+ raw_solutions = [self._generate_with_huggingface(prompt) for prompt in prompts]
217
+
218
+ # 각 솔루션에 대해 후처리 수행
219
+ processed_solutions = []
220
+ for i, (prompt, solution) in enumerate(zip(prompts, raw_solutions)):
221
+ # 1. 마크다운에서 Python 코드 추출
222
+ extracted = self._extract_python_code(solution)
223
+ if extracted and extracted != solution:
224
+ self.logger.log_info(f"🔍 Extracted Python code from markdown block for batch item {i+1}")
225
+ solution = extracted
226
+
227
+ # 2. HumanEval 문제인 경우 import 추가
228
+ # 프롬프트에서 problem ID 추출 (프롬프트에 포함되어 있다고 가정)
229
+ if 'HumanEval' in prompt:
230
+ # 프롬프트에서 원본 problem prompt 추출 시도
231
+ # 프롬프트 구조에 따라 조정 필요
232
+ solution = self._add_imports_from_prompt(solution, prompt)
233
+
234
+ # 3. 함수 정의 수정 (필요한 경우)
235
+ # generate_diverse와 동일한 처리
236
+ solution = self._fix_function_definition(solution, prompt)
237
+
238
+ processed_solutions.append(solution)
239
+
240
+ return processed_solutions
241
+
242
+ def _generate_batch_with_vllm(self, prompts: List[str], temperature: float = 0.7) -> List[str]:
243
+ """VLLM으로 배치 처리"""
244
+
245
+ # VLLM 샘플링 파라미터
246
+ # seed를 제거하여 매번 다른 응답 생성
247
+ sampling_params = SamplingParams(
248
+ temperature=temperature,
249
+ top_p=0.85,
250
+ max_tokens=1024,
251
+ stop=[] # stop 토큰 명시적으로 비움
252
+ )
253
+
254
+ # VLLM 배치 생성
255
+ outputs = self.model.generate(prompts, sampling_params, use_tqdm=False)
256
+
257
+ # 결과 추출
258
+ solutions = []
259
+ for i, output in enumerate(outputs):
260
+ solution = output.outputs[0].text.replace("\t", " ")
261
+ # 디버깅: finish_reason 확인
262
+ finish_reason = output.outputs[0].finish_reason
263
+ if finish_reason != "stop" and i < 3: # 처음 3개만 로깅
264
+ self.logger.log_warning(f"Output {i} finish_reason: {finish_reason}, length: {len(solution)}")
265
+ solutions.append(solution.strip())
266
+
267
+ return solutions
268
+
269
+ def _generate_with_huggingface(self, prompt: str) -> str:
270
+ """HuggingFace 백엔드로 생성 (attention mask 수정)"""
271
+
272
+ # 토크나이저 처리 (attention mask 경고 수정)
273
+ inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096)
274
+
275
+ # attention mask 명시적으로 설정
276
+ if 'attention_mask' not in inputs:
277
+ inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
278
+
279
+ # 디바이스 이동 (AZR 방식 그대로)
280
+ device = getattr(self.model, 'device', 'cuda' if torch.cuda.is_available() else 'cpu')
281
+ if isinstance(device, str):
282
+ inputs = {k: v.to(device) for k, v in inputs.items()}
283
+ else:
284
+ # 모델이 이미 특정 디바이스에 있는 경우
285
+ inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()}
286
+
287
+ with torch.no_grad():
288
+ # 메모리 정리 (AZR 방식 그대로)
289
+ if torch.cuda.is_available():
290
+ torch.cuda.empty_cache()
291
+
292
+ # AZR evaluation과 동일한 greedy 설정
293
+ outputs = self.model.generate(
294
+ inputs['input_ids'],
295
+ attention_mask=inputs['attention_mask'], # attention mask 명시적으로 전달
296
+ max_new_tokens=2048, # 원래 AZR 평가 설정
297
+ do_sample=False, # greedy mode (--greedy와 동일)
298
+ pad_token_id=self.tokenizer.eos_token_id
299
+ )
300
+
301
+ # 솔루션 추출 (AZR 방식 그대로)
302
+ solution = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
303
+ solution = solution[len(prompt):].strip()
304
+
305
+ return solution
306
+
307
+ def _generate_with_huggingface_diverse(self, prompt: str, temperature: float = 0.7) -> str:
308
+ """다양한 솔루션 생성용 HuggingFace 백엔드 (높은 temperature)"""
309
+
310
+ # 토크나이저 처리
311
+ inputs = self.tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096)
312
+
313
+ # attention mask 명시적으로 설정
314
+ if 'attention_mask' not in inputs:
315
+ inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])
316
+
317
+ # 디바이스 이동
318
+ device = getattr(self.model, 'device', 'cuda' if torch.cuda.is_available() else 'cpu')
319
+ if isinstance(device, str):
320
+ inputs = {k: v.to(device) for k, v in inputs.items()}
321
+ else:
322
+ # 모델이 이미 특정 디바이스에 있는 경우
323
+ inputs = {k: v.to(next(self.model.parameters()).device) for k, v in inputs.items()}
324
+
325
+ with torch.no_grad():
326
+ # 메모리 정리
327
+ if torch.cuda.is_available():
328
+ torch.cuda.empty_cache()
329
+
330
+ # 다양성을 위한 sampling 설정
331
+ outputs = self.model.generate(
332
+ inputs['input_ids'],
333
+ attention_mask=inputs['attention_mask'],
334
+ max_new_tokens=2048,
335
+ do_sample=True, # sampling 활성화
336
+ temperature=temperature, # 높은 temperature
337
+ top_p=0.95, # 다양성을 위해 top_p 사용
338
+ pad_token_id=self.tokenizer.eos_token_id,
339
+ eos_token_id=self.tokenizer.eos_token_id
340
+ )
341
+
342
+ # 솔루션 추출
343
+ solution = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
344
+ solution = solution[len(prompt):].strip()
345
+
346
+ return solution
347
+
348
+ def _extract_python_code(self, solution: str) -> str:
349
+ """개선된 Python 코드 추출 (AZR 방식 + 추가 패턴)"""
350
+
351
+ # 1. AZR의 extract_code 함수 먼저 시도
352
+ try:
353
+ extracted = extract_code(solution, language="python")
354
+ if extracted:
355
+ return extracted
356
+ except:
357
+ pass
358
+
359
+ # 2. 다양한 마크다운 패턴 시도
360
+ patterns = [
361
+ r'```python\n(.*?)```', # ```python ... ```
362
+ r'```\n(.*?)```', # ``` ... ```
363
+ r'```py\n(.*?)```', # ```py ... ```
364
+ r'```Python\n(.*?)```', # ```Python ... ```
365
+ r'Here is.*?:\n\n```python\n(.*?)```', # 설명 텍스트 포함
366
+ r'Here is.*?:\n\n```\n(.*?)```', # 설명 텍스트 포함
367
+ ]
368
+
369
+ for pattern in patterns:
370
+ matches = re.findall(pattern, solution, re.DOTALL | re.IGNORECASE)
371
+ if matches:
372
+ return matches[-1].strip()
373
+
374
+ # 3. def로 시작하는 함수 찾기
375
+ lines = solution.split('\n')
376
+ code_lines = []
377
+ in_function = False
378
+
379
+ for line in lines:
380
+ if line.strip().startswith('def '):
381
+ in_function = True
382
+ code_lines.append(line)
383
+ elif in_function and (line.startswith(' ') or line.strip() == ''):
384
+ code_lines.append(line)
385
+ elif in_function and line.strip() and not line.startswith(' '):
386
+ # 함수 정의 끝
387
+ break
388
+
389
+ if code_lines:
390
+ return '\n'.join(code_lines)
391
+
392
+ # 4. 원본 반환
393
+ return solution
394
+
395
+ def _add_imports_from_prompt(self, solution: str, prompt: str) -> str:
396
+ """HumanEval 프롬프트에서 import 문을 추출하여 솔루션에 추가 (EvalPlus 방식)"""
397
+
398
+ # 이미 import가 있으면 그대로 반환
399
+ if 'from typing import' in solution or 'import typing' in solution:
400
+ return solution
401
+
402
+ # 프롬프트에서 import 문 추출
403
+ import_lines = []
404
+ prompt_lines = prompt.split('\n')
405
+
406
+ for line in prompt_lines:
407
+ stripped = line.strip()
408
+ # import 문 찾기
409
+ if (stripped.startswith('from ') and 'import' in stripped) or stripped.startswith('import '):
410
+ import_lines.append(line)
411
+ # 함수 정의가 시작되면 중단
412
+ elif stripped.startswith('def '):
413
+ break
414
+
415
+ # import가 없으면 원본 반환
416
+ if not import_lines:
417
+ return solution
418
+
419
+ # import 추가
420
+ self.logger.log_info(f"🔧 Adding imports from prompt: {import_lines}")
421
+
422
+ # 솔루션이 이미 import로 시작하는지 확인
423
+ solution_lines = solution.split('\n')
424
+ first_non_empty_line = None
425
+ for i, line in enumerate(solution_lines):
426
+ if line.strip():
427
+ first_non_empty_line = i
428
+ break
429
+
430
+ # import를 맨 앞에 추가
431
+ if first_non_empty_line is not None:
432
+ # 기존 import 뒤에 추가하거나 맨 앞에 추가
433
+ imports_text = '\n'.join(import_lines) + '\n\n'
434
+
435
+ # 첫 번째 비어있지 않은 줄이 import인 경우
436
+ if solution_lines[first_non_empty_line].strip().startswith(('import ', 'from ')):
437
+ # 마지막 import 찾기
438
+ last_import_idx = first_non_empty_line
439
+ for i in range(first_non_empty_line, len(solution_lines)):
440
+ if solution_lines[i].strip() and not solution_lines[i].strip().startswith(('import ', 'from ')):
441
+ break
442
+ if solution_lines[i].strip().startswith(('import ', 'from ')):
443
+ last_import_idx = i
444
+
445
+ # 마지막 import 다음에 추가
446
+ solution_lines.insert(last_import_idx + 1, '')
447
+ solution_lines.insert(last_import_idx + 1, '\n'.join(import_lines))
448
+ return '\n'.join(solution_lines)
449
+ else:
450
+ # 맨 앞에 추가
451
+ return imports_text + solution
452
+
453
+ return imports_text + solution
454
+
455
+ def _fix_function_definition(self, solution: str, prompt: str, problem_id: str = "") -> str:
456
+ """함수 정의가 누락된 경우 복구 + lpw 스타일 중복 처리"""
457
+
458
+ # lpw 스타일: 프롬프트에서 함수 이름 추출
459
+ func_def_match = re.search(r'def\s+(\w+)\([^)]*\)(?:\s*->\s*[^:]+)?:', prompt)
460
+ if not func_def_match:
461
+ return solution
462
+
463
+ entry_point = func_def_match.group(1)
464
+ func_def_line = func_def_match.group(0)
465
+
466
+ # HumanEval의 경우 전체 코드를 반환하므로 중복 처리 불필요
467
+ if 'HumanEval' in problem_id:
468
+ # 이미 전체 코드가 있으므로 그대로 반환
469
+ return solution
470
+
471
+ # MBPP의 경우 기존 로직 유지
472
+ # Case 1: LLM이 전체 함수를 생성한 경우 (lpw 스타일 체크)
473
+ if (prompt in solution) or (f'def {entry_point}(' in solution):
474
+ # 함수가 이미 포함되어 있음
475
+ self.logger.log_info(f"✅ Function definition already present for {entry_point}")
476
+ return solution
477
+
478
+ # Case 2: 함수 본문만 생성한 경우 - 함수 정의 추가
479
+ if solution and not solution.startswith('def '):
480
+ # 함수 정의와 함수 내용을 결합
481
+ lines = solution.split('\n')
482
+ fixed_lines = [func_def_line]
483
+
484
+ for line in lines:
485
+ if line.strip(): # 빈 줄이 아닌 경우
486
+ # if __name__ == "__main__": 부분은 함수 밖에 있어야 함
487
+ if line.strip().startswith('if __name__'):
488
+ # 함수 정의 끝내고 메인 부분 시작
489
+ fixed_lines.append('') # 빈 줄 추가
490
+ fixed_lines.append(line.strip())
491
+ else:
492
+ # 함수 내용은 4칸 인덴테이션
493
+ if not line.startswith(' ') and line.strip():
494
+ line = ' ' + line.lstrip()
495
+ fixed_lines.append(line)
496
+ else:
497
+ fixed_lines.append(line)
498
+
499
+ solution = '\n'.join(fixed_lines)
500
+ self.logger.log_info(f"🔧 Fixed function definition for {entry_point}")
501
+
502
+ return solution
503
+
504
+ def generate_fallback_solution(self, problem: Dict[str, Any]) -> str:
505
+ """문제 생성 실패 시 대체 솔루션 생성"""
506
+
507
+ entry_point = problem.get('entry_point', 'solution')
508
+ problem_description = problem.get('prompt', '')
509
+
510
+ # 문제 유형별 기본 템플릿 (기존 방식)
511
+ if 'similar_elements' in problem_description:
512
+ # similar_elements 문제 (Mbpp/2)
513
+ solution = f"""def {entry_point}(test_tup1, test_tup2):
514
+ return tuple(set(test_tup1) & set(test_tup2))"""
515
+ elif 'kth_element' in problem_description:
516
+ # kth_element 문제
517
+ solution = f"""def {entry_point}(arr, k):
518
+ return sorted(arr)[k-1]"""
519
+ else:
520
+ # 일반 템플릿
521
+ solution = f"""def {entry_point}(*args):
522
+ # TODO: Implement this function
523
+ return None"""
524
+
525
+ self.logger.log_info(f"🔄 Generated fallback solution for {entry_point}")
526
+ return solution
527
+
528
+ def validate_syntax(self, solution: str) -> Tuple[bool, Optional[str]]:
529
+ """솔루션 구문 검증"""
530
+ try:
531
+ compile(solution, '<string>', 'exec')
532
+ return True, None
533
+ except SyntaxError as e:
534
+ return False, str(e)
535
+ except Exception as e:
536
+ return False, str(e)
537
+
538
+ def extract_function_signature(self, prompt: str) -> Optional[Dict[str, str]]:
539
+ """프롬프트에서 함수 시그니처 추출"""
540
+
541
+ # def function_name(args) -> return_type: 패턴 매칭
542
+ pattern = r'def\s+(\w+)\(([^)]*)\)(?:\s*->\s*([^:]+))?:'
543
+ match = re.search(pattern, prompt)
544
+
545
+ if match:
546
+ func_name = match.group(1)
547
+ args = match.group(2)
548
+ return_type = match.group(3)
549
+
550
+ return {
551
+ 'name': func_name,
552
+ 'args': args.strip(),
553
+ 'return_type': return_type.strip() if return_type else None,
554
+ 'full_signature': match.group(0)
555
+ }
556
+
557
+ return None
558
+
559
+ def format_solution(self, raw_solution: str, problem: Dict[str, Any]) -> str:
560
+ """솔루션 형식 정리"""
561
+
562
+ # 기본 정리
563
+ solution = raw_solution.strip()
564
+
565
+ # 함수 정의 확인 및 수정
566
+ if not solution.startswith('def '):
567
+ signature = self.extract_function_signature(problem.get('prompt', ''))
568
+ if signature:
569
+ # 함수 정의 추가
570
+ lines = solution.split('\n')
571
+ indented_lines = [' ' + line if line.strip() else line for line in lines]
572
+ solution = signature['full_signature'] + '\n' + '\n'.join(indented_lines)
573
+
574
+ # 불필요한 설명 텍스트 제거
575
+ lines = solution.split('\n')
576
+ code_lines = []
577
+ in_function = False
578
+
579
+ for line in lines:
580
+ if line.strip().startswith('def '):
581
+ in_function = True
582
+ code_lines.append(line)
583
+ elif in_function:
584
+ code_lines.append(line)
585
+ elif line.strip() and not any(keyword in line.lower() for keyword in
586
+ ['explanation', 'here', 'this function', 'the solution']):
587
+ code_lines.append(line)
588
+
589
+ return '\n'.join(code_lines).strip()
590
+
591
+ @staticmethod
592
+ def extract_docstring_from_function(code: str) -> str:
593
+ """함수 코드에서 docstring을 추출"""
594
+ import re
595
+
596
+ # 함수 정의 다음에 오는 docstring 패턴 매칭
597
+ # """...""" 또는 '''...''' 형태
598
+ docstring_patterns = [
599
+ r'def\s+\w+\([^)]*\):\s*\n\s*"""(.*?)"""', # """..."""
600
+ r'def\s+\w+\([^)]*\):\s*\n\s*\'\'\'(.*?)\'\'\'', # '''...'''
601
+ ]
602
+
603
+ for pattern in docstring_patterns:
604
+ match = re.search(pattern, code, re.DOTALL)
605
+ if match:
606
+ docstring = match.group(1).strip()
607
+ # 여러 줄인 경우 깔끔하게 정리
608
+ lines = docstring.split('\n')
609
+ cleaned_lines = []
610
+ for line in lines:
611
+ cleaned_line = line.strip()
612
+ if cleaned_line:
613
+ cleaned_lines.append(cleaned_line)
614
+
615
+ return ' '.join(cleaned_lines)
616
+
617
+ # docstring이 없는 경우 기본 메시지 반환
618
+ return "Find the function that produces these outputs from these inputs."
619
+
620
+ def _extract_function_code(self, code: str) -> str:
621
+ """코드에서 함수 정의와 필요한 import 추출"""
622
+ import re
623
+
624
+ lines = code.strip().split('\n')
625
+ import_lines = []
626
+ func_lines = []
627
+ in_function = False
628
+ indent_level = 0
629
+
630
+ # 1. import 문 수집
631
+ for line in lines:
632
+ stripped = line.strip()
633
+ if (stripped.startswith('import ') or stripped.startswith('from ')) and not stripped.startswith('#'):
634
+ import_lines.append(line)
635
+
636
+ # 2. 함수 정의 찾기
637
+ for line in lines:
638
+ if line.strip().startswith('def '):
639
+ in_function = True
640
+ func_lines = [line]
641
+ # 첫 줄의 들여쓰기 레벨 저장
642
+ indent_level = len(line) - len(line.lstrip())
643
+ elif in_function:
644
+ # 빈 줄이거나 같은/더 깊은 들여쓰기면 함수의 일부
645
+ if not line.strip() or (line.strip() and len(line) - len(line.lstrip()) > indent_level):
646
+ func_lines.append(line)
647
+ else:
648
+ # 함수 끝
649
+ break
650
+
651
+ # 3. import + function 결합
652
+ if func_lines:
653
+ result_lines = import_lines + [''] + func_lines if import_lines else func_lines
654
+ return '\n'.join(result_lines)
655
+ else:
656
+ return code
657
+
658
+ def evaluate_solution(self, problem: Dict[str, Any], solution: str) -> Dict[str, Any]:
659
+ """LLM 솔루션을 벤치마크 테스트로 평가 (EvalPlus 필수)"""
660
+ try:
661
+ # EvalPlus 함수들 임포트 (pip으로 설치된 버전 사용)
662
+ self.logger.log_info("🔄 Attempting to import EvalPlus...")
663
+ from evalplus.evaluate import check_correctness
664
+ from evalplus.gen.util import trusted_exec
665
+ from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS
666
+ from evalplus.eval import PASS
667
+ self.logger.log_info("✅ Using EvalPlus for evaluation")
668
+ except ImportError as e:
669
+ # EvalPlus가 없으면 오류로 처리 (fallback 제거)
670
+ self.logger.log_error(f"❌ EvalPlus is required but not available: {e}")
671
+ import traceback
672
+ self.logger.log_error(f"📋 Import traceback: {traceback.format_exc()}")
673
+ return {
674
+ 'correct': False,
675
+ 'passed_tests': 0,
676
+ 'total_tests': 0,
677
+ 'error': f"EvalPlus import failed: {e}. Please install EvalPlus properly.",
678
+ 'execution_results': [],
679
+ 'base_passed': 0,
680
+ 'plus_passed': 0,
681
+ 'base_total': 0,
682
+ 'plus_total': 0
683
+ }
684
+ except Exception as e:
685
+ self.logger.log_error(f"❌ EvalPlus import failed with unexpected error: {e}")
686
+ return {
687
+ 'correct': False,
688
+ 'passed_tests': 0,
689
+ 'total_tests': 0,
690
+ 'error': f"EvalPlus import error: {e}",
691
+ 'execution_results': [],
692
+ 'base_passed': 0,
693
+ 'plus_passed': 0,
694
+ 'base_total': 0,
695
+ 'plus_total': 0
696
+ }
697
+
698
+ result = {
699
+ 'correct': False,
700
+ 'passed_tests': 0,
701
+ 'total_tests': 0,
702
+ 'error': None,
703
+ 'execution_results': [],
704
+ 'base_passed': 0,
705
+ 'plus_passed': 0,
706
+ 'base_total': 0,
707
+ 'plus_total': 0
708
+ }
709
+
710
+ try:
711
+ # 1. 함수 정의 추출
712
+ extracted_code = self._extract_function_code(solution)
713
+ if not extracted_code:
714
+ result['error'] = "No function definition found"
715
+ return result
716
+
717
+ # 2. 데이터셋 타입 결정
718
+ task_id = problem.get('task_id', '')
719
+ if task_id.startswith('Mbpp'):
720
+ dataset = 'mbpp'
721
+ elif task_id.startswith('HumanEval'):
722
+ dataset = 'humaneval'
723
+ else:
724
+ # 기본값
725
+ dataset = 'mbpp'
726
+
727
+ # 3. expected outputs 생성 (canonical solution 사용)
728
+ entry_point = problem.get('entry_point', '')
729
+ canonical_solution = problem.get('canonical_solution', '')
730
+
731
+ if not canonical_solution:
732
+ result['error'] = "No canonical_solution found"
733
+ return result
734
+
735
+ # Expected outputs 계산
736
+ expected_output = {}
737
+
738
+ # Base tests
739
+ base_inputs = problem.get('base_input', [])
740
+ if base_inputs:
741
+ expected_output['base'], expected_output['base_time'] = trusted_exec(
742
+ problem.get('prompt', '') + canonical_solution,
743
+ base_inputs,
744
+ entry_point,
745
+ record_time=True,
746
+ output_not_none=entry_point in MBPP_OUTPUT_NOT_NONE_TASKS
747
+ )
748
+
749
+ # Plus tests
750
+ plus_inputs = problem.get('plus_input', [])
751
+ if plus_inputs:
752
+ expected_output['plus'], expected_output['plus_time'] = trusted_exec(
753
+ problem.get('prompt', '') + canonical_solution,
754
+ plus_inputs,
755
+ entry_point,
756
+ record_time=True,
757
+ output_not_none=entry_point in MBPP_OUTPUT_NOT_NONE_TASKS
758
+ )
759
+
760
+ # 4. EvalPlus check_correctness 호출
761
+ evalplus_result = check_correctness(
762
+ dataset=dataset,
763
+ completion_id=0,
764
+ problem=problem,
765
+ solution=extracted_code,
766
+ expected_output=expected_output,
767
+ base_only=False, # Plus tests도 실행
768
+ fast_check=False, # 모든 테스트 실행
769
+ identifier=task_id
770
+ )
771
+
772
+ # 5. 결과 파싱
773
+ if 'base' in evalplus_result:
774
+ base_stat, base_details = evalplus_result['base']
775
+ result['base_total'] = len(base_inputs)
776
+ if base_stat == PASS:
777
+ result['base_passed'] = result['base_total']
778
+ else:
779
+ result['base_passed'] = sum(1 for d in base_details if d) if base_details else 0
780
+
781
+ result['passed_tests'] += result['base_passed']
782
+ result['total_tests'] += result['base_total']
783
+
784
+ if 'plus' in evalplus_result:
785
+ plus_stat, plus_details = evalplus_result['plus']
786
+ result['plus_total'] = len(plus_inputs)
787
+ if plus_stat == PASS:
788
+ result['plus_passed'] = result['plus_total']
789
+ else:
790
+ result['plus_passed'] = sum(1 for d in plus_details if d) if plus_details else 0
791
+
792
+ result['passed_tests'] += result['plus_passed']
793
+ result['total_tests'] += result['plus_total']
794
+
795
+ # EvalPlus 기준: 모든 테스트 통과해야 correct
796
+ result['correct'] = (result['passed_tests'] == result['total_tests']) and result['total_tests'] > 0
797
+
798
+ # 에러 메시지 설정
799
+ if not result['correct']:
800
+ if base_stat != PASS:
801
+ result['error'] = f"Base tests failed: {base_stat}"
802
+ elif 'plus' in evalplus_result and plus_stat != PASS:
803
+ result['error'] = f"Plus tests failed: {plus_stat}"
804
+
805
+ # 로깅
806
+ self.logger.log_info(f"EvalPlus evaluation for {task_id}:")
807
+ self.logger.log_info(f" Base: {result['base_passed']}/{result['base_total']}")
808
+ self.logger.log_info(f" Plus: {result['plus_passed']}/{result['plus_total']}")
809
+ self.logger.log_info(f" Total: {result['passed_tests']}/{result['total_tests']}")
810
+ self.logger.log_info(f" Correct: {result['correct']}")
811
+
812
+ except Exception as e:
813
+ result['error'] = f"Evaluation failed: {str(e)}"
814
+ import traceback
815
+ self.logger.log_info(f"Evaluation traceback: {traceback.format_exc()}")
816
+
817
+ return result
818
+
819
+
820
+ @staticmethod
821
+ def load_model_with_optimizations(model_name: str, device: str,
822
+ config: TestTimeConfig, use_vllm: bool = True, tensor_parallel_size: int = 1) -> Tuple[Any, Any]:
823
+ """모델과 토크나이저 로드 (AZR 스타일 최적화, VLLM 지원)"""
824
+
825
+ # 토크나이저 로드
826
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
827
+ if tokenizer.pad_token is None:
828
+ tokenizer.pad_token = tokenizer.eos_token
829
+
830
+ # VLLM 사용 가능 여부 확인 및 모델 로드
831
+ if use_vllm and VLLM_AVAILABLE and device.startswith('cuda'):
832
+ try:
833
+ # GPU 디바이스 설정 (이미 설정된 CUDA_VISIBLE_DEVICES 우선 사용)
834
+ import os
835
+ if 'CUDA_VISIBLE_DEVICES' not in os.environ:
836
+ gpu_id = device.split(':')[1] if ':' in device else '0'
837
+ os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
838
+ else:
839
+ # 이미 설정된 CUDA_VISIBLE_DEVICES 사용
840
+ gpu_id = os.environ['CUDA_VISIBLE_DEVICES']
841
+ print(f"🎯 Using existing CUDA_VISIBLE_DEVICES: {gpu_id}")
842
+
843
+ # VLLM 모델 로드 (Ray Actor 환경에서 메모리 최적화)
844
+ model = LLM(
845
+ model=model_name,
846
+ dtype=str(config.torch_dtype).split('.')[-1], # torch.float16 -> float16
847
+ trust_remote_code=True,
848
+ gpu_memory_utilization=config.gpu_memory_utilization,
849
+ max_model_len=getattr(config, 'max_model_len', 2048), # 충분한 길이로 증가
850
+ tensor_parallel_size=tensor_parallel_size, # GPU 개수에 맞춤
851
+ )
852
+ print(f"✅ VLLM model loaded successfully on GPU {gpu_id} (tensor_parallel_size={tensor_parallel_size})")
853
+ return model, tokenizer
854
+ except Exception as e:
855
+ import traceback
856
+ print(f"⚠️ VLLM loading failed: {e}")
857
+ print(f"🔍 Full traceback: {traceback.format_exc()}")
858
+ print(f"🔄 Falling back to HuggingFace")
859
+
860
+ # HuggingFace 모델 로드 (기존 방식)
861
+ model = AutoModelForCausalLM.from_pretrained(
862
+ model_name,
863
+ torch_dtype=config.torch_dtype,
864
+ device_map=device if device.startswith('cuda') else None,
865
+ trust_remote_code=True,
866
+ attn_implementation="flash_attention_2" if config.use_flash_attention and device.startswith('cuda') else None,
867
+ use_cache=False, # 학습용으로 캐시 비활성화
868
+ )
869
+
870
+ # Gradient checkpointing 활성화
871
+ # Gradient checkpointing 비활성화 - 추론 시에는 불필요하고 경고만 발생
872
+ # 학습이 필요한 경우 별도로 활성화해야 함
873
+ if hasattr(model, 'gradient_checkpointing_disable'):
874
+ model.gradient_checkpointing_disable()
875
+
876
+ print(f"✅ HuggingFace model loaded successfully")
877
+ return model, tokenizer
absolute_zero_reasoner/testtime/task_generator.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TestTime Task Generator
3
+
4
+ AZR 추론용 프롬프트 기반 Induction/Deduction/Abduction 태스크 생성
5
+ 요구사항 3: "AZR처럼 템플릿을 활용하여 induction, deduction, abduction 문제를 생성"
6
+ """
7
+
8
+ from typing import Dict, List, Any, Optional, Tuple
9
+ import random
10
+
11
+ from .config import TestTimeConfig
12
+ from .logger import TestTimeLogger
13
+ # AZR 추론용 프롬프트 직접 사용
14
+ from ..data_construction.prompts import get_code_problem_predictor_prompt
15
+ from .solution_generator import InitialSolutionGenerator
16
+
17
+
18
+ class TestTimeTaskGenerator:
19
+ """IPO 트리플에서 3종 태스크 생성"""
20
+
21
+ def __init__(self, config: TestTimeConfig, logger: Optional[TestTimeLogger] = None):
22
+ self.config = config
23
+ self.logger = logger or TestTimeLogger()
24
+
25
+ # AZR 추론용 프롬프트 직접 사용 (get_code_problem_predictor_prompt)
26
+ # 함수 코드 정리용 solution generator 인스턴스 생성
27
+ self.solution_generator = InitialSolutionGenerator(None, None, config, logger)
28
+
29
+ def generate_tasks(self, ipo_triples: List[Dict[str, Any]],
30
+ problem_id: str, round_num: int = 1) -> Dict[str, List[Dict[str, Any]]]:
31
+ """IPO 트리플에서 3종 태스크 생성 (각 트리플마다 3가지 태스크 모두 생성)"""
32
+
33
+ self.logger.log_info(f"🎯 Generating tasks for {problem_id} from {len(ipo_triples)} triples")
34
+
35
+ # 🔧 수정: 분배 로직 제거, 각 IPO 트리플에서 3가지 태스크 모두 생성
36
+ induction_tasks = []
37
+ deduction_tasks = []
38
+ abduction_tasks = []
39
+
40
+ for i, triple in enumerate(ipo_triples):
41
+ # 각 트리플에서 induction 태스크 생성
42
+ induction_task = self._generate_single_induction_task(triple, i, problem_id, round_num)
43
+ if induction_task:
44
+ induction_tasks.append(induction_task)
45
+
46
+ # 각 트리플에서 deduction 태스크 생성
47
+ deduction_task = self._generate_single_deduction_task(triple, i, problem_id, round_num)
48
+ if deduction_task:
49
+ deduction_tasks.append(deduction_task)
50
+
51
+ # 각 트리플에서 abduction 태스크 생성
52
+ abduction_task = self._generate_single_abduction_task(triple, i, problem_id, round_num)
53
+ if abduction_task:
54
+ abduction_tasks.append(abduction_task)
55
+
56
+ all_tasks = {
57
+ 'induction': induction_tasks,
58
+ 'deduction': deduction_tasks,
59
+ 'abduction': abduction_tasks
60
+ }
61
+
62
+ # 로깅
63
+ task_counts = {k: len(v) for k, v in all_tasks.items()}
64
+ total_generated = sum(task_counts.values())
65
+
66
+ self.logger.log_info(f"✅ Generated {len(induction_tasks)} induction, {len(deduction_tasks)} deduction, {len(abduction_tasks)} abduction tasks")
67
+
68
+ self.logger.log_task_generation(
69
+ problem_id,
70
+ induction_tasks,
71
+ deduction_tasks,
72
+ abduction_tasks
73
+ )
74
+
75
+ return all_tasks
76
+
77
+ def _generate_single_induction_task(self, triple: Dict[str, Any], index: int, problem_id: str, round_num: int) -> Optional[Dict[str, Any]]:
78
+ """단일 IPO 트리플에서 induction 태스크 생성"""
79
+
80
+ try:
81
+ # 입력-출력 쌍 준비
82
+ # 평가를 위해서는 실제 인자(triple['input'])를 사용
83
+ input_output_pairs = [(triple['input'], triple['actual_output'])]
84
+
85
+ # 표시용으로는 full_input_str 사용
86
+ display_input = triple.get('full_input_str', triple['input'])
87
+
88
+ # 🔧 수정: clean한 함수 코드만 추출 (test case 제거)
89
+ clean_program = self._extract_clean_function_code(triple['program'])
90
+
91
+ # 매개변수로 받은 problem_id 사용 (AZR 통합용)
92
+ original_problem_id = triple.get('id', '').split('_triple_')[0] # 원본 추출 로직 보존
93
+
94
+ # HumanEval인 경우 특별 처리
95
+ if 'HumanEval' in problem_id:
96
+ # 원본 프로그램에서 함수 설명 추출 (doctest 예시가 있는 원본에서)
97
+ extracted_message = self._extract_function_description(triple['program'])
98
+ if not extracted_message:
99
+ extracted_message = "Find the function that produces these outputs from these inputs."
100
+ else:
101
+ # MBPP는 기존 방식 유지
102
+ extracted_message = InitialSolutionGenerator.extract_docstring_from_function(clean_program)
103
+
104
+ # 사용자 정의: input_output_pairs + message → program
105
+ # 프롬프트용으로는 display 입력 사용
106
+ display_pairs = [(display_input, triple['actual_output'])]
107
+ azr_prompt = get_code_problem_predictor_prompt(
108
+ problem_type='code_f',
109
+ snippet=clean_program, # 🔧 수정: clean한 코드 사용
110
+ input_output_pairs=display_pairs,
111
+ message=extracted_message
112
+ )
113
+
114
+ # AZR 메타데이터 생성
115
+ source_program_id = triple.get('source_program_id', f'program_{index//3}')
116
+ ipo_index = triple.get('ipo_index', index % 3)
117
+
118
+ task = {
119
+ 'task_id': f'induction_{index}',
120
+ 'task_type': 'induction',
121
+ 'triple_id': triple['id'],
122
+ 'source_program_id': source_program_id, # 🆕 추가
123
+ 'ipo_index': ipo_index, # 🆕 추가
124
+ 'ipo_triple': { # 🆕 추가
125
+ 'input': triple['input'],
126
+ 'output': triple['actual_output'],
127
+ 'program': triple['program']
128
+ },
129
+ 'prompt': azr_prompt,
130
+ 'expected_solution': clean_program, # 🔧 수정: clean한 코드 사용
131
+ 'evaluation_data': {
132
+ 'input_output_pairs': input_output_pairs, # 평가용으로는 실제 인자 사용
133
+ 'original_function': triple['program']
134
+ },
135
+
136
+ # 🆕 AZR 학습용 메타데이터
137
+ 'uid': f"{problem_id}_round_{round_num}_induction_{index}",
138
+ 'ipo_group_id': f"{problem_id}_program_{source_program_id}_ipo_{ipo_index}",
139
+ 'original_problem_id': problem_id,
140
+ 'round': round_num,
141
+ 'extra_info': {'metric': 'code_f'}, # induction task는 code_f
142
+ 'basic_accuracy': 0.0, # 초기값, evaluation에서 업데이트됨
143
+ 'ground_truth': clean_program # AZR parquet 형식에서 사용
144
+ }
145
+
146
+ return task
147
+
148
+ except Exception as e:
149
+ self.logger.log_error(f"Failed to generate induction task for triple {triple.get('id', 'unknown')}: {e}")
150
+ return None
151
+
152
+ def _generate_single_deduction_task(self, triple: Dict[str, Any], index: int, problem_id: str, round_num: int) -> Optional[Dict[str, Any]]:
153
+ """단일 IPO 트리플에서 deduction 태스크 생성"""
154
+
155
+ try:
156
+ # 매개변수로 받은 problem_id 사용 (AZR 통합용)
157
+ original_problem_id = triple.get('id', '').split('_triple_')[0] # 원본 추출 로직 보존
158
+
159
+ # HumanEval인 경우 doctest 예시 제거
160
+ if 'HumanEval' in original_problem_id:
161
+ clean_program = self._remove_doctest_examples(triple['program'])
162
+ else:
163
+ # MBPP는 기존 방식 유지
164
+ clean_program = self._extract_clean_function_code(triple['program'])
165
+
166
+ # 사용자 정의: program + input → output
167
+ azr_prompt = get_code_problem_predictor_prompt(
168
+ problem_type='code_o', # 프로그램+입력→출력
169
+ snippet=clean_program, # 🔧 수정: clean한 코드 사용
170
+ input_args=triple['input']
171
+ )
172
+
173
+ # AZR 메타데이터 생성
174
+ source_program_id = triple.get('source_program_id', f'program_{index//3}')
175
+ ipo_index = triple.get('ipo_index', index % 3)
176
+
177
+ task = {
178
+ 'task_id': f'deduction_{index}',
179
+ 'task_type': 'deduction',
180
+ 'triple_id': triple['id'],
181
+ 'source_program_id': source_program_id, # 🆕 추가
182
+ 'ipo_index': ipo_index, # 🆕 추가
183
+ 'ipo_triple': { # 🆕 추가
184
+ 'input': triple['input'],
185
+ 'output': triple['actual_output'],
186
+ 'program': triple['program']
187
+ },
188
+ 'prompt': azr_prompt,
189
+ 'expected_solution': triple['actual_output'], # 🔧 수정: expected_solution으로 통일
190
+ 'evaluation_data': {
191
+ 'function_code': clean_program, # 🔧 수정: clean한 코드 사용 (complete_pipeline과 일치)
192
+ 'test_input': triple['input'], # 🔧 수정: complete_pipeline과 일치
193
+ 'original_function': triple['program']
194
+ },
195
+
196
+ # 🆕 AZR 학습용 메타데이터
197
+ 'uid': f"{problem_id}_round_{round_num}_deduction_{index}",
198
+ 'ipo_group_id': f"{problem_id}_program_{source_program_id}_ipo_{ipo_index}",
199
+ 'original_problem_id': problem_id,
200
+ 'round': round_num,
201
+ 'extra_info': {'metric': 'code_o'}, # deduction task는 code_o
202
+ 'basic_accuracy': 0.0, # 초기값, evaluation에서 업데이트됨
203
+ 'ground_truth': triple['actual_output'] # AZR parquet 형식에서 ���용
204
+ }
205
+
206
+ return task
207
+
208
+ except Exception as e:
209
+ self.logger.log_error(f"Failed to generate deduction task for triple {triple.get('id', 'unknown')}: {e}")
210
+ return None
211
+
212
+ def _generate_single_abduction_task(self, triple: Dict[str, Any], index: int, problem_id: str, round_num: int) -> Optional[Dict[str, Any]]:
213
+ """단일 IPO 트리플에서 abduction 태스크 생성"""
214
+
215
+ try:
216
+ # 매개변수로 받은 problem_id 사용 (AZR 통합용)
217
+ original_problem_id = triple.get('id', '').split('_triple_')[0] # 원본 추출 로직 보존
218
+
219
+ # HumanEval인 경우 doctest 예시 제거
220
+ if 'HumanEval' in original_problem_id:
221
+ clean_program = self._remove_doctest_examples(triple['program'])
222
+ else:
223
+ # MBPP는 기존 방식 유지
224
+ clean_program = self._extract_clean_function_code(triple['program'])
225
+
226
+ # 사용자 정의: program + output → input
227
+ azr_prompt = get_code_problem_predictor_prompt(
228
+ problem_type='code_i', # 프로그램+출력→입력
229
+ snippet=clean_program, # 🔧 수정: clean한 코드 사용
230
+ output=triple['actual_output'] # 🔧 수정: output 파라미터 사용
231
+ )
232
+
233
+ # AZR 메타데이터 생성
234
+ source_program_id = triple.get('source_program_id', f'program_{index//3}')
235
+ ipo_index = triple.get('ipo_index', index % 3)
236
+
237
+ task = {
238
+ 'task_id': f'abduction_{index}',
239
+ 'task_type': 'abduction',
240
+ 'triple_id': triple['id'],
241
+ 'source_program_id': source_program_id, # 🆕 추가
242
+ 'ipo_index': ipo_index, # 🆕 추가
243
+ 'ipo_triple': { # 🆕 추가
244
+ 'input': triple['input'],
245
+ 'output': triple['actual_output'],
246
+ 'program': triple['program']
247
+ },
248
+ 'prompt': azr_prompt,
249
+ 'expected_solution': triple.get('full_input_str', triple['input']), # 🔧 수정: 전체 함수 호출 사용
250
+ 'evaluation_data': {
251
+ 'function_code': clean_program, # 🔧 수정: clean한 코드 사용 (complete_pipeline과 일치)
252
+ 'expected_output': triple['actual_output'], # 🔧 수정: complete_pipeline과 일치
253
+ 'original_function': triple['program']
254
+ },
255
+
256
+ # 🆕 AZR 학습용 메타데이터
257
+ 'uid': f"{problem_id}_round_{round_num}_abduction_{index}",
258
+ 'ipo_group_id': f"{problem_id}_program_{source_program_id}_ipo_{ipo_index}",
259
+ 'original_problem_id': problem_id,
260
+ 'round': round_num,
261
+ 'extra_info': {'metric': 'code_i'}, # abduction task는 code_i
262
+ 'basic_accuracy': 0.0, # 초기값, evaluation에서 업데이트됨
263
+ 'ground_truth': triple.get('full_input_str', triple['input']) # AZR parquet 형식에서 사용
264
+ }
265
+
266
+ return task
267
+
268
+ except Exception as e:
269
+ self.logger.log_error(f"Failed to generate abduction task for triple {triple.get('id', 'unknown')}: {e}")
270
+ return None
271
+
272
+ def generate_induction_tasks(self, ipo_triples: List[Dict[str, Any]],
273
+ count: int) -> List[Dict[str, Any]]:
274
+ """Induction 태스크: 입력-출력 쌍에서 프로그램 추론 (사용자 정의 유지)"""
275
+
276
+ tasks = []
277
+ selected_triples = random.sample(ipo_triples, min(count, len(ipo_triples)))
278
+
279
+ for i, triple in enumerate(selected_triples):
280
+ # 입력-출력 쌍 준비
281
+ input_output_pairs = [(triple['input'], triple['actual_output'])]
282
+
283
+ # 🔧 수정: clean한 함수 코드만 추출 (test case 제거)
284
+ clean_program = self._extract_clean_function_code(triple['program'])
285
+
286
+ # LLM이 생성한 함수에서 docstring 추출해서 message로 사용
287
+ extracted_message = InitialSolutionGenerator.extract_docstring_from_function(clean_program)
288
+
289
+ # 사용자 정의: input_output_pairs + message → program
290
+ azr_prompt = get_code_problem_predictor_prompt(
291
+ problem_type='code_f',
292
+ snippet=clean_program, # 🔧 수정: clean한 코드 사용
293
+ input_output_pairs=input_output_pairs,
294
+ message=extracted_message
295
+ )
296
+
297
+ task = {
298
+ 'task_id': f'induction_{i}',
299
+ 'task_type': 'induction',
300
+ 'triple_id': triple['id'],
301
+ 'prompt': azr_prompt,
302
+ 'expected_solution': clean_program, # 🔧 수정: clean한 코드 사용
303
+ 'evaluation_data': {
304
+ 'input_output_pairs': input_output_pairs,
305
+ 'original_function': triple['program']
306
+ }
307
+ }
308
+
309
+ tasks.append(task)
310
+
311
+ return tasks
312
+
313
+ def generate_deduction_tasks(self, ipo_triples: List[Dict[str, Any]],
314
+ count: int) -> List[Dict[str, Any]]:
315
+ """Deduction 태스크: 프로그램+입력에서 출력 예측 (사용자 정의에 맞게 수정)"""
316
+
317
+ tasks = []
318
+ selected_triples = random.sample(ipo_triples, min(count, len(ipo_triples)))
319
+
320
+ for i, triple in enumerate(selected_triples):
321
+ # 🔧 수정: clean한 함수 코드만 추출 (test case 제거)
322
+ clean_program = self._extract_clean_function_code(triple['program'])
323
+
324
+ # 사용자 정의: program + input → output
325
+ azr_prompt = get_code_problem_predictor_prompt(
326
+ problem_type='code_o', # 프로그램+입력→출력
327
+ snippet=clean_program, # 🔧 수정: clean한 코드 사용
328
+ input_args=triple['input']
329
+ )
330
+
331
+ task = {
332
+ 'task_id': f'deduction_{i}',
333
+ 'task_type': 'deduction',
334
+ 'triple_id': triple['id'],
335
+ 'prompt': azr_prompt,
336
+ 'expected_solution': triple['actual_output'],
337
+ 'evaluation_data': {
338
+ 'function_code': clean_program, # 🔧 수정: clean한 코드 사용
339
+ 'test_input': triple['input']
340
+ }
341
+ }
342
+
343
+ tasks.append(task)
344
+
345
+ return tasks
346
+
347
+ def generate_abduction_tasks(self, ipo_triples: List[Dict[str, Any]],
348
+ count: int) -> List[Dict[str, Any]]:
349
+ """Abduction 태스크: 프로그램+출력에서 입력 예측 (사용자 정의에 맞게 수정)"""
350
+
351
+ tasks = []
352
+ selected_triples = random.sample(ipo_triples, min(count, len(ipo_triples)))
353
+
354
+ for i, triple in enumerate(selected_triples):
355
+ # 🔧 수정: clean한 함수 코드만 추출 (test case 제거)
356
+ clean_program = self._extract_clean_function_code(triple['program'])
357
+
358
+ # 사용자 정의: program + output → input
359
+ azr_prompt = get_code_problem_predictor_prompt(
360
+ problem_type='code_i', # 프로그램+출력→입력
361
+ snippet=clean_program, # 🔧 수정: clean한 코드 사용
362
+ output=triple['actual_output']
363
+ )
364
+
365
+ task = {
366
+ 'task_id': f'abduction_{i}',
367
+ 'task_type': 'abduction',
368
+ 'triple_id': triple['id'],
369
+ 'prompt': azr_prompt,
370
+ 'expected_solution': triple.get('full_input_str', triple['input']), # 🔧 수정: 전체 함수 호출 사용
371
+ 'evaluation_data': {
372
+ 'function_code': clean_program, # 🔧 수정: clean한 코드 사용
373
+ 'expected_output': triple['actual_output']
374
+ }
375
+ }
376
+
377
+ tasks.append(task)
378
+
379
+ return tasks
380
+
381
+ def _extract_clean_function_code(self, program_with_tests: str) -> str:
382
+ """🔧 수정: 프로그램에서 test case와 assert문을 제거하고 순수한 함수 코드만 추출"""
383
+
384
+ # solution_generator의 _extract_function_code 메서드 사용
385
+ clean_code = self.solution_generator._extract_function_code(program_with_tests)
386
+
387
+ # 로깅 (디버깅용)
388
+ if "assert" in program_with_tests or "# Test" in program_with_tests:
389
+ self.logger.log_info("🧹 Cleaned function code (removed test cases)")
390
+
391
+ return clean_code
392
+
393
+ def get_task_statistics(self, all_tasks: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
394
+ """태스크 생성 통계"""
395
+
396
+ stats = {
397
+ 'total_tasks': sum(len(tasks) for tasks in all_tasks.values()),
398
+ 'tasks_by_type': {task_type: len(tasks) for task_type, tasks in all_tasks.items()},
399
+ 'task_types': list(all_tasks.keys())
400
+ }
401
+
402
+ return stats
403
+
404
+ def _remove_doctest_examples(self, code: str) -> str:
405
+ """HumanEval 코드에서 doctest 예시 제거"""
406
+ import re
407
+
408
+ lines = code.split('\n')
409
+ result_lines = []
410
+ in_docstring = False
411
+ docstring_indent = 0
412
+ skip_next = False
413
+
414
+ for line in lines:
415
+ stripped = line.strip()
416
+
417
+ # docstring 시작/끝 감지
418
+ if '"""' in line or "'''" in line:
419
+ if not in_docstring:
420
+ in_docstring = True
421
+ docstring_indent = len(line) - len(line.lstrip())
422
+ result_lines.append(line)
423
+ else:
424
+ in_docstring = False
425
+ result_lines.append(line)
426
+ continue
427
+
428
+ # doctest 예시 라인 건너뛰기
429
+ if in_docstring:
430
+ if stripped.startswith('>>>'):
431
+ skip_next = True # 다음 라인(결과)도 건너뛰기
432
+ continue
433
+ elif skip_next and stripped and not stripped.startswith('>>>'):
434
+ skip_next = False
435
+ continue
436
+ else:
437
+ skip_next = False
438
+
439
+ result_lines.append(line)
440
+
441
+ return '\n'.join(result_lines)
442
+
443
+ def _extract_function_description(self, code: str) -> str:
444
+ """docstring에서 함수 설명 추출 (예시 제외)"""
445
+ import re
446
+
447
+ # 여러 형태의 docstring 매칭
448
+ patterns = [
449
+ r'"""(.*?)"""', # triple double quotes
450
+ r"'''(.*?)'''", # triple single quotes
451
+ ]
452
+
453
+ for pattern in patterns:
454
+ match = re.search(pattern, code, re.DOTALL)
455
+ if match:
456
+ description = match.group(1).strip()
457
+ # 예시 전까지의 모든 설명 추출
458
+ result_lines = []
459
+ lines = description.split('\n')
460
+ for line in lines:
461
+ cleaned_line = line.strip()
462
+ # >>> 예시가 시작되면 중단
463
+ if cleaned_line.startswith('>>>'):
464
+ break
465
+ # 빈 줄이 아니고 예시가 아닌 경우 추가
466
+ if cleaned_line:
467
+ result_lines.append(cleaned_line)
468
+
469
+ # 모든 설명 라인을 공백으로 연결
470
+ if result_lines:
471
+ return ' '.join(result_lines)
472
+
473
+ return ""
absolute_zero_reasoner/trainer/__init__.py ADDED
File without changes
absolute_zero_reasoner/trainer/ppo/__init__.py ADDED
File without changes
absolute_zero_reasoner/trainer/ppo/azr_ray_trainer.py ADDED
The diff for this file is too large to render. See raw diff
 
absolute_zero_reasoner/trainer/ppo/reason_rl_ray_trainer.py ADDED
@@ -0,0 +1,768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from typing import Optional
3
+ from copy import deepcopy
4
+ from collections import defaultdict
5
+
6
+ from omegaconf import OmegaConf, open_dict
7
+ import torch
8
+ import numpy as np
9
+ from torch.utils.data import Dataset, Sampler
10
+ from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, reduce_metrics, compute_data_metrics, compute_timing_metrics, AdvantageEstimator, compute_response_mask
11
+ from verl.utils.debug import marked_timer
12
+ from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto, DataProto
13
+ from verl.utils.dataset.rl_dataset import collate_fn
14
+ from verl import DataProto
15
+ from verl.protocol import pad_dataproto_to_divisor, unpad_dataproto
16
+ from verl.single_controller.ray import RayWorkerGroup
17
+ from verl.trainer.ppo import core_algos
18
+ from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
19
+ from verl.trainer.ppo.ray_trainer import Role, WorkerType, ResourcePoolManager
20
+ from verl.utils.tracking import ValidationGenerationsLogger
21
+
22
+ from absolute_zero_reasoner.utils.dataset.rl_dataset import RLHFDataset
23
+
24
+
25
+
26
+ class ReasonRLRayPPOTrainer(RayPPOTrainer):
27
+ def __init__(
28
+ self,
29
+ config,
30
+ tokenizer,
31
+ role_worker_mapping: dict[Role, WorkerType],
32
+ resource_pool_manager: ResourcePoolManager,
33
+ ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
34
+ processor=None,
35
+ reward_fn=None,
36
+ val_reward_fn=None,
37
+ train_dataset: Optional[Dataset] = None,
38
+ val_dataset: Optional[Dataset] = None,
39
+ collate_fn=None,
40
+ train_sampler: Optional[Sampler] = None,
41
+ device_name="cuda",
42
+ ):
43
+ """
44
+ Initialize distributed PPO trainer with Ray backend.
45
+ Note that this trainer runs on the driver process on a single CPU/GPU node.
46
+
47
+ Args:
48
+ config: Configuration object containing training parameters.
49
+ tokenizer: Tokenizer used for encoding and decoding text.
50
+ role_worker_mapping (dict[Role, WorkerType]): Mapping from roles to worker classes.
51
+ resource_pool_manager (ResourcePoolManager): Manager for Ray resource pools.
52
+ ray_worker_group_cls (RayWorkerGroup, optional): Class for Ray worker groups. Defaults to RayWorkerGroup.
53
+ processor: Optional data processor, used for multimodal data
54
+ reward_fn: Function for computing rewards during training.
55
+ val_reward_fn: Function for computing rewards during validation.
56
+ train_dataset (Optional[Dataset], optional): Training dataset. Defaults to None.
57
+ val_dataset (Optional[Dataset], optional): Validation dataset. Defaults to None.
58
+ collate_fn: Function to collate data samples into batches.
59
+ train_sampler (Optional[Sampler], optional): Sampler for the training dataset. Defaults to None.
60
+ device_name (str, optional): Device name for training (e.g., "cuda", "cpu"). Defaults to "cuda".
61
+ """
62
+
63
+ # Store the tokenizer for text processing
64
+ self.tokenizer = tokenizer
65
+ self.processor = processor
66
+ self.config = config
67
+ self.reward_fn = reward_fn
68
+ self.val_reward_fn = val_reward_fn
69
+
70
+ self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
71
+ assert self.hybrid_engine, "Currently, only support hybrid engine"
72
+
73
+ if self.hybrid_engine:
74
+ assert Role.ActorRollout in role_worker_mapping, f"{role_worker_mapping.keys()=}"
75
+
76
+ self.role_worker_mapping = role_worker_mapping
77
+ self.resource_pool_manager = resource_pool_manager
78
+ self.use_reference_policy = Role.RefPolicy in role_worker_mapping
79
+ self.use_rm = Role.RewardModel in role_worker_mapping
80
+ self.ray_worker_group_cls = ray_worker_group_cls
81
+ self.device_name = device_name
82
+ self.validation_generations_logger = ValidationGenerationsLogger()
83
+
84
+ # if ref_in_actor is True, the reference policy will be actor without lora applied
85
+ self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
86
+
87
+ # define in-reward KL control
88
+ # kl loss control currently not suppoorted
89
+ if config.algorithm.use_kl_in_reward:
90
+ self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
91
+
92
+ if self.config.algorithm.adv_estimator == AdvantageEstimator.GAE:
93
+ self.use_critic = True
94
+ elif self.config.algorithm.adv_estimator in [
95
+ AdvantageEstimator.GRPO,
96
+ AdvantageEstimator.GRPO_PASSK,
97
+ AdvantageEstimator.REINFORCE_PLUS_PLUS,
98
+ AdvantageEstimator.REMAX,
99
+ AdvantageEstimator.RLOO,
100
+ AdvantageEstimator.OPO,
101
+ AdvantageEstimator.REINFORCE_PLUS_PLUS_BASELINE,
102
+ ]:
103
+ self.use_critic = False
104
+ else:
105
+ raise NotImplementedError
106
+
107
+ self._validate_config()
108
+ self._create_dataloader()
109
+
110
+ def _validate_config(self):
111
+ config = self.config
112
+ # number of GPUs total
113
+ n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes
114
+ if config.actor_rollout_ref.actor.strategy == "megatron":
115
+ model_parallel_size = config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size * config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size
116
+ assert n_gpus % (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size) == 0, f"n_gpus ({n_gpus}) must be divisible by model_parallel_size ({model_parallel_size}) times context_parallel_size ({config.actor_rollout_ref.actor.megatron.context_parallel_size})"
117
+ megatron_dp = n_gpus // (model_parallel_size * config.actor_rollout_ref.actor.megatron.context_parallel_size)
118
+ minimal_bsz = megatron_dp * config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
119
+ else:
120
+ minimal_bsz = n_gpus
121
+
122
+ # 1. Check total batch size for data correctness
123
+ real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n
124
+ assert real_train_batch_size % minimal_bsz == 0, f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})."
125
+
126
+ # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu"
127
+ # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu".
128
+ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
129
+ settings = {
130
+ "actor_rollout_ref.actor": "micro_batch_size",
131
+ "critic": "micro_batch_size",
132
+ "reward_model": "micro_batch_size",
133
+ "actor_rollout_ref.ref": "log_prob_micro_batch_size",
134
+ "actor_rollout_ref.rollout": "log_prob_micro_batch_size",
135
+ }
136
+
137
+ if name in settings:
138
+ param = settings[name]
139
+ param_per_gpu = f"{param}_per_gpu"
140
+
141
+ if mbs is None and mbs_per_gpu is None:
142
+ raise ValueError(f"[{name}] Please set at least one of '{name}.{param}' or '{name}.{param_per_gpu}'.")
143
+
144
+ if mbs is not None and mbs_per_gpu is not None:
145
+ raise ValueError(f"[{name}] You have set both '{name}.{param}' AND '{name}.{param_per_gpu}'. Please remove '{name}.{param}' because only '*_{param_per_gpu}'" + "is supported (the former is deprecated).")
146
+
147
+ if not config.actor_rollout_ref.actor.use_dynamic_bsz:
148
+ # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu
149
+ check_mutually_exclusive(
150
+ config.actor_rollout_ref.actor.ppo_micro_batch_size,
151
+ config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu,
152
+ "actor_rollout_ref.actor",
153
+ )
154
+
155
+ if self.use_reference_policy:
156
+ # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
157
+ check_mutually_exclusive(
158
+ config.actor_rollout_ref.ref.log_prob_micro_batch_size,
159
+ config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu,
160
+ "actor_rollout_ref.ref",
161
+ )
162
+
163
+ # The rollout section also has log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu
164
+ check_mutually_exclusive(
165
+ config.actor_rollout_ref.rollout.log_prob_micro_batch_size,
166
+ config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu,
167
+ "actor_rollout_ref.rollout",
168
+ )
169
+
170
+ if self.use_critic and not config.critic.use_dynamic_bsz:
171
+ # Check for critic micro-batch size conflicts
172
+ check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, "critic")
173
+
174
+ # Check for reward model micro-batch size conflicts
175
+ if config.reward_model.enable and not config.reward_model.use_dynamic_bsz:
176
+ check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, "reward_model")
177
+
178
+ # Actor
179
+ # check if train_batch_size is larger than ppo_mini_batch_size
180
+ # if NOT dynamic_bsz, we must ensure:
181
+ # ppo_mini_batch_size is divisible by ppo_micro_batch_size
182
+ # ppo_micro_batch_size * sequence_parallel_size >= n_gpus
183
+ if not config.actor_rollout_ref.actor.use_dynamic_bsz:
184
+ # assert config.data.train_batch_size >= config.actor_rollout_ref.actor.ppo_mini_batch_size
185
+ sp_size = config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1)
186
+ if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None:
187
+ assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0
188
+ assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus
189
+
190
+ assert config.actor_rollout_ref.actor.loss_agg_mode in [
191
+ "token-mean",
192
+ "seq-mean-token-sum",
193
+ "seq-mean-token-mean",
194
+ "seq-mean-token-sum-norm",
195
+ ], f"Invalid loss_agg_mode: {config.actor_rollout_ref.actor.loss_agg_mode}"
196
+
197
+ if config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
198
+ print("NOTICE: You have both enabled in-reward kl and kl loss.")
199
+
200
+ # critic
201
+ if self.use_critic and not config.critic.use_dynamic_bsz:
202
+ assert config.data.train_batch_size >= config.critic.ppo_mini_batch_size
203
+ sp_size = config.critic.get("ulysses_sequence_parallel_size", 1)
204
+ if config.critic.ppo_micro_batch_size is not None:
205
+ assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
206
+ assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus
207
+
208
+ # Check if use_remove_padding is enabled when using sequence parallelism for fsdp
209
+ if config.actor_rollout_ref.actor.strategy == "fsdp" and (config.actor_rollout_ref.actor.get("ulysses_sequence_parallel_size", 1) > 1 or config.actor_rollout_ref.ref.get("ulysses_sequence_parallel_size", 1) > 1):
210
+ assert config.actor_rollout_ref.model.use_remove_padding, "When using sequence parallelism for actor/ref policy, you must enable `use_remove_padding`."
211
+
212
+ if self.use_critic and config.critic.strategy == "fsdp":
213
+ if config.critic.get("ulysses_sequence_parallel_size", 1) > 1:
214
+ assert config.critic.model.use_remove_padding, "When using sequence parallelism for critic, you must enable `use_remove_padding`."
215
+
216
+ if config.data.get("val_batch_size", None) is not None:
217
+ print("WARNING: val_batch_size is deprecated." + " Validation datasets are sent to inference engines as a whole batch," + " which will schedule the memory themselves.")
218
+
219
+ # check eval config
220
+ if config.actor_rollout_ref.rollout.val_kwargs.do_sample:
221
+ assert config.actor_rollout_ref.rollout.temperature > 0, "validation gen temperature should be greater than 0 when enabling do_sample"
222
+
223
+ # check multi_turn with tool config
224
+ if config.actor_rollout_ref.rollout.multi_turn.enable:
225
+ assert config.actor_rollout_ref.rollout.multi_turn.tool_config_path is not None or config.actor_rollout_ref.rollout.multi_turn.interaction_config_path is not None, "tool_config_path or interaction_config_path must be set when enabling multi_turn with tool, due to no role-playing support"
226
+ assert config.algorithm.adv_estimator in [AdvantageEstimator.GRPO], "only GRPO is tested for multi-turn with tool"
227
+
228
+ print("[validate_config] All configuration checks passed successfully!")
229
+
230
+ def _create_dataloader(self):
231
+ """
232
+ Changed the prompt length of validation set to have another prompt length.
233
+ Create the train and val dataloader.
234
+ """
235
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
236
+ self.train_dataset = RLHFDataset(parquet_files=self.config.data.train_files,
237
+ tokenizer=self.tokenizer,
238
+ prompt_key=self.config.data.prompt_key,
239
+ max_prompt_length=self.config.data.max_prompt_length,
240
+ filter_prompts=True,
241
+ return_raw_chat=self.config.data.get('return_raw_chat', False),
242
+ truncation='error',
243
+ extra_source_key="train")
244
+ # use sampler for better ckpt resume
245
+ if self.config.data.shuffle:
246
+ train_dataloader_generator = torch.Generator()
247
+ train_dataloader_generator.manual_seed(self.config.data.get('seed', 1))
248
+ sampler = RandomSampler(data_source=self.train_dataset, generator=train_dataloader_generator)
249
+ else:
250
+ sampler = SequentialSampler(data_source=self.train_dataset)
251
+
252
+ self.train_dataloader = DataLoader(dataset=self.train_dataset,
253
+ batch_size=self.config.data.train_batch_size,
254
+ drop_last=True,
255
+ collate_fn=collate_fn,
256
+ sampler=sampler)
257
+
258
+ self.val_dataset = RLHFDataset(parquet_files=self.config.data.val_files,
259
+ tokenizer=self.tokenizer,
260
+ prompt_key=self.config.data.prompt_key,
261
+ max_prompt_length=self.config.data.max_prompt_length,
262
+ filter_prompts=True,
263
+ return_raw_chat=self.config.data.get('return_raw_chat', False),
264
+ truncation='error',
265
+ extra_source_key="val")
266
+ self.val_dataloader = DataLoader(dataset=self.val_dataset,
267
+ batch_size=len(self.val_dataset),
268
+ shuffle=True,
269
+ drop_last=True,
270
+ collate_fn=collate_fn)
271
+
272
+ assert len(self.train_dataloader) >= 1
273
+ assert len(self.val_dataloader) >= 1
274
+
275
+ print(f'Size of train dataloader: {len(self.train_dataloader)}')
276
+ print(f'Size of val dataloader: {len(self.val_dataloader)}')
277
+
278
+ # inject total_training_steps to actor/critic optim_config. This is hacky.
279
+ total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
280
+
281
+ if self.config.trainer.total_training_steps is not None:
282
+ total_training_steps = self.config.trainer.total_training_steps
283
+
284
+ self.total_training_steps = total_training_steps
285
+ print(f'Total training steps: {self.total_training_steps}')
286
+
287
+ OmegaConf.set_struct(self.config, True)
288
+ with open_dict(self.config):
289
+ self.config.actor_rollout_ref.actor.optim.total_training_steps = total_training_steps
290
+ # Only set critic total_training_steps if critic is actually used
291
+ if self.use_critic:
292
+ self.config.critic.optim.total_training_steps = total_training_steps
293
+
294
+
295
+ def _validate(self, do_sample: bool = False):
296
+ """
297
+ The validation loop of PPO.
298
+ The only difference is logging more metrics.
299
+ """
300
+ from collections import defaultdict
301
+ reward_tensor_lst = []
302
+ data_source_lst = []
303
+
304
+ # Lists to collect samples for the table
305
+ sample_inputs = []
306
+ sample_outputs = []
307
+ sample_scores = []
308
+
309
+ all_eval_metrics = defaultdict(list)
310
+
311
+ for test_data in self.val_dataloader:
312
+ test_batch = DataProto.from_single_dict(test_data)
313
+
314
+ # we only do validation on rule-based rm
315
+ if self.config.reward_model.enable and test_batch[0].non_tensor_batch['reward_model']['style'] == 'model':
316
+ return {}
317
+
318
+ # Store original inputs
319
+ input_ids = test_batch.batch['input_ids']
320
+ input_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
321
+ sample_inputs.extend(input_texts)
322
+
323
+ batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
324
+ non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
325
+ if "multi_modal_data" in test_batch.non_tensor_batch:
326
+ non_tensor_batch_keys_to_pop.append("multi_modal_data")
327
+ if "raw_prompt" in test_batch.non_tensor_batch:
328
+ non_tensor_batch_keys_to_pop.append("raw_prompt")
329
+ if "tools_kwargs" in test_batch.non_tensor_batch:
330
+ non_tensor_batch_keys_to_pop.append("tools_kwargs")
331
+ if "interaction_kwargs" in test_batch.non_tensor_batch:
332
+ non_tensor_batch_keys_to_pop.append("interaction_kwargs")
333
+ test_gen_batch = test_batch.pop(
334
+ batch_keys=batch_keys_to_pop,
335
+ non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
336
+ )
337
+
338
+ test_gen_batch.meta_info = {
339
+ 'eos_token_id': self.tokenizer.eos_token_id,
340
+ 'pad_token_id': self.tokenizer.pad_token_id,
341
+ 'recompute_log_prob': False,
342
+ 'do_sample': do_sample,
343
+ 'validate': True,
344
+ }
345
+
346
+ # pad to be divisible by dp_size
347
+ size_divisor = self.actor_rollout_wg.world_size if not self.async_rollout_mode else self.config.actor_rollout_ref.rollout.agent.num_workers
348
+ test_gen_batch_padded, pad_size = pad_dataproto_to_divisor(test_gen_batch, size_divisor)
349
+ if not self.async_rollout_mode:
350
+ test_output_gen_batch_padded = self.actor_rollout_wg.generate_sequences(test_gen_batch_padded)
351
+ else:
352
+ test_output_gen_batch_padded = self.async_rollout_manager.generate_sequences(test_gen_batch_padded)
353
+
354
+ # unpad
355
+ test_output_gen_batch = unpad_dataproto(test_output_gen_batch_padded, pad_size=pad_size)
356
+ print('validation generation end')
357
+
358
+ # Store generated outputs
359
+ output_ids = test_output_gen_batch.batch["responses"]
360
+
361
+ output_texts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]
362
+ sample_outputs.extend(output_texts)
363
+
364
+ test_batch = test_batch.union(test_output_gen_batch)
365
+
366
+ # evaluate using reward_function
367
+ reward_tensor, eval_metrics = self.val_reward_fn(test_batch)
368
+ for k, v in eval_metrics.items():
369
+ all_eval_metrics[k].append(v)
370
+
371
+ # Store scores
372
+ scores = reward_tensor.sum(-1).cpu().tolist()
373
+ sample_scores.extend(scores)
374
+
375
+ reward_tensor_lst.append(reward_tensor)
376
+ data_source_lst.append(test_batch.non_tensor_batch.get('data_source', ['unknown'] * reward_tensor.shape[0]))
377
+
378
+ self._maybe_log_val_generations(inputs=sample_inputs, outputs=sample_outputs, scores=sample_scores)
379
+
380
+ reward_tensor = torch.cat(reward_tensor_lst, dim=0).sum(-1).cpu() # (batch_size,)
381
+ data_sources = np.concatenate(data_source_lst, axis=0)
382
+
383
+ # evaluate test_score based on data source
384
+ data_source_reward = {}
385
+ for i in range(reward_tensor.shape[0]):
386
+ data_source = data_sources[i]
387
+ if data_source not in data_source_reward:
388
+ data_source_reward[data_source] = []
389
+ data_source_reward[data_source].append(reward_tensor[i].item())
390
+
391
+ metric_dict = {}
392
+ for data_source, rewards in data_source_reward.items():
393
+ metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards)
394
+
395
+ for k, v in all_eval_metrics.items():
396
+ metric_dict[k] = np.mean(v)
397
+
398
+ if self.config.eval.get('save_generations', False):
399
+ import json
400
+ with open(f'{self.config.trainer.experiment_name}_generations_{self.global_steps}.json', 'w') as f:
401
+ json.dump({
402
+ 'inputs': sample_inputs,
403
+ 'outputs': sample_outputs,
404
+ 'scores': sample_scores
405
+ }, f)
406
+ return metric_dict
407
+
408
+
409
+ def fit(self):
410
+ """
411
+ The training loop of PPO.
412
+ The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
413
+ The light-weight advantage computation is done on the driver process.
414
+
415
+ The only difference is logging more metrics.
416
+ """
417
+ from absolute_zero_reasoner.utils.tracking import ReasonRLTracking
418
+ from absolute_zero_reasoner.utils.logging_utils.stdout import PrettyPrinter as pp
419
+ from omegaconf import OmegaConf
420
+
421
+ # Display training setup header
422
+ pp.section_header("Training Setup")
423
+
424
+ logger = ReasonRLTracking(
425
+ project_name=self.config.trainer.project_name,
426
+ experiment_name=self.config.trainer.experiment_name,
427
+ default_backend=self.config.trainer.logger,
428
+ config=OmegaConf.to_container(self.config, resolve=True),
429
+ tags=self.config.trainer.wandb_tags,
430
+ resume="must" if self.config.trainer.resume_mode == 'auto' and \
431
+ self.config.trainer.wandb_run_id is not None else False, # Add resume flag
432
+ run_id=self.config.trainer.wandb_run_id \
433
+ if self.config.trainer.wandb_run_id is not None else None # Pass existing run ID
434
+ )
435
+
436
+ pp.status("Config", f"Project: {self.config.trainer.project_name}, Experiment: {self.config.trainer.experiment_name}", "info")
437
+ pp.status("Algorithm", f"Using {self.config.algorithm.adv_estimator} advantage estimator", "info")
438
+ pp.status("Setup", f"Critic enabled: {self.use_critic}, Reference policy: {self.use_reference_policy}", "info")
439
+
440
+ self.global_steps = 0
441
+
442
+ # load checkpoint before doing anything
443
+ pp.status("Checkpoint", "Loading checkpoint if available...", "info")
444
+ self._load_checkpoint()
445
+
446
+ # base model chat template
447
+ if self.config.actor_rollout_ref.model.pretrained_tokenizer:
448
+ self.tokenizer.chat_template = "{%- for message in messages -%}{{- '\n' if not loop.first -}}{{- message['content'] -}}{%- endfor -%}"
449
+
450
+ # perform validation before training
451
+ # currently, we only support validation using the reward_function.
452
+ if self.val_reward_fn is not None and self.config.trainer.get('val_before_train', True) and self.global_steps == 0:
453
+ pp.section_header("Initial Validation")
454
+ pp.status("Validation", "Running initial validation...", "info")
455
+
456
+ val_metrics = self._validate(do_sample=self.config.eval.do_sample)
457
+
458
+ # Convert metrics to table format
459
+ metrics_table = []
460
+ for k, v in val_metrics.items():
461
+ metrics_table.append([k, f"{v:.4f}" if isinstance(v, float) else v])
462
+
463
+ pp.table(["Metric", "Value"], metrics_table, "Initial Validation Results")
464
+ logger.log(data=val_metrics, step=self.global_steps)
465
+
466
+ # save val metrics to model path
467
+ if self.config.eval.get('log_to_model_path', False):
468
+ import json
469
+ import os
470
+ with open(os.path.join(self.config.actor_rollout_ref.model.path, 'math_metrics.json'), 'w') as f:
471
+ json.dump(val_metrics, f)
472
+
473
+ if self.config.trainer.get('val_only', False):
474
+ pp.status("Training", "Validation only mode, exiting", "success")
475
+ return
476
+
477
+ # we start from step 1
478
+ self.global_steps += 1
479
+ last_val_metrics = None
480
+ self.max_steps_duration = 0
481
+
482
+ pp.section_header("Starting Training")
483
+ pp.status("Training", f"Starting training for {self.config.trainer.total_epochs} epochs ({total_steps} steps)", "info")
484
+
485
+ for epoch in range(self.config.trainer.total_epochs):
486
+ pp.status("Epoch", f"Starting epoch {epoch+1}/{self.config.trainer.total_epochs}", "info")
487
+
488
+ for batch_idx, batch_dict in enumerate(self.train_dataloader):
489
+ do_profile = self.global_steps in self.config.trainer.profile_steps if self.config.trainer.profile_steps is not None else False
490
+ if do_profile:
491
+ self.actor_rollout_wg.start_profile()
492
+ if self.use_reference_policy:
493
+ self.ref_policy_wg.start_profile()
494
+ if self.use_critic:
495
+ self.critic_wg.start_profile()
496
+ if self.use_rm:
497
+ self.rm_wg.start_profile()
498
+
499
+ metrics = {}
500
+ timing_raw = {}
501
+ batch: DataProto = DataProto.from_single_dict(batch_dict)
502
+
503
+ # pop those keys for generation
504
+ batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
505
+ non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
506
+ if "multi_modal_data" in batch.non_tensor_batch:
507
+ non_tensor_batch_keys_to_pop.append("multi_modal_data")
508
+ if "raw_prompt" in batch.non_tensor_batch:
509
+ non_tensor_batch_keys_to_pop.append("raw_prompt")
510
+ if "tools_kwargs" in batch.non_tensor_batch:
511
+ non_tensor_batch_keys_to_pop.append("tools_kwargs")
512
+ if "interaction_kwargs" in batch.non_tensor_batch:
513
+ non_tensor_batch_keys_to_pop.append("interaction_kwargs")
514
+ gen_batch = batch.pop(
515
+ batch_keys=batch_keys_to_pop,
516
+ non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
517
+ )
518
+
519
+ is_last_step = self.global_steps >= self.total_training_steps
520
+
521
+ with marked_timer("step", timing_raw):
522
+ # generate a batch
523
+ with marked_timer("gen", timing_raw, color="red"):
524
+ pp.status("Step", f"Generating sequences for batch {batch_idx+1}", "info")
525
+ gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
526
+
527
+ if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
528
+ with marked_timer("gen_max", timing_raw, color="purple"):
529
+ gen_baseline_batch = deepcopy(gen_batch)
530
+ gen_baseline_batch.meta_info["do_sample"] = False
531
+ gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
532
+
533
+ batch = batch.union(gen_baseline_output)
534
+ reward_baseline_tensor, _ = self.reward_fn(batch)
535
+ reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
536
+
537
+ batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
538
+
539
+ batch.batch["reward_baselines"] = reward_baseline_tensor
540
+
541
+ del gen_baseline_batch, gen_baseline_output
542
+
543
+ pp.status("Processing", "Preparing batch with UUIDs", "info")
544
+ batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))],
545
+ dtype=object)
546
+ # repeat to align with repeated responses in rollout
547
+ batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
548
+ batch = batch.union(gen_batch_output)
549
+
550
+ batch.batch["response_mask"] = compute_response_mask(batch)
551
+ pp.status("Processing", "Balancing batch across ranks", "info")
552
+ # Balance the number of valid tokens across DP ranks.
553
+ # NOTE: This usually changes the order of data in the `batch`,
554
+ # which won't affect the advantage calculation (since it's based on uid),
555
+ # but might affect the loss calculation (due to the change of mini-batching).
556
+ # TODO: Decouple the DP balancing and mini-batching.
557
+ if self.config.trainer.balance_batch:
558
+ self._balance_batch(batch, metrics=metrics)
559
+
560
+ # compute global_valid tokens
561
+ batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
562
+ # recompute old_log_probs
563
+ with marked_timer("old_log_prob", timing_raw, color="blue"):
564
+ old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
565
+ entropys = old_log_prob.batch["entropys"]
566
+ response_masks = batch.batch["response_mask"]
567
+ loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
568
+ entropy_agg = core_algos.agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
569
+ old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
570
+ metrics.update(old_log_prob_metrics)
571
+ old_log_prob.batch.pop("entropys")
572
+ batch = batch.union(old_log_prob)
573
+
574
+ if "rollout_log_probs" in batch.batch.keys():
575
+ # TODO: we may want to add diff of probs too.
576
+ rollout_old_log_probs = batch.batch["rollout_log_probs"]
577
+ actor_old_log_probs = batch.batch["old_log_probs"]
578
+ attention_mask = batch.batch["attention_mask"]
579
+ responses = batch.batch["responses"]
580
+ response_length = responses.size(1)
581
+ response_mask = attention_mask[:, -response_length:]
582
+
583
+ rollout_probs = torch.exp(rollout_old_log_probs)
584
+ actor_probs = torch.exp(actor_old_log_probs)
585
+ rollout_probs_diff = torch.abs(rollout_probs - actor_probs)
586
+ rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())
587
+ rollout_probs_diff_max = torch.max(rollout_probs_diff)
588
+ rollout_probs_diff_mean = torch.mean(rollout_probs_diff)
589
+ rollout_probs_diff_std = torch.std(rollout_probs_diff)
590
+ metrics.update(
591
+ {
592
+ "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(),
593
+ "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(),
594
+ "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(),
595
+ }
596
+ )
597
+
598
+ if self.use_reference_policy:
599
+ # compute reference log_prob
600
+ with marked_timer("ref", timing_raw, color="olive"):
601
+ if not self.ref_in_actor:
602
+ ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
603
+ else:
604
+ ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
605
+ batch = batch.union(ref_log_prob)
606
+
607
+ # compute values
608
+ if self.use_critic:
609
+ with marked_timer('values', timing_raw):
610
+ pp.status("Computation", "Computing critic values", "info")
611
+ values = self.critic_wg.compute_values(batch)
612
+ batch = batch.union(values)
613
+
614
+ with marked_timer('adv', timing_raw):
615
+ # compute scores. Support both model and function-based.
616
+ pp.status("Rewards", "Computing rewards", "info")
617
+ if self.use_rm:
618
+ # we first compute reward model score
619
+ reward_tensor = self.rm_wg.compute_rm_score(batch)
620
+ batch = batch.union(reward_tensor)
621
+
622
+ # we combine with rule-based rm
623
+ reward_tensor, train_metrics = self.reward_fn(batch)
624
+ train_metrics = {k: np.mean(v) for k, v in train_metrics.items()}
625
+ metrics.update(train_metrics)
626
+ batch.batch['token_level_scores'] = reward_tensor
627
+
628
+ # compute rewards. apply_kl_penalty if available
629
+ if self.config.algorithm.use_kl_in_reward:
630
+ pp.status("KL Penalty", "Applying KL penalty", "info")
631
+ batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty)
632
+ metrics.update(kl_metrics)
633
+ else:
634
+ batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
635
+
636
+ # compute advantages, executed on the driver process
637
+ pp.status("Advantage", f"Computing {self.config.algorithm.adv_estimator} advantage", "info")
638
+ batch = compute_advantage(batch,
639
+ adv_estimator=self.config.algorithm.adv_estimator,
640
+ gamma=self.config.algorithm.gamma,
641
+ lam=self.config.algorithm.lam,
642
+ num_repeat=self.config.actor_rollout_ref.rollout.n)
643
+
644
+ # update critic
645
+ if self.use_critic:
646
+ with marked_timer('update_critic', timing_raw):
647
+ pp.status("Update", "Updating critic network", "info")
648
+ critic_output = self.critic_wg.update_critic(batch)
649
+ critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
650
+ metrics.update(critic_output_metrics)
651
+
652
+ # implement critic warmup
653
+ if self.config.trainer.critic_warmup <= self.global_steps:
654
+ # update actor
655
+ with marked_timer('update_actor', timing_raw):
656
+ batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
657
+ pp.status("Update", "Updating actor network", "info")
658
+ actor_output = self.actor_rollout_wg.update_actor(batch)
659
+ actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
660
+ metrics.update(actor_output_metrics)
661
+
662
+ # Log rollout generations if enabled
663
+ rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
664
+ if rollout_data_dir:
665
+ with marked_timer("dump_rollout_generations", timing_raw, color="green"):
666
+ print(batch.batch.keys())
667
+ inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
668
+ outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
669
+ scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
670
+ self._dump_generations(
671
+ inputs=inputs,
672
+ outputs=outputs,
673
+ scores=scores,
674
+ reward_extra_infos_dict=train_metrics,
675
+ dump_path=rollout_data_dir,
676
+ )
677
+
678
+ # validate
679
+ if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
680
+ self.global_steps % self.config.trainer.test_freq == 0:
681
+ with marked_timer('testing', timing_raw):
682
+ pp.section_header(f"Validation (Step {self.global_steps})")
683
+ pp.status("Validation", "Running validation", "info")
684
+ val_metrics: dict = self._validate()
685
+ if is_last_step:
686
+ last_val_metrics = val_metrics
687
+
688
+ # Convert metrics to table format
689
+ val_metrics_table = []
690
+ for k, v in val_metrics.items():
691
+ val_metrics_table.append([k, f"{v:.4f}" if isinstance(v, float) else v])
692
+
693
+ pp.table(["Metric", "Value"], val_metrics_table, f"Validation Results (Step {self.global_steps})")
694
+ metrics.update(val_metrics)
695
+
696
+ if self.config.trainer.save_freq > 0 and \
697
+ self.global_steps % self.config.trainer.save_freq == 0:
698
+ with marked_timer('save_checkpoint', timing_raw):
699
+ pp.status("Checkpoint", f"Saving checkpoint at step {self.global_steps}", "success")
700
+ self._save_checkpoint()
701
+
702
+ steps_duration = timing_raw["step"]
703
+ self.max_steps_duration = max(self.max_steps_duration, steps_duration)
704
+ # training metrics
705
+ metrics.update(
706
+ {
707
+ "training/global_step": self.global_steps,
708
+ "training/epoch": epoch,
709
+ }
710
+ )
711
+ # collect metrics
712
+ metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
713
+ metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
714
+
715
+ # Display key metrics in a table
716
+ key_metrics = {k: v for k, v in metrics.items()}
717
+ if key_metrics:
718
+ metrics_table = []
719
+ for k, v in key_metrics.items():
720
+ metrics_table.append([k, f"{v:.4f}" if isinstance(v, float) else v])
721
+ pp.table(["Metric", "Value"], metrics_table, f"Step {self.global_steps} Results")
722
+
723
+ # Display timing info
724
+ timing_metrics = {k: v for k, v in metrics.items() if 'time' in k}
725
+ if timing_metrics:
726
+ timing_table = []
727
+ for k, v in timing_metrics.items():
728
+ timing_table.append([k, f"{v:.4f}s" if isinstance(v, float) else v])
729
+ pp.table(["Operation", "Time"], timing_table, "Timing Information")
730
+
731
+ logger.log(data=metrics, step=self.global_steps)
732
+
733
+ # Show progress within epoch
734
+ pp.progress_bar(self.global_steps, total_steps, f"Training Progress (Epoch {epoch+1})")
735
+
736
+ self.global_steps += 1
737
+
738
+ if self.global_steps >= self.total_training_steps:
739
+ pp.section_header("Training Complete")
740
+ # perform validation after training
741
+ if self.val_reward_fn is not None:
742
+ pp.status("Validation", "Running final validation", "info")
743
+ val_metrics = self._validate()
744
+
745
+ # Convert metrics to table format
746
+ final_metrics_table = []
747
+ for k, v in val_metrics.items():
748
+ final_metrics_table.append([k, f"{v:.4f}" if isinstance(v, float) else v])
749
+
750
+ pp.table(["Metric", "Value"], final_metrics_table, "Final Validation Results")
751
+ logger.log(data=val_metrics, step=self.global_steps)
752
+
753
+ if self.config.trainer.save_freq > 0 and \
754
+ (self.global_steps - 1) % self.config.trainer.save_freq != 0:
755
+ with marked_timer('save_checkpoint', timing_raw):
756
+ pp.status("Checkpoint", "Saving final checkpoint", "success")
757
+ self._save_checkpoint()
758
+
759
+ pp.status("Training", "Training completed successfully!", "success")
760
+ if do_profile:
761
+ self.actor_rollout_wg.stop_profile()
762
+ if self.use_reference_policy:
763
+ self.ref_policy_wg.stop_profile()
764
+ if self.use_critic:
765
+ self.critic_wg.stop_profile()
766
+ if self.use_rm:
767
+ self.rm_wg.stop_profile()
768
+ return
absolute_zero_reasoner/trainer/ppo/ttrlvr_azr_integration.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TTRLVR + AZR Integration Module
3
+
4
+ TTRLVR의 데이터로부터 AZR의 학습에 필요한 형태로 변환하고,
5
+ TTRLVR의 reward 계산 로직을 통합
6
+ """
7
+
8
+ import os
9
+ import pandas as pd
10
+ from typing import Dict, List, Any, Optional
11
+ import torch
12
+ from transformers import AutoTokenizer
13
+
14
+ from ...rewards.ttrlvr_reward_manager import TTRLVRRewardManager
15
+
16
+
17
+ class TTRLVRAZRDataProcessor:
18
+ """TTRLVR 데이터를 AZR 학습에 맞게 처리하는 클래스"""
19
+
20
+ def __init__(self, tokenizer: AutoTokenizer):
21
+ self.tokenizer = tokenizer
22
+ self.reward_manager = TTRLVRRewardManager(
23
+ tokenizer=tokenizer,
24
+ num_examine=0,
25
+ reward_fn_extraction_type='rule',
26
+ math_metric='accuracy',
27
+ split='test',
28
+ splitter='boxed',
29
+ output_path='./ttrlvr_output',
30
+ max_prompt_length=2048,
31
+ generation_reward_config=type('obj', (object,), {
32
+ 'use_original_code_as_ref': False,
33
+ 'reward_type': 'code_execution',
34
+ 'weight': 1.0
35
+ })
36
+ )
37
+
38
+ def load_ttrlvr_data(self, data_path: str) -> Dict[str, pd.DataFrame]:
39
+ """TTRLVR parquet 파일들을 로드"""
40
+ data_by_type = {}
41
+
42
+ for task_type in ['induction', 'deduction', 'abduction']:
43
+ file_path = os.path.join(data_path, f"{task_type}.parquet")
44
+ if os.path.exists(file_path):
45
+ df = pd.read_parquet(file_path)
46
+ data_by_type[task_type] = df
47
+
48
+ return data_by_type
49
+
50
+ def prepare_batch_for_azr(self, batch_data: List[Dict[str, Any]]) -> Dict[str, Any]:
51
+ """
52
+ TTRLVR 배치 데이터를 AZR 형식으로 변환
53
+
54
+ Returns:
55
+ prompts: 프롬프트 리스트
56
+ metadata: 각 샘플의 메타데이터 (task_type, evaluation_data 등)
57
+ """
58
+ prompts = []
59
+ metadata = []
60
+
61
+ for data in batch_data:
62
+ # prompt 추출 (TTRLVR은 conversation 형식으로 저장)
63
+ if isinstance(data['prompt'], list) and len(data['prompt']) > 0:
64
+ prompt_text = data['prompt'][0].get('content', '')
65
+ else:
66
+ prompt_text = str(data['prompt'])
67
+
68
+ prompts.append(prompt_text)
69
+
70
+ # 메타데이터 구성
71
+ meta = {
72
+ 'task_type': self._extract_task_type_from_uid(data.get('uid', '')),
73
+ 'expected_solution': data.get('ground_truth', ''),
74
+ 'problem': data.get('problem', {}),
75
+ 'ipo_group_id': data.get('ipo_group_id', ''),
76
+ 'uid': data.get('uid', '')
77
+ }
78
+
79
+ # evaluation_data 구성 (task 타입별)
80
+ if meta['task_type'] == 'induction':
81
+ # IPO에서 input/output 쌍 추출
82
+ meta['evaluation_data'] = {
83
+ 'input_output_pairs': [
84
+ (meta['problem'].get('input', ''),
85
+ meta['problem'].get('output', ''))
86
+ ]
87
+ }
88
+ elif meta['task_type'] == 'deduction':
89
+ meta['evaluation_data'] = {
90
+ 'function_code': meta['problem'].get('snippet', ''),
91
+ 'input': meta['problem'].get('input', '')
92
+ }
93
+ elif meta['task_type'] == 'abduction':
94
+ meta['evaluation_data'] = {
95
+ 'function_code': meta['problem'].get('snippet', ''),
96
+ 'expected_output': meta['problem'].get('output', '')
97
+ }
98
+
99
+ metadata.append(meta)
100
+
101
+ return {
102
+ 'prompts': prompts,
103
+ 'metadata': metadata
104
+ }
105
+
106
+ def compute_rewards_for_responses(self,
107
+ prompts: List[str],
108
+ responses: List[str],
109
+ metadata: List[Dict[str, Any]]) -> List[float]:
110
+ """
111
+ 모델 응답에 대한 reward 계산
112
+ complete_pipeline.py의 _compute_rewards_with_azr과 동일한 로직 사용
113
+ """
114
+ return self.reward_manager.compute_rewards(prompts, responses, metadata)
115
+
116
+ def _extract_task_type_from_uid(self, uid: str) -> str:
117
+ """UID에서 task 타입 추출"""
118
+ if 'induction' in uid:
119
+ return 'induction'
120
+ elif 'deduction' in uid:
121
+ return 'deduction'
122
+ elif 'abduction' in uid:
123
+ return 'abduction'
124
+ else:
125
+ return 'unknown'
absolute_zero_reasoner/utils/__init__.py ADDED
File without changes
absolute_zero_reasoner/utils/auxiliary.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ reflection_keywords = [
2
+ "wait", "recheck", "retry", "rethink", "re-verify", "re-evaluate",
3
+ "check again", "try again", "think again", "verify again",
4
+ "evaluate again", "let's correct", "however", "alternatively",
5
+ "reconsider", "review", "revisit", "double-check", "cross-check",
6
+ "second look", "reassess", "inspect", "examine again", "re-examine",
7
+ "revise", "adjust", "modify", "recalibrate", "pause", "reflect",
8
+ "clarify", "confirm", "validate again", "on second thought",
9
+ "in retrospect", "upon reflection", "alternately", "perhaps",
10
+ "maybe", "on the other hand"
11
+ ]
absolute_zero_reasoner/utils/code_utils/__init__.py ADDED
File without changes
absolute_zero_reasoner/utils/code_utils/checks.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import ast
3
+ import re
4
+ from typing import List
5
+
6
+
7
+ def check_determinism(code: str, inputs: str, executor, prev_output: str = None, n_runs: int = 1):
8
+ """expects an executor that outputs string output and status"""
9
+ all_outputs = set()
10
+ if prev_output is not None:
11
+ hash = hashlib.md5(str(prev_output).encode()).hexdigest()
12
+ all_outputs.add(hash)
13
+ for _ in range(n_runs):
14
+ result = executor.run_code(code, inputs)[0]
15
+ hash = hashlib.md5(str(result).encode()).hexdigest()
16
+ all_outputs.add(hash)
17
+ return len(all_outputs) == 1
18
+
19
+
20
+ def contains_banned_imports(code: str, banned_keywords: List[str], banned_keywords_for_errors_and_exceptions: List[str] = []) -> bool:
21
+ """Check if code imports any banned modules using AST parsing."""
22
+ try:
23
+ tree = ast.parse(code)
24
+ for node in ast.walk(tree):
25
+ if isinstance(node, ast.Import):
26
+ for alias in node.names:
27
+ if any(banned in alias.name.split('.') for banned in banned_keywords):
28
+ return True
29
+ elif isinstance(node, ast.ImportFrom):
30
+ module = node.module.split('.') if node.module else []
31
+ if any(banned in module for banned in banned_keywords):
32
+ return True
33
+ for alias in node.names:
34
+ if any(banned in alias.name.split('.') for banned in banned_keywords):
35
+ return True
36
+
37
+ if banned_keywords_for_errors_and_exceptions:
38
+ # Check for assert statements
39
+ if isinstance(node, ast.Assert) and 'assert' in banned_keywords_for_errors_and_exceptions:
40
+ return True
41
+
42
+ # Check for raise statements
43
+ elif isinstance(node, ast.Raise) and 'raise' in banned_keywords_for_errors_and_exceptions:
44
+ return True
45
+
46
+ # Check for try-except blocks
47
+ elif isinstance(node, ast.Try) and 'try' in banned_keywords_for_errors_and_exceptions:
48
+ return True
49
+
50
+ # Check for except handlers
51
+ elif isinstance(node, ast.ExceptHandler) and 'except' in banned_keywords_for_errors_and_exceptions:
52
+ return True
53
+
54
+ return False
55
+ except SyntaxError:
56
+ # Fallback to simple check if AST parsing fails
57
+ return any(re.search(rf'\b{re.escape(banned)}\b', code) for banned in banned_keywords)
58
+
59
+
60
+ def check_no_definitions(code: str, composite_functions: List[str]) -> bool:
61
+ try:
62
+ tree = ast.parse(code)
63
+ except SyntaxError:
64
+ return False
65
+
66
+ for node in tree.body:
67
+ if isinstance(node, ast.FunctionDef) and node.name in composite_functions:
68
+ return False
69
+ return True
70
+
71
+
72
+ def check_composite_function(code: str, composite_functions: List[str]) -> bool:
73
+ composite_function_names = [f"g_{i}" for i in range(len(composite_functions))]
74
+
75
+ try:
76
+ tree = ast.parse(code)
77
+ except SyntaxError:
78
+ return False
79
+
80
+ f_def = None
81
+ for node in tree.body:
82
+ if isinstance(node, ast.FunctionDef) and node.name == 'f':
83
+ f_def = node
84
+ break
85
+ if f_def is None:
86
+ return False
87
+
88
+ parameters = {arg.arg for arg in f_def.args.args}
89
+
90
+ assigned_vars_visitor = AssignedVarsVisitor()
91
+ for stmt in f_def.body:
92
+ assigned_vars_visitor.visit(stmt)
93
+ scope_vars = parameters | assigned_vars_visitor.assigned
94
+
95
+ call_checker = CallChecker(composite_function_names, scope_vars)
96
+ for stmt in f_def.body:
97
+ call_checker.visit(stmt)
98
+
99
+ result = call_checker.called == set(composite_function_names) and call_checker.valid
100
+ return result
101
+
102
+
103
+ class AssignedVarsVisitor(ast.NodeVisitor):
104
+ def __init__(self):
105
+ self.assigned = set()
106
+
107
+ def visit_Assign(self, node):
108
+ for target in node.targets:
109
+ self.collect_names(target)
110
+ self.generic_visit(node)
111
+
112
+ def collect_names(self, node):
113
+ if isinstance(node, ast.Name):
114
+ self.assigned.add(node.id)
115
+ elif isinstance(node, (ast.Tuple, ast.List)):
116
+ for elt in node.elts:
117
+ self.collect_names(elt)
118
+
119
+
120
+ class CallChecker(ast.NodeVisitor):
121
+ def __init__(self, composite_functions, scope_vars):
122
+ self.composite_functions = composite_functions
123
+ self.scope_vars = scope_vars
124
+ self.called = set()
125
+ self.valid = True
126
+ self.local_scopes = [{}]
127
+
128
+ def visit_FunctionDef(self, node):
129
+ self.local_scopes.append({arg.arg: None for arg in node.args.args})
130
+ self.generic_visit(node)
131
+ self.local_scopes.pop()
132
+
133
+ def visit_ListComp(self, node):
134
+ comp_scope = {}
135
+ for gen in node.generators:
136
+ if isinstance(gen.iter, ast.Name) and gen.iter.id in self.scope_vars:
137
+ self.collect_names(gen.target, comp_scope)
138
+ self.local_scopes.append(comp_scope)
139
+ self.visit(node.elt)
140
+ for gen in node.generators:
141
+ for comp_if in gen.ifs:
142
+ self.visit(comp_if)
143
+ self.local_scopes.pop()
144
+
145
+ def visit_Call(self, node):
146
+ if isinstance(node.func, ast.Name):
147
+ if node.func.id in self.composite_functions:
148
+ func_name = node.func.id
149
+ self.called.add(func_name)
150
+ current_scope = self.build_current_scope()
151
+ for arg in node.args:
152
+ names = self.get_names(arg)
153
+ if not all(name in current_scope for name in names):
154
+ self.valid = False
155
+ elif node.func.id in {n.name for n in ast.walk(node) if isinstance(n, ast.FunctionDef)}:
156
+ for parent in ast.walk(node):
157
+ if isinstance(parent, ast.FunctionDef) and parent.name == node.func.id:
158
+ for param, arg in zip(parent.args.args, node.args):
159
+ if isinstance(arg, ast.Name):
160
+ self.local_scopes[-1][param.arg] = arg.id
161
+ self.generic_visit(node)
162
+
163
+ def build_current_scope(self):
164
+ scope = set(self.scope_vars)
165
+ for local_scope in self.local_scopes:
166
+ scope.update(local_scope.keys())
167
+ for mapped_var in local_scope.values():
168
+ if mapped_var:
169
+ scope.add(mapped_var)
170
+ return scope
171
+
172
+ def collect_names(self, node, scope_dict):
173
+ if isinstance(node, ast.Name):
174
+ scope_dict[node.id] = None
175
+ elif isinstance(node, (ast.Tuple, ast.List)):
176
+ for elt in node.elts:
177
+ self.collect_names(elt, scope_dict)
178
+
179
+ def get_names(self, node):
180
+ return [n.id for n in ast.walk(node) if isinstance(n, ast.Name)
181
+ and isinstance(n.ctx, ast.Load)
182
+ and n.id not in self.composite_functions]
absolute_zero_reasoner/utils/code_utils/parsers.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import re
3
+ from typing import List
4
+
5
+
6
+ def parse_imports(code_snippet: str) -> List[str]:
7
+ imports = []
8
+ try:
9
+ tree = ast.parse(code_snippet)
10
+ for node in ast.walk(tree):
11
+ if isinstance(node, (ast.Import, ast.ImportFrom)):
12
+ # Reconstruct import line from AST node
13
+ if isinstance(node, ast.Import):
14
+ import_line = "import " + ", ".join(
15
+ [alias.name + (f" as {alias.asname}" if alias.asname else "")
16
+ for alias in node.names]
17
+ )
18
+ else:
19
+ module = node.module or ""
20
+ import_line = f"from {module} import " + ", ".join(
21
+ [alias.name + (f" as {alias.asname}" if alias.asname else "")
22
+ for alias in node.names]
23
+ )
24
+ if node.level > 0:
25
+ import_line = f"from {'.' * node.level}{module} import " + ", ".join(
26
+ [alias.name + (f" as {alias.asname}" if alias.asname else "")
27
+ for alias in node.names]
28
+ )
29
+ imports.append(import_line)
30
+ except Exception as e:
31
+ import_pattern = r"^\s*(?:from|import)\s+.*$"
32
+ imports = [i.strip() for i in re.findall(import_pattern, code_snippet, re.MULTILINE)]
33
+ return imports
34
+
35
+
36
+ def parse_error(error_message: str) -> str:
37
+ # split by colon
38
+ error_message = error_message.split(':')[0]
39
+ return error_message.strip()
40
+
41
+
42
+ def replace_main_function_name(code: str, old_name: str, new_name: str) -> str:
43
+ """
44
+ Replace all occurrences of `old_name` with `new_name` in the code.
45
+ Replace the definition and all recursive calls of `old_name` with `new_name`.
46
+ """
47
+ tree = ast.parse(code)
48
+ for node in ast.walk(tree):
49
+ if isinstance(node, ast.FunctionDef) and node.name == old_name:
50
+ node.name = new_name
51
+ elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == old_name:
52
+ node.func.id = new_name
53
+ return ast.unparse(tree)
54
+
55
+
56
+ def remove_comments_and_docstrings(code: str) -> str:
57
+ """
58
+ Remove all comments and docstrings from the code.
59
+ """
60
+ try:
61
+ tree = ast.parse(code)
62
+ for node in ast.walk(tree):
63
+ if isinstance(node, (ast.AsyncFunctionDef, ast.FunctionDef, ast.ClassDef, ast.Module)):
64
+ # Remove all leading docstrings
65
+ while node.body and isinstance(node.body[0], ast.Expr):
66
+ expr = node.body[0].value
67
+ if isinstance(expr, (ast.Str, ast.Constant)) and (
68
+ isinstance(expr.value, str) if isinstance(expr, ast.Constant) else True
69
+ ):
70
+ node.body.pop(0)
71
+ else:
72
+ break
73
+
74
+ # Convert back to code - AST unparse already removes comments
75
+ code_without_docstrings = ast.unparse(tree)
76
+
77
+ # Only remove empty lines and trim whitespace
78
+ lines = [
79
+ line.rstrip()
80
+ for line in code_without_docstrings.split('\n')
81
+ if line.strip()
82
+ ]
83
+
84
+ return '\n'.join(lines)
85
+ except Exception as e:
86
+ return code # Return original code if parsing fails
87
+
88
+
89
+ def remove_any_not_definition_imports(code: str) -> str:
90
+ """
91
+ Remove anything that is not a definition or import.
92
+ Preserves:
93
+ - Import/From imports
94
+ - Class definitions
95
+ - Function/AsyncFunction definitions
96
+ Removes:
97
+ - Top-level assignments
98
+ - Standalone expressions
99
+ - Constant declarations
100
+ """
101
+ class DefinitionFilter(ast.NodeTransformer):
102
+ def visit_Module(self, node):
103
+ # Keep only definitions and imports (explicitly exclude assignments)
104
+ node.body = [
105
+ n for n in node.body
106
+ if isinstance(n, (
107
+ ast.Import,
108
+ ast.ImportFrom,
109
+ ast.FunctionDef,
110
+ ast.AsyncFunctionDef,
111
+ ast.ClassDef
112
+ ))
113
+ ]
114
+ return node
115
+
116
+ try:
117
+ tree = ast.parse(code)
118
+ tree = DefinitionFilter().visit(tree)
119
+ ast.fix_missing_locations(tree)
120
+
121
+ # Remove empty lines and format
122
+ cleaned = ast.unparse(tree)
123
+ return '\n'.join([line for line in cleaned.split('\n') if line.strip()])
124
+
125
+ except Exception as e:
126
+ return code
127
+
128
+
129
+ class PrintRemover(ast.NodeTransformer):
130
+ def visit_Expr(self, node):
131
+ # Handle top-level print statements
132
+ if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name) and node.value.func.id == 'print':
133
+ return None
134
+ return node
135
+
136
+ def visit_Call(self, node):
137
+ # Handle print calls in other contexts (like assignments)
138
+ if isinstance(node.func, ast.Name) and node.func.id == 'print':
139
+ return ast.Constant(value=None)
140
+ return node
141
+
142
+ def _handle_block(self, node):
143
+ self.generic_visit(node)
144
+ if not node.body:
145
+ node.body.append(ast.Pass())
146
+ return node
147
+
148
+ def visit_For(self, node):
149
+ return self._handle_block(node)
150
+
151
+ def visit_While(self, node):
152
+ return self._handle_block(node)
153
+
154
+ def visit_FunctionDef(self, node):
155
+ return self._handle_block(node)
156
+
157
+ def visit_AsyncFunctionDef(self, node):
158
+ return self._handle_block(node)
159
+
160
+ def visit_If(self, node):
161
+ return self._handle_block(node)
162
+
163
+ def visit_With(self, node):
164
+ return self._handle_block(node)
165
+
166
+ def visit_Try(self, node):
167
+ self.generic_visit(node)
168
+
169
+ # Handle main try body
170
+ if not node.body:
171
+ node.body.append(ast.Pass())
172
+
173
+ # Handle except handlers
174
+ for handler in node.handlers:
175
+ if not handler.body:
176
+ handler.body.append(ast.Pass())
177
+
178
+ # Handle else clause
179
+ if node.orelse and not node.orelse:
180
+ node.orelse.append(ast.Pass())
181
+
182
+ # Handle finally clause
183
+ if node.finalbody and not node.finalbody:
184
+ node.finalbody.append(ast.Pass())
185
+
186
+ return node
187
+
188
+
189
+ def remove_print_statements(code: str) -> str:
190
+ """
191
+ Remove all print statements from the code.
192
+ """
193
+ tree = ast.parse(code)
194
+ tree = PrintRemover().visit(tree)
195
+ ast.fix_missing_locations(tree)
196
+ return ast.unparse(tree)
197
+
198
+
199
+ if __name__ == "__main__":
200
+ print(parse_error("NameError: name 'x' is not defined"))
201
+ print(parse_error("TypeError: unsupported operand type(s) for -: 'str' and 'str'"))
202
+ print(parse_error("ValueError: invalid literal for int() with base 10: 'x'"))
absolute_zero_reasoner/utils/code_utils/python_executor.py ADDED
@@ -0,0 +1,435 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # https://github.com/QwenLM/QwQ/blob/main/eval/eval/math_opensource_utils/python_executor.py
4
+
5
+ import copy
6
+ import datetime
7
+ import io
8
+ import logging
9
+ import pickle
10
+ import traceback
11
+ from concurrent.futures import TimeoutError
12
+ from contextlib import redirect_stdout
13
+ from functools import partial
14
+ from typing import Any, Dict, Optional, List, Tuple
15
+ import ast
16
+ import time
17
+
18
+ import numpy as np
19
+ import dateutil.relativedelta
20
+ import regex
21
+ from pebble import ProcessPool
22
+ from timeout_decorator import timeout
23
+ from tqdm import tqdm
24
+
25
+ from absolute_zero_reasoner.utils.code_utils.templates import (
26
+ RUN_CODE_TEMPLATE,
27
+ EVAL_INPUT_PREDICTION_TEMPLATE,
28
+ EVAL_OUTPUT_PREDICTION_TEMPLATE,
29
+ VALIDATE_CODE_TEMPLATE,
30
+ CHECK_DETERMINISM_TEMPLATE,
31
+ EVAL_K_INPUT_PREDICTION_TEMPLATE,
32
+ EVAL_K_OUTPUT_PREDICTION_TEMPLATE,
33
+ )
34
+ from absolute_zero_reasoner.utils.code_utils.checks import contains_banned_imports
35
+ from absolute_zero_reasoner.utils.code_utils.parsers import parse_error
36
+
37
+
38
+ class GenericRuntime:
39
+ GLOBAL_DICT = {}
40
+ LOCAL_DICT = None
41
+ HEADERS = []
42
+
43
+ def __init__(self):
44
+ self._global_vars = copy.copy(self.GLOBAL_DICT)
45
+ self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
46
+
47
+ for c in self.HEADERS:
48
+ self.exec_code(c)
49
+
50
+ def exec_code(self, code_piece: str) -> None:
51
+ if regex.search(r'(\s|^)?input\(', code_piece):
52
+ # regex.search(r'(\s|^)?os.', code_piece):
53
+ raise RuntimeError()
54
+ exec(code_piece, self._global_vars)
55
+
56
+ # TODO: use: https://github.com/shroominic/codebox-api
57
+ # @high safe exec in sandbox
58
+ # byte_code = compile_restricted(
59
+ # code_piece,
60
+ # filename='<inline code>',
61
+ # mode='exec'
62
+ # )
63
+ # print("global vars:", self._global_vars)
64
+ # _print_ = PrintCollector
65
+ # exec(byte_code, {'__builtins__': utility_builtins}, None)
66
+
67
+ def eval_code(self, expr: str) -> Any:
68
+ return eval(expr, self._global_vars)
69
+
70
+ def inject(self, var_dict: Dict[str, Any]) -> None:
71
+ for k, v in var_dict.items():
72
+ self._global_vars[k] = v
73
+
74
+ @property
75
+ def answer(self):
76
+ return self._global_vars['answer']
77
+
78
+
79
+ class DateRuntime(GenericRuntime):
80
+ GLOBAL_DICT = {
81
+ 'datetime': datetime.datetime,
82
+ 'timedelta': dateutil.relativedelta.relativedelta,
83
+ 'relativedelta': dateutil.relativedelta.relativedelta
84
+ }
85
+
86
+
87
+ class CustomDict(dict):
88
+ def __iter__(self):
89
+ return list(super().__iter__()).__iter__()
90
+
91
+
92
+ class ColorObjectRuntime(GenericRuntime):
93
+ GLOBAL_DICT = {'dict': CustomDict}
94
+
95
+
96
+ class PythonExecutor:
97
+ def __init__(
98
+ self,
99
+ runtime: Optional[Any] = None,
100
+ get_answer_symbol: Optional[str] = None,
101
+ get_answer_expr: Optional[str] = None,
102
+ get_answer_from_stdout: bool = False,
103
+ timeout_length: int = 10,
104
+ ast_check: bool = False,
105
+ max_workers: int = 1,
106
+ ) -> None:
107
+ self.runtime = runtime if runtime else GenericRuntime()
108
+ self.answer_symbol = get_answer_symbol
109
+ self.answer_expr = get_answer_expr
110
+ self.get_answer_from_stdout = get_answer_from_stdout
111
+ self.timeout_length = timeout_length
112
+ self.ast_check = ast_check
113
+ self.max_workers = max_workers
114
+ self._process_pool = None
115
+
116
+ def __del__(self):
117
+ try:
118
+ self.cleanup()
119
+ # self.pool.terminate()
120
+ except Exception as e:
121
+ print(f"Error terminating pool: {e}")
122
+ pass
123
+
124
+ def cleanup(self):
125
+ """Explicitly clean up the process pool"""
126
+ if self._process_pool is not None:
127
+ self._process_pool.close()
128
+ self._process_pool.join()
129
+ self._process_pool = None
130
+
131
+ def _get_process_pool(self, size_hint):
132
+ """Get or create a ProcessPool with appropriate size"""
133
+ if self._process_pool is None:
134
+ self._process_pool = ProcessPool(max_workers=min(size_hint, self.max_workers))
135
+ return self._process_pool
136
+
137
+ def process_generation_to_code(self, gens: str):
138
+ return [g.strip().split('\n') for g in gens]
139
+
140
+ def run_code(self, code: str, inputs: str, imports: List[str] = []) -> Tuple[str, str]:
141
+ if isinstance(imports, np.ndarray):
142
+ imports = imports.tolist()
143
+ if imports:
144
+ code = '\n'.join(imports) + '\n' + code
145
+ code_snippet = RUN_CODE_TEMPLATE.format(code=code, inputs=inputs)
146
+ # print(code_snippet)
147
+ if self.ast_check:
148
+ try:
149
+ ast.parse(code_snippet)
150
+ except:
151
+ return '', 'error'
152
+ return self.apply(code_snippet)
153
+
154
+ def validate_code(self, code: str, inputs: str, imports: List[str] = []) -> bool:
155
+ if isinstance(imports, np.ndarray):
156
+ imports = imports.tolist()
157
+ if imports:
158
+ code = '\n'.join(imports) + '\n' + code
159
+ code_snippet = VALIDATE_CODE_TEMPLATE.format(code=code, inputs=inputs)
160
+ if self.ast_check:
161
+ try:
162
+ ast.parse(code_snippet)
163
+ except:
164
+ return False
165
+ _, status = self.apply(code_snippet)
166
+ return not 'error' in status.lower()
167
+
168
+ def eval_input_prediction(self, code: str, gold_output: str, agent_input: str, imports: List[str] = []) -> float:
169
+ if isinstance(imports, np.ndarray):
170
+ imports = imports.tolist()
171
+ if imports:
172
+ code = '\n'.join(imports) + '\n' + code
173
+ code_snippet = EVAL_INPUT_PREDICTION_TEMPLATE.format(code=code, gold_output=gold_output, agent_input=agent_input)
174
+ if self.ast_check:
175
+ try:
176
+ ast.parse(code_snippet)
177
+ except:
178
+ return 0.0
179
+ max_retries = 3
180
+ for retry in range(max_retries):
181
+ try:
182
+ correct, status = self.apply(code_snippet)
183
+ return 0.0 if 'error' in status.lower() or not eval(correct) else 1.0
184
+ except Exception as e:
185
+ if retry == max_retries - 1:
186
+ error_details = traceback.format_exc()
187
+ print(f"Error in eval_input_prediction: {e}\n{error_details}")
188
+ return
189
+ time.sleep(0.1 * (retry + 1)) # Exponential backoff
190
+
191
+ def eval_output_prediction(self, code: str, gold_output: str, agent_output: str, imports: List[str] = []) -> float:
192
+ try: # fast check if we dont need to run the code
193
+ if eval(gold_output) == eval(agent_output):
194
+ return 1.0
195
+ except:
196
+ pass
197
+ if isinstance(imports, np.ndarray):
198
+ imports = imports.tolist()
199
+ if imports:
200
+ code = '\n'.join(imports) + '\n' + code
201
+ code_snippet = EVAL_OUTPUT_PREDICTION_TEMPLATE.format(code=code, gold_output=gold_output, agent_output=agent_output)
202
+ if self.ast_check:
203
+ try:
204
+ ast.parse(code_snippet)
205
+ except:
206
+ return 0.0
207
+ max_retries = 3
208
+ for retry in range(max_retries):
209
+ try:
210
+ correct, status = self.apply(code_snippet)
211
+ return 0.0 if 'error' in status.lower() or not eval(correct) else 1.0
212
+ except Exception as e:
213
+ if retry == max_retries - 1:
214
+ error_details = traceback.format_exc()
215
+ print(f"Error in eval_output_prediction: {e}\n{error_details}")
216
+ return
217
+ time.sleep(0.1 * (retry + 1)) # Exponential backoff
218
+
219
+ def eval_k_input_prediction(self, code: str, gold_output: str, k_agent_inputs: List[str], imports: List[str] = []) -> List[float]:
220
+ if isinstance(imports, np.ndarray):
221
+ imports = imports.tolist()
222
+ if imports:
223
+ code = '\n'.join(imports) + '\n' + code
224
+ invalid_lists = []
225
+ valid_k_agent_inputs = []
226
+ for k_agent_input in k_agent_inputs:
227
+ try:
228
+ ast.parse(f'f({k_agent_input})')
229
+ valid_k_agent_inputs.append(k_agent_input)
230
+ except:
231
+ invalid_lists.append(0.0)
232
+ acc_list, status = self.apply(EVAL_K_INPUT_PREDICTION_TEMPLATE(code=code, gold_output=gold_output, k_agent_inputs=valid_k_agent_inputs))
233
+ assert 'error' not in status.lower()
234
+ output_acc = eval(acc_list) + invalid_lists
235
+ assert len(output_acc) == len(k_agent_inputs)
236
+ return output_acc
237
+
238
+ def eval_k_output_prediction(self, code: str, gold_output: str, k_agent_outputs: List[str], imports: List[str] = []) -> List[float]:
239
+ if isinstance(imports, np.ndarray):
240
+ imports = imports.tolist()
241
+ if imports:
242
+ code = '\n'.join(imports) + '\n' + code
243
+ invalid_lists = []
244
+ valid_k_agent_outputs = []
245
+ for k_agent_output in k_agent_outputs:
246
+ try:
247
+ if k_agent_output != '':
248
+ ast.parse(f'f({k_agent_output})')
249
+ valid_k_agent_outputs.append(k_agent_output)
250
+ else:
251
+ invalid_lists.append(0.0)
252
+ except:
253
+ invalid_lists.append(0.0)
254
+ acc_list, status = self.apply(EVAL_K_OUTPUT_PREDICTION_TEMPLATE(code=code, gold_output=gold_output, k_agent_outputs=valid_k_agent_outputs))
255
+ assert 'error' not in status.lower()
256
+ output_acc = eval(acc_list) + invalid_lists
257
+ assert len(output_acc) == len(k_agent_outputs)
258
+ return output_acc
259
+
260
+ def check_all(
261
+ self,
262
+ code: str,
263
+ inputs: str,
264
+ banned_keywords: List[str] = [],
265
+ check_determinism: bool = True,
266
+ imports: List[str] = [],
267
+ check_error: bool = False,
268
+ banned_keywords_for_errors_and_exceptions: List[str] = [],
269
+ ) -> Tuple[bool, str]:
270
+ if isinstance(imports, np.ndarray):
271
+ imports = imports.tolist()
272
+ if imports:
273
+ code = '\n'.join(imports) + '\n' + code
274
+ if contains_banned_imports(code=code, banned_keywords=banned_keywords, banned_keywords_for_errors_and_exceptions=banned_keywords_for_errors_and_exceptions if check_error else []):
275
+ return False, None
276
+ if check_error:
277
+ code_snippet = RUN_CODE_TEMPLATE.format(code=code, inputs=inputs)
278
+ try:
279
+ ast.parse(code_snippet)
280
+ except:
281
+ return False, 'error'
282
+ output, status = self.apply(code_snippet)
283
+ if check_determinism: # run the code again, see if outputs are same
284
+ output_2, status_2 = self.apply(code_snippet)
285
+ if status_2.lower() != status.lower() and output != output_2:
286
+ return False, 'error'
287
+ # True if the code is valid code but might have error, output no error if the code returns something
288
+ return True, 'NoError' if status.lower() == 'done' else parse_error(status)
289
+ else:
290
+ if check_determinism:
291
+ code_snippet = CHECK_DETERMINISM_TEMPLATE.format(code=code, inputs=inputs)
292
+ else:
293
+ code_snippet = RUN_CODE_TEMPLATE.format(code=code, inputs=inputs)
294
+ if self.ast_check:
295
+ try:
296
+ ast.parse(code_snippet)
297
+ except:
298
+ return False, 'error'
299
+ output, status = self.apply(code_snippet)
300
+ return not 'error' in status.lower(), output
301
+
302
+ @staticmethod
303
+ def execute(
304
+ code,
305
+ get_answer_from_stdout=None,
306
+ runtime=None,
307
+ answer_symbol=None,
308
+ answer_expr=None,
309
+ timeout_length=10,
310
+ auto_mode=False
311
+ ):
312
+ try:
313
+ if auto_mode:
314
+ if "print(" in code[-1]:
315
+ program_io = io.StringIO()
316
+ with redirect_stdout(program_io):
317
+ timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
318
+ program_io.seek(0)
319
+ result = program_io.read()
320
+ else:
321
+ # print(code)
322
+ timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
323
+ result = timeout(timeout_length)(runtime.eval_code)(code[-1])
324
+ else:
325
+ if get_answer_from_stdout:
326
+ program_io = io.StringIO()
327
+ with redirect_stdout(program_io):
328
+ timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
329
+ program_io.seek(0)
330
+ result = program_io.read()
331
+ elif answer_symbol:
332
+ timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
333
+ result = runtime._global_vars[answer_symbol]
334
+ elif answer_expr:
335
+ timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
336
+ result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
337
+ else:
338
+ timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
339
+ result = timeout(timeout_length)(runtime.eval_code)(code[-1])
340
+ report = "Done"
341
+ str(result) # codec check
342
+ pickle.dumps(result) # serialization check
343
+ except:
344
+ result = ''
345
+ report = traceback.format_exc().split('\n')[-2]
346
+ return result, report
347
+
348
+ def apply(self, code):
349
+ return self.batch_apply([code])[0]
350
+
351
+ @staticmethod
352
+ def truncate(s, max_length=400):
353
+ half = max_length // 2
354
+ if len(s) > max_length:
355
+ s = s[:half] + "..." + s[-half:]
356
+ return s
357
+
358
+ def batch_apply(self, batch_code):
359
+ all_code_snippets = self.process_generation_to_code(batch_code)
360
+
361
+ timeout_cnt = 0
362
+ all_exec_results = []
363
+
364
+ pool = self._get_process_pool(len(all_code_snippets))
365
+ executor = partial(
366
+ self.execute,
367
+ get_answer_from_stdout=self.get_answer_from_stdout,
368
+ runtime=self.runtime,
369
+ answer_symbol=self.answer_symbol,
370
+ answer_expr=self.answer_expr,
371
+ timeout_length=self.timeout_length,
372
+ auto_mode=True
373
+ )
374
+
375
+ try:
376
+ future = pool.map(executor, all_code_snippets, timeout=self.timeout_length)
377
+ iterator = future.result()
378
+
379
+ if len(all_code_snippets) > 100:
380
+ progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
381
+ else:
382
+ progress_bar = None
383
+
384
+ while True:
385
+ try:
386
+ result = next(iterator)
387
+ all_exec_results.append(result)
388
+ except StopIteration:
389
+ break
390
+ except TimeoutError as error:
391
+ logging.warning(f"Timeout error in code execution: {error}")
392
+ all_exec_results.append(("", "Timeout Error"))
393
+ timeout_cnt += 1
394
+ except Exception as error:
395
+ logging.warning(f"Error in code execution: {error}")
396
+ all_exec_results.append(("", f"Error: {str(error)}"))
397
+ if progress_bar is not None:
398
+ progress_bar.update(1)
399
+
400
+ if progress_bar is not None:
401
+ progress_bar.close()
402
+ except Exception as e:
403
+ logging.error(f"Critical error in batch execution: {e}")
404
+ # Make sure we have results for all snippets
405
+ while len(all_exec_results) < len(all_code_snippets):
406
+ all_exec_results.append(("", f"Critical Error: {str(e)}"))
407
+
408
+ # Cleanup the pool on critical errors
409
+ self.cleanup()
410
+
411
+ batch_results = []
412
+ for code, (res, report) in zip(all_code_snippets, all_exec_results):
413
+ # post processing
414
+ res, report = str(res).strip(), str(report).strip()
415
+ res, report = self.truncate(res), self.truncate(report)
416
+ batch_results.append((res, report))
417
+ return batch_results
418
+
419
+
420
+ def _test():
421
+ batch_code = [
422
+ """
423
+ def f(a):
424
+ return a
425
+ print(f(1,2))
426
+ """
427
+ ]
428
+
429
+ executor = PythonExecutor(get_answer_from_stdout=True)
430
+ predictions = executor.apply(batch_code[0])
431
+ print(predictions)
432
+
433
+
434
+ if __name__ == '__main__':
435
+ _test()
absolute_zero_reasoner/utils/code_utils/sandboxfusion_executor.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ from typing import List, Tuple
3
+ import ast
4
+ import time
5
+ import requests
6
+ import docker
7
+ from docker.errors import DockerException
8
+ import socket
9
+
10
+ import numpy as np
11
+ from pebble import ProcessPool
12
+ from sandbox_fusion import run_code, RunCodeRequest, set_endpoint, RunStatus
13
+
14
+ from absolute_zero_reasoner.utils.code_utils.templates import (
15
+ RUN_CODE_TEMPLATE_REPR,
16
+ EVAL_INPUT_PREDICTION_TEMPLATE_REPR,
17
+ EVAL_OUTPUT_PREDICTION_TEMPLATE_REPR,
18
+ VALIDATE_CODE_TEMPLATE_REPR,
19
+ CHECK_DETERMINISM_TEMPLATE_REPR,
20
+ EVAL_K_INPUT_PREDICTION_TEMPLATE,
21
+ EVAL_K_OUTPUT_PREDICTION_TEMPLATE,
22
+ )
23
+ from absolute_zero_reasoner.utils.code_utils.checks import contains_banned_imports
24
+ from absolute_zero_reasoner.utils.code_utils.parsers import parse_error
25
+
26
+
27
+ # Docker images
28
+ IMAGES = {
29
+ 'global': 'volcengine/sandbox-fusion:server-20250609',
30
+ 'china': 'vemlp-cn-beijing.cr.volces.com/preset-images/code-sandbox:server-20250609'
31
+ }
32
+ class DockerAPIRunner:
33
+ def __init__(self, use_china_mirror=True, silent=False):
34
+ self.image = IMAGES['china'] if use_china_mirror else IMAGES['global']
35
+ self.container = None
36
+ self.silent = silent
37
+ self.client = docker.from_env()
38
+ self.port = self._find_free_port()
39
+
40
+ def _find_free_port(self):
41
+ """Find an available port dynamically"""
42
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
43
+ s.bind(('', 0))
44
+ s.listen(1)
45
+ port = s.getsockname()[1]
46
+ return port
47
+
48
+ def start(self):
49
+ """Start the Docker container using Docker API"""
50
+ try:
51
+ # Pull image if not exists
52
+ if not self.silent:
53
+ print(f"Pulling image: {self.image}")
54
+ self.client.images.pull(self.image)
55
+
56
+ # Run container
57
+ self.container = self.client.containers.run(
58
+ self.image,
59
+ ports={'8080/tcp': self.port},
60
+ detach=True,
61
+ remove=True # Auto-remove when stopped
62
+ )
63
+
64
+ if not self.silent:
65
+ print(f"Container started: {self.container.short_id}")
66
+ return True
67
+
68
+ except DockerException as e:
69
+ if not self.silent:
70
+ print(f"Error starting container: {e}")
71
+ return False
72
+
73
+ def stop(self):
74
+ """Stop the Docker container"""
75
+ if self.container:
76
+ try:
77
+ self.container.stop()
78
+ if not self.silent:
79
+ print("Container stopped")
80
+ return True
81
+ except DockerException as e:
82
+ if not self.silent:
83
+ print(f"Error stopping container: {e}")
84
+ return False
85
+ return False
86
+
87
+ def _wait_for_container_ready(self, max_wait_time: int = 60, check_interval: float = 1.0):
88
+ """Wait for the Docker container to be ready"""
89
+ if not self.container:
90
+ raise Exception("Container not started")
91
+
92
+ start_time = time.time()
93
+ while time.time() - start_time < max_wait_time:
94
+ # Reload container status
95
+ self.container.reload()
96
+
97
+ if not self.silent:
98
+ print(f"Container status: {self.container.status}")
99
+
100
+ if self.container.status == 'running':
101
+ # Container is running, now check if service is ready
102
+ # First try a simple port connection test
103
+ try:
104
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
105
+ sock.settimeout(2)
106
+ result = sock.connect_ex(('localhost', self.port))
107
+ sock.close()
108
+
109
+ if result == 0: # Port is open
110
+ # Try to make a simple request to test the service
111
+ try:
112
+ response = requests.get(f'http://localhost:{self.port}/', timeout=2)
113
+ if not self.silent:
114
+ print(f"Service responded with status: {response.status_code}")
115
+ return True # Service is responding
116
+ except requests.exceptions.RequestException:
117
+ # Try alternative endpoints or just accept that port is open
118
+ if not self.silent:
119
+ print(f"Port {self.port} is open, assuming service is ready")
120
+ return True
121
+ except:
122
+ pass
123
+ elif self.container.status in ['exited', 'dead']:
124
+ # Get container logs for debugging
125
+ logs = self.container.logs().decode('utf-8')
126
+ raise Exception(f"Container failed to start. Status: {self.container.status}. Logs: {logs[:500]}")
127
+
128
+ time.sleep(check_interval)
129
+
130
+ # Get final container logs for debugging
131
+ logs = self.container.logs().decode('utf-8') if self.container else "No container"
132
+ raise Exception(f"Container not ready after {max_wait_time} seconds. Final status: {self.container.status if self.container else 'None'}. Logs: {logs[:500]}")
133
+
134
+
135
+ class SandboxfusionExecutor:
136
+ def __init__(
137
+ self,
138
+ timeout_length: int = 10,
139
+ ast_check: bool = False,
140
+ max_workers: int = 1,
141
+ use_china_mirror: bool = True,
142
+ ) -> None:
143
+ self.runner = DockerAPIRunner(use_china_mirror=use_china_mirror)
144
+ running = self.runner.start()
145
+ if not running:
146
+ raise Exception("Failed to start Sandboxfusion Docker container")
147
+
148
+ # Wait for the container to be ready
149
+ self._wait_for_container_ready()
150
+ set_endpoint(f'http://localhost:{self.runner.port}')
151
+
152
+ self.timeout_length = timeout_length
153
+ self.ast_check = ast_check
154
+ self.max_workers = max_workers
155
+
156
+ def _wait_for_container_ready(self, max_wait_time: int = 60, check_interval: float = 1.0):
157
+ """Wait for the Docker container to be ready"""
158
+ self.runner._wait_for_container_ready(max_wait_time, check_interval)
159
+
160
+ def __del__(self):
161
+ try:
162
+ self.cleanup()
163
+ self.runner.stop()
164
+ except Exception as e:
165
+ print(f"Error terminating pool: {e}")
166
+ pass
167
+
168
+ def cleanup(self):
169
+ self.runner.stop()
170
+
171
+ def process_generation_to_code(self, gens: str):
172
+ return [g.strip().split('\n') for g in gens]
173
+
174
+ def run_code(self, code: str, inputs: str, imports: List[str] = []) -> Tuple[str, str]:
175
+ if isinstance(imports, np.ndarray):
176
+ imports = imports.tolist()
177
+ if imports:
178
+ code = '\n'.join(imports) + '\n' + code
179
+ code_snippet = RUN_CODE_TEMPLATE_REPR.format(code=code, inputs=inputs)
180
+ # print(code_snippet)
181
+ if self.ast_check:
182
+ try:
183
+ ast.parse(code_snippet)
184
+ except:
185
+ return '', 'error'
186
+ return self.apply(code_snippet)
187
+
188
+ def validate_code(self, code: str, inputs: str, imports: List[str] = []) -> bool:
189
+ if isinstance(imports, np.ndarray):
190
+ imports = imports.tolist()
191
+ if imports:
192
+ code = '\n'.join(imports) + '\n' + code
193
+ code_snippet = VALIDATE_CODE_TEMPLATE_REPR.format(code=code, inputs=inputs)
194
+ if self.ast_check:
195
+ try:
196
+ ast.parse(code_snippet)
197
+ except:
198
+ return False
199
+ _, status = self.apply(code_snippet)
200
+ return not 'error' in status.lower()
201
+
202
+ def eval_input_prediction(self, code: str, gold_output: str, agent_input: str, imports: List[str] = []) -> float:
203
+ if isinstance(imports, np.ndarray):
204
+ imports = imports.tolist()
205
+ if imports:
206
+ code = '\n'.join(imports) + '\n' + code
207
+ code_snippet = EVAL_INPUT_PREDICTION_TEMPLATE_REPR.format(code=code, gold_output=gold_output, agent_input=agent_input)
208
+ if self.ast_check:
209
+ try:
210
+ ast.parse(code_snippet)
211
+ except:
212
+ return 0.0
213
+ max_retries = 3
214
+ for retry in range(max_retries):
215
+ try:
216
+ correct, status = self.apply(code_snippet)
217
+ return 0.0 if 'error' in status.lower() or not eval(correct) else 1.0
218
+ except Exception as e:
219
+ if retry == max_retries - 1:
220
+ error_details = traceback.format_exc()
221
+ print(f"Error in eval_input_prediction: {e}\n{error_details}")
222
+ return
223
+ time.sleep(0.1 * (retry + 1)) # Exponential backoff
224
+
225
+ def eval_output_prediction(self, code: str, gold_output: str, agent_output: str, imports: List[str] = []) -> float:
226
+ try: # fast check if we dont need to run the code
227
+ if eval(gold_output) == eval(agent_output):
228
+ return 1.0
229
+ except:
230
+ pass
231
+ if isinstance(imports, np.ndarray):
232
+ imports = imports.tolist()
233
+ if imports:
234
+ code = '\n'.join(imports) + '\n' + code
235
+ code_snippet = EVAL_OUTPUT_PREDICTION_TEMPLATE_REPR.format(code=code, gold_output=gold_output, agent_output=agent_output)
236
+ if self.ast_check:
237
+ try:
238
+ ast.parse(code_snippet)
239
+ except:
240
+ return 0.0
241
+ max_retries = 3
242
+ for retry in range(max_retries):
243
+ try:
244
+ correct, status = self.apply(code_snippet)
245
+ return 0.0 if 'error' in status.lower() or not eval(correct) else 1.0
246
+ except Exception as e:
247
+ if retry == max_retries - 1:
248
+ error_details = traceback.format_exc()
249
+ print(f"Error in eval_output_prediction: {e}\n{error_details}")
250
+ return
251
+ time.sleep(0.1 * (retry + 1)) # Exponential backoff
252
+
253
+ def eval_k_input_prediction(self, code: str, gold_output: str, k_agent_inputs: List[str], imports: List[str] = []) -> List[float]:
254
+ if isinstance(imports, np.ndarray):
255
+ imports = imports.tolist()
256
+ if imports:
257
+ code = '\n'.join(imports) + '\n' + code
258
+ invalid_lists = []
259
+ valid_k_agent_inputs = []
260
+ for k_agent_input in k_agent_inputs:
261
+ try:
262
+ ast.parse(f'f({k_agent_input})')
263
+ valid_k_agent_inputs.append(k_agent_input)
264
+ except:
265
+ invalid_lists.append(0.0)
266
+ acc_list, status = self.apply(EVAL_K_INPUT_PREDICTION_TEMPLATE(code=code, gold_output=gold_output, k_agent_inputs=valid_k_agent_inputs, repr_output=True))
267
+ assert 'error' not in status.lower()
268
+ output_acc = eval(acc_list) + invalid_lists
269
+ assert len(output_acc) == len(k_agent_inputs)
270
+ return output_acc
271
+
272
+ def eval_k_output_prediction(self, code: str, gold_output: str, k_agent_outputs: List[str], imports: List[str] = []) -> List[float]:
273
+ if isinstance(imports, np.ndarray):
274
+ imports = imports.tolist()
275
+ if imports:
276
+ code = '\n'.join(imports) + '\n' + code
277
+ invalid_lists = []
278
+ valid_k_agent_outputs = []
279
+ for k_agent_output in k_agent_outputs:
280
+ try:
281
+ if k_agent_output != '':
282
+ ast.parse(f'f({k_agent_output})')
283
+ valid_k_agent_outputs.append(k_agent_output)
284
+ else:
285
+ invalid_lists.append(0.0)
286
+ except:
287
+ invalid_lists.append(0.0)
288
+ acc_list, status = self.apply(EVAL_K_OUTPUT_PREDICTION_TEMPLATE(code=code, gold_output=gold_output, k_agent_outputs=valid_k_agent_outputs, repr_output=True))
289
+ assert 'error' not in status.lower()
290
+ output_acc = eval(acc_list) + invalid_lists
291
+ assert len(output_acc) == len(k_agent_outputs)
292
+ return output_acc
293
+
294
+ def check_all(
295
+ self,
296
+ code: str,
297
+ inputs: str,
298
+ banned_keywords: List[str] = [],
299
+ check_determinism: bool = True,
300
+ imports: List[str] = [],
301
+ check_error: bool = False,
302
+ banned_keywords_for_errors_and_exceptions: List[str] = [],
303
+ ) -> Tuple[bool, str]:
304
+ if isinstance(imports, np.ndarray):
305
+ imports = imports.tolist()
306
+ if imports:
307
+ code = '\n'.join(imports) + '\n' + code
308
+ if contains_banned_imports(code=code, banned_keywords=banned_keywords, banned_keywords_for_errors_and_exceptions=banned_keywords_for_errors_and_exceptions if check_error else []):
309
+ return False, None
310
+ if check_error:
311
+ code_snippet = RUN_CODE_TEMPLATE_REPR.format(code=code, inputs=inputs)
312
+ try:
313
+ ast.parse(code_snippet)
314
+ except:
315
+ return False, 'error'
316
+ output, status = self.apply(code_snippet)
317
+ if check_determinism: # run the code again, see if outputs are same
318
+ output_2, status_2 = self.apply(code_snippet)
319
+ if status_2.lower() != status.lower() and output != output_2:
320
+ return False, 'error'
321
+ # True if the code is valid code but might have error, output no error if the code returns something
322
+ return True, 'NoError' if status.lower() == 'done' else parse_error(status)
323
+ else:
324
+ if check_determinism:
325
+ code_snippet = CHECK_DETERMINISM_TEMPLATE_REPR.format(code=code, inputs=inputs)
326
+ else:
327
+ code_snippet = RUN_CODE_TEMPLATE_REPR.format(code=code, inputs=inputs)
328
+ if self.ast_check:
329
+ try:
330
+ ast.parse(code_snippet)
331
+ except:
332
+ return False, 'error'
333
+ output, status = self.apply(code_snippet)
334
+ return not 'error' in status.lower(), output
335
+
336
+ def apply(self, code) -> Tuple[str, str]:
337
+ try:
338
+ response = run_code(
339
+ RunCodeRequest(
340
+ code=code,
341
+ language='python',
342
+ compile_timeout=self.timeout_length,
343
+ run_timeout=self.timeout_length,
344
+ )
345
+ )
346
+ if response.status == RunStatus.Success:
347
+ # taking [1:-1] to exclude prefix space and suffix newline
348
+ return response.run_result.stdout.split('<FINAL_REPR_SYMBOL>')[-1][1:-1], 'done'
349
+ else:
350
+ return '', 'error'
351
+
352
+ except Exception as e:
353
+ error_msg = f"Execution error: {str(e)}"
354
+ return error_msg, 'error'
355
+
356
+
357
+ def _test():
358
+ batch_code = [
359
+ """
360
+ def f(a):
361
+ return a
362
+ print('<FINAL_REPR_SYMBOL>', repr(f(12eee)))
363
+ """
364
+ ]
365
+
366
+ executor = SandboxfusionExecutor()
367
+ predictions = executor.apply(batch_code[0])
368
+ print(predictions)
369
+
370
+
371
+ if __name__ == '__main__':
372
+ _test()
absolute_zero_reasoner/utils/code_utils/templates.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+
4
+ RUN_CODE_TEMPLATE = """{code}
5
+ repr(f({inputs}))"""
6
+
7
+ RUN_CODE_TEMPLATE_REPR = """{code}
8
+ print('<FINAL_REPR_SYMBOL>', repr(f({inputs})))"""
9
+
10
+ VALIDATE_CODE_TEMPLATE = """{code}
11
+ repr(f({inputs}))"""
12
+
13
+ VALIDATE_CODE_TEMPLATE_REPR = """{code}
14
+ print('<FINAL_REPR_SYMBOL>', repr(f({inputs})))"""
15
+
16
+ EVAL_INPUT_PREDICTION_TEMPLATE = """{code}
17
+ {gold_output} == f({agent_input})"""
18
+
19
+ EVAL_INPUT_PREDICTION_TEMPLATE_REPR = """{code}
20
+ print('<FINAL_REPR_SYMBOL>', repr({gold_output} == f({agent_input})))"""
21
+
22
+ EVAL_OUTPUT_PREDICTION_TEMPLATE = """{code}
23
+ eval({gold_output}) == eval({agent_output})"""
24
+
25
+ EVAL_OUTPUT_PREDICTION_TEMPLATE_REPR = """{code}
26
+ print('<FINAL_REPR_SYMBOL>', repr(eval({gold_output}) == eval({agent_output})))"""
27
+
28
+ CHECK_DETERMINISM_TEMPLATE = """{code}
29
+ returns = f({inputs})
30
+ if returns != f({inputs}):
31
+ raise Exception('Non-deterministic code')
32
+ repr(returns)"""
33
+
34
+ CHECK_DETERMINISM_TEMPLATE_REPR = """{code}
35
+ returns = f({inputs})
36
+ if returns != f({inputs}):
37
+ raise Exception('Non-deterministic code')
38
+ print('<FINAL_REPR_SYMBOL>', repr(returns))"""
39
+
40
+ def EVAL_K_INPUT_PREDICTION_TEMPLATE(code: str, gold_output: str, k_agent_inputs: List[str], repr_output: bool = False):
41
+ output_string = f"""{code}
42
+ acc_list = []"""
43
+ for inp in k_agent_inputs:
44
+ output_string += f"""\ntry:
45
+ acc_list.append({gold_output} == f({inp}))
46
+ except:
47
+ acc_list.append(False)"""
48
+ # then compute the mean of the list
49
+ if repr_output:
50
+ output_string += """\nprint('<FINAL_REPR_SYMBOL>', repr(acc_list))"""
51
+ else:
52
+ output_string += """\nacc_list"""
53
+ return output_string
54
+
55
+ def EVAL_K_OUTPUT_PREDICTION_TEMPLATE(code: str, gold_output: str, k_agent_outputs: List[str], repr_output: bool = False):
56
+ output_string = f"""{code}
57
+ acc_list = []"""
58
+ for out in k_agent_outputs:
59
+ output_string += f"""\ntry:
60
+ acc_list.append({gold_output} == {out})
61
+ except:
62
+ acc_list.append(False)"""
63
+ # then compute the mean of the list
64
+ if repr_output:
65
+ output_string += """\nprint('<FINAL_REPR_SYMBOL>', repr(acc_list))"""
66
+ else:
67
+ output_string += """\nacc_list"""
68
+ return output_string
absolute_zero_reasoner/utils/convert2hf.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
2
+ import torch
3
+ import fire
4
+ from collections import defaultdict
5
+
6
+
7
+ def main(
8
+ fsdp_checkpoint_path, huggingface_model_path, output_path, pretrained_tokenizer=True, world_size=4
9
+ ):
10
+ """
11
+ Convert FSDP checkpoint to HuggingFace checkpoint
12
+ Args:
13
+ fsdp_checkpoint_path: path to the FSDP checkpoint
14
+ huggingface_model_path: path to the HuggingFace model
15
+ output_path: path to save the converted checkpoint
16
+ Usage:
17
+ python reason_rl/utils/convert2hf.py \
18
+ checkpoints/azr/azr/test/test_answer/Qwen2.5-7B/answer_conditional/global_step_160_copy/actor \
19
+ checkpoints/azr/azr/test/test_answer/Qwen2.5-7B/answer_conditional/global_step_160_copy/actor/huggingface/ \
20
+ azr_90_composite_160_steps
21
+ """
22
+ state_dict = defaultdict(list)
23
+
24
+ for rank in range(int(world_size)):
25
+ filepath = f"{fsdp_checkpoint_path}/model_world_size_{world_size}_rank_{rank}.pt"
26
+ print("loading", filepath)
27
+ this_state_dict = torch.load(filepath)
28
+ for key, value in this_state_dict.items():
29
+ state_dict[key].append(value.to_local())
30
+
31
+ for key in state_dict:
32
+ state_dict[key] = torch.cat(state_dict[key], dim=0)
33
+
34
+ config = AutoConfig.from_pretrained(huggingface_model_path)
35
+ model = AutoModelForCausalLM.from_config(config)
36
+ model.load_state_dict(state_dict)
37
+
38
+ model.save_pretrained(output_path, max_shard_size="10GB")
39
+
40
+ tokenizer = AutoTokenizer.from_pretrained(huggingface_model_path)
41
+ tokenizer.save_pretrained(output_path)
42
+
43
+ # manually change the tokenizer.chat_template to
44
+ if pretrained_tokenizer:
45
+ chat_template = "{%- for message in messages -%}{{- '\n' if not loop.first -}}{{- message['content'] -}}{%- endfor -%}"
46
+ import os
47
+ import json
48
+ with open(os.path.join(output_path, "tokenizer_config.json"), "r") as f:
49
+ tokenizer_config = json.load(f)
50
+ tokenizer_config["chat_template"] = chat_template
51
+ with open(os.path.join(output_path, "tokenizer_config.json"), "w") as f:
52
+ json.dump(tokenizer_config, f)
53
+
54
+ if __name__ == "__main__":
55
+ fire.Fire(main)
absolute_zero_reasoner/utils/dataset/__init__.py ADDED
File without changes
absolute_zero_reasoner/utils/dataset/ipo_grouped_sampler.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ IPO Group-aware Batch Sampler for TTRLVR
3
+
4
+ 동일한 ipo_group_id를 가진 task들을 같은 배치에 묶는 커스텀 샘플러
5
+ 이를 통해 동일한 IPO triple에서 생성된 induction/deduction/abduction task들이
6
+ 함께 학습되도록 보장합니다.
7
+ """
8
+
9
+ import torch
10
+ from torch.utils.data import Sampler, BatchSampler
11
+ from typing import Iterator, List, Optional
12
+ import random
13
+ from collections import defaultdict
14
+ import pandas as pd
15
+ import numpy as np
16
+
17
+
18
+ class IPOGroupedBatchSampler(Sampler):
19
+ """동일한 IPO에서 생성된 task들을 같은 배치에 묶는 샘플러"""
20
+
21
+ def __init__(self,
22
+ dataset,
23
+ batch_size: int,
24
+ shuffle: bool = True,
25
+ drop_last: bool = False,
26
+ seed: int = 42):
27
+ """
28
+ Args:
29
+ dataset: ipo_group_id를 가진 데이터셋 (TTRLVRDataset)
30
+ batch_size: 배치 크기
31
+ shuffle: 그룹 순서를 섞을지 여부
32
+ drop_last: 마지막 불완전한 배치를 버릴지 여부
33
+ seed: 랜덤 시드
34
+ """
35
+ self.dataset = dataset
36
+ self.batch_size = batch_size
37
+ self.shuffle = shuffle
38
+ self.drop_last = drop_last
39
+ self.generator = torch.Generator()
40
+ self.generator.manual_seed(seed)
41
+
42
+ # ipo_group_id별로 인덱스 그룹핑
43
+ self.groups = defaultdict(list)
44
+ self._build_groups()
45
+
46
+ # 배치 생성
47
+ self._create_batches()
48
+
49
+ def _build_groups(self):
50
+ """데이터셋에서 ipo_group_id별로 인덱스를 그룹핑"""
51
+
52
+ for idx in range(len(self.dataset)):
53
+ # TTRLVRDataset의 dataframe에서 직접 접근
54
+ if hasattr(self.dataset, 'dataframe'):
55
+ row = self.dataset.dataframe.iloc[idx]
56
+ ipo_group_id = row.get('ipo_group_id', None)
57
+
58
+ # ipo_group_id가 없으면 개별 그룹으로 처리
59
+ if not ipo_group_id or ipo_group_id == '':
60
+ ipo_group_id = f'individual_{idx}'
61
+ else:
62
+ # Fallback: 개별 그룹
63
+ ipo_group_id = f'individual_{idx}'
64
+
65
+ self.groups[ipo_group_id].append(idx)
66
+
67
+ print(f"[IPOGroupedBatchSampler] Built {len(self.groups)} IPO groups from {len(self.dataset)} samples")
68
+
69
+ # 그룹 크기 통계
70
+ group_sizes = [len(indices) for indices in self.groups.values()]
71
+ if group_sizes:
72
+ print(f" - Group sizes: min={min(group_sizes)}, max={max(group_sizes)}, avg={np.mean(group_sizes):.2f}")
73
+
74
+ def _create_batches(self):
75
+ """그룹별로 배치 생성"""
76
+ self.batches = []
77
+
78
+ # 모든 인덱스를 수집 (그룹 단위로)
79
+ all_indices = []
80
+
81
+ for group_id, indices in self.groups.items():
82
+ # 같은 IPO 그룹의 task들을 함께 유지
83
+ # 일반적으로 3개 (induction, deduction, abduction)
84
+ if len(indices) <= self.batch_size:
85
+ # 그룹이 배치 크기보다 작으면 그대로 사용
86
+ all_indices.extend(indices)
87
+ else:
88
+ # 그룹이 배치 크기보다 크면 분할 (드물지만 가능)
89
+ for i in range(0, len(indices), self.batch_size):
90
+ chunk = indices[i:i + self.batch_size]
91
+ all_indices.extend(chunk)
92
+
93
+ # 배치 생성
94
+ current_batch = []
95
+ for idx in all_indices:
96
+ current_batch.append(idx)
97
+
98
+ if len(current_batch) == self.batch_size:
99
+ self.batches.append(current_batch)
100
+ current_batch = []
101
+
102
+ # 마지막 불완전한 배치 처리
103
+ if current_batch and not self.drop_last:
104
+ self.batches.append(current_batch)
105
+ elif current_batch and self.drop_last:
106
+ print(f"[IPOGroupedBatchSampler] Dropped last incomplete batch of size {len(current_batch)}")
107
+
108
+ print(f"[IPOGroupedBatchSampler] Created {len(self.batches)} batches")
109
+
110
+ def __iter__(self) -> Iterator[List[int]]:
111
+ """배치 반복자"""
112
+ # 배치 순서 섞기
113
+ if self.shuffle:
114
+ indices = torch.randperm(len(self.batches), generator=self.generator).tolist()
115
+ shuffled_batches = [self.batches[i] for i in indices]
116
+ else:
117
+ shuffled_batches = self.batches
118
+
119
+ # 각 배치 yield
120
+ for batch in shuffled_batches:
121
+ # 배치 내부도 섞을 수 있음 (선택적)
122
+ if self.shuffle:
123
+ random.shuffle(batch)
124
+ yield batch
125
+
126
+ def __len__(self) -> int:
127
+ """전체 배치 수"""
128
+ return len(self.batches)
129
+
130
+
131
+ class IPOGroupPreservingBatchSampler(BatchSampler):
132
+ """
133
+ IPO 그룹을 최대한 보존하면서 배치를 생성하는 샘플러
134
+
135
+ 이 샘플러는 다음 우선순위로 작동합니다:
136
+ 1. 같은 ipo_group_id를 가진 샘플들을 우선적으로 같은 배치에 배치
137
+ 2. 배치 크기를 채우기 위해 필요시 다른 그룹의 샘플 추가
138
+ 3. 모든 샘플이 정확히 한 번씩 사용되도록 보장
139
+ """
140
+
141
+ def __init__(self,
142
+ dataset,
143
+ batch_size: int,
144
+ shuffle: bool = True,
145
+ drop_last: bool = False,
146
+ seed: int = 42):
147
+ """
148
+ Args:
149
+ dataset: TTRLVRDataset 인스턴스
150
+ batch_size: 배치 크기
151
+ shuffle: 배치 및 그룹 순서 섞기
152
+ drop_last: 마지막 불완전한 배치 버리기
153
+ seed: 랜덤 시드
154
+ """
155
+ self.dataset = dataset
156
+ self.batch_size = batch_size
157
+ self.shuffle = shuffle
158
+ self.drop_last = drop_last
159
+ self.seed = seed
160
+
161
+ # 그룹별 인덱스 구축
162
+ self.groups = self._build_groups()
163
+
164
+ def _build_groups(self):
165
+ """ipo_group_id별로 샘플 인덱스 그룹핑"""
166
+ groups = defaultdict(list)
167
+
168
+ for idx in range(len(self.dataset)):
169
+ if hasattr(self.dataset, 'dataframe'):
170
+ row = self.dataset.dataframe.iloc[idx]
171
+ ipo_group_id = row.get('ipo_group_id', '')
172
+
173
+ # 빈 값이면 개별 처리
174
+ if not ipo_group_id:
175
+ ipo_group_id = f'single_{idx}'
176
+ else:
177
+ ipo_group_id = f'single_{idx}'
178
+
179
+ groups[ipo_group_id].append(idx)
180
+
181
+ return groups
182
+
183
+ def __iter__(self):
184
+ """배치 생성 및 반복"""
185
+ # 그룹들을 리스트로 변환
186
+ group_list = list(self.groups.items())
187
+
188
+ # 셔플
189
+ if self.shuffle:
190
+ random.seed(self.seed)
191
+ random.shuffle(group_list)
192
+
193
+ # 배치 생성
194
+ current_batch = []
195
+
196
+ for group_id, indices in group_list:
197
+ # 그룹 내 인덱스도 셔플
198
+ if self.shuffle:
199
+ random.shuffle(indices)
200
+
201
+ for idx in indices:
202
+ current_batch.append(idx)
203
+
204
+ # 배치가 가득 차면 yield
205
+ if len(current_batch) == self.batch_size:
206
+ yield current_batch
207
+ current_batch = []
208
+
209
+ # 마지막 배치 처리
210
+ if current_batch and not self.drop_last:
211
+ yield current_batch
212
+
213
+ def __len__(self):
214
+ """전체 배치 수 계산"""
215
+ total_samples = len(self.dataset)
216
+
217
+ if self.drop_last:
218
+ return total_samples // self.batch_size
219
+ else:
220
+ return (total_samples + self.batch_size - 1) // self.batch_size