xnetba committed
Commit a1b778d · Parent: 3847bfa

Update app.py

Files changed (1)
  1. app.py +22 -10
app.py CHANGED
@@ -7,15 +7,22 @@ openchat_preprompt = (
 )
 
 def get_client(model: str):
-    InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
+    if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
+        return Client(os.getenv("OPENCHAT_API_URL"))
+    return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
+
 
 def get_usernames(model: str):
     """
     Returns:
         (str, str, str, str): pre-prompt, username, bot name, separator
     """
-    if model in ("OpenAssistant/oasst-sft-1-pythia-12b"):
+    if model in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
         return "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
+    if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
+        return openchat_preprompt, "<human>: ", "<bot>: ", "\n"
+    return "", "User: ", "Assistant: ", "\n"
+
 
 def predict(
     model: str,

@@ -52,7 +59,7 @@ def predict(
 
     partial_words = ""
 
-    if model in ("OpenAssistant/oasst-sft-1-pythia-12b"):
+    if model in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
         iterator = client.generate_stream(
             total_inputs,
             typical_p=typical_p,
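Taken together, the two hunks above fix get_client (the old body built an InferenceAPIClient but never returned it), route the OpenChat model to a self-hosted endpoint, and teach get_usernames and predict about the second OpenAssistant checkpoint. A minimal usage sketch, not part of the commit, assuming the `text_generation` package's Client / InferenceAPIClient imported at the top of app.py (outside these hunks) and the same HF_TOKEN / OPENCHAT_API_URL environment variables; the prompt text and sampling values are illustrative:

```python
import os

from text_generation import Client, InferenceAPIClient  # same clients app.py uses


def get_client(model: str):
    # After this commit: the OpenChat model talks to a self-hosted
    # text-generation-inference endpoint; everything else goes through
    # the hosted Inference API.
    if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        return Client(os.getenv("OPENCHAT_API_URL"))
    return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))


# Roles for an OpenAssistant model, as returned by get_usernames in the diff.
preprompt, user, bot, sep = "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
prompt = f"{preprompt}{user}Hello!{sep}{bot}"

client = get_client("OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5")
for response in client.generate_stream(prompt, max_new_tokens=64, typical_p=0.2):
    if not response.token.special:
        print(response.token.text, end="", flush=True)
```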
@@ -109,7 +116,7 @@ def radio_on_change(
     repetition_penalty,
     watermark,
 ):
-    if value in ("OpenAssistant/oasst-sft-1-pythia-12b"):
+    if value in ("OpenAssistant/oasst-sft-1-pythia-12b", "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5"):
         typical_p = typical_p.update(value=0.2, visible=True)
         top_p = top_p.update(visible=False)
         top_k = top_k.update(visible=False)

@@ -117,6 +124,14 @@ def radio_on_change(
         disclaimer = disclaimer.update(visible=False)
         repetition_penalty = repetition_penalty.update(visible=False)
         watermark = watermark.update(False)
+    elif value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
+        typical_p = typical_p.update(visible=False)
+        top_p = top_p.update(value=0.25, visible=True)
+        top_k = top_k.update(value=50, visible=True)
+        temperature = temperature.update(value=0.6, visible=True)
+        repetition_penalty = repetition_penalty.update(value=1.01, visible=True)
+        watermark = watermark.update(False)
+        disclaimer = disclaimer.update(visible=True)
     else:
         typical_p = typical_p.update(visible=False)
         top_p = top_p.update(value=0.95, visible=True)

@@ -136,18 +151,14 @@ def radio_on_change(
     )
 
 
-title = """<h1 align="center">LLM Chat</h1>"""
-description = """LLM Chat predložak:
-```
-Mjenjajući predložak.
+title = """<h1 align="center">xChat</h1>"""
+description = """
 """
 
 text_generation_inference = """
-<div align="center">Pokrenuto od: <a href=https://xnet.ba/>xnet.ba</a></div>
 """
 
 openchat_disclaimer = """
-<div align="center">Checkout the official <a href=https://xnet.ba>xChat app</a> for the full experience.</div>
 """
 
 with gr.Blocks(

@@ -162,6 +173,7 @@ with gr.Blocks(
                 choices=[
                     "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
                     "OpenAssistant/oasst-sft-1-pythia-12b",
+                    "togethercomputer/GPT-NeoXT-Chat-Base-20B",
                 ],
                 label="Model",
                 interactive=True,
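For reference, the per-model decoding defaults that the radio_on_change hunks apply when the model radio switches, condensed from the diff (the dict below is an illustrative restatement, not how app.py stores them; the app applies these through Gradio .update() calls on the individual sliders):

```python
# Decoding defaults per model, as set by radio_on_change in this commit.
DECODING_DEFAULTS = {
    "OpenAssistant/oasst-sft-1-pythia-12b": {"typical_p": 0.2, "watermark": False},
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5": {"typical_p": 0.2, "watermark": False},
    "togethercomputer/GPT-NeoXT-Chat-Base-20B": {
        "top_p": 0.25,
        "top_k": 50,
        "temperature": 0.6,
        "repetition_penalty": 1.01,
        "watermark": False,
    },
}
# Any other model falls through to the else branch: typical_p hidden and
# top_p set to 0.95 (the rest of that branch lies outside this hunk).
```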