Andro0s commited on
Commit
8eae154
·
verified ·
1 Parent(s): 932f265

Update Train.py

Browse files
Files changed (1) hide show
  1. Train.py +41 -16
Train.py CHANGED
@@ -1,8 +1,22 @@
1
- def train_lora(epochs, batch_size, learning_rate):
 
 
 
 
 
 
 
 
 
 
 
 
2
  try:
3
- dataset = load_dataset("json", data_files=DATASET_PATH)
4
-
5
- # Tokenización correcta
 
 
6
  def tokenize_fn(example):
7
  return tokenizer(
8
  example["prompt"] + example["completion"],
@@ -11,34 +25,45 @@ def train_lora(epochs, batch_size, learning_rate):
11
  max_length=256,
12
  )
13
 
14
- tokenized = dataset.map(tokenize_fn, batched=False)
15
-
16
- # Asegúrate que las columnas correctas estén
17
- tokenized.set_format(type="torch", columns=["input_ids", "attention_mask"])
18
 
 
 
 
 
 
19
  data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
20
 
 
21
  training_args = TrainingArguments(
22
- output_dir=LORA_PATH,
23
  per_device_train_batch_size=int(batch_size),
24
- num_train_epochs=int(epochs),
25
- learning_rate=learning_rate,
26
  save_total_limit=1,
27
  logging_steps=10,
28
  push_to_hub=False
29
  )
30
 
 
31
  trainer = Trainer(
32
- model=base_model,
 
33
  args=training_args,
34
  train_dataset=tokenized["train"],
35
  data_collator=data_collator,
36
  )
37
 
 
38
  trainer.train()
39
- base_model.save_pretrained(LORA_PATH)
40
- tokenizer.save_pretrained(LORA_PATH)
 
 
 
41
 
42
- return "✅ Entrenamiento completado y guardado en ./lora_output"
 
43
  except Exception as e:
44
- return f"❌ Error durante el entrenamiento: {e}"
 
1
+ from datasets import load_dataset
2
+ from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
3
+ # Se asume que peft, tokenizer, base_model, etc., están definidos globalmente.
4
+
5
+ def train_lora(epochs, batch_size, learning_rate, model_to_train, tokenizer, dataset_path, lora_path):
6
+ """
7
+ Ejecuta el entrenamiento del modelo LoRA de forma eficiente.
8
+
9
+ :param model_to_train: El modelo PEFT (LoRA) ya envuelto y listo para entrenar.
10
+ :param tokenizer: El tokenizer cargado.
11
+ :param dataset_path: Ruta al archivo JSON del dataset.
12
+ :param lora_path: Ruta donde se guardarán los adaptadores LoRA.
13
+ """
14
  try:
15
+ # 1. Carga del Dataset (Asegúrate de que 'tu_dataset.json' exista)
16
+ print(f"🔄 Cargando dataset desde: {dataset_path}")
17
+ dataset = load_dataset("json", data_files=dataset_path)
18
+
19
+ # 2. Tokenización eficiente
20
  def tokenize_fn(example):
21
  return tokenizer(
22
  example["prompt"] + example["completion"],
 
25
  max_length=256,
26
  )
27
 
28
+ # 🟢 MEJORA: batched=True para tokenización más rápida
29
+ tokenized = dataset.map(tokenize_fn, batched=True, remove_columns=dataset["train"].column_names)
 
 
30
 
31
+ # 3. Preparación final de los datos
32
+ # No es estrictamente necesario si ya se usa DataCollator, pero es buena práctica.
33
+ tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
34
+
35
+ # El DataCollatorForLanguageModeling se encarga de clonar 'input_ids' a 'labels'
36
  data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
37
 
38
+ # 4. Argumentos de Entrenamiento
39
  training_args = TrainingArguments(
40
+ output_dir=lora_path,
41
  per_device_train_batch_size=int(batch_size),
42
+ num_train_epochs=float(epochs), # 🟢 MEJORA: Usar float para aceptar épocas decimales
43
+ learning_rate=float(learning_rate), # 🟢 MEJORA: Usar float
44
  save_total_limit=1,
45
  logging_steps=10,
46
  push_to_hub=False
47
  )
48
 
49
+ # 5. Inicialización y Entrenamiento del Trainer
50
  trainer = Trainer(
51
+ # 🟢 CORRECCIÓN CRÍTICA: Debe usarse el modelo PEFT (LoRA) ya envuelto
52
+ model=model_to_train,
53
  args=training_args,
54
  train_dataset=tokenized["train"],
55
  data_collator=data_collator,
56
  )
57
 
58
+ print("🚀 Iniciando entrenamiento...")
59
  trainer.train()
60
+
61
+ # 6. Guardado Correcto de los Adaptadores
62
+ # 🟢 CORRECCIÓN CRÍTICA: Guardar solo los adaptadores LoRA (peft)
63
+ model_to_train.save_pretrained(lora_path)
64
+ tokenizer.save_pretrained(lora_path)
65
 
66
+ return f"✅ Entrenamiento completado. Adaptadores LoRA guardados en {lora_path}"
67
+
68
  except Exception as e:
69
+ return f"❌ Error durante el entrenamiento: {e}"