# MoR/Planning/model.py
# Repository page metadata: author GagaLey, commit "framework" (7bf4b88), size 2.99 kB.
import sys
import os
import torch.nn as nn
import ast
from unsloth import FastLanguageModel
from transformers import TextStreamer
import contractions
import re
from Planning.utils import remove_inner_single_quotes
class Planner(nn.Module):
    """Wrap a fine-tuned chat model that turns a query into a parsed plan.

    The model and tokenizer are loaded once from
    ``Planning/checkpoints/<dataset_name>/lora_model/`` through
    ``FastLanguageModel.from_pretrained`` and switched to inference mode.
    ``forward`` generates a reply for a single query and parses that reply
    as a Python literal.
    """

    # Marker that precedes the assistant turn in the decoded output.
    _ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    # End-of-turn token that is stripped from the extracted reply.
    _END_OF_TURN = "<|eot_id|>"
    # Every exception ``ast.literal_eval`` is documented to raise on
    # malformed input; catching exactly these avoids a bare ``except``.
    _LITERAL_ERRORS = (
        SyntaxError,
        ValueError,
        TypeError,
        MemoryError,
        RecursionError,
    )

    def __init__(self, dataset_name):
        """Load the checkpoint for *dataset_name* and prepare it for inference.

        Args:
            dataset_name: Name of the dataset. It selects the checkpoint
                directory and, when equal to ``'prime'``, the parsing path
                used by ``forward``.
        """
        super().__init__()
        self.dataset_name = dataset_name
        self.checkpoint_path = f"Planning/checkpoints/{dataset_name}/lora_model/"
        self.max_seq_length = 2048
        self.dtype = None
        self.load_in_4bit = True

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.checkpoint_path,
            max_seq_length=self.max_seq_length,
            dtype=self.dtype,
            load_in_4bit=self.load_in_4bit,
        )
        FastLanguageModel.for_inference(model)
        self.model = model
        self.tokenizer = tokenizer

    def forward(self, query):
        """Generate a plan for *query* and return it parsed.

        Args:
            query: User query text, sent as a single user chat message.

        Returns:
            The object produced by ``ast.literal_eval`` on the model's reply.
            When the reply cannot be parsed, the fallback
            ``{"Metapath": "", "Restriction": {}}`` is returned instead.

        Raises:
            ValueError: If the decoded output contains no assistant header,
                so no reply can be extracted.
        """
        reply = self._generate_reply(query)
        # ******* special processing for prime dataset
        if self.dataset_name == 'prime':
            return self._parse_prime(reply)
        return self._parse_with_repair(reply)

    def _generate_reply(self, query):
        """Run generation for *query* and return the assistant's raw text."""
        message = {'content': query, 'role': 'user'}
        inputs = self.tokenizer.apply_chat_template(
            [message],
            tokenize=True,
            add_generation_prompt=True,  # Must add for generation
            return_tensors="pt",
        ).to("cuda")
        # TextStreamer echoes the generated text to stdout while generating.
        text_streamer = TextStreamer(self.tokenizer, skip_prompt=True)
        outputs = self.model.generate(
            input_ids=inputs,
            streamer=text_streamer,
            # Upper bound on tokens generated beyond the input.
            max_new_tokens=128,
            use_cache=True,
            temperature=1.5,
            # min_p is a minimum token probability, scaled by the probability
            # of the most likely token (it is not a cumulative probability).
            min_p=0.1,
        )
        decoded = self.tokenizer.batch_decode(outputs)

        parts = decoded[0].split(self._ASSISTANT_HEADER)
        if len(parts) <= 1:
            raise ValueError(
                "Model output does not contain an assistant header; "
                "cannot extract the reply."
            )
        # Only the text between the first and (if any) second header is used.
        return parts[1].replace(self._END_OF_TURN, "")

    def _parse_prime(self, reply):
        """Parse *reply* as a literal, falling back to an empty plan."""
        try:
            # Parse the string using ast.literal_eval
            return ast.literal_eval(reply)
        except self._LITERAL_ERRORS as e:
            print(f"Error parsing the string: {e}")
            return self._empty_plan()

    def _parse_with_repair(self, reply):
        """Parse *reply*, retrying once after repairing inner single quotes."""
        # Presumably expands contractions so their apostrophes do not break
        # the quoted literal -- confirm against the contractions package.
        reply = contractions.fix(reply)
        try:
            return ast.literal_eval(reply)
        except self._LITERAL_ERRORS:
            print("Fail")

        try:
            repaired = re.sub(r"\['(.*?)'", remove_inner_single_quotes, reply)  # TODO: need optimize
            return ast.literal_eval(repaired)
        except Exception:
            # Best-effort repair: the helper is defined elsewhere, so any
            # ordinary failure here degrades to the empty plan rather than
            # crashing. KeyboardInterrupt/SystemExit still propagate.
            return self._empty_plan()

    @staticmethod
    def _empty_plan():
        """Return a fresh fallback plan used when the reply cannot be parsed."""
        return {
            "Metapath": "",
            "Restriction": {},
        }