tmberooney committed on
Commit
fd3967b
·
1 Parent(s): 85d95c8

Upload medllama_use.py

Files changed (1)
  1. medllama_use.py +126 -0
medllama_use.py ADDED
@@ -0,0 +1,126 @@
+ # -*- coding: utf-8 -*-
+ """Medllama use.ipynb
+
+ Automatically generated by Colaboratory.
+
+ Original file is located at
+     https://colab.research.google.com/drive/1pZiJn21DK8U77WfKyxw94zNVYnxR40LP
+ """
+
+ #!pip install transformers accelerate peft bitsandbytes gradio
+
+ from huggingface_hub import notebook_login
+ import torch
+ notebook_login()
+
+ from peft import PeftModel, PeftConfig
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Load the LoRA adapter config, the 4-bit quantised Llama-2 base model, and the tokenizer
+ config = PeftConfig.from_pretrained("tmberooney/medllama")
+ model = AutoModelForCausalLM.from_pretrained(
+     "meta-llama/Llama-2-7b-chat-hf",
+     load_in_4bit=True,
+     torch_dtype=torch.float16,
+     device_map="auto",
+ )
+ model = PeftModel.from_pretrained(model, "tmberooney/medllama")
+ tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+ # device_map="auto" already places the quantised weights on the available GPU, and an
+ # explicit .to('cuda:0') is not supported for 4-bit bitsandbytes models.
+
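+ # Quick smoke test of the fine-tuned adapter -- a minimal sketch, assuming the model and
+ # tokenizer loaded above; the prompt text is only an illustrative example.
+ test_prompt = "[INST] What is hypertension? [/INST]"
+ test_inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)
+ with torch.no_grad():
+     test_output = model.generate(**test_inputs, max_new_tokens=64)
+ print(tokenizer.decode(test_output[0], skip_special_tokens=True))
+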
+ """### Using Gradio App"""
+
+ from transformers import pipeline
+
+ llama_pipeline = pipeline(
+     "text-generation",  # LLM task
+     model=model,
+     torch_dtype=torch.float16,
+     device_map="auto",
+     tokenizer=tokenizer,
+ )
+
+ SYSTEM_PROMPT = """<s>[INST] <<SYS>>
+ You are a helpful medical bot. Your answers are clear and concise, with accurate medical information.
+ <</SYS>>
+
+ """
+
+ # Formatting function for message and history
+ def format_message(message: str, history: list, memory_limit: int = 3) -> str:
+     """
+     Formats the message and history for the Llama model.
+
+     Parameters:
+         message (str): Current message to send.
+         history (list): Past conversation history as (user, assistant) pairs.
+         memory_limit (int): Limit on how many past interactions to consider.
+
+     Returns:
+         str: Formatted message string.
+     """
+     # Always keep len(history) <= memory_limit
+     if len(history) > memory_limit:
+         history = history[-memory_limit:]
+
+     if len(history) == 0:
+         return SYSTEM_PROMPT + f"{message} [/INST]"
+
+     # The first turn keeps the system prompt attached to the opening user message
+     formatted_message = SYSTEM_PROMPT + f"{history[0][0]} [/INST] {history[0][1]} </s>"
+
+     # Handle the rest of the conversation history
+     for user_msg, model_answer in history[1:]:
+         formatted_message += f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"
+
+     # Handle the current message
+     formatted_message += f"<s>[INST] {message} [/INST]"
+
+     return formatted_message
+
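+ # Illustrative note (hypothetical values): for a one-turn history, format_message returns a
+ # string shaped like the Llama-2 chat template, roughly:
+ #   <s>[INST] <<SYS>> ...system prompt... <</SYS>>
+ #
+ #   previous question [/INST] previous answer </s><s>[INST] current question [/INST]
+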
79
+ from transformers import TextIteratorStreamer
80
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
81
+
82
+ # Generate a response from the Llama model
83
+ def get_model_response(message: str, history: list) -> str:
84
+ """
85
+ Generates a conversational response from the Llama model.
86
+
87
+ Parameters:
88
+ message (str): User's input message.
89
+ history (list): Past conversation history.
90
+
91
+ Returns:
92
+ str: Generated response from the Llama model.
93
+ """
94
+ query = format_message(message, history)
95
+ response = ""
96
+
97
+ sequences = llama_pipeline(
98
+ query,
99
+ generation_config = model.generation_config,
100
+ do_sample=True,
101
+ top_k=10,
102
+ streamer=streamer,
103
+ top_p=0.7,
104
+ temperature=0.7,
105
+ num_return_sequences=1,
106
+ eos_token_id=tokenizer.eos_token_id,
107
+ max_length=1024,
108
+ )
109
+
110
+ generated_text = sequences[0]['generated_text']
111
+ response = generated_text[len(query):] # Remove the prompt from the output
112
+
113
+ partial_message = ""
114
+ for new_token in streamer:
115
+ if new_token != '<':
116
+ partial_message += new_token
117
+ yield partial_message
118
+
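+ # Illustrative usage (hypothetical question, not part of the app below): outside Gradio the
+ # generator can be drained directly, e.g.
+ #   for partial in get_model_response("What does amoxicillin treat?", history=[]):
+ #       print(partial)
+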
119
+ import gradio as gr
120
+
121
+ gr.ChatInterface(fn=get_model_response,
122
+ chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
123
+ title="Medllama : The Medically Fine-tuned LLaMA-2").queue().launch()
124
+
125
+ !gradio deploy
126
+