# midmb-kculture-qa / 4bit_test.py
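"""4-bit (NF4) quantized inference for the sooh098/midmb-kculture-qa checkpoint.

Reads Korean-culture QA test samples from a JSONL file, generates one answer
per sample with seeded sampling, and writes the predictions back into the
original test JSON as item["output"]["answer"].
"""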
import json
import os
import torch
import random
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# === Fix random seeds ===
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# Make cuDNN deterministic so repeated runs produce identical results.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
print(f"🔒 Seeds fixed (seed={seed})")
# === Settings ===
fp16_ckpt = "sooh098/midmb-kculture-qa"
test_json_path = "../data/korean_culture_qa_V1.0_test+.json"
test_jsonl_path = "../data/test.jsonl"
output_json_path = "../data/output.json"
use_cuda = torch.cuda.is_available()
# === 4-bit quantization config (NF4 with fp16 compute) ===
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
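# Optional knob (not in the original script): bnb_4bit_use_double_quant=True
# applies nested quantization for a further memory saving at some speed cost.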
# === Load model and tokenizer ===
print("🚀 Loading 4-bit model...")
tokenizer = AutoTokenizer.from_pretrained(
    fp16_ckpt,
    trust_remote_code=True,
    use_fast=True,
)
# Causal LMs often lack a pad token; fall back to EOS so generate() gets a
# valid pad_token_id below.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    fp16_ckpt,
    quantization_config=bnb_config,
    device_map={"": 0},  # 4-bit bitsandbytes weights require a CUDA device
    torch_dtype=torch.float16 if use_cuda else torch.float32,
    trust_remote_code=True,
)
model.eval()
print("✅ 4-bit model loaded")
# === Run predictions ===
predictions = []
with open(test_jsonl_path, "r", encoding="utf-8") as f:
    lines = f.readlines()

for i, line in enumerate(tqdm(lines, desc="🔍 Generating predictions", unit="sample")):
    sample = json.loads(line)
    input_text = sample["input"]
    instruction = sample.get("instruction", "")
    # Midm prompt format (Llama-3-style header/EOT special tokens)
    prompt = (
        "<|begin_of_text|>\n"
        f"<|start_header_id|>system<|end_header_id|>\n{instruction}\n"
        "<|eot_id|>\n"
        f"<|start_header_id|>user<|end_header_id|>\n{input_text}\n"
        "<|eot_id|>\n"
        "<|start_header_id|>assistant<|end_header_id|>\n"
    )
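    # Alternative (assumes the checkpoint ships a chat template): building the
    # prompt via tokenizer.apply_chat_template(messages, tokenize=False,
    # add_generation_prompt=True) would avoid hardcoding the special tokens.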
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # generate() does not accept token_type_ids; drop them if the tokenizer adds them.
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]
    with torch.no_grad():
        # Re-seed before every call so sampling is reproducible per sample
        torch.manual_seed(seed)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            repetition_penalty=1.05,
            temperature=0.28,
            top_p=0.85,
            top_k=20,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Keep special tokens so the assistant header can anchor the split below.
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
    raw_response = output_text.split("<|start_header_id|>assistant<|end_header_id|>\n")[-1]
    prediction = raw_response.split("<|end_of_text|>")[0].strip()
    predictions.append(prediction)
    print(f"\n📝 Sample {i + 1}")
    print(f"🤖 Prediction: {prediction}")
# === Load the original test JSON ===
with open(test_json_path, "r", encoding="utf-8") as f:
    test_data = json.load(f)

# === Attach predictions to each item ===
for i, item in enumerate(test_data):
    answer = predictions[i] if i < len(predictions) else ""
    item["output"] = {"answer": answer}
# === Save results ===
os.makedirs(os.path.dirname(output_json_path), exist_ok=True)
with open(output_json_path, "w", encoding="utf-8") as f:
    json.dump(test_data, f, ensure_ascii=False, indent=2)
print(f"\n✅ Final results saved to '{output_json_path}'.")