Update src/streamlit_app.py
Browse files- src/streamlit_app.py +2 -2
src/streamlit_app.py
CHANGED
@@ -3,7 +3,7 @@ import torch
|
|
3 |
import os
|
4 |
import pickle
|
5 |
import torch.nn.functional as F
|
6 |
-
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM
|
7 |
import asyncio
|
8 |
import sys
|
9 |
|
@@ -27,7 +27,7 @@ def load_prediction_model():
|
|
27 |
label_encoder = pickle.load(f)
|
28 |
id_to_class = {idx: class_name for idx, class_name in enumerate(label_encoder.classes_)}
|
29 |
|
30 |
-
model =
|
31 |
# model.load_state_dict(torch.load('Divyanshu04/Issue_categorizer', map_location=torch.device('cpu'))['model_state_dict'])
|
32 |
model.eval()
|
33 |
return tokenizer, model, id_to_class
|
|
|
3 |
import os
|
4 |
import pickle
|
5 |
import torch.nn.functional as F
|
6 |
+
from transformers import BertTokenizer, BertForSequenceClassification, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
|
7 |
import asyncio
|
8 |
import sys
|
9 |
|
|
|
27 |
label_encoder = pickle.load(f)
|
28 |
id_to_class = {idx: class_name for idx, class_name in enumerate(label_encoder.classes_)}
|
29 |
|
30 |
+
model = AutoModelForSequenceClassification.from_pretrained("Divyanshu04/Issue_categorizer")
|
31 |
# model.load_state_dict(torch.load('Divyanshu04/Issue_categorizer', map_location=torch.device('cpu'))['model_state_dict'])
|
32 |
model.eval()
|
33 |
return tokenizer, model, id_to_class
|