Spaces:
Sleeping
Sleeping
File size: 3,684 Bytes
fd3967b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
# -*- coding: utf-8 -*-
"""Medllama use.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1pZiJn21DK8U77WfKyxw94zNVYnxR40LP
"""
#!pip install transformers accelerate peft bitsandbytes gradio
from huggingface_hub import notebook_login
import torch
# Interactive Hugging Face Hub login, run before any checkpoint is downloaded.
# NOTE(review): `notebook_login` is the notebook-oriented login helper; confirm it
# behaves as intended when this file is executed as a plain script rather than
# inside Colab/Jupyter.
notebook_login()
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
# Adapter configuration; its `base_model_name_or_path` is used below to pick
# the tokenizer.
config = PeftConfig.from_pretrained("tmberooney/medllama")

# Base chat model, loaded in 4-bit with fp16 compute. `device_map="auto"`
# already places the weights on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Wrap the base model with the fine-tuned PEFT adapter weights.
model = PeftModel.from_pretrained(model, "tmberooney/medllama")
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)

# NOTE: the former `model = model.to('cuda:0')` was removed. The quantized
# model is already placed by `device_map="auto"`; moving it again is redundant
# on a single GPU and breaks placement on CPU-only or multi-GPU hosts.
"""### Using Gradio App"""
from transformers import pipeline
# Text-generation pipeline built around the already-loaded model and tokenizer.
llama_pipeline = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    torch_dtype=torch.float16,
)
# Opening of every prompt: begin-of-sequence marker, instruction tag and the
# system instructions, in the Llama-2 chat layout used by `format_message`.
SYSTEM_PROMPT = """<s>[INST] <<SYS>>
You are a helpful medical bot. Your answers are clear and concise with medical information.
<</SYS>>
"""


def format_message(message: str, history: list, memory_limit: int = 3) -> str:
    """
    Formats the message and history for the Llama model.

    Parameters:
        message (str): Current message to send.
        history (list): Past conversation history as (user message, model
            answer) pairs, oldest first.
        memory_limit (int): Limit on how many past interactions to consider.
            Only the most recent ``memory_limit`` pairs are kept; a value of
            zero or less means no history is included.

    Returns:
        str: Formatted message string, starting with SYSTEM_PROMPT and ending
        with the current message wrapped in an open ``[INST] ... [/INST]`` turn.
    """
    # A plain `history[-memory_limit:]` would keep the WHOLE history when
    # memory_limit is 0 (because -0 == 0), so non-positive limits are handled
    # explicitly.
    recent = history[-memory_limit:] if memory_limit > 0 else []

    if not recent:
        return SYSTEM_PROMPT + f"{message} [/INST]"

    # The first remembered exchange shares the opening [INST] tag with the
    # system prompt; every later exchange opens its own <s>[INST] turn.
    parts = [SYSTEM_PROMPT + f"{recent[0][0]} [/INST] {recent[0][1]} </s>"]
    parts.extend(
        f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"
        for user_msg, model_answer in recent[1:]
    )

    # The current message is left open for the model to answer.
    parts.append(f"<s>[INST] {message} [/INST]")
    return "".join(parts)
from transformers import TextIteratorStreamer
# Module-level streamer that exposes generated text as an iterator.
# It is configured to omit the prompt (skip_prompt) and special tokens
# (skip_special_tokens) from the decoded text, with a 10-second timeout.
streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
# Generate a response from the Llama model
def get_model_response(message: str, history: list) -> str:
    """
    Streams a conversational response from the Llama model.

    This is a generator: it yields the response accumulated so far each time
    new text arrives, so the chat UI can update while generation is running.

    Parameters:
        message (str): User's input message.
        history (list): Past conversation history.

    Yields:
        str: The partial response generated so far.
    """
    from threading import Thread

    query = format_message(message, history)

    # One streamer per call, so concurrent chats never read each other's text.
    token_stream = TextIteratorStreamer(
        tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True
    )

    generation_kwargs = dict(
        generation_config=model.generation_config,
        do_sample=True,
        top_k=10,
        streamer=token_stream,
        top_p=0.7,
        temperature=0.7,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=1024,
    )

    # The pipeline call blocks until generation finishes. Running it in a
    # worker thread lets this generator consume the streamer while text is
    # still being produced, instead of replaying it afterwards.
    worker = Thread(
        target=llama_pipeline,
        args=(query,),
        kwargs=generation_kwargs,
        daemon=True,
    )
    worker.start()

    partial_message = ""
    for new_token in token_stream:
        # A lone '<' chunk is dropped, as in the original filter.
        if new_token != '<':
            partial_message += new_token
            yield partial_message

    worker.join()
import gradio as gr
# Build the chat UI around the streaming response generator, enable queuing
# and start the server.
gr.ChatInterface(
    fn=get_model_response,
    chatbot=gr.Chatbot(
        show_label=False,
        show_share_button=False,
        show_copy_button=True,
        likeable=True,
        layout="panel",
    ),
    title="Medllama : The Medically Fine-tuned LLaMA-2",
).queue().launch()

# `!gradio deploy` is IPython shell syntax and a SyntaxError in a plain .py
# file, so it is kept only as a comment (run `gradio deploy` from a shell).
#!gradio deploy
|