amurienne commited on
Commit
a6c80c9
·
verified ·
1 Parent(s): 124b836

enabling zerogpu

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -20,24 +20,31 @@
20
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
  # THE SOFTWARE.
22
 
 
 
 
23
  import gradio as gr
24
 
25
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
26
 
 
 
 
27
  fw_modelcard = "amurienne/gallek-m2m100"
28
  bw_modelcard = "amurienne/kellag-m2m100"
29
 
30
  fw_model = AutoModelForSeq2SeqLM.from_pretrained(fw_modelcard)
31
  fw_tokenizer = AutoTokenizer.from_pretrained(fw_modelcard)
32
 
33
- fw_translation_pipeline = pipeline("translation", model=fw_model, tokenizer=fw_tokenizer, src_lang='fr', tgt_lang='br', max_length=400, device="cpu")
34
 
35
  bw_model = AutoModelForSeq2SeqLM.from_pretrained(bw_modelcard)
36
  bw_tokenizer = AutoTokenizer.from_pretrained(bw_modelcard)
37
 
38
- bw_translation_pipeline = pipeline("translation", model=bw_model, tokenizer=bw_tokenizer, src_lang='br', tgt_lang='fr', max_length=400, device="cpu")
39
 
40
  # translation function
 
41
  def translate(text, direction):
42
  if direction == "fr_to_br":
43
  return fw_translation_pipeline("traduis de français en breton: " + text)[0]['translation_text']
 
20
  # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
  # THE SOFTWARE.
22
 
23
+ import spaces
24
+ import torch
25
+
26
  import gradio as gr
27
 
28
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
29
 
30
+ device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ print(f"current device is: {device}")
32
+
33
  fw_modelcard = "amurienne/gallek-m2m100"
34
  bw_modelcard = "amurienne/kellag-m2m100"
35
 
36
  fw_model = AutoModelForSeq2SeqLM.from_pretrained(fw_modelcard)
37
  fw_tokenizer = AutoTokenizer.from_pretrained(fw_modelcard)
38
 
39
+ fw_translation_pipeline = pipeline("translation", model=fw_model, tokenizer=fw_tokenizer, src_lang='fr', tgt_lang='br', max_length=400, device=device)
40
 
41
  bw_model = AutoModelForSeq2SeqLM.from_pretrained(bw_modelcard)
42
  bw_tokenizer = AutoTokenizer.from_pretrained(bw_modelcard)
43
 
44
+ bw_translation_pipeline = pipeline("translation", model=bw_model, tokenizer=bw_tokenizer, src_lang='br', tgt_lang='fr', max_length=400, device=device)
45
 
46
  # translation function
47
+ @spaces.GPU
48
  def translate(text, direction):
49
  if direction == "fr_to_br":
50
  return fw_translation_pipeline("traduis de français en breton: " + text)[0]['translation_text']