p1atdev committed
Commit dfbfb4f · verified · 1 Parent(s): 3b0a6b8

Update app.py

Files changed (1):
  1. app.py +33 -14
app.py CHANGED
@@ -1,3 +1,18 @@
+try:
+    import flash_attn
+except:
+    import subprocess
+
+    print("Installing flash-attn...")
+    subprocess.run(
+        "pip install flash-attn --no-build-isolation",
+        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+        shell=True,
+    )
+    import flash_attn
+
+    print("flash-attn installed.")
+
 import os
 import time
 import spaces
@@ -24,10 +39,6 @@ DEVICE = (
 BAD_WORD_KEYWORDS = ["(medium)"]
 
 
-def fix_compiled_state_dict(state_dict: dict):
-    return {k.replace("._orig_mod.", "."): v for k, v in state_dict.items()}
-
-
 def get_bad_words_ids(tokenizer: PreTrainedTokenizerFast):
     ids = [
         [id]
@@ -38,17 +49,12 @@ def get_bad_words_ids(tokenizer: PreTrainedTokenizerFast):
 
 
 def prepare_models():
-    config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
-    model = AutoModelForPreTraining.from_config(
-        config, torch_dtype=torch.bfloat16, trust_remote_code=True
+    model = AutoModelForPreTraining.from_pretrained(
+        MODEL_NAME, torch_dtype=torch.bfloat16, trust_remote_code=True
     )
     model.decoder_model.use_cache = True
     processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
 
-    state_dict = load_file(MODEL_PATH)
-    state_dict = {k.replace("._orig_mod.", "."): v for k, v in state_dict.items()}
-    model.load_state_dict(state_dict)
-
     model.eval()
     model = model.to(DEVICE)
     # model = torch.compile(model)
@@ -60,11 +66,17 @@ def demo():
     model, processor = prepare_models()
     ban_ids = get_bad_words_ids(processor.decoder_tokenizer)
 
+    translation_mode_map = {
+        "translate": "exact",
+        "translate+extend": "approx",
+    }
+
     @spaces.GPU(duration=5)
     @torch.inference_mode()
     def generate_tags(
         text: str,
         auto_detect: bool,
+        mode: str,
         copyright_tags: str = "",
         length: str = "short",
         max_new_tokens: int = 128,
@@ -77,7 +89,7 @@ def demo():
             "<|bos|>"
             f"<|aspect_ratio:tall|><|rating:general|><|length:{length}|>"
             "<|reserved_2|><|reserved_3|><|reserved_4|>"
-            "<|translate:exact|><|input_end|>"
+            f"<|translate:{translation_mode_map[mode]}|><|input_end|>"
             "<copyright>" + copyright_tags.strip()
         )
         if not auto_detect:
@@ -146,6 +158,11 @@ def demo():
                     ],
                     value="short",
                 )
+                translation_mode = gr.Radio(
+                    label="Translation mode",
+                    choices=list(translation_mode_map.keys()),
+                    value=list(translation_mode_map.keys())[0],
+                )
                 translate_btn = gr.Button(value="Translate", variant="primary")
 
             with gr.Accordion(label="Advanced", open=False):
@@ -174,7 +191,8 @@ def demo():
                     )
 
             with gr.Column():
-                output = gr.Textbox(label="Output", lines=4, interactive=False)
+                output_translation = gr.Textbox(label="Output (translation)", lines=4, interactive=False)
+                output_extension = gr.Textbox(label="Output (extension)", lines=4, interactive=False)
                 time_elapsed = gr.Markdown(value="")
 
         gr.Examples(
@@ -239,6 +257,7 @@ def demo():
             inputs=[
                 text,
                 auto_detect,
+                translation_mode,
                 copyright_tags,
                 length,
                 max_new_tokens,
@@ -247,7 +266,7 @@ def demo():
                 top_k,
                 top_p,
             ],
-            outputs=[output, time_elapsed],
+            outputs=[output_translation, output_extension, time_elapsed],
         )
 
     ui.launch()
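
The first hunk installs flash-attn lazily at import time instead of requiring it at build time. Below is a standalone sketch of the same pattern; the deviations from the committed code (catching ImportError specifically, merging os.environ instead of replacing it, and passing check=True) are assumptions added for illustration, not part of this commit.

    import os
    import subprocess

    try:
        import flash_attn  # noqa: F401  # already installed, nothing to do
    except ImportError:
        print("Installing flash-attn...")
        subprocess.run(
            "pip install flash-attn --no-build-isolation",
            # Skip building the CUDA extension during pip install; this app runs
            # as a ZeroGPU Space (it uses @spaces.GPU), so no GPU is attached
            # while the app is starting up.
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            shell=True,
            check=True,
        )
        import flash_attn  # noqa: F401

        print("flash-attn installed.")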
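
The remaining hunks thread a user-selectable translation mode through the app: a new Radio maps the labels "translate" and "translate+extend" to the prompt token values "exact" and "approx", generate_tags gains a mode argument, and the single Output textbox is split into Output (translation) and Output (extension). A minimal sketch of how the mode reaches the prompt is shown below; build_prompt is a hypothetical helper (app.py assembles this string inline inside generate_tags), and the fixed aspect-ratio and rating values mirror the ones visible in the diff.

    # Hypothetical helper illustrating the new <|translate:...|> control token.
    translation_mode_map = {
        "translate": "exact",          # radio label -> token value
        "translate+extend": "approx",
    }


    def build_prompt(mode: str, length: str = "short", copyright_tags: str = "") -> str:
        return (
            "<|bos|>"
            f"<|aspect_ratio:tall|><|rating:general|><|length:{length}|>"
            "<|reserved_2|><|reserved_3|><|reserved_4|>"
            # Previously hard-coded to "exact"; now driven by the radio choice.
            f"<|translate:{translation_mode_map[mode]}|><|input_end|>"
            "<copyright>" + copyright_tags.strip()
        )


    print(build_prompt("translate+extend"))
    # ...<|translate:approx|><|input_end|><copyright>

Because the click handler now lists three outputs, generate_tags is expected to return a translation string, an extension string, and the elapsed-time markdown, in that order.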