Dmtlant commited on
Commit
02eca50
·
verified ·
1 Parent(s): 48052da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -37
app.py CHANGED
@@ -1,41 +1,14 @@
1
- import torch
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
- # Загрузка модели и токенизатора
5
- model_name = "openvla/openvla-7b"
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
8
 
9
- # Перемещение модели на GPU, если доступно
10
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- model = model.to(device)
12
 
13
- def generate_response(prompt):
14
- # Токенизация входного текста
15
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
16
-
17
- # Генерация ответа
18
- with torch.no_grad():
19
- outputs = model.generate(
20
- **inputs,
21
- max_new_tokens=100,
22
- do_sample=True,
23
- temperature=0.7,
24
- top_p=0.9
25
- )
26
-
27
- # Декодирование и возврат ответа
28
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
29
- return response.strip()
30
 
31
- # Основной цикл чат-бота
32
- print("Чат-бот готов! Введите 'выход' для завершения.")
33
- while True:
34
- user_input = input("Вы: ")
35
- if user_input.lower() == 'выход':
36
- break
37
-
38
- response = generate_response(user_input)
39
- print("Бот:", response)
40
-
41
- print("До свидания!")
 
1
+ from gliclass import GLiClassModel, ZeroShotClassificationPipeline
2
+ from transformers import AutoTokenizer
3
 
4
+ model = GLiClassModel.from_pretrained("knowledgator/gliclass-large-v1.0")
5
+ tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-large-v1.0")
 
 
6
 
7
+ pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')
 
 
8
 
9
+ text = "One day I will see the world!"
10
+ labels = ["travel", "dreams", "sport", "science", "politics"]
11
+ results = pipeline(text, labels, threshold=0.5)[0] #because we have one text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ for result in results:
14
+ print(result["label"], "=>", result["score"])