#!/usr/bin/env python3
"""
๋ชจ๋ธ๋ณ„ ํŠน์ • ๋ฌธ์ œ ์‘๋‹ต ๋น„๊ต ๋„๊ตฌ (์ „์ฒด ์‘๋‹ต ์ถœ๋ ฅ ๋ฒ„์ „)
์‚ฌ์šฉ๋ฒ•: python compare_models.py --benchmark aime24 --problem_idx 0 --models qwen25_7b_base azr_coder_7b azr_base_7b
์ฃผ์š” ๋ณ€๊ฒฝ์‚ฌํ•ญ:
- ๋ฌธ์ œ ํ…์ŠคํŠธ์™€ ์ถ”๋ก  ๊ณผ์ •์„ ์ „์ฒด ์ถœ๋ ฅ (๊ธธ์ด ์ œํ•œ ์ œ๊ฑฐ)
- ๋” ์ƒ์„ธํ•œ ๋ถ„์„์„ ์œ„ํ•œ ์™„์ „ํ•œ ์‘๋‹ต ํ™•์ธ ๊ฐ€๋Šฅ
"""
import json
import argparse
import os
import glob
import sys
from typing import List, Dict, Any
import re


def extract_answer_from_response(response: str) -> str:
    """Extract the final answer from a model response."""
    # Look for \boxed{...} patterns first
    boxed_pattern = r'\\boxed\{([^}]+)\}'
    matches = re.findall(boxed_pattern, response)
    if matches:
        return matches[-1]  # use the last boxed answer

    # Fall back to bare numbers
    number_pattern = r'\b\d+(?:\.\d+)?\b'
    numbers = re.findall(number_pattern, response)
    if numbers:
        return numbers[-1]

    return "No answer found"


def load_problem_result(model_dir: str, benchmark: str, problem_idx: int) -> Dict[str, Any]:
    """Load one model's result for a specific problem."""
    # Find the result file for this benchmark
    result_files = glob.glob(f"{model_dir}/{benchmark}/*.jsonl")
    if not result_files:
        return {"error": f"No result file found for benchmark '{benchmark}'"}

    result_file = result_files[0]  # use the first matching file
    try:
        with open(result_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()

        if problem_idx >= len(lines):
            return {"error": f"Problem index {problem_idx} is out of range ({len(lines)} problems total)"}

        problem_data = json.loads(lines[problem_idx])
        return problem_data
    except Exception as e:
        return {"error": f"Failed to read file: {str(e)}"}


def format_response_text(text: str, max_length: int = None) -> str:
    """Format a response text, truncating it only if max_length is set."""
    if max_length is not None and len(text) > max_length:
        return text[:max_length] + "... (truncated)"
    return text


def compare_models_on_problem(
    benchmark: str,
    problem_idx: int,
    model_dirs: List[str],
    results_base_dir: str = "/home/ubuntu/RLVR/Absolute-Zero-Reasoner/evaluation/math_eval/EVAL/results"
):
    """Compare model responses on a specific problem."""
    print("=" * 100)
    print(f"📊 Benchmark: {benchmark.upper()}")
    print(f"🔢 Problem index: {problem_idx}")
    print("=" * 100)

    # Load each model's result
    model_results = {}
    problem_text = None
    correct_answer = None

    for model_name in model_dirs:
        model_path = os.path.join(results_base_dir, model_name)
        result = load_problem_result(model_path, benchmark, problem_idx)
        model_results[model_name] = result

        # Take the problem text and gold answer from the first model that loaded
        if problem_text is None and "error" not in result:
            problem_text = result.get("question", "Problem text not found")
            correct_answer = result.get("gt", result.get("answer", "Gold answer not found"))

    # Print the problem
    print("📝 Problem:")
    print("-" * 80)
    print(format_response_text(problem_text or "(no model result loaded)"))
    print()
    print(f"✅ Gold answer: {correct_answer}")
    print()

    # Preferred display order for models
    preferred_order = [
        "qwen25_7b_base",
        "azr_base_7b",
        "qwen25_7b_coder",
        "azr_coder_7b"
    ]

    # Sort the requested models into the preferred order: models listed in
    # preferred_order come first, any remaining models keep their input order
    ordered_models = [m for m in preferred_order if m in model_dirs]
    ordered_models += [m for m in model_dirs if m not in ordered_models]

    # Compare results per model
    print("🤖 Per-model responses:")
    print("=" * 100)

    for model_name in ordered_models:
        result = model_results[model_name]
        print(f"\n🔸 {model_name.upper()}")
        print("-" * 60)

        if "error" in result:
            print(f"❌ Error: {result['error']}")
            continue

        # Model prediction and correctness flag
        model_pred = result.get("pred", [""])[0] if result.get("pred") else ""
        is_correct = result.get("score", [False])[0] if result.get("score") else False

        # Extract an answer from the full response
        full_response = result.get("code", [""])[0] if result.get("code") else ""
        extracted_answer = extract_answer_from_response(full_response)

        # Print the outcome
        status = "✅ correct" if is_correct else "❌ wrong"
        print(f"Result: {status}")
        print(f"Model answer: {model_pred}")
        print(f"Extracted answer: {extracted_answer}")
        print()

        # Full reasoning trace
        if full_response:
            print("🧠 Reasoning trace (full):")
            print(format_response_text(full_response))
        else:
            print("No reasoning trace found")
        print("-" * 60)

    # Summary statistics
    print("\n📈 Summary:")
    print("-" * 40)

    correct_models = []
    wrong_models = []
    error_models = []

    for model_name in ordered_models:
        result = model_results[model_name]
        if "error" in result:
            error_models.append(model_name)
        elif (result.get("score") or [False])[0]:
            correct_models.append(model_name)
        else:
            wrong_models.append(model_name)

    if correct_models:
        print(f"✅ Correct: {', '.join(correct_models)}")
    if wrong_models:
        print(f"❌ Wrong: {', '.join(wrong_models)}")
    if error_models:
        print(f"⚠️ Errored: {', '.join(error_models)}")

    # Guard against division by zero when every model errored
    evaluated = len(model_dirs) - len(error_models)
    if evaluated > 0:
        print(f"Accuracy: {len(correct_models)}/{evaluated} ({len(correct_models) / evaluated * 100:.1f}%)")
    else:
        print("Accuracy: n/a (no model produced a result)")


def list_available_benchmarks(results_base_dir: str, model_name: str) -> List[str]:
    """Return the list of available benchmarks for a model."""
    model_path = os.path.join(results_base_dir, model_name)
    if not os.path.exists(model_path):
        return []

    benchmarks = []
    for item in os.listdir(model_path):
        item_path = os.path.join(model_path, item)
        if os.path.isdir(item_path):
            benchmarks.append(item)
    return sorted(benchmarks)


def get_problem_count(results_base_dir: str, model_name: str, benchmark: str) -> int:
    """Return the number of problems in a benchmark."""
    result_files = glob.glob(f"{results_base_dir}/{model_name}/{benchmark}/*.jsonl")
    if not result_files:
        return 0
    try:
        with open(result_files[0], 'r', encoding='utf-8') as f:
            return len(f.readlines())
    except Exception:
        return 0


def main():
    parser = argparse.ArgumentParser(description="Compare model responses on a specific problem")
    parser.add_argument("--benchmark", "-b", type=str, required=True,
                        help="benchmark name (e.g. aime24, aime25, amc23, math500)")
    parser.add_argument("--problem_idx", "-p", type=int, required=True,
                        help="problem index (0-based)")
    parser.add_argument("--models", "-m", nargs="+", required=True,
                        help="model directory names to compare")
    parser.add_argument("--results_dir", "-r", type=str,
                        default="/home/ubuntu/RLVR/Absolute-Zero-Reasoner/evaluation/math_eval/EVAL/results",
                        help="path to the results directory")
    parser.add_argument("--list", "-l", action="store_true",
                        help="list available benchmarks and problem counts")
    parser.add_argument("--full", "-f", action="store_true",
                        help="print full responses (kept for backward compatibility; output is now always full)")
    parser.add_argument("--output", "-o", type=str,
                        help="save the results to a file (e.g. --output result.txt)")

    args = parser.parse_args()

    if args.list:
        print("📋 Available benchmarks:")
        print("-" * 50)
        # List benchmarks based on the first model
        if args.models:
            benchmarks = list_available_benchmarks(args.results_dir, args.models[0])
            for benchmark in benchmarks:
                problem_count = get_problem_count(args.results_dir, args.models[0], benchmark)
                print(f"• {benchmark}: {problem_count} problems (indices 0-{problem_count - 1})")
        return

    # Set up the output file
    output_file = None
    original_stdout = sys.stdout
    if args.output:
        try:
            output_file = open(args.output, 'w', encoding='utf-8')
            # Announce on the console first, *then* redirect stdout to the
            # file, so the notice does not end up in the saved results
            print(f"📝 Saving results to {args.output}...")
            sys.stdout = output_file
        except Exception as e:
            print(f"❌ Failed to open output file: {e}")
            return

    try:
        # Run the comparison
        compare_models_on_problem(
            benchmark=args.benchmark,
            problem_idx=args.problem_idx,
            model_dirs=args.models,
            results_base_dir=args.results_dir
        )
    finally:
        # When writing to a file, restore stdout and close the file
        if output_file:
            sys.stdout = original_stdout
            output_file.close()
            print(f"✅ Results saved to {args.output}.")
            print(f"📖 View the file: cat {args.output}")
            print(f"📄 Page through it: less {args.output}")
            print(f"🔍 Search for one model: grep -A 10 'QWEN25_7B_BASE' {args.output}")


if __name__ == "__main__":
    main()
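
# Example invocations (benchmark names follow the --benchmark help string;
# the model directory names are the ones hard-coded in preferred_order above,
# and the problem index in the last line is arbitrary):
#   python compare_models.py --list -m qwen25_7b_base
#   python compare_models.py -b aime24 -p 0 -m qwen25_7b_base azr_coder_7b azr_base_7b
#   python compare_models.py -b math500 -p 12 -m azr_base_7b azr_coder_7b -o result.txt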