Spaces:
Sleeping
Sleeping
File size: 3,747 Bytes
8de32f7 bf128d3 8de32f7 bf128d3 8de32f7 a54e1f4 093078e 7bd4b8c 093078e 8de32f7 a54e1f4 8de32f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
# -*- coding: utf-8 -*-
"""Medllama use.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1pZiJn21DK8U77WfKyxw94zNVYnxR40LP
"""
# Notebook-only setup steps, kept for reference:
#!pip install transformers accelerate peft bitsandbytes gradio
#from huggingface_hub import notebook_login
#notebook_login()
import os

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face access token, read from the AUTH environment variable
# (None when the variable is not set).
secret_key = os.getenv("AUTH")

# Adapter configuration; it records which base checkpoint the adapter targets.
config = PeftConfig.from_pretrained("tmberooney/medllama")

# Base model, loaded 4-bit quantized with fp16 compute and automatic device
# placement.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    use_auth_token=secret_key,
    load_in_4bit=True,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Attach the fine-tuned adapter weights on top of the base model.
model = PeftModel.from_pretrained(model, "tmberooney/medllama")

# Pass the same token used for the base model: the tokenizer is fetched from
# the adapter's recorded base repository, which is presumably the same
# access-controlled checkpoint as above (verify against the adapter config).
# With secret_key set to None this is equivalent to the default behaviour.
tokenizer = AutoTokenizer.from_pretrained(
    config.base_model_name_or_path,
    use_auth_token=secret_key,
)

# NOTE(review): device_map="auto" has already placed the weights; this
# explicit move is kept from the original script - confirm it is still needed.
model = model.to('cuda:0')
"""### Using Gradio App"""
from transformers import pipeline

# Wrap the already-loaded model and tokenizer in a text-generation pipeline
# so callers can pass a prompt string and generation options in one call.
llama_pipeline = pipeline(
    task="text-generation",  # LLM task
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
)
SYSTEM_PROMPT = """<s>[INST] <<SYS>>
You are a helpful medical bot. Your answers are clear and concise with medical information.
<</SYS>>
"""


# Formatting function for message and history
def format_message(message: str, history: list, memory_limit: int = 3) -> str:
    """
    Formats the message and history for the Llama model.

    The result starts with SYSTEM_PROMPT, followed by the retained past
    exchanges (each closed with ``</s>``), and ends with the current message
    wrapped in ``[INST] ... [/INST]`` so the model continues with its answer.

    Parameters:
        message (str): Current message to send.
        history (list): Past conversation history as (user, answer) pairs,
            oldest first. ``None`` is treated as an empty history.
        memory_limit (int): Limit on how many past interactions to consider.
            Only the most recent ``memory_limit`` pairs are kept; a value
            of zero or less drops the history entirely.

    Returns:
        str: Formatted message string
    """
    # `history[-0:]` is the whole list, so a non-positive limit has to be
    # handled explicitly instead of relying on the slice alone.
    if history and memory_limit > 0:
        recent = history[-memory_limit:]
    else:
        recent = []

    if not recent:
        return SYSTEM_PROMPT + f"{message} [/INST]"

    # The first retained exchange shares the opening "<s>[INST]" that is
    # already part of SYSTEM_PROMPT; later exchanges open their own.
    parts = [SYSTEM_PROMPT + f"{recent[0][0]} [/INST] {recent[0][1]} </s>"]

    # Handle conversation history
    parts.extend(
        f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"
        for user_msg, model_answer in recent[1:]
    )

    # Handle the current message
    parts.append(f"<s>[INST] {message} [/INST]")
    return "".join(parts)
from threading import Thread
from typing import Iterator

from transformers import TextIteratorStreamer


def _make_streamer() -> TextIteratorStreamer:
    """Build a token streamer that skips the prompt and special tokens."""
    return TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)


# Module-level instance retained so the name `streamer` still exists;
# get_model_response creates its own streamer for every request.
streamer = _make_streamer()


# Generate a response from the Llama model
def get_model_response(message: str, history: list) -> Iterator[str]:
    """
    Streams a conversational response from the Llama model.

    Generation runs in a background thread while this generator reads the
    decoded text chunks from a streamer, so partial answers can be yielded
    while the model is still generating.

    Parameters:
        message (str): User's input message.
        history (list): Past conversation history.

    Yields:
        str: The response accumulated so far, growing with each new chunk.
    """
    query = format_message(message, history)

    # One streamer per request: a shared one could hand tokens left over from
    # an abandoned request to the next caller.
    token_stream = _make_streamer()

    generation_kwargs = dict(
        generation_config=model.generation_config,
        do_sample=True,
        top_k=10,
        streamer=token_stream,
        top_p=0.7,
        temperature=0.7,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=1024,
    )

    # The pipeline call blocks until generation ends, so it runs in a worker
    # thread; otherwise nothing could be yielded before the full answer exists.
    worker = Thread(
        target=llama_pipeline,
        args=(query,),
        kwargs=generation_kwargs,
        daemon=True,
    )
    worker.start()

    partial_message = ""
    for new_token in token_stream:
        # Chunks that are exactly '<' are dropped (presumably a stray markup
        # fragment in the model output - confirm before changing).
        if new_token != '<':
            partial_message += new_token
            yield partial_message

    worker.join()
import gradio as gr

# Chat display: no label or share button, with copy and like controls,
# rendered in the "panel" layout.
chat_window = gr.Chatbot(
    show_label=False,
    show_share_button=False,
    show_copy_button=True,
    likeable=True,
    layout="panel",
)

# Chat UI that calls get_model_response for each user message.
demo = gr.ChatInterface(
    fn=get_model_response,
    chatbot=chat_window,
    title="Medllama : The Medically Fine-tuned LLaMA-2",
)

# Enable the request queue, then start the app.
demo.queue().launch()
#!gradio deploy
|