File size: 3,128 Bytes
0d3411a
 
 
dd3dbad
 
 
0d3411a
dd3dbad
 
 
 
 
 
 
0d3411a
dd3dbad
0d3411a
 
 
dd3dbad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d3411a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer, BertForSequenceClassification
import torch
from sklearn.preprocessing import LabelEncoder

labels = ['мода', 'спорт', 'технологии', 'финансы', 'крипта']
label_encoder = LabelEncoder()
label_encoder.fit(labels)

# Загрузка сохраненной модели и токенизатора в Streamlit
loaded_model_path = "rubert-base-cased"
loaded_tokenizer_path = BertForSequenceClassification.from_pretrained(loaded_model_path)

# Инициализация модели и токенизатора
loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)

# Создание модели с архитектурой BertForSequenceClassification
# Передайте в аргумент `num_labels` количество классов, для которых модель будет выполнять классификацию
model = BertForSequenceClassification(num_labels=len(labels))

# Загрузка весов из сохраненного файла
weights_path = "model_weights_epoch_8.pt"
state_dict = torch.load(weights_path, map_location='cpu')  # Укажите 'cuda' вместо 'cpu', если используете GPU
model.load_state_dict(state_dict)

# Пример использования загруженной модели
user_input = "Ваш текст для классификации"
predicted_class = predict_class(user_input, model=model, tokenizer=loaded_tokenizer, label_encoder=label_encoder)
print(predicted_class)


# #Загрузка сохраненной модели и токенизатора в Streamlit
# loaded_model_path = "nlp_project/model"
# loaded_tokenizer_path = "nlp_project/tokenizer"

# loaded_model = BertForSequenceClassification.from_pretrained(loaded_model_path)
# loaded_tokenizer = BertTokenizer.from_pretrained(loaded_tokenizer_path)



def predict_class(user_input, model=loaded_model, tokenizer=loaded_tokenizer, label_encoder=label_encoder, max_length=128):
    if not user_input:
        return "Введите текст"
    def tokenize_text(text):
        encoded_text = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_length,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt'
        )
        return encoded_text

    encoded_text = tokenize_text(user_input)
    with torch.no_grad():
        model.eval()
        input_ids = encoded_text['input_ids']
        attention_mask = encoded_text['attention_mask']
        outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    predicted_class_index = torch.argmax(logits, dim=1).item()
    
    # Получение названия класса
    predicted_class = label_encoder.classes_[predicted_class_index]
    return predicted_class