# thesis/model/mistral.py
# Mistral model module for chat interaction and model instance control
# external imports
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import gradio as gr
# internal imports
from utils import modelling as mdl
from utils import formatting as fmt
# global model and tokenizer instance (created on initial build)
device = mdl.get_device()
if device == torch.device("cuda"):
    n_gpus, max_memory, bnb_config = mdl.gpu_loading_config()

    MODEL = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2",
        quantization_config=bnb_config,
        device_map="auto",  # dispatch the model efficiently across the available resources
        max_memory={i: max_memory for i in range(n_gpus)},
    )
else:
    MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    MODEL.to(device)
TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
# default model config
CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
CONFIG.update(**{
    "temperature": 0.7,
    "max_new_tokens": 50,
    "max_length": 50,
    "top_p": 0.9,
    "repetition_penalty": 1.2,
    "do_sample": True,
    "seed": 42,
})
# function to (re)set the generation config
def set_config(config_dict: dict):
    # if no config dict is given, fall back to the default values
    if config_dict == {}:
        config_dict = {
            "temperature": 0.7,
            "max_new_tokens": 50,
            "max_length": 50,
            "top_p": 0.9,
            "repetition_penalty": 1.2,
            "do_sample": True,
            "seed": 42,
        }

    CONFIG.update(**config_dict)
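# illustrative usage sketch (assumed caller, e.g. a UI settings callback, not shown in this module):
#   set_config({"temperature": 0.3, "max_new_tokens": 100})  # override selected values
#   set_config({})  # an empty dict restores the defaults listed above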
# advanced formatting function that takes into account a conversation history
# CREDIT: adapted from Venkata Bhanu Teja Pallakonda in Huggingface discussions
## see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/discussions/
def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
    prompt = ""

    if knowledge != "":
        gr.Info("""
            Mistral does not support additional knowledge, so it will be ignored.
            """)

    # if there is no history, use the system prompt and an example message
    if len(history) == 0:
        prompt = f"""
        <s>[INST] {system_prompt} [/INST] How can I help you today? </s>
        [INST] {message} [/INST]
        """
    else:
        # takes the very first exchange and the system prompt as the base
        prompt = f"""
        <s>[INST] {system_prompt} {history[0][0]} [/INST] {history[0][1]}</s>
        """

        # adds the rest of the conversation history as context
        for conversation in history[1:]:
            prompt += f"[INST] {conversation[0]} [/INST] {conversation[1]}</s>"

        # adds the current message at the end of the prompt
        prompt += f"[INST] {message} [/INST]"

    return prompt
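# illustrative sketch of the resulting prompt string (made-up inputs, whitespace trimmed):
#   format_prompt("How are you?", [], "You are a helpful assistant.")
#   -> "<s>[INST] You are a helpful assistant. [/INST] How can I help you today? </s>
#       [INST] How are you? [/INST]"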
# function to extract the actual answer, because Mistral always returns the full prompt
def format_answer(answer: str):
    # empty answer string
    formatted_answer = ""

    # extracting the text after the [/INST] tokens
    parts = answer.split("[/INST]")
    if len(parts) >= 3:
        # return the text after the second occurrence of [/INST]
        formatted_answer = parts[2].strip()
    else:
        # return an empty string if there are fewer than two occurrences of [/INST]
        formatted_answer = ""

    print(f"Cut {answer} into {formatted_answer}.")

    return formatted_answer
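# illustrative sketch of the splitting behaviour (made-up strings):
#   format_answer("<s>[INST] sys [/INST] greeting</s>[INST] question [/INST] final answer")
#   -> "final answer"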
def respond(prompt: str):
    # tokenizing the input prompt and moving it to the model device
    input_ids = TOKENIZER(prompt, return_tensors="pt")["input_ids"].to(device)

    # generating text from the tokenized input and decoding the output
    output_ids = MODEL.generate(input_ids, generation_config=CONFIG)
    output_text = TOKENIZER.batch_decode(output_ids)
    output_text = fmt.format_output_text(output_text)

    return format_answer(output_text)
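# illustrative end-to-end sketch (assumed call pattern, not part of the app wiring shown here):
#   prompt = format_prompt("What is the capital of France?", [], "You are a helpful assistant.")
#   answer = respond(prompt)
#   print(answer)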