import json
import traceback
from datetime import datetime

import torch
import gradio as gr
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

from connect_huggingface import setup_huggingface
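
# Expected config.json layout, inferred from the keys this script reads.
# The values shown are illustrative placeholders, not tested defaults:
#
# {
#     "model": {"name": "...", "tokenizer": "..."},
#     "dataset": {"name": "...", "train_split": "train", "eval_split": "validation"},
#     "training": {
#         "epochs": 3, "batch_size": 8, "learning_rate": 2e-5, "warmup_ratio": 0.1,
#         "evaluation_strategy": "steps", "eval_steps": 500,
#         "save_strategy": "steps", "save_steps": 500, "save_total_limit": 2,
#         "load_best_model_at_end": true, "metric_for_best_model": "eval_loss",
#         "greater_is_better": false, "gradient_accumulation_steps": 1,
#         "logging_steps": 50, "fp16": false, "bf16": true
#     }
# }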


class TrainingCallback(TrainerCallback):
    """Collects the Trainer's log dicts so they can be returned to the UI."""

    def __init__(self):
        self.logs = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Called by the Trainer at each logging step.
        if logs:
            self.logs.append(logs)

    def get_logs(self):
        return "\n".join(str(log) for log in self.logs)


def start_training():
    try:
        if not setup_huggingface():
            return "Error: could not configure Hugging Face", "### Training logs\nHugging Face configuration error"

        # NOTE: intermediate `status` values are never displayed; gr.Interface
        # only shows what the function eventually returns.
        status = "Setting up the environment..."
        logs = f"### Training logs\nStarted at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"

        # All model, dataset, and training settings come from config.json
        # (see the layout sketch above).
        with open('config.json', 'r') as f:
            config = json.load(f)

        logs += "- Loading model and tokenizer...\n"
        tokenizer = AutoTokenizer.from_pretrained(config['model']['tokenizer'])
        model = AutoModelForCausalLM.from_pretrained(
            config['model']['name'],
            torch_dtype=torch.bfloat16 if config['training']['bf16'] else torch.float32,
            device_map="auto"
        )

        status = "Loading the dataset..."
        logs += f"- Loading {config['dataset']['name']}...\n"
        dataset = load_dataset(config['dataset']['name'])

        status = "Configuring training..."
        logs += "- Setting up training parameters...\n"

        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=config['training']['epochs'],
            per_device_train_batch_size=config['training']['batch_size'],
            learning_rate=config['training']['learning_rate'],
            warmup_ratio=config['training']['warmup_ratio'],
            evaluation_strategy=config['training']['evaluation_strategy'],
            eval_steps=config['training']['eval_steps'],
            save_strategy=config['training']['save_strategy'],
            save_steps=config['training']['save_steps'],
            save_total_limit=config['training']['save_total_limit'],
            load_best_model_at_end=config['training']['load_best_model_at_end'],
            metric_for_best_model=config['training']['metric_for_best_model'],
            greater_is_better=config['training']['greater_is_better'],
            gradient_accumulation_steps=config['training']['gradient_accumulation_steps'],
            logging_steps=config['training']['logging_steps'],
            fp16=config['training']['fp16'],
            bf16=config['training']['bf16']
        )
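        # Note (version assumption): newer transformers releases renamed
        # `evaluation_strategy` to `eval_strategy`; if your installed version
        # rejects the kwarg above, adjust it (and the config key) accordingly.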

        callback = TrainingCallback()

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset[config['dataset']['train_split']],
            eval_dataset=dataset[config['dataset']['eval_split']],
            tokenizer=tokenizer,  # keeps the tokenizer with checkpoints; otherwise it is loaded but unused
            callbacks=[callback]
        )
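
        # NOTE (assumption): the splits above go to Trainer as-is, so the
        # dataset must already be tokenized (i.e. have input_ids/labels
        # columns). For a raw-text dataset you would first map the tokenizer
        # over it, e.g. (the "text" column name is hypothetical):
        #
        #     dataset = dataset.map(
        #         lambda batch: tokenizer(batch["text"], truncation=True),
        #         batched=True,
        #     )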

        status = "Training in progress..."
        logs += "- Starting training...\n"

        trainer.train()

        logs += "\n### Detailed logs\n"
        logs += callback.get_logs()

        status = "Training completed successfully!"
        logs += "\n\nTraining completed successfully!"

        return status, logs

    except Exception as e:
        error_msg = f"Error during training: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return "Error during training", f"### Training logs\n❌ {error_msg}"


demo = gr.Interface(
    fn=start_training,
    inputs=[],  # no inputs: the Submit button simply triggers start_training
    outputs=[
        gr.Textbox(label="Training status"),
        gr.Markdown(label="Training logs"),
    ],
    title="AUTO Training Space",
    description="Click the button to start training the model.",
)


if __name__ == "__main__":
    demo.launch()