kryman27 commited on
Commit
e1bdc34
·
verified ·
1 Parent(s): f4a79f7

Update train_model.py

Browse files
Files changed (1) hide show
  1. train_model.py +6 -8
train_model.py CHANGED
@@ -1,12 +1,14 @@
1
  from transformers import LayoutLMForTokenClassification, Trainer, TrainingArguments
2
  from datasets import load_dataset
3
 
4
- # Wczytanie przygotowanego zbioru danych
5
  dataset = load_dataset("json", data_files="training_data.json")["train"]
6
- dataset = dataset.train_test_split(test_size=0.2) # Podział na trening i test
7
 
8
- # Ładowanie modelu LayoutLM do dostrajania
9
- model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=10)
 
 
10
 
11
  training_args = TrainingArguments(
12
  output_dir="./layoutlmv3_finetuned",
@@ -27,9 +29,5 @@ trainer = Trainer(
27
  )
28
 
29
  trainer.train()
30
-
31
- # Zapisanie modelu lokalnie
32
  model.save_pretrained("./layoutlmv3_finetuned")
33
-
34
- # Wysłanie modelu do Hugging Face (tylko jeśli masz konto)
35
  model.push_to_hub("kryman27/layoutlmv3-finetuned")
 
1
  from transformers import LayoutLMForTokenClassification, Trainer, TrainingArguments
2
  from datasets import load_dataset
3
 
4
+ # Upewnij się, że training_data.json zawiera etykiety odpowiadające nowym polom
5
  dataset = load_dataset("json", data_files="training_data.json")["train"]
6
+ dataset = dataset.train_test_split(test_size=0.2)
7
 
8
+ # Dostosuj liczbę etykiet do rozszerzonego zakresu ekstrakcji (przykładowo 15)
9
+ num_labels = 15
10
+
11
+ model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=num_labels)
12
 
13
  training_args = TrainingArguments(
14
  output_dir="./layoutlmv3_finetuned",
 
29
  )
30
 
31
  trainer.train()
 
 
32
  model.save_pretrained("./layoutlmv3_finetuned")
 
 
33
  model.push_to_hub("kryman27/layoutlmv3-finetuned")