# -*- coding: utf-8 -*-
"""Medllama use.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1pZiJn21DK8U77WfKyxw94zNVYnxR40LP
"""
#!pip install transformers accelerate peft bitsandbytes gradio
#from huggingface_hub import notebook_login
import torch
#notebook_login()
import os
# Hugging Face access token read from the environment (e.g. a Space secret);
# it is read here but not passed to from_pretrained when loading the model below.
secret_key = os.getenv("HUGGINGFACE_ACCESS_TOKEN")
from torch import nn
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("tmberooney/medllama-merged")
model = AutoModelForCausalLM.from_pretrained("tmberooney/medllama-merged")
# Keep the main sub-modules on CPU (module names follow the Llama architecture
# used by this checkpoint).
device_map = {"model.embed_tokens": "cpu",
              "model.layers": "cpu",
              "model.norm": "cpu",
              "lm_head": "cpu"}

# nn.DataParallel is not used here: the wrapper hides attributes such as
# model.generation_config that the text-generation pipeline below relies on.
# model = nn.DataParallel(model)

# Move the model parameters to the specified devices
for name, param in model.named_parameters():
    for module_prefix, device in device_map.items():
        if name.startswith(module_prefix):
            param.data = param.data.to(device=device)
            break
#model = model.to('cuda:0')
"""### Using Gradio App"""
from transformers import pipeline

# The tokenizer was already loaded above from "tmberooney/medllama-merged", so it
# is reused here rather than reloaded from a separate base-model config.
llama_pipeline = pipeline(
    "text-generation",  # LLM task
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
    tokenizer=tokenizer,
)
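# Illustration only (hypothetical prompt): calling the pipeline directly, e.g.
#   llama_pipeline("<s>[INST] I have a headache and a mild fever. [/INST]", max_length=256)
# returns a list of dicts whose "generated_text" field contains the prompt followed
# by the model's answer.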
SYSTEM_PROMPT = """<s>[INST] <<SYS>>
You are a doctor trained to listen to patients describe what they have been feeling. Take note of their symptoms, the medications they have been taking, and their medical history. Based on this, make an appropriate diagnosis.
<</SYS>>
"""
# Formatting function for message and history
def format_message(message: str, history: list, memory_limit: int = 3) -> str:
    """
    Formats the message and history for the Llama model.

    Parameters:
        message (str): Current message to send.
        history (list): Past conversation history.
        memory_limit (int): Limit on how many past interactions to consider.

    Returns:
        str: Formatted message string
    """
    # always keep len(history) <= memory_limit
    if len(history) > memory_limit:
        history = history[-memory_limit:]

    if len(history) == 0:
        return SYSTEM_PROMPT + f"{message} [/INST]"

    formatted_message = SYSTEM_PROMPT + f"{history[0][0]} [/INST] {history[0][1]} </s>"

    # Handle conversation history
    for user_msg, model_answer in history[1:]:
        formatted_message += f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"

    # Handle the current message
    formatted_message += f"<s>[INST] {message} [/INST]"

    return formatted_message
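# Illustration only (hypothetical values): with an empty history,
#   format_message("I have a sore throat.", [])
# returns SYSTEM_PROMPT + "I have a sore throat. [/INST]".
# With prior turns, the first (user, answer) pair completes the opening [INST]
# block and every later pair is added as "<s>[INST] user [/INST] answer </s>"
# before the current message.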
from transformers import TextIteratorStreamer
streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
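# TextIteratorStreamer collects decoded text chunks in a queue as generate() runs;
# iterating over it yields those chunks, with the prompt and special tokens skipped
# because of skip_prompt/skip_special_tokens above.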
# Generate a response from the Llama model
def get_model_response(message: str, history: list) -> str:
    """
    Generates a conversational response from the Llama model.

    Parameters:
        message (str): User's input message.
        history (list): Past conversation history.

    Yields:
        str: Partial response text from the Llama model, growing as tokens arrive.
    """
    query = format_message(message, history)
    response = ""

    # The pipeline call blocks until generation is complete; the tokens it pushed
    # into `streamer` are then replayed below so the Gradio UI can render the
    # answer incrementally.
    sequences = llama_pipeline(
        query,
        generation_config=model.generation_config,
        do_sample=True,
        top_k=10,
        streamer=streamer,
        top_p=0.7,
        temperature=0.7,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=1024,
    )

    generated_text = sequences[0]['generated_text']
    response = generated_text[len(query):]  # remove the prompt from the output

    partial_message = ""
    for new_token in streamer:
        if new_token != '<':  # drop stray special-token fragments
            partial_message += new_token
            yield partial_message
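# gr.ChatInterface streams the reply: because get_model_response is a generator,
# every yielded string replaces the bot message in place, giving a typing effect.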
import gradio as gr

gr.ChatInterface(
    fn=get_model_response,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False,
                       show_copy_button=True, likeable=True, layout="panel"),
    title="Medllama: The Medically Fine-tuned LLaMA-2",
).queue().launch()
#!gradio deploy