vic35get committed
Commit ea69137 · verified · 1 Parent(s): 6b75a34

Update app.py

Files changed (1)
app.py: +7 -9
app.py CHANGED
@@ -1,21 +1,19 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForSequenceClassification
-import torch
+from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
 # Load the model and the tokenizer
 model_name = "vic35get/nhtsa_complaints_classifier"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForSequenceClassification.from_pretrained(model_name)
+pipeline_clf = pipeline("text-classification", tokenizer=tokenizer, model=model)
 
 # Inference function
-def predict(text):
-    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
-    with torch.no_grad():
-        outputs = model(**inputs)
-    return torch.argmax(outputs.logits, dim=1).item()
-
+def predict(text: str):
+    classification = pipeline_clf(text)[0]
+    return classification.get('label')
+
 # Gradio interface
 iface = gr.Interface(fn=predict, inputs="text", outputs="text")
 
 # Run the interface
-iface.launch()
+iface.launch()
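
Note on the new predict(): a Hugging Face text-classification pipeline returns a list of dicts with 'label' and 'score' keys, so classification.get('label') yields the predicted class name as a string instead of the raw argmax index the old code returned. A minimal standalone sketch of that behavior (the sample complaint text is illustrative, and the exact label strings depend on the model's config):

from transformers import pipeline

# Build the same kind of pipeline the updated app.py constructs
clf = pipeline("text-classification", model="vic35get/nhtsa_complaints_classifier")

# The pipeline returns something like [{'label': 'LABEL_0', 'score': 0.97}]
result = clf("The brakes failed while driving on the highway.")[0]
print(result.get("label"))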