Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -53,16 +53,13 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
53 |
|
54 |
# Load the causal LM weights in half precision; device placement is delegated
# to the library via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)
|
55 |
|
56 |
-
class ChatBot:
|
57 |
-
def __init__(self):
|
58 |
-
self.history = []
|
59 |
|
60 |
class ChatBot:
|
61 |
def __init__(self):
|
62 |
# Initialize the ChatBot class with an empty history
|
63 |
self.history = []
|
64 |
|
65 |
-
def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
|
66 |
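# Build the prompt from the example instruction/answer pair and the system prompt (NOTE(review): user_input is not interpolated here — confirm intended)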
|
67 |
formatted_input = f"<s> [INST] {example_instruction} [/INST] {example_answer}</s> [INST] {system_prompt} [/INST]"
|
68 |
|
@@ -70,7 +67,7 @@ class ChatBot:
|
|
70 |
user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
|
71 |
|
72 |
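# Generate a response with the loaded model (NOTE(review): only a plain AutoModelForCausalLM load is visible; no PEFT adapter is applied here — confirm)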
|
73 |
-
response =
|
74 |
|
75 |
# Decode the generated response to text
|
76 |
response_text = tokenizer.decode(response[0], skip_special_tokens=True)
|
|
|
53 |
|
54 |
# Load the causal LM weights in half precision; device placement is delegated
# to the library via device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)
|
55 |
|
|
|
|
|
|
|
56 |
|
57 |
class ChatBot:
|
58 |
def __init__(self):
|
59 |
# Initialize the ChatBot class with an empty history
|
60 |
self.history = []
|
61 |
|
62 |
+
def predict(self, user_input, system_prompt="You are an expert medical analyst:" , example_instruction="produce a json", example_answer = "please dont make small talk "):
|
63 |
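# Build the prompt from the example instruction/answer pair and the system prompt (NOTE(review): user_input is not interpolated here — confirm intended)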
|
64 |
formatted_input = f"<s> [INST] {example_instruction} [/INST] {example_answer}</s> [INST] {system_prompt} [/INST]"
|
65 |
|
|
|
67 |
user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
|
68 |
|
69 |
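# Generate a response with the loaded model (NOTE(review): only a plain AutoModelForCausalLM load is visible; no PEFT adapter is applied here — confirm)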
|
70 |
+
response = model.generate(input_ids=user_input_ids, max_length=512, pad_token_id=tokenizer.eos_token_id)
|
71 |
|
72 |
# Decode the generated response to text
|
73 |
response_text = tokenizer.decode(response[0], skip_special_tokens=True)
|