MjolnirThor committed on
Commit
629121e
·
verified ·
1 Parent(s): a5a7bfd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Fine-tune ``google/flan-t5-large`` on a slice of ``health360/Healix-Shot``.

The script loads the first 100,000 training rows, holds out 10% of them for
evaluation, tokenizes the ``text`` column, trains with the ``training_args``
imported from ``training_config`` and finally pushes the model and tokenizer
to the Hugging Face Hub.
"""

from datasets import load_dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
)

from training_config import training_args

DATASET_NAME = "health360/Healix-Shot"
TRAIN_SPLIT = "train[:100000]"
HUB_REPO_ID = "MjolnirThor/flan-t5-custom-handler"
MAX_LENGTH = 512
EVAL_FRACTION = 0.1
# Fixed seed so the train/eval partition is the same on every run.
SPLIT_SEED = 42

print("Starting training process...")

# Load dataset
print("Loading dataset...")
dataset = load_dataset(DATASET_NAME, split=TRAIN_SPLIT)

# Initialize model and tokenizer
print("Initializing model and tokenizer...")
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


def tokenize_function(examples):
    """Tokenize a batch of rows and attach the ``labels`` the trainer needs.

    Args:
        examples: A batch mapping with a ``"text"`` key, as passed by
            ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``) extended
        with a ``labels`` entry.
    """
    # No padding here: DataCollatorForSeq2Seq pads each batch to its longest
    # member, and pads ``labels`` with -100 so padding is ignored by the loss.
    encodings = tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_LENGTH,
        return_attention_mask=True,
    )
    # Without ``labels`` the seq2seq model has no decoder inputs and no loss,
    # so training fails on the first step.
    # NOTE(review): only the ``text`` column is used by this script, so the
    # target mirrors the input (reconstruction objective). If the dataset has
    # a dedicated target/answer column, tokenize that here instead -- confirm.
    encodings["labels"] = [list(ids) for ids in encodings["input_ids"]]
    return encodings


# Process dataset
print("Processing dataset...")
train_test_split = dataset.train_test_split(
    test_size=EVAL_FRACTION,
    seed=SPLIT_SEED,
)
# Map both splits in one call; the raw columns are dropped so only the
# tokenized features reach the trainer.
tokenized_splits = train_test_split.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset.column_names,
)
tokenized_train = tokenized_splits["train"]
tokenized_eval = tokenized_splits["test"]

# Initialize trainer
print("Initializing trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model),
)

# Train and save
print("Starting the training...")
trainer.train()
print("Training complete, saving model...")
model.push_to_hub(HUB_REPO_ID)
tokenizer.push_to_hub(HUB_REPO_ID)
print("Model saved successfully!")