Foxy4377 committed on
Commit
f53a7b1
·
verified ·
1 Parent(s): cb346ef

Create obyzala.py

Browse files
Files changed (1) hide show
  1. obyzala.py +56 -0
obyzala.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering, Trainer, TrainingArguments
2
+ from datasets import Dataset
3
+ import requests
4
+
5
# Download the FHeta module index; its raw text becomes the single training
# answer below. A timeout keeps the script from hanging forever on a stalled
# connection (requests applies no timeout by default).
pisyn = requests.get(
    "https://raw.githubusercontent.com/Fixyres/FHeta/refs/heads/main/modules.json",
    timeout=30,
)
# Fail fast on HTTP errors so an error page is never used as training text.
pisyn.raise_for_status()

# One row per training example: the question and the text that answers it.
data = [
    {"question": "Какая твоя база данных модулей? И по какой базе ты ищешь все модули?", "answer": pisyn.text}
]
9
+
10
# Base checkpoint that gets fine-tuned for extractive question answering.
# NOTE(review): the training question is written in Russian while this
# checkpoint name suggests an uncased English model - confirm a multilingual
# checkpoint is not what was intended.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForQuestionAnswering.from_pretrained("distilbert-base-uncased")

# `data` is a list of row dicts, but Dataset.from_dict requires a mapping of
# column name -> list of values, so pivot the rows into columns first.
dataset = Dataset.from_dict(
    {column: [row[column] for row in data] for column in ("question", "answer")}
)
14
+
15
def preprocess_function(examples):
    """Tokenize a batch of question/answer rows for extractive QA training.

    Question-answering heads are trained on ``start_positions`` and
    ``end_positions`` (token indices of the answer span), not on a ``labels``
    tensor of token ids. The rows carry no separate context passage, so the
    answer text is encoded as the second sequence of the pair and the target
    span covers every answer token that survives truncation.

    Args:
        examples: Batch mapping with "question" and "answer" columns, each a
            list of strings (the shape ``Dataset.map(..., batched=True)``
            passes in).

    Returns:
        The tokenizer encoding for the batch with "start_positions" and
        "end_positions" added, one integer per row.
    """
    questions = examples["question"]
    answers = examples["answer"]

    # Plain Python lists are returned (no return_tensors): Dataset.map stores
    # the values as columns and the Trainer's collator builds the tensors.
    encoded = tokenizer(questions, answers, padding=True, truncation=True)

    start_positions = []
    end_positions = []
    for row in range(len(questions)):
        # sequence_ids marks each token as None (special/padding), 0 (question)
        # or 1 (answer); this requires a fast tokenizer.
        segment_ids = encoded.sequence_ids(row)
        answer_tokens = [pos for pos, seg in enumerate(segment_ids) if seg == 1]
        if answer_tokens:
            start_positions.append(answer_tokens[0])
            end_positions.append(answer_tokens[-1])
        else:
            # Empty answer: fall back to the first token, the usual
            # "no answer" target for extractive QA.
            start_positions.append(0)
            end_positions.append(0)

    encoded["start_positions"] = start_positions
    encoded["end_positions"] = end_positions
    return encoded
23
+
24
# Run the preprocessing over the whole dataset in batches.
tokenized_datasets = dataset.map(preprocess_function, batched=True)

# Hyper-parameters and output locations for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="./results",
    logging_dir="./logs",
    logging_steps=10,
    num_train_epochs=3,
    per_device_train_batch_size=8,
)

trainer = Trainer(model, training_args, train_dataset=tokenized_datasets)
trainer.train()

# Persist the fine-tuned weights first, then the tokenizer files, side by side.
for artifact in (model, tokenizer):
    artifact.save_pretrained("./FHeta")

# Reload both from disk so the inference code below runs on the saved copy.
tokenizer = AutoTokenizer.from_pretrained("./FHeta")
model = AutoModelForQuestionAnswering.from_pretrained("./FHeta")
47
+
48
def get_answer(query, context=None):
    """Return the answer span the QA model extracts for ``query``.

    A question-answering model output exposes ``start_logits`` and
    ``end_logits`` (one score per input token); it has no ``logits`` entry,
    and logits are scores rather than token ids, so they cannot be decoded
    directly. The answer is the slice of the *input* tokens between the
    highest-scoring start and end positions.

    Args:
        query: Question text.
        context: Optional passage to extract the answer from. When omitted
            the model can only pick a span out of ``query`` itself.

    Returns:
        The decoded answer text, or an empty string when the predicted end
        position lies before the predicted start position.
    """
    # Truncate so over-long inputs do not exceed the model's maximum length.
    inputs = tokenizer(query, context, truncation=True, return_tensors="pt")
    outputs = model(**inputs)

    # Batch size is 1, so take row 0 of each per-token score matrix.
    start = int(outputs.start_logits.argmax(dim=-1)[0])
    end = int(outputs.end_logits.argmax(dim=-1)[0])
    if end < start:
        return ""

    span_ids = inputs["input_ids"][0, start:end + 1]
    return tokenizer.decode(span_ids, skip_special_tokens=True)
53
+
54
# Smoke-test the reloaded model with a sample query and print its answer.
query = "Модуль FHeta"
answer = get_answer(query)
print(answer)