"""
Tool for comparing model responses to a specific benchmark problem, model by model (full-response output version).

Usage:
    python compare_models.py --benchmark aime24 --problem_idx 0 --models qwen25_7b_base azr_coder_7b azr_base_7b

Key changes:
- Problem text and reasoning traces are printed in full (length limit removed).
- Complete responses can be inspected for more detailed analysis.
"""

import json
import argparse
import os
import glob
import sys
from typing import List, Dict, Any
import re


def extract_answer_from_response(response: str) -> str:
    """Extract the final answer from a model response."""
    # Prefer the last \boxed{...} expression; it usually contains the final answer.
    boxed_pattern = r'\\boxed\{([^}]+)\}'
    matches = re.findall(boxed_pattern, response)
    if matches:
        return matches[-1]

    # Fall back to the last bare number in the response.
    number_pattern = r'\b\d+(?:\.\d+)?\b'
    numbers = re.findall(number_pattern, response)
    if numbers:
        return numbers[-1]

    return "No answer found"


def load_problem_result(model_dir: str, benchmark: str, problem_idx: int) -> Dict[str, Any]:
    """Load a single problem's result record for one model."""
    result_files = glob.glob(f"{model_dir}/{benchmark}/*.jsonl")

    if not result_files:
        return {"error": f"No result files found for benchmark '{benchmark}'"}

    # Use the first matching .jsonl file; each line holds one problem's record.
    result_file = result_files[0]

    try:
        with open(result_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        if problem_idx >= len(lines):
            return {"error": f"Problem index {problem_idx} is out of range ({len(lines)} problems in total)"}

        problem_data = json.loads(lines[problem_idx])
        return problem_data

    except Exception as e:
        return {"error": f"Failed to read file: {str(e)}"}


def format_response_text(text: str, max_length: int = None) -> str:
    """Format response text, truncating only when max_length is given."""
    if max_length is not None and len(text) > max_length:
        return text[:max_length] + "... (truncated)"
    return text


def compare_models_on_problem(
    benchmark: str,
    problem_idx: int,
    model_dirs: List[str],
    results_base_dir: str = "/home/ubuntu/RLVR/Absolute-Zero-Reasoner/evaluation/math_eval/EVAL/results"
):
    """Compare the requested models on a single benchmark problem."""
    print("=" * 100)
    print(f"Benchmark: {benchmark.upper()}")
    print(f"Problem index: {problem_idx}")
    print("=" * 100)

    model_results = {}
    problem_text = None
    correct_answer = None

    for model_name in model_dirs:
        model_path = os.path.join(results_base_dir, model_name)
        result = load_problem_result(model_path, benchmark, problem_idx)
        model_results[model_name] = result

        # Take the problem text and ground truth from the first model that loaded successfully.
        if problem_text is None and "error" not in result:
            problem_text = result.get("question", "Problem text not found")
            correct_answer = result.get("gt", result.get("answer", "Ground truth not found"))

    print("Problem:")
    print("-" * 80)
    print(format_response_text(problem_text))
    print()
    print(f"Correct answer: {correct_answer}")
    print()

    # Show models in a fixed preferred order first, then any remaining models as given.
    preferred_order = [
        "qwen25_7b_base",
        "azr_base_7b",
        "qwen25_7b_coder",
        "azr_coder_7b"
    ]

    ordered_models = []

    for preferred_model in preferred_order:
        if preferred_model in model_dirs:
            ordered_models.append(preferred_model)

    for model in model_dirs:
        if model not in ordered_models:
            ordered_models.append(model)

    print("Per-model responses:")
    print("=" * 100)

    for model_name in ordered_models:
        result = model_results[model_name]
        print(f"\n{model_name.upper()}")
        print("-" * 60)

        if "error" in result:
            print(f"Error: {result['error']}")
            continue

        # Each record stores per-sample lists: "pred" (parsed answer), "score" (correctness),
        # and "code" (the raw model response).
        model_pred = result.get("pred", [""])[0] if result.get("pred") else ""
        is_correct = result.get("score", [False])[0] if result.get("score") else False

        full_response = result.get("code", [""])[0] if result.get("code") else ""
        extracted_answer = extract_answer_from_response(full_response)

        status = "Correct" if is_correct else "Wrong"
        print(f"Result: {status}")
        print(f"Model answer: {model_pred}")
        print(f"Extracted answer: {extracted_answer}")
        print()

        if full_response:
            print("Reasoning trace (full):")
            print(format_response_text(full_response))
        else:
            print("No reasoning trace found")

        print("-" * 60)

    print("\nSummary:")
    print("-" * 40)
    correct_models = []
    wrong_models = []
    error_models = []

    for model_name in ordered_models:
        result = model_results[model_name]
        if "error" in result:
            error_models.append(model_name)
        elif result.get("score", [False])[0]:
            correct_models.append(model_name)
        else:
            wrong_models.append(model_name)

    if correct_models:
        print(f"Correct: {', '.join(correct_models)}")
    if wrong_models:
        print(f"Wrong: {', '.join(wrong_models)}")
    if error_models:
        print(f"Errors: {', '.join(error_models)}")

    # Guard against division by zero when no model's results could be loaded.
    evaluated = len(model_dirs) - len(error_models)
    if evaluated > 0:
        print(f"Accuracy: {len(correct_models)}/{evaluated} ({len(correct_models) / evaluated * 100:.1f}%)")
    else:
        print("Accuracy: n/a (no results could be loaded)")


def list_available_benchmarks(results_base_dir: str, model_name: str) -> List[str]:
    """Return the benchmarks available for the given model."""
    model_path = os.path.join(results_base_dir, model_name)
    if not os.path.exists(model_path):
        return []

    benchmarks = []
    for item in os.listdir(model_path):
        item_path = os.path.join(model_path, item)
        if os.path.isdir(item_path):
            benchmarks.append(item)

    return sorted(benchmarks)


def get_problem_count(results_base_dir: str, model_name: str, benchmark: str) -> int:
    """Return the number of problems in the given benchmark."""
    result_files = glob.glob(f"{results_base_dir}/{model_name}/{benchmark}/*.jsonl")
    if not result_files:
        return 0

    try:
        with open(result_files[0], 'r', encoding='utf-8') as f:
            return len(f.readlines())
    except Exception:
        return 0


def main():
    parser = argparse.ArgumentParser(description="Compare model responses on a specific problem")
    parser.add_argument("--benchmark", "-b", type=str, required=True,
                        help="Benchmark name (e.g. aime24, aime25, amc23, math500)")
    parser.add_argument("--problem_idx", "-p", type=int, required=True,
                        help="Problem index (0-based)")
    parser.add_argument("--models", "-m", nargs="+", required=True,
                        help="Names of the model result directories to compare")
    parser.add_argument("--results_dir", "-r", type=str,
                        default="/home/ubuntu/RLVR/Absolute-Zero-Reasoner/evaluation/math_eval/EVAL/results",
                        help="Path to the results directory")
    parser.add_argument("--list", "-l", action="store_true",
                        help="List the available benchmarks and their problem counts")
    parser.add_argument("--full", "-f", action="store_true",
                        help="Print full responses (kept for compatibility; full output is now always on)")
    parser.add_argument("--output", "-o", type=str,
                        help="Save the output to a file (e.g. --output result.txt)")

    args = parser.parse_args()

    if args.list:
        print("Available benchmarks:")
        print("-" * 50)

        if args.models:
            benchmarks = list_available_benchmarks(args.results_dir, args.models[0])
            for benchmark in benchmarks:
                problem_count = get_problem_count(args.results_dir, args.models[0], benchmark)
                print(f"- {benchmark}: {problem_count} problems (indices 0-{problem_count - 1})")
        return

    # Optionally redirect stdout to a file so the whole comparison is captured verbatim.
    output_file = None
    original_stdout = sys.stdout
    if args.output:
        try:
            output_file = open(args.output, 'w', encoding='utf-8')
            print(f"Saving output to {args.output} ...")
            sys.stdout = output_file
        except Exception as e:
            print(f"Failed to create the output file: {e}")
            return

    try:
        compare_models_on_problem(
            benchmark=args.benchmark,
            problem_idx=args.problem_idx,
            model_dirs=args.models,
            results_base_dir=args.results_dir
        )
    finally:
        # Restore stdout before reporting where the output was written.
        if output_file:
            sys.stdout = original_stdout
            output_file.close()
            print(f"Output saved to {args.output}.")
            print(f"View it with: cat {args.output}")
            print(f"Page through it with: less {args.output}")
            print(f"Search for a specific model with: grep -A 10 'QWEN25_7B_BASE' {args.output}")


if __name__ == "__main__":
    main()
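

# Example invocations (illustrative; they assume the default results layout and the model
# directory names mentioned in the module docstring; the output filename is hypothetical):
#   python compare_models.py --list -b aime24 -p 0 -m qwen25_7b_base
#   python compare_models.py -b aime24 -p 0 -m qwen25_7b_base azr_coder_7b --output aime24_p0.txt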