# medical_chatbot/train.py
import torch
import json
import os
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from huggingface_hub import login
# ✅ Authenticate with Hugging Face
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("Hugging Face token not found. Add it in 'Secrets'.")
login(token=HF_TOKEN)
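# Note: in a Hugging Face Space, values added under Settings → Secrets are
# exposed to the app as environment variables, which is how HF_TOKEN arrives here.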
# ✅ Load Extracted Data
dataset_path = "medical_dataset.json"
if not os.path.exists(dataset_path):
    raise FileNotFoundError(f"Dataset file '{dataset_path}' not found!")
with open(dataset_path, "r", encoding="utf-8") as f:
    data = json.load(f)
if not isinstance(data, list):
    raise ValueError("Dataset should be a list of dictionaries.")
dataset = Dataset.from_list(data)
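# Expected record shape (an assumption inferred from the fields used below):
# {"prompt": "What causes iron deficiency?", "response": "Common causes include ..."}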
# ✅ Load Tokenizer
model_name = "tiiuae/falcon-rw-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Falcon ships without a pad token; reuse EOS
# ✅ Tokenize Data (batched: each column arrives as a list of values)
def preprocess_function(examples):
    texts = [
        f"Medical Q&A: {prompt} {response}"
        for prompt, response in zip(examples["prompt"], examples["response"])
    ]
    model_inputs = tokenizer(texts, padding="max_length", truncation=True, max_length=512)
    # ✅ Causal LM labels mirror input_ids; copy so the lists are not aliased
    model_inputs["labels"] = [ids.copy() for ids in model_inputs["input_ids"]]
    return model_inputs
# ✅ Apply tokenization
tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset.column_names)
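# Optional sanity check (a minimal sketch): every example should come out
# padded/truncated to exactly max_length tokens.
# print(len(tokenized_dataset[0]["input_ids"]))  # expected: 512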
# ✅ Load Model with LoRA (Optimized for Falcon)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,  # ✅ Save memory with half precision
    device_map="auto",          # ✅ Auto-assign to CPU/GPU
)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["query_key_value"],  # ✅ Correct target module for Falcon's fused attention
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)
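# Optional: report how few parameters LoRA leaves trainable (built-in PEFT helper)
model.print_trainable_parameters()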
# ✅ Define Training Arguments
training_args = TrainingArguments(
    output_dir="./medical_falcon",
    per_device_train_batch_size=1,
    num_train_epochs=3,  # ✅ Adjust epochs as needed
    logging_dir="./logs",
    save_steps=100,
    evaluation_strategy="no",
    save_total_limit=2,
    fp16=True,  # ✅ Enable mixed-precision training
)
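# Note: recent transformers releases rename evaluation_strategy to eval_strategy;
# keep the spelling that matches your installed version.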
# ✅ Train Model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)
trainer.train()
# ✅ Save Model
model_path = "fine_tuned_medical_falcon"
trainer.save_model(model_path)
tokenizer.save_pretrained(model_path)
print(f"✅ Model fine-tuned and saved at: {model_path}")
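# Follow-up sketch for inference (assumes the adapter saved above): because the
# Trainer wraps a PeftModel, save_model() writes only the LoRA adapter, and
# AutoPeftModelForCausalLM reloads it together with its base model.
# from peft import AutoPeftModelForCausalLM
# model = AutoPeftModelForCausalLM.from_pretrained("fine_tuned_medical_falcon")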