adel67460 commited on
Commit
1946897
·
verified ·
1 Parent(s): d583187

Upload train_interface.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_interface.py +194 -0
train_interface.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import HfApi, login
3
+ import json
4
+ import os
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
6
+ from datasets import load_dataset
7
+ import traceback
8
+ import torch
9
+ from connect_huggingface import setup_huggingface
10
+ from accelerate import Accelerator
11
+ from accelerate.utils import set_seed
12
+ from transformers import DataCollatorForLanguageModeling
13
+ import pandas as pd
14
+ from datetime import datetime
15
+
16
+ # Charger la configuration
17
+ with open('config.json', 'r') as f:
18
+ config = json.load(f)
19
+
20
+ class TrainingCallback:
21
+ def __init__(self, status_box, log_box):
22
+ self.status_box = status_box
23
+ self.log_box = log_box
24
+ self.logs = []
25
+
26
+ def on_log(self, args, state, control, logs=None):
27
+ if logs:
28
+ timestamp = datetime.now().strftime("%H:%M:%S")
29
+ log_entry = f"[{timestamp}] Loss: {logs.get('loss', 'N/A'):.4f}"
30
+ if 'eval_loss' in logs:
31
+ log_entry += f", Eval Loss: {logs['eval_loss']:.4f}"
32
+ self.logs.append(log_entry)
33
+ self.log_box.update(value="\n".join(self.logs[-20:])) # Keep last 20 logs
34
+
35
+ def on_step_end(self, args, state, control):
36
+ self.status_box.update(value=f"Étape {state.global_step}/{state.max_steps}")
37
+
38
+ def format_prompt(instruction, input_text, output):
39
+ """Formate le prompt pour l'entraînement"""
40
+ if input_text and input_text.strip():
41
+ return f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
42
+ return f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
43
+
44
+ def preprocess_function(examples, tokenizer):
45
+ """Prétraite les données pour l'entraînement"""
46
+ # Créer les prompts
47
+ prompts = [
48
+ format_prompt(instruction, input_text, output)
49
+ for instruction, input_text, output in zip(
50
+ examples['instruction'],
51
+ examples['input'],
52
+ examples['output']
53
+ )
54
+ ]
55
+
56
+ # Tokenizer les prompts avec padding
57
+ model_inputs = tokenizer(
58
+ prompts,
59
+ padding=True,
60
+ truncation=True,
61
+ max_length=512,
62
+ return_tensors=None
63
+ )
64
+
65
+ # Créer les labels (décalés de 1 pour l'entraînement causal)
66
+ labels = model_inputs["input_ids"].copy()
67
+
68
+ # Mettre -100 sur le padding pour l'ignorer dans la loss
69
+ for i, label in enumerate(labels):
70
+ labels[i] = [-100 if token == tokenizer.pad_token_id else token for token in label]
71
+
72
+ model_inputs["labels"] = labels
73
+ return model_inputs
74
+
75
+ def compute_metrics(eval_pred):
76
+ """Calcule les métriques d'évaluation"""
77
+ predictions, labels = eval_pred
78
+ # Convertir en tenseurs PyTorch
79
+ predictions = torch.tensor(predictions)
80
+ labels = torch.tensor(labels)
81
+
82
+ # Calculer la loss
83
+ loss = torch.nn.functional.cross_entropy(
84
+ predictions.view(-1, predictions.size(-1)),
85
+ labels.view(-1),
86
+ ignore_index=-100
87
+ )
88
+
89
+ return {
90
+ "loss": loss.item()
91
+ }
92
+
93
+ def start_training(status_box=gr.Textbox(), log_box=gr.Markdown()):
94
+ """Lance l'entraînement du modèle"""
95
+ try:
96
+ # Configuration de Hugging Face
97
+ if not setup_huggingface():
98
+ return "Erreur : Impossible de configurer Hugging Face"
99
+
100
+ status_box.update(value="Configuration de l'environnement...")
101
+
102
+ # Charger le modèle et le tokenizer
103
+ tokenizer = AutoTokenizer.from_pretrained(
104
+ config['model']['name'],
105
+ trust_remote_code=True
106
+ )
107
+ model = AutoModelForCausalLM.from_pretrained(
108
+ config['model']['name'],
109
+ trust_remote_code=True,
110
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
111
+ )
112
+
113
+ # Configurer le tokenizer
114
+ tokenizer.pad_token = tokenizer.eos_token
115
+
116
+ status_box.update(value="Chargement du dataset...")
117
+
118
+ # Charger le dataset
119
+ dataset = load_dataset(config['dataset']['name'])
120
+
121
+ # Prétraiter le dataset
122
+ tokenized_dataset = dataset.map(
123
+ lambda x: preprocess_function(x, tokenizer),
124
+ batched=True,
125
+ remove_columns=dataset["train"].column_names
126
+ )
127
+
128
+ # Créer le data collator
129
+ data_collator = DataCollatorForLanguageModeling(
130
+ tokenizer=tokenizer,
131
+ mlm=False
132
+ )
133
+
134
+ status_box.update(value="Configuration de l'entraînement...")
135
+
136
+ # Configuration de l'entraînement
137
+ training_args = TrainingArguments(
138
+ output_dir=config['training']['output_dir'],
139
+ num_train_epochs=config['training']['epochs'],
140
+ per_device_train_batch_size=config['training']['batch_size'],
141
+ gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
142
+ learning_rate=float(config['training']['learning_rate']),
143
+ bf16=config['training'].get('bf16', True),
144
+ logging_steps=10,
145
+ evaluation_strategy="steps",
146
+ eval_steps=100,
147
+ save_strategy="steps",
148
+ save_steps=100,
149
+ save_total_limit=1,
150
+ load_best_model_at_end=True,
151
+ )
152
+
153
+ # Créer le callback
154
+ callback = TrainingCallback(status_box, log_box)
155
+
156
+ # Créer le trainer
157
+ trainer = Trainer(
158
+ model=model,
159
+ args=training_args,
160
+ train_dataset=tokenized_dataset["train"],
161
+ eval_dataset=tokenized_dataset["validation"] if "validation" in tokenized_dataset else None,
162
+ tokenizer=tokenizer,
163
+ data_collator=data_collator,
164
+ compute_metrics=compute_metrics,
165
+ callbacks=[callback]
166
+ )
167
+
168
+ status_box.update(value="Démarrage de l'entraînement...")
169
+
170
+ # Lancer l'entraînement
171
+ trainer.train()
172
+
173
+ # Sauvegarder le modèle final
174
+ trainer.save_model()
175
+
176
+ status_box.update(value="Entraînement terminé !")
177
+ return "Entraînement terminé avec succès !"
178
+
179
+ except Exception as e:
180
+ error_msg = f"Erreur pendant l'entraînement : {str(e)}\n{traceback.format_exc()}"
181
+ print(error_msg)
182
+ return error_msg
183
+
184
+ # Interface Gradio
185
+ demo = gr.Interface(
186
+ fn=start_training,
187
+ inputs=[gr.Textbox(label="Statut de l'entraînement"), gr.Markdown(label="Logs de l'entraînement")],
188
+ outputs=[gr.Textbox(label="Statut de l'entraînement"), gr.Markdown(label="Logs de l'entraînement")],
189
+ title="AUTO Training Space",
190
+ description="Cliquez sur le bouton pour lancer l'entraînement du modèle."
191
+ )
192
+
193
+ if __name__ == "__main__":
194
+ demo.launch()