kenchan0226 committed
Commit c17ea75 · verified · 1 parent: d48ef54

Create app.py

Files changed (1): app.py (+96, -0)
app.py ADDED
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, Qwen2TokenizerFast, TextIteratorStreamer

DESCRIPTION = """\
# SeaLLMs Med

SeaLLMs Med is a medical version of SeaLLMs.
"""

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model once at startup; device_map="auto" places the
# weights on the available GPU(s).
model_id = "SeaLLMs/SeaLLMs-v3-7B-Chat"
tokenizer = Qwen2TokenizerFast.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.config.sliding_window = 4096
model.eval()


@spaces.GPU(duration=90)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Rebuild the conversation in the chat-template message format.
    conversation = []
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    # Trim from the left so that the most recent turns are kept.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Run generation in a background thread and stream decoded tokens back.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield the accumulated text so the UI updates incrementally.
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[],
    stop_btn=None,
    examples=[
        ["Explain the symptoms of COVID-19"],
    ],
    cache_examples=False,
)

with gr.Blocks(css="style.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
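
For reference, a minimal sketch of calling the running Space from Python with gradio_client; the Space ID below is hypothetical (substitute the real one), and it assumes gradio's default "/chat" API name for a ChatInterface:

from gradio_client import Client

client = Client("kenchan0226/seallms-med")  # hypothetical Space ID
# ChatInterface registers its fn under the "/chat" API name by default.
result = client.predict("Explain the symptoms of COVID-19", api_name="/chat")
print(result)

Note that device_map="auto" relies on accelerate, so the Space's dependencies need accelerate alongside torch, transformers, gradio, and spaces.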