Update app.py
Browse files
app.py
CHANGED
@@ -7,15 +7,22 @@ openchat_preprompt = (
|
|
7 |
)
|
8 |
|
9 |
def get_client(model: str):
|
10 |
-
|
|
|
|
|
|
|
11 |
|
12 |
def get_usernames(model: str):
|
13 |
"""
|
14 |
Returns:
|
15 |
(str, str, str, str): pre-prompt, username, bot name, separator
|
16 |
"""
|
17 |
-
if model in ("OpenAssistant/oasst-sft-1-pythia-12b"):
|
18 |
return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
|
|
|
|
|
|
|
|
|
19 |
|
20 |
def predict(
|
21 |
model: str,
|
@@ -52,7 +59,7 @@ def predict(
|
|
52 |
|
53 |
partial_words = ""
|
54 |
|
55 |
-
if model in ("OpenAssistant/oasst-sft-1-pythia-12b"):
|
56 |
iterator = client.generate_stream(
|
57 |
total_inputs,
|
58 |
typical_p=typical_p,
|
@@ -109,7 +116,7 @@ def radio_on_change(
|
|
109 |
repetition_penalty,
|
110 |
watermark,
|
111 |
):
|
112 |
-
if value in ("OpenAssistant/oasst-sft-1-pythia-12b"):
|
113 |
typical_p = typical_p.update(value=0.2, visible=True)
|
114 |
top_p = top_p.update(visible=False)
|
115 |
top_k = top_k.update(visible=False)
|
@@ -117,6 +124,14 @@ def radio_on_change(
|
|
117 |
disclaimer = disclaimer.update(visible=False)
|
118 |
repetition_penalty = repetition_penalty.update(visible=False)
|
119 |
watermark = watermark.update(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
else:
|
121 |
typical_p = typical_p.update(visible=False)
|
122 |
top_p = top_p.update(value=0.95, visible=True)
|
@@ -136,18 +151,14 @@ def radio_on_change(
|
|
136 |
)
|
137 |
|
138 |
|
139 |
-
title = """<h1 align="center">
|
140 |
-
description = """
|
141 |
-
```
|
142 |
-
Mjenjajući predložak.
|
143 |
"""
|
144 |
|
145 |
text_generation_inference = """
|
146 |
-
<div align="center">Pokrenuto od: <a href=https://xnet.ba/>xnet.ba</a></div>
|
147 |
"""
|
148 |
|
149 |
openchat_disclaimer = """
|
150 |
-
<div align="center">Checkout the official <a href=https://xnet.ba>xChat app</a> for the full experience.</div>
|
151 |
"""
|
152 |
|
153 |
with gr.Blocks(
|
@@ -162,6 +173,7 @@ with gr.Blocks(
|
|
162 |
choices=[
|
163 |
"OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
|
164 |
"OpenAssistant/oasst-sft-1-pythia-12b",
|
|
|
165 |
],
|
166 |
label="Model",
|
167 |
interactive=True,
|
|
|
7 |
)
|
8 |
|
9 |
def get_client(model: str):
|
10 |
+
if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
|
11 |
+
return Client(os.getenv("OPENCHAT_API_URL"))
|
12 |
+
return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
|
13 |
+
|
14 |
|
15 |
def get_usernames(model: str):
|
16 |
"""
|
17 |
Returns:
|
18 |
(str, str, str, str): pre-prompt, username, bot name, separator
|
19 |
"""
|
20 |
+
if model in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
|
21 |
return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
|
22 |
+
if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
|
23 |
+
return openchat_preprompt, "<human>: ", "<bot>: ", "\n"
|
24 |
+
return "", "User: ", "Assistant: ", "\n"
|
25 |
+
|
26 |
|
27 |
def predict(
|
28 |
model: str,
|
|
|
59 |
|
60 |
partial_words = ""
|
61 |
|
62 |
+
if model in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
|
63 |
iterator = client.generate_stream(
|
64 |
total_inputs,
|
65 |
typical_p=typical_p,
|
|
|
116 |
repetition_penalty,
|
117 |
watermark,
|
118 |
):
|
119 |
+
if value in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
|
120 |
typical_p = typical_p.update(value=0.2, visible=True)
|
121 |
top_p = top_p.update(visible=False)
|
122 |
top_k = top_k.update(visible=False)
|
|
|
124 |
disclaimer = disclaimer.update(visible=False)
|
125 |
repetition_penalty = repetition_penalty.update(visible=False)
|
126 |
watermark = watermark.update(False)
|
127 |
+
elif value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
|
128 |
+
typical_p = typical_p.update(visible=False)
|
129 |
+
top_p = top_p.update(value=0.25, visible=True)
|
130 |
+
top_k = top_k.update(value=50, visible=True)
|
131 |
+
temperature = temperature.update(value=0.6, visible=True)
|
132 |
+
repetition_penalty = repetition_penalty.update(value=1.01, visible=True)
|
133 |
+
watermark = watermark.update(False)
|
134 |
+
disclaimer = disclaimer.update(visible=True)
|
135 |
else:
|
136 |
typical_p = typical_p.update(visible=False)
|
137 |
top_p = top_p.update(value=0.95, visible=True)
|
|
|
151 |
)
|
152 |
|
153 |
|
154 |
+
title = """<h1 align="center">xChat</h1>"""
|
155 |
+
description = """
|
|
|
|
|
156 |
"""
|
157 |
|
158 |
text_generation_inference = """
|
|
|
159 |
"""
|
160 |
|
161 |
openchat_disclaimer = """
|
|
|
162 |
"""
|
163 |
|
164 |
with gr.Blocks(
|
|
|
173 |
choices=[
|
174 |
"OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
|
175 |
"OpenAssistant/oasst-sft-1-pythia-12b",
|
176 |
+
"togethercomputer/GPT-NeoXT-Chat-Base-20B",
|
177 |
],
|
178 |
label="Model",
|
179 |
interactive=True,
|