"""
Simple TTRLVR + AZR integration test.

Covers the most basic components:
1. Task Generator test
2. Data Converter test
3. Iterative Trainer basic setup test
"""

|
|
import os |
|
import sys |
|
import tempfile |
|
import shutil |
|
from pathlib import Path |
|
|
|
|
|
sys.path.append('/home/ubuntu/RLVR/TestTime-RLVR-v2') |
|
|
|
def test_task_generator():
    """Basic Task Generator test."""
    print("🧪 Testing Task Generator...")

    try:
        from absolute_zero_reasoner.testtime.config import TestTimeConfig
        from absolute_zero_reasoner.testtime.logger import TestTimeLogger
        from absolute_zero_reasoner.testtime.task_generator import TestTimeTaskGenerator

        config = TestTimeConfig()
        config.model_name = "Qwen/Qwen2.5-7B"
        logger = TestTimeLogger()

        task_generator = TestTimeTaskGenerator(config, logger)

        # A single mock IPO (input/program/output) triple to drive task generation.
        test_ipo_triples = [{
            'id': 'test_triple_0',
            'input': '[1, 2, 3]',
            'actual_output': '[2, 4, 6]',
            'program': 'def test_func(lst):\n return [x * 2 for x in lst]',
            'full_input_str': 'test_func([1, 2, 3])',
            'source_program_id': 'program_0',
            'ipo_index': 0
        }]

        tasks = task_generator.generate_tasks(test_ipo_triples, "TestProblem", 1)

        # All three AZR task types should be present.
        assert 'induction' in tasks
        assert 'deduction' in tasks
        assert 'abduction' in tasks

        total_tasks = sum(len(task_list) for task_list in tasks.values())
        print(f"✅ Task Generator: Generated {total_tasks} tasks")

        # Every generated task should carry the AZR metadata fields.
        for task_type, task_list in tasks.items():
            if task_list:
                task = task_list[0]
                assert 'uid' in task
                assert 'ipo_group_id' in task
                assert 'basic_accuracy' in task
                print(f"✅ Task Generator: {task_type} has AZR metadata")

        return True

    except Exception as e:
        print(f"❌ Task Generator test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_data_converter():
    """Basic Data Converter test."""
    print("\n🧪 Testing Data Converter...")

    try:
        from absolute_zero_reasoner.testtime.complete_pipeline import CompleteTestTimePipeline
        from absolute_zero_reasoner.testtime.config import TestTimeConfig
        from absolute_zero_reasoner.testtime.logger import TestTimeLogger

        config = TestTimeConfig()
        logger = TestTimeLogger()
        pipeline = CompleteTestTimePipeline(config, logger)

        # Minimal mock task dict mirroring the structure produced by the task generator.
        mock_tasks = {
            'induction': [{
                'task_id': 'induction_0',
                'task_type': 'induction',
                'prompt': 'Test prompt',
                'uid': 'TestProblem_round_1_induction_0',
                'ipo_group_id': 'TestProblem_program_0_ipo_0',
                'source_program_id': 'program_0',
                'ipo_index': 0,
                'ipo_triple': {
                    'input': '[1, 2, 3]',
                    'output': '[2, 4, 6]',
                    'program': 'def test_func(lst):\n return [x * 2 for x in lst]'
                },
                'ground_truth': 'def test_func(lst):\n return [x * 2 for x in lst]',
                'extra_info': {'metric': 'code_f'},
                'basic_accuracy': 1.0,
                'original_problem_id': 'TestProblem',
                'round': 1
            }]
        }

        with tempfile.TemporaryDirectory() as temp_dir:
            # Convert the mock tasks into AZR training data files.
            saved_files = pipeline._save_azr_training_data(
                mock_tasks, "TestProblem", 1, temp_dir
            )

            assert 'induction' in saved_files
            assert os.path.exists(saved_files['induction'])

            # Read the Parquet file back and validate its schema.
            import pandas as pd
            df = pd.read_parquet(saved_files['induction'])

            assert len(df) == 1
            assert 'prompt' in df.columns
            assert 'uid' in df.columns
            assert 'ipo_group_id' in df.columns

        print("✅ Data Converter: Parquet file created and validated")
        return True

    except Exception as e:
        print(f"❌ Data Converter test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_iterative_trainer_basic():
    """Basic Iterative Trainer setup test."""
    print("\n🧪 Testing Iterative Trainer Setup...")

    try:
        from utils.iterative_trainer import IterativeTrainer
        from absolute_zero_reasoner.testtime.config import TestTimeConfig, BenchmarkConfig
        from absolute_zero_reasoner.testtime.logger import TestTimeLogger

        config = TestTimeConfig()
        logger = TestTimeLogger()

        trainer = IterativeTrainer(config, logger)

        # The trainer should start from the base model and the default checkpoint directory.
        assert trainer.current_model_path == "Qwen/Qwen2.5-7B"
        assert trainer.checkpoint_dir == "/data/RLVR/checkpoints/ttrlvr_azr"

        print("✅ Iterative Trainer: Basic setup successful")
        return True

    except Exception as e:
        print(f"❌ Iterative Trainer test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def main():
    """Run all basic integration tests and report the results."""
    print("🚀 TTRLVR + AZR Simple Integration Test")
    print("=" * 60)

    tests = [
        ("Task Generator", test_task_generator),
        ("Data Converter", test_data_converter),
        ("Iterative Trainer", test_iterative_trainer_basic)
    ]

    results = []

    for test_name, test_func in tests:
        try:
            result = test_func()
            results.append((test_name, result))
        except Exception as e:
            print(f"💥 {test_name} crashed: {e}")
            results.append((test_name, False))

    print("\n" + "=" * 60)
    print("📊 Test Results:")

    passed = 0
    total = len(results)

    for test_name, result in results:
        status = "✅ PASS" if result else "❌ FAIL"
        print(f"  {status} {test_name}")
        if result:
            passed += 1

    print(f"\nOverall: {passed}/{total} tests passed ({passed/total*100:.1f}%)")

    if passed == total:
        print("\n🎉 All simple integration tests passed!")
        return 0
    else:
        print(f"\n⚠️ {total - passed} tests failed")
        return 1


if __name__ == '__main__':
    exit_code = main()
    sys.exit(exit_code)