Upload train_interface.py with huggingface_hub
Browse files- train_interface.py +64 -183
train_interface.py
CHANGED
@@ -1,194 +1,75 @@
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
-
import json
|
4 |
import os
|
5 |
-
from
|
6 |
-
|
7 |
-
import traceback
|
8 |
-
import torch
|
9 |
from connect_huggingface import setup_huggingface
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
|
16 |
# Charger la configuration
|
|
|
17 |
with open('config.json', 'r') as f:
|
18 |
config = json.load(f)
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
def on_log(self, args, state, control, logs=None):
|
27 |
-
if logs:
|
28 |
-
timestamp = datetime.now().strftime("%H:%M:%S")
|
29 |
-
log_entry = f"[{timestamp}] Loss: {logs.get('loss', 'N/A'):.4f}"
|
30 |
-
if 'eval_loss' in logs:
|
31 |
-
log_entry += f", Eval Loss: {logs['eval_loss']:.4f}"
|
32 |
-
self.logs.append(log_entry)
|
33 |
-
self.log_box.update(value="\n".join(self.logs[-20:])) # Keep last 20 logs
|
34 |
-
|
35 |
-
def on_step_end(self, args, state, control):
|
36 |
-
self.status_box.update(value=f"Étape {state.global_step}/{state.max_steps}")
|
37 |
-
|
38 |
-
def format_prompt(instruction, input_text, output):
|
39 |
-
"""Formate le prompt pour l'entraînement"""
|
40 |
-
if input_text and input_text.strip():
|
41 |
-
return f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n{output}"
|
42 |
-
return f"### Instruction:\n{instruction}\n\n### Response:\n{output}"
|
43 |
-
|
44 |
-
def preprocess_function(examples, tokenizer):
|
45 |
-
"""Prétraite les données pour l'entraînement"""
|
46 |
-
# Créer les prompts
|
47 |
-
prompts = [
|
48 |
-
format_prompt(instruction, input_text, output)
|
49 |
-
for instruction, input_text, output in zip(
|
50 |
-
examples['instruction'],
|
51 |
-
examples['input'],
|
52 |
-
examples['output']
|
53 |
-
)
|
54 |
-
]
|
55 |
-
|
56 |
-
# Tokenizer les prompts avec padding
|
57 |
-
model_inputs = tokenizer(
|
58 |
-
prompts,
|
59 |
-
padding=True,
|
60 |
-
truncation=True,
|
61 |
-
max_length=512,
|
62 |
-
return_tensors=None
|
63 |
-
)
|
64 |
-
|
65 |
-
# Créer les labels (décalés de 1 pour l'entraînement causal)
|
66 |
-
labels = model_inputs["input_ids"].copy()
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
trust_remote_code=True,
|
110 |
-
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
111 |
-
)
|
112 |
-
|
113 |
-
# Configurer le tokenizer
|
114 |
-
tokenizer.pad_token = tokenizer.eos_token
|
115 |
-
|
116 |
-
status_box.update(value="Chargement du dataset...")
|
117 |
-
|
118 |
-
# Charger le dataset
|
119 |
-
dataset = load_dataset(config['dataset']['name'])
|
120 |
-
|
121 |
-
# Prétraiter le dataset
|
122 |
-
tokenized_dataset = dataset.map(
|
123 |
-
lambda x: preprocess_function(x, tokenizer),
|
124 |
-
batched=True,
|
125 |
-
remove_columns=dataset["train"].column_names
|
126 |
-
)
|
127 |
-
|
128 |
-
# Créer le data collator
|
129 |
-
data_collator = DataCollatorForLanguageModeling(
|
130 |
-
tokenizer=tokenizer,
|
131 |
-
mlm=False
|
132 |
-
)
|
133 |
-
|
134 |
-
status_box.update(value="Configuration de l'entraînement...")
|
135 |
-
|
136 |
-
# Configuration de l'entraînement
|
137 |
-
training_args = TrainingArguments(
|
138 |
-
output_dir=config['training']['output_dir'],
|
139 |
-
num_train_epochs=config['training']['epochs'],
|
140 |
-
per_device_train_batch_size=config['training']['batch_size'],
|
141 |
-
gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
|
142 |
-
learning_rate=float(config['training']['learning_rate']),
|
143 |
-
bf16=config['training'].get('bf16', True),
|
144 |
-
logging_steps=10,
|
145 |
-
evaluation_strategy="steps",
|
146 |
-
eval_steps=100,
|
147 |
-
save_strategy="steps",
|
148 |
-
save_steps=100,
|
149 |
-
save_total_limit=1,
|
150 |
-
load_best_model_at_end=True,
|
151 |
-
)
|
152 |
-
|
153 |
-
# Créer le callback
|
154 |
-
callback = TrainingCallback(status_box, log_box)
|
155 |
-
|
156 |
-
# Créer le trainer
|
157 |
-
trainer = Trainer(
|
158 |
-
model=model,
|
159 |
-
args=training_args,
|
160 |
-
train_dataset=tokenized_dataset["train"],
|
161 |
-
eval_dataset=tokenized_dataset["validation"] if "validation" in tokenized_dataset else None,
|
162 |
-
tokenizer=tokenizer,
|
163 |
-
data_collator=data_collator,
|
164 |
-
compute_metrics=compute_metrics,
|
165 |
-
callbacks=[callback]
|
166 |
-
)
|
167 |
-
|
168 |
-
status_box.update(value="Démarrage de l'entraînement...")
|
169 |
-
|
170 |
-
# Lancer l'entraînement
|
171 |
-
trainer.train()
|
172 |
-
|
173 |
-
# Sauvegarder le modèle final
|
174 |
-
trainer.save_model()
|
175 |
-
|
176 |
-
status_box.update(value="Entraînement terminé !")
|
177 |
-
return "Entraînement terminé avec succès !"
|
178 |
-
|
179 |
-
except Exception as e:
|
180 |
-
error_msg = f"Erreur pendant l'entraînement : {str(e)}\n{traceback.format_exc()}"
|
181 |
-
print(error_msg)
|
182 |
-
return error_msg
|
183 |
-
|
184 |
-
# Interface Gradio
|
185 |
-
demo = gr.Interface(
|
186 |
-
fn=start_training,
|
187 |
-
inputs=[gr.Textbox(label="Statut de l'entraînement"), gr.Markdown(label="Logs de l'entraînement")],
|
188 |
-
outputs=[gr.Textbox(label="Statut de l'entraînement"), gr.Markdown(label="Logs de l'entraînement")],
|
189 |
-
title="AUTO Training Space",
|
190 |
-
description="Cliquez sur le bouton pour lancer l'entraînement du modèle."
|
191 |
-
)
|
192 |
|
193 |
if __name__ == "__main__":
|
194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from train_interface import start_training
|
|
|
3 |
import os
|
4 |
+
from huggingface_hub import login
|
5 |
+
import json
|
|
|
|
|
6 |
from connect_huggingface import setup_huggingface
|
7 |
+
import sys
|
8 |
+
|
9 |
+
print("=== Démarrage de l'application ===")
|
10 |
+
print(f"Python version: {sys.version}")
|
11 |
+
print(f"Working directory: {os.getcwd()}")
|
12 |
|
13 |
# Charger la configuration
|
14 |
+
print("\nChargement de la configuration...")
|
15 |
with open('config.json', 'r') as f:
|
16 |
config = json.load(f)
|
17 |
|
18 |
+
def create_interface():
|
19 |
+
# Configurer Hugging Face
|
20 |
+
print("\nConfiguration de Hugging Face...")
|
21 |
+
if not setup_huggingface():
|
22 |
+
print("Erreur : Impossible de configurer Hugging Face")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
+
with gr.Blocks(theme="huggingface") as demo:
|
25 |
+
gr.Markdown("# AUTO Training Space")
|
26 |
+
|
27 |
+
with gr.Row():
|
28 |
+
with gr.Column():
|
29 |
+
gr.Markdown(f"""
|
30 |
+
### Configuration actuelle
|
31 |
+
- **Modèle** : {config['model']['name']}
|
32 |
+
- **Dataset** : {config['dataset']['name']}
|
33 |
+
- **Nombre d'époques** : {config['training']['epochs']}
|
34 |
+
|
35 |
+
### Format du dataset
|
36 |
+
Le dataset contient des exemples structurés avec :
|
37 |
+
- Une instruction (question utilisateur)
|
38 |
+
- Une entrée (contexte optionnel)
|
39 |
+
- Une sortie (réponse avec recommandations)
|
40 |
+
|
41 |
+
### Optimisations
|
42 |
+
- Utilisation de BF16 pour une meilleure performance
|
43 |
+
- Gestion optimisée des données avec pandas
|
44 |
+
""")
|
45 |
+
|
46 |
+
with gr.Row():
|
47 |
+
with gr.Column():
|
48 |
+
status_output = gr.Textbox(
|
49 |
+
label="Statut de l'entraînement",
|
50 |
+
value="En attente de démarrage...",
|
51 |
+
interactive=False
|
52 |
+
)
|
53 |
+
logs_output = gr.Markdown(
|
54 |
+
value="### Logs d'entraînement\nLes logs seront affichés ici pendant l'entraînement."
|
55 |
+
)
|
56 |
+
|
57 |
+
train_button = gr.Button("Démarrer l'entraînement")
|
58 |
+
train_button.click(
|
59 |
+
fn=start_training,
|
60 |
+
inputs=[],
|
61 |
+
outputs=[status_output, logs_output]
|
62 |
+
)
|
63 |
+
|
64 |
+
return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
if __name__ == "__main__":
|
67 |
+
print("\nCréation de l'interface...")
|
68 |
+
demo = create_interface()
|
69 |
+
|
70 |
+
print("\nLancement de l'application...")
|
71 |
+
demo.launch(
|
72 |
+
server_name="0.0.0.0",
|
73 |
+
server_port=7860,
|
74 |
+
show_api=False
|
75 |
+
)
|