MisterAI commited on
Commit
eb96a2e
·
verified ·
1 Parent(s): e38bac9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #BS_app.py_02
2
+ #Training NOK
3
+
4
+ #testing bloom1b training
5
+
6
+ import gradio as gr
7
+ import os
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
9
+ from datasets import load_dataset, Dataset
10
+ from huggingface_hub import HfApi, HfFolder
11
+
12
+ # Récupérer token depuis les variables d'environnement
13
+ hf_token = os.getenv("MisterAI_bigscience_bloom_560m")
14
+
15
+ # Configurer le token pour l'utilisation avec Hugging Face
16
+ if hf_token:
17
+ HfFolder.save_token(hf_token)
18
+ else:
19
+ raise ValueError("Le token Hugging Face n'est pas configuré. Assurez-vous qu'il est défini dans les variables d'environnement.")
20
+
21
+ # Chargement du modèle et du tokenizer
22
+ model_name = "MisterAI/bigscience_bloom-560m"
23
+ model = AutoModelForCausalLM.from_pretrained(model_name)
24
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
25
+
26
+ # Fonction pour générer une réponse
27
+ def generate_response(input_text):
28
+ inputs = tokenizer(input_text, return_tensors="pt")
29
+ outputs = model.generate(**inputs, max_length=100)
30
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
31
+ return response
32
+
33
+ # Fonction pour le fine-tuning
34
+ def fine_tune_model(dataset_path, dataset_file, epochs, batch_size, prefix):
35
+ # Chargement du dataset
36
+ if dataset_path.startswith("https://huggingface.co/datasets/"):
37
+ dataset = load_dataset('json', data_files={dataset_file: dataset_path})
38
+ else:
39
+ dataset = load_dataset('json', data_files={dataset_file: dataset_path})
40
+
41
+ # Préparation des données
42
+ dataset = Dataset.from_dict(dataset[dataset_file])
43
+ dataset = dataset.map(lambda x: tokenizer(x['question'] + ' ' + x['chosen'], truncation=True, padding='max_length'), batched=True)
44
+ dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
45
+
46
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
47
+
48
+ # Configuration de l'entraînement
49
+ training_args = TrainingArguments(
50
+ output_dir=f"./{prefix}_{model_name.split('/')[-1]}",
51
+ num_train_epochs=epochs,
52
+ per_device_train_batch_size=batch_size,
53
+ save_steps=10_000,
54
+ save_total_limit=2,
55
+ push_to_hub=True,
56
+ hub_model_id=f"{prefix}_{model_name.split('/')[-1]}",
57
+ hub_strategy="checkpoint",
58
+ hub_token=hf_token,
59
+ )
60
+
61
+ trainer = Trainer(
62
+ model=model,
63
+ args=training_args,
64
+ data_collator=data_collator,
65
+ train_dataset=dataset,
66
+ )
67
+
68
+ # Lancement de l'entraînement
69
+ trainer.train()
70
+
71
+ # Sauvegarde du modèle avec un préfixe
72
+ trainer.save_model(f"./{prefix}_{model_name.split('/')[-1]}")
73
+ tokenizer.save_pretrained(f"./{prefix}_{model_name.split('/')[-1]}")
74
+
75
+ # Push vers Hugging Face Hub
76
+ api = HfApi()
77
+ api.upload_folder(
78
+ folder_path=f"./{prefix}_{model_name.split('/')[-1]}",
79
+ repo_id=f"{prefix}_{model_name.split('/')[-1]}",
80
+ repo_type="model"
81
+ )
82
+
83
+ return "Fine-tuning terminé et modèle sauvegardé."
84
+
85
+ # Interface Gradio
86
+ with gr.Blocks() as demo:
87
+ with gr.Tab("Chatbot"):
88
+ chat_interface = gr.Interface(
89
+ fn=generate_response,
90
+ inputs="text",
91
+ outputs="text",
92
+ title="Chat avec le modèle",
93
+ description="Entrez votre message pour obtenir une réponse du modèle"
94
+ )
95
+
96
+ with gr.Tab("Fine-Tuning"):
97
+ with gr.Row():
98
+ dataset_path = gr.Textbox(label="Chemin du dataset")
99
+ dataset_file = gr.Textbox(label="Nom du fichier du dataset")
100
+ epochs = gr.Number(label="Nombre d'époques", value=1)
101
+ batch_size = gr.Number(label="Taille du batch", value=2)
102
+ prefix = gr.Textbox(label="Préfixe pour les fichiers sauvegardés")
103
+
104
+ fine_tune_button = gr.Button("Lancer le Fine-Tuning")
105
+
106
+ fine_tune_output = gr.Textbox(label="État du Fine-Tuning")
107
+
108
+ fine_tune_button.click(
109
+ fine_tune_model,
110
+ inputs=[dataset_path, dataset_file, epochs, batch_size, prefix],
111
+ outputs=fine_tune_output
112
+ )
113
+
114
+
115
+ # Lancement de la démo
116
+ if __name__ == "__main__":
117
+ demo.launch()
118
+
119
+
120
+