File size: 2,539 Bytes
6f972a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import json
import pprint
def read_json(file_path):
    """Parse the JSON document stored at ``file_path`` and return it.

    The file is opened as UTF-8 text. Errors raised by ``open`` or by the
    JSON decoder are not caught here and propagate to the caller.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return json.load(handle)
def write_json(file_path, data):
    """Serialize ``data`` as JSON into ``file_path``.

    The target is opened in ``'w'`` mode (existing content is replaced) with
    UTF-8 encoding. Non-ASCII characters are written as-is rather than as
    ``\\uXXXX`` escapes, and nested structures are indented by four spaces.
    """
    with open(file_path, mode='w', encoding='utf-8') as handle:
        json.dump(data, handle, indent=4, ensure_ascii=False)
# Dataset to be rated. The scoring loop at the bottom of this script reads
# entry['conversations'][0]['value'] and entry['conversations'][1]['value']
# from the entries it iterates over.
data = read_json("DataSet/train_samples_all_tuning.json")
from transformers import AutoModelForCausalLM, AutoTokenizer
# Identifier handed to both ``from_pretrained`` calls below, so the model and
# the tokenizer are loaded from the same checkpoint.
model_name = "Model/QwQ-32B-Preview"

# dtype selection and device placement are left to transformers' "auto" settings.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
def chat_QwQ(
    prompt,
    *,
    max_new_tokens=512,
    system_prompt=(
        "You are a helpful and harmless assistant. "
        "You are Qwen developed by Alibaba. "
        "You should think step-by-step."
    ),
):
    """Send a single-turn chat request to the module-level model and return its reply.

    Uses the module-level ``tokenizer`` and ``model`` objects.

    Args:
        prompt: Text placed in the ``user`` message.
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply. Defaults to 512, the value that was previously hard-coded;
            raise it when the step-by-step answer gets cut off.
        system_prompt: Text placed in the ``system`` message. Defaults to the
            system prompt that was previously hard-coded.

    Returns:
        The decoded reply as a string, with special tokens skipped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    # Render the conversation to plain text (tokenize=False) and append the
    # marker that opens the assistant turn.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
    )
    # Drop the leading len(input_ids) tokens of every output sequence so that
    # only the newly generated part (not the echoed prompt) is decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    # A single prompt was submitted, so the batch holds exactly one reply.
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
# from transformers import MarianMTModel, MarianTokenizer
# model_name = "Model/opus-mt-en-zh"
# tokenizer = MarianTokenizer.from_pretrained(model_name)
# model = MarianMTModel.from_pretrained(model_name)
# Instruction prefix of every rating request; the text of the datapoint is
# appended directly after the "###### " separator.
RATING_INSTRUCTION = "This is a question-answering datapoint based on image information. Determine whether the answer can be judged without relying on the image. If it can, this is considered bad data; if it requires the image, it is considered good data. Rate this datapoint on a scale from 1 (bad) to 5 (good). ###### "

# The unconditional ``break`` at the end of the body stops the loop after the
# first entry, so exactly one datapoint (if any) is rated and printed.
for sample in data:
    turns = sample['conversations']
    # Presumably the question turn followed by the answer turn; they are
    # joined without any separator.
    first_turn = turns[0]['value']
    second_turn = turns[1]['value']
    prompt = RATING_INSTRUCTION + first_turn + second_turn
    answer = chat_QwQ(prompt)

    # Disabled: translation of the answer with the commented-out MarianMT
    # model/tokenizer defined earlier in this script.
    # english_text = answer
    # inputs = tokenizer.encode(english_text, return_tensors="pt", truncation=True)
    # translated = model.generate(inputs, max_length=40, num_beams=4, early_stopping=True)
    # chinese_translation = tokenizer.decode(translated[0], skip_special_tokens=True)

    pprint.pprint(prompt)
    pprint.pprint(answer)
    break