|
|
|
"""
Step 5-only execution script.

Runs only the VeRL PPO training step, reusing existing AZR training data.
"""
|
|
|
import os |
|
import sys |
|
import argparse |
|
from pathlib import Path |
|
|
|
|
|
# Make the repository root (so the `test.*` package below can be imported) and
# its `test/` directory importable. Both entries go to the END of the search
# path, root first, after whatever the interpreter already has.
sys.path.extend([
    '/home/ubuntu/RLVR/TestTime-RLVR-v2',
    '/home/ubuntu/RLVR/TestTime-RLVR-v2/test',
])
|
|
|
from test.utils.iterative_trainer import IterativeTrainer |
|
|
|
def main(argv=None):
    """Run only the VeRL PPO training step (Step 5) on existing AZR data.

    Validates that ``--data_path`` is an existing directory containing
    ``induction.parquet``, ``deduction.parquet`` and ``abduction.parquet``,
    prints their sizes, builds an ``IterativeTrainer`` from ``--config`` and
    calls its ``run_verl_training_only`` method.

    Args:
        argv: Command-line arguments excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, exactly as a
            plain ``main()`` call always did; an explicit list allows
            programmatic use and testing.

    Returns:
        Process exit code: 0 on success; 1 if validation fails, if building
        the trainer or running training raises, or if the trainer's result
        does not report success.
    """
    parser = argparse.ArgumentParser(description='Run VeRL training (Step 5) only with existing data')
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to existing azr_training_data directory')
    parser.add_argument('--round', type=int, default=1,
                        help='Round number for logging (default: 1)')
    parser.add_argument('--experiment_name', type=str, default=None,
                        help='Custom experiment name')
    parser.add_argument('--config', type=str,
                        default='/home/ubuntu/RLVR/TestTime-RLVR-v2/test/configs/ttrlvr_azr_ppo_4gpu.yaml',
                        help='VeRL config file path')
    args = parser.parse_args(argv)

    # --- Validate the training-data directory ---------------------------------
    data_path = Path(args.data_path)
    if not data_path.exists():
        print(f"❌ Error: Data path does not exist: {data_path}", file=sys.stderr)
        return 1
    if not data_path.is_dir():
        # A regular file would otherwise be reported as "missing files";
        # state the real problem instead.
        print(f"❌ Error: Data path is not a directory: {data_path}", file=sys.stderr)
        return 1

    required_files = ['induction.parquet', 'deduction.parquet', 'abduction.parquet']
    # exists() (not is_file()) is kept deliberately: presumably a parquet
    # entry could also be a dataset directory — TODO confirm with the trainer.
    missing_files = [name for name in required_files
                     if not (data_path / name).exists()]
    if missing_files:
        print(f"❌ Error: Missing required files: {missing_files}", file=sys.stderr)
        return 1

    print(f"✅ Found all required training data files in: {data_path}")
    for file_name in required_files:
        file_size = (data_path / file_name).stat().st_size
        print(f" 📄 {file_name}: {file_size:,} bytes")

    # --- Build the trainer and run Step 5 -------------------------------------
    # Trainer construction is inside the try so that a bad config is reported
    # the same way as a training failure (message + traceback, exit code 1)
    # instead of escaping as an unhandled exception.
    try:
        print(f"🔄 Initializing trainer with config: {args.config}")
        trainer = IterativeTrainer(config_path=args.config)

        print("🚀 Starting VeRL training (Step 5 only)")
        print(f"📂 Data path: {data_path}")
        print(f"🔁 Round: {args.round}")

        result = trainer.run_verl_training_only(
            training_data_path=str(data_path),
            round_num=args.round,
            experiment_name=args.experiment_name,
        )
    except Exception as e:
        import traceback

        print(f"💥 Training failed with exception: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1

    # The result is read as a mapping with optional 'success', 'error',
    # 'duration' and 'model_path' keys; a None result is treated as an
    # unspecified failure rather than crashing on .get().
    result = result or {}
    if not result.get('success', False):
        print(f"❌ VeRL training failed: {result.get('error', 'Unknown error')}", file=sys.stderr)
        return 1

    print("✅ VeRL training completed successfully!")
    duration = result.get('duration')
    # Only attach the unit when a duration was actually reported.
    duration_text = 'N/A' if duration is None else f"{duration} seconds"
    print(f"⏱️ Duration: {duration_text}")
    if 'model_path' in result:
        print(f"🤖 Updated model: {result['model_path']}")

    print("🎉 Step 5 training completed!")
    return 0
|
|
|
if __name__ == "__main__":
    # Hand main()'s integer result (0 success, 1 failure) straight to the
    # interpreter as the process exit status.
    sys.exit(main())