# Truck2 / QWQ_infer.py
# Wendy-Fly's picture
# Upload QWQ_infer.py with huggingface_hub
# 6f972a2 verified
import json
import pprint
def read_json(file_path):
    """Load and return the JSON document stored at *file_path*.

    Args:
        file_path: Path of the JSON file to read. The file is opened in
            text mode and decoded as UTF-8.

    Returns:
        Whatever ``json.load`` produces for the file's content (dict, list,
        str, number, bool or ``None``).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)
def write_json(file_path, data) -> None:
    """Serialize *data* as JSON and write it to *file_path*.

    The file is created or truncated and written as UTF-8. Non-ASCII
    characters are written as-is (``ensure_ascii=False``) rather than as
    ``\\uXXXX`` escapes, and the output is pretty-printed with a 4-space
    indent.

    Args:
        file_path: Destination path of the JSON file.
        data: Any object ``json.dump`` can serialize.

    Raises:
        OSError: If the file cannot be opened for writing.
        TypeError: If *data* contains an object that is not JSON serializable.
    """
    with open(file_path, 'w', encoding='utf-8') as file:
        json.dump(data, file, ensure_ascii=False, indent=4)
# Load the samples to be rated. read_json returns the parsed JSON document.
# NOTE(review): the rating loop at the bottom of this file reads
# item['conversations'][0]['value'] and item['conversations'][1]['value'] from
# each element, so this is presumably a list of dicts with that layout —
# confirm against the dataset file.
data = read_json("DataSet/train_samples_all_tuning.json")
# NOTE(review): this import sits mid-file; it is only needed from here on.
from transformers import AutoModelForCausalLM, AutoTokenizer
# Identifier passed to both from_pretrained calls below.
# NOTE(review): looks like a relative local checkpoint directory — confirm.
model_name = "Model/QwQ-32B-Preview"
# Load the causal LM at import time; `model` is used as a module-level global
# by chat_QwQ.
model = AutoModelForCausalLM.from_pretrained(
model_name,
# NOTE(review): "auto" presumably defers the dtype choice and device
# placement to transformers — see the from_pretrained documentation.
torch_dtype="auto",
device_map="auto"
)
# Tokenizer loaded from the same identifier as the model; also used as a
# module-level global by chat_QwQ.
tokenizer = AutoTokenizer.from_pretrained(model_name)
def chat_QwQ(prompt: str, *, max_new_tokens: int = 512) -> str:
    """Send a single-turn *prompt* to the loaded QwQ model and return its reply.

    The prompt is wrapped in a two-message chat (a fixed system message plus
    the user message), rendered to text with the tokenizer's chat template,
    and passed to ``model.generate``. Relies on the module-level ``model``
    and ``tokenizer`` globals.

    Args:
        prompt: The user message.
        max_new_tokens: Upper bound on the number of tokens to generate,
            forwarded to ``model.generate``. Defaults to 512, the value that
            was previously hard-coded; raise it if replies come back cut off.

    Returns:
        The decoded reply text with special tokens removed.
    """
    messages = [
        {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."},
        {"role": "user", "content": prompt},
    ]
    # Render the chat to a plain string (tokenize=False) and append the
    # generation prompt so the model continues as the assistant.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Batch of one, moved to the device the model lives on.
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
    )
    # Drop the first len(input_ids) tokens of every output sequence so that
    # only the newly generated continuation is decoded, not the echoed prompt.
    completion_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    # Only one sequence was submitted, so take the first decoded string.
    return tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0]
# from transformers import MarianMTModel, MarianTokenizer
# model_name = "Model/opus-mt-en-zh"
# tokenizer = MarianTokenizer.from_pretrained(model_name)
# model = MarianMTModel.from_pretrained(model_name)
# Ask QwQ to rate dataset samples on how much they depend on the image.
for sample in data:
    # Text of the first two conversation entries.
    # NOTE(review): presumably the question followed by its answer — confirm
    # against the dataset.
    turns = sample['conversations']
    first_turn = turns[0]['value']
    second_turn = turns[1]['value']
    # NOTE(review): the two texts are joined with no separator, so the model
    # sees them run together; consider adding one if ratings look off.
    sentence = first_turn + second_turn
    prompt = "This is a question-answering datapoint based on image information. Determine whether the answer can be judged without relying on the image. If it can, this is considered bad data; if it requires the image, it is considered good data. Rate this datapoint on a scale from 1 (bad) to 5 (good). ###### " + sentence
    answer = chat_QwQ(prompt)

    # Disabled EN->ZH translation of the answer (needs the commented-out
    # MarianMT model/tokenizer setup earlier in the file):
    # english_text = answer
    # inputs = tokenizer.encode(english_text, return_tensors="pt", truncation=True)
    # translated = model.generate(inputs, max_length=40, num_beams=4, early_stopping=True)
    # chinese_translation = tokenizer.decode(translated[0], skip_special_tokens=True)

    pprint.pprint(prompt)
    pprint.pprint(answer)
    # Stop after the first sample: only one datapoint is rated per run.
    # Remove this `break` to process the whole dataset.
    break