# -*- coding: utf-8 -*-
"""Medllama use.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1pZiJn21DK8U77WfKyxw94zNVYnxR40LP
"""
#!pip install transformers accelerate peft bitsandbytes gradio
#from huggingface_hub import notebook_login
#notebook_login()

import os
import torch

secret_key = os.getenv("HUGGINGFACE_ACCESS_TOKEN")
from torch import nn
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("tmberooney/medllama-merged")
model = AutoModelForCausalLM.from_pretrained("tmberooney/medllama-merged")
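
# NOTE: secret_key is read from the environment above but never used. If the model repo
# were gated, one possible (hedged, not part of the original notebook) way to use it would
# be to pass it to the loaders, e.g.:
#   tokenizer = AutoTokenizer.from_pretrained("tmberooney/medllama-merged", token=secret_key)
#   model = AutoModelForCausalLM.from_pretrained("tmberooney/medllama-merged", token=secret_key)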
'''device_map = {"transformer.word_embeddings": "cpu",
                 "transformer.word_embeddings_layernorm": "cpu",
                 "lm_head": "cpu",
                 "transformer.h": "cpu",
                 "transformer.ln_f": "cpu"}

model = nn.DataParallel(model)

# Move the model parameters to the specified devices
for name, param in model.named_parameters():
    if name in device_map:
        param.data = param.to(device=device_map[name])

#model = model.to('cuda:0')'''
"""### Using Gradio App"""
from transformers import pipeline
#tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
llama_pipeline = pipeline(
    "text-generation",  # LLM task
    model=model,
    torch_dtype=torch.float16,
    device_map="cpu",
    tokenizer=tokenizer,
)
SYSTEM_PROMPT = """<s>[INST] <<SYS>>
You are a Doctor, Who is trained to listen to patients' claims about what they have been feeling. You are to notice their symptoms. Answer in concise medical terminology in giving a diagnosis in no more than 50 words
<</SYS>>
"""
# Formatting function for message and history
def format_message(message: str, history: list, memory_limit: int = 3) -> str:
    """
    Formats the message and history for the Llama model.

    Parameters:
        message (str): Current message to send.
        history (list): Past conversation history.
        memory_limit (int): Limit on how many past interactions to consider.

    Returns:
        str: Formatted message string
    """
    # Always keep len(history) <= memory_limit
    if len(history) > memory_limit:
        history = history[-memory_limit:]

    if len(history) == 0:
        return SYSTEM_PROMPT + f"{message} [/INST]"

    formatted_message = SYSTEM_PROMPT + f"{history[0][0]} [/INST] {history[0][1]} </s>"

    # Handle conversation history
    for user_msg, model_answer in history[1:]:
        formatted_message += f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"

    # Handle the current message
    formatted_message += f"<s>[INST] {message} [/INST]"

    return formatted_message
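
# Illustrative sketch of the Llama-2 chat prompt that format_message builds (hypothetical
# inputs, not from the original notebook):
#   format_message("I also have a mild fever", [("I have a sore throat", "Likely pharyngitis.")])
# returns, as one concatenated string:
#   SYSTEM_PROMPT + "I have a sore throat [/INST] Likely pharyngitis. </s>"
#   + "<s>[INST] I also have a mild fever [/INST]"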
from transformers import TextIteratorStreamer
streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
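# The TextIteratorStreamer queues decoded tokens as they are produced; skip_prompt=True
# drops the prompt text and skip_special_tokens=True drops tokens such as </s>, so only
# newly generated response text is read back when the streamer is iterated below.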
# Generate a response from the Llama model
def get_model_response(message: str, history: list) -> str:
    """
    Generates a conversational response from the Llama model.

    Parameters:
        message (str): User's input message.
        history (list): Past conversation history.

    Returns:
        str: Generated response from the Llama model.
    """
    query = format_message(message, history)
    response = ""

    sequences = llama_pipeline(
        query,
        generation_config=model.generation_config,
        do_sample=True,
        top_k=10,
        streamer=streamer,
        top_p=0.7,
        temperature=0.7,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_length=1024,
    )

    generated_text = sequences[0]['generated_text']
    response = generated_text[len(query):]  # Remove the prompt from the output

    # Stream the tokens collected by the streamer back to Gradio, skipping
    # special-token fragments that begin with '<'.
    partial_message = ""
    for new_token in streamer:
        if new_token != '<':
            partial_message += new_token
            yield partial_message
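
# Minimal sanity check outside Gradio (hypothetical input; assumes the CPU pipeline above
# has finished loading):
#   for partial in get_model_response("I have had a dull headache for three days", []):
#       print(partial)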
import gradio as gr
gr.ChatInterface(
    fn=get_model_response,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    title="Medllama : The Medically Fine-tuned LLaMA-2",
).queue().launch()
#!gradio deploy