# -*- coding: utf-8 -*-
"""Medllama use.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1pZiJn21DK8U77WfKyxw94zNVYnxR40LP
"""

#!pip install transformers accelerate peft bitsandbytes gradio

#from huggingface_hub import notebook_login
#notebook_login()

import os

import torch

# Read the Hugging Face access token from the environment instead of an interactive login.
secret_key = os.getenv("HUGGINGFACE_ACCESS_TOKEN")
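# A minimal sketch of non-interactive authentication with the token read above
# (an assumption; only needed when the model repository is gated or private).
from huggingface_hub import login

if secret_key:
    login(token=secret_key)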
from torch import nn

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("tmberooney/medllama-merged")
model = AutoModelForCausalLM.from_pretrained("tmberooney/medllama-merged")
# Optional manual placement of selected modules on the CPU; prefixes that do not
# exist in this model are simply skipped.
device_map = {"transformer.word_embeddings": "cpu",
              "transformer.word_embeddings_layernorm": "cpu",
              "lm_head": "cpu",
              "transformer.h": "cpu",
              "transformer.ln_f": "cpu"}

# nn.DataParallel would rename every parameter to "module.*" and the pipeline below
# expects a plain PreTrainedModel, so the model is left unwrapped.
#model = nn.DataParallel(model)

# Move the model parameters to the specified devices. Parameter names end in
# ".weight"/".bias", so match on the module prefix rather than exact equality.
for name, param in model.named_parameters():
    for prefix, device in device_map.items():
        if name.startswith(prefix):
            param.data = param.data.to(device)
            break

#model = model.to('cuda:0')
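# A hedged alternative to the manual placement above (a sketch, not the notebook's
# original approach): let accelerate shard the model automatically and load it in
# 4-bit with bitsandbytes, assuming a CUDA GPU is available.
#
# from transformers import BitsAndBytesConfig
#
# quant_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_compute_dtype=torch.float16,
# )
# model = AutoModelForCausalLM.from_pretrained(
#     "tmberooney/medllama-merged",
#     quantization_config=quant_config,
#     device_map="auto",
# )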
"""### Using Gradio App""" | |
from transformers import pipeline | |
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path) | |
llama_pipeline = pipeline(
    "text-generation",  # LLM task
    model=model,
    torch_dtype=torch.float16,
    device_map="auto",
    tokenizer=tokenizer
)
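# Optional smoke test of the pipeline (commented out so nothing is generated at
# import time); the prompt string here is just an illustrative assumption:
#print(llama_pipeline("Hello doctor, I have had a sore throat for two days.", max_new_tokens=32)[0]["generated_text"])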
SYSTEM_PROMPT = """<s>[INST] <<SYS>> | |
ou are a Doctor, Who is trained to listen to patients claims about what they have been feeling. You are to notice their sypmtoms, the medications they have been taking and their medical history. Based on this You have made a appropriate Diagnosis. | |
<</SYS>> | |
""" | |
# Formatting function for message and history
def format_message(message: str, history: list, memory_limit: int = 3) -> str:
    """
    Formats the message and history for the Llama model.

    Parameters:
        message (str): Current message to send.
        history (list): Past conversation history.
        memory_limit (int): Limit on how many past interactions to consider.

    Returns:
        str: Formatted message string
    """
    # always keep len(history) <= memory_limit
    if len(history) > memory_limit:
        history = history[-memory_limit:]

    if len(history) == 0:
        return SYSTEM_PROMPT + f"{message} [/INST]"

    formatted_message = SYSTEM_PROMPT + f"{history[0][0]} [/INST] {history[0][1]} </s>"

    # Handle conversation history
    for user_msg, model_answer in history[1:]:
        formatted_message += f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"

    # Handle the current message
    formatted_message += f"<s>[INST] {message} [/INST]"

    return formatted_message
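# An illustrative check of the prompt layout produced by format_message, using a
# made-up single-turn history (the strings are placeholders, not real data):
#
#   history = [["I have a headache", "How long has it lasted?"]]
#   format_message("About two days", history)
#
# yields the LLaMA-2 chat format:
#
#   <s>[INST] <<SYS>> ... <</SYS>>
#   I have a headache [/INST] How long has it lasted? </s><s>[INST] About two days [/INST]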
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
# Generate a response from the Llama model
def get_model_response(message: str, history: list) -> str:
    """
    Generates a conversational response from the Llama model.

    Parameters:
        message (str): User's input message.
        history (list): Past conversation history.

    Returns:
        str: Generated response from the Llama model.
    """
    query = format_message(message, history)
    response = ""

    sequences = llama_pipeline(
        query,
        generation_config=model.generation_config,
        do_sample=True,
        top_k=10,
        streamer=streamer,
        top_p=0.7,
        temperature=0.7,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=1024,
    )
    generated_text = sequences[0]['generated_text']
    response = generated_text[len(query):]  # remove the prompt from the output

    # The pipeline call above runs to completion, so by this point the streamer's
    # queue already holds every generated token; drain it and yield partial text.
    partial_message = ""
    for new_token in streamer:
        if new_token != '<':
            partial_message += new_token
            yield partial_message
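# A hedged sketch (an alternative, not the notebook's original approach) of true
# token-by-token streaming: run generation in a background thread so the streamer
# can be drained while the model is still producing tokens. The function name and
# sampling arguments here are illustrative assumptions.
from threading import Thread

def get_model_response_streaming(message: str, history: list):
    query = format_message(message, history)
    inputs = tokenizer(query, return_tensors="pt").to(model.device)
    thread = Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, do_sample=True,
                    top_p=0.7, temperature=0.7, max_new_tokens=1024),
    )
    thread.start()
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message
    thread.join()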
import gradio as gr

gr.ChatInterface(fn=get_model_response,
                 chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
                 title="Medllama : The Medically Fine-tuned LLaMA-2").queue().launch()

#!gradio deploy