# app.py
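# Gradio app for LoRA fine-tuning of deepseek-coder on question/answer pairs
# pulled from a Postgres table, with live progress reporting and a stop button.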
import os
import json
import torch
import pandas as pd
import gradio as gr
from sqlalchemy import create_engine, text
from transformers import (
    TrainingArguments,
    Trainer,
    TrainerCallback,
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model
)
from datetime import datetime
# Changed to a model that doesn't require flash-attention
MODEL_NAME = "deepseek-ai/deepseek-coder-6.7b-base"
OUTPUT_DIR = "/tmp/finetuned_models"
LOGS_DIR = "/tmp/training_logs"
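# Note: /tmp is wiped on restart in most hosted environments (e.g. Hugging Face
# Spaces); point OUTPUT_DIR at persistent storage if checkpoints must survive.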
class TrainingInterface:
    def __init__(self):
        self.current_status = "Idle"
        self.progress = 0
        self.is_training = False
    def get_database_url(self):
        database_url = os.environ.get('DATABASE_URL')
        if not database_url:
            raise ValueError("DATABASE_URL not found in environment variables")
        return database_url
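    # Expected form (assumption, not verified against the deployment):
    # postgresql://user:password@host:5432/dbname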
    def fetch_training_data(self, progress=gr.Progress()):
        try:
            database_url = self.get_database_url()
            engine = create_engine(database_url)
            progress(0, desc="Connecting to database...")
            with engine.connect() as conn:
                result = conn.execute(text("SELECT COUNT(*) FROM bents"))
                total_rows = result.scalar()
                progress(0.25, desc=f"Fetching {total_rows} rows...")
                query = text("SELECT chunk_id, text FROM bents")
                df = pd.read_sql_query(query, conn)
            progress(0.5, desc="Data fetched successfully")
            return df
        except Exception as e:
            raise gr.Error(f"Database error: {str(e)}")
    def prepare_training_data(self, df, progress=gr.Progress()):
        formatted_data = []
        try:
            total_rows = len(df)
            for idx, (_, row_data) in enumerate(df.iterrows()):
                progress(idx / total_rows, desc="Preparing training data...")
                chunk_id = str(row_data['chunk_id']).strip()
                # Named `answer` to avoid shadowing sqlalchemy's `text` import.
                answer = str(row_data['text']).strip()
                if chunk_id and answer:
                    # Question/Answer format used for deepseek-coder
                    formatted_text = f"Question: {chunk_id}\nAnswer: {answer}"
                    formatted_data.append({"text": formatted_text})
            if not formatted_data:
                raise ValueError("No valid training data found")
            return formatted_data
        except Exception as e:
            raise gr.Error(f"Data preparation error: {str(e)}")
    def stop_training(self):
        self.is_training = False
        return "Training stopped by user."
    def train_model(
        self,
        learning_rate=2e-4,
        num_epochs=3,
        batch_size=4,
        progress=gr.Progress()
    ):
        try:
            self.is_training = True
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            specific_output_dir = os.path.join(OUTPUT_DIR, f"run_{timestamp}")
            os.makedirs(specific_output_dir, exist_ok=True)
            os.makedirs(LOGS_DIR, exist_ok=True)
            progress(0.1, desc="Fetching data...")
            if not self.is_training:
                return "Training cancelled."
            df = self.fetch_training_data()
            formatted_data = self.prepare_training_data(df)
            progress(0.2, desc="Loading model...")
            if not self.is_training:
                return "Training cancelled."
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
            # The data collator pads batches, which requires a pad token;
            # fall back to EOS if the tokenizer does not define one.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map="auto"
            )
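            # Note: newer transformers versions prefer passing
            # quantization_config=BitsAndBytesConfig(load_in_8bit=True)
            # instead of the bare load_in_8bit flag.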
            progress(0.3, desc="Setting up LoRA...")
            if not self.is_training:
                return "Training cancelled."
            # Updated LoRA config for deepseek-coder model
            lora_config = LoraConfig(
                r=16,
                lora_alpha=32,
                target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
                lora_dropout=0.05,
                bias="none",
                task_type="CAUSAL_LM"
            )
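            # The target modules above are the attention projections of the
            # LLaMA-style deepseek architecture; also targeting the MLP
            # projections (gate_proj, up_proj, down_proj) is a common variant
            # when more adapter capacity is needed.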
            model = prepare_model_for_kbit_training(model)
            model = get_peft_model(model, lora_config)
            progress(0.4, desc="Configuring training...")
            if not self.is_training:
                return "Training cancelled."
            training_args = TrainingArguments(
                output_dir=specific_output_dir,
                num_train_epochs=num_epochs,
                per_device_train_batch_size=batch_size,
                learning_rate=learning_rate,
                fp16=True,
                gradient_accumulation_steps=8,
                gradient_checkpointing=True,
                logging_dir=os.path.join(LOGS_DIR, f"run_{timestamp}"),
                logging_steps=10,
                save_strategy="epoch",
                evaluation_strategy="no",  # No eval split is built from the table
                save_total_limit=2,
                remove_unused_columns=False,
            )
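            # With the defaults above, the effective batch size is
            # per_device_train_batch_size * gradient_accumulation_steps = 4 * 8 = 32.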
            dataset = Dataset.from_dict({
                'text': [item['text'] for item in formatted_data]
            })
            # The Trainer needs token ids, not raw strings, so tokenize before
            # training. max_length=512 is an assumed cap, not a tuned value.
            def tokenize_fn(examples):
                return tokenizer(examples['text'], truncation=True, max_length=512)
            dataset = dataset.map(tokenize_fn, batched=True, remove_columns=['text'])
            data_collator = DataCollatorForLanguageModeling(
                tokenizer=tokenizer,
                mlm=False
            )
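            # mlm=False gives standard causal-LM training: labels are a copy of
            # input_ids and the loss is next-token prediction.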
            # Must subclass transformers.TrainerCallback (not gr.Progress) for
            # the Trainer to invoke the on_* hooks.
            class ProgressCallback(TrainerCallback):
                def __init__(self, progress_callback, training_interface):
                    self.progress_callback = progress_callback
                    self.training_interface = training_interface
                def on_train_begin(self, args, state, control, **kwargs):
                    if not self.training_interface.is_training:
                        control.should_training_stop = True
                    self.progress_callback(0.5, desc="Training started...")
                def on_epoch_begin(self, args, state, control, **kwargs):
                    if not self.training_interface.is_training:
                        control.should_training_stop = True
                    epoch_progress = state.epoch / args.num_train_epochs
                    total_progress = 0.5 + (epoch_progress * 0.4)
                    self.progress_callback(
                        total_progress,
                        desc=f"Training epoch {int(state.epoch) + 1}/{int(args.num_train_epochs)}..."
                    )
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                data_collator=data_collator,
                callbacks=[ProgressCallback(progress, self)]
            )
            if not self.is_training:
                return "Training cancelled."
            trainer.train()
            if not self.is_training:
                return "Training cancelled."
            progress(0.9, desc="Saving model...")
            trainer.save_model()
            tokenizer.save_pretrained(specific_output_dir)
            progress(1.0, desc="Training completed!")
            return f"Training completed! Model saved in {specific_output_dir}"
        except Exception as e:
            self.is_training = False
            raise gr.Error(f"Training error: {str(e)}")
def create_training_interface():
    interface = TrainingInterface()
    with gr.Blocks(title="DeepSeek Coder Training Interface") as app:
        gr.Markdown("# DeepSeek Coder Fine-tuning Interface")
        with gr.Row():
            with gr.Column():
                learning_rate = gr.Slider(
                    minimum=1e-5,
                    maximum=1e-3,
                    value=2e-4,
                    label="Learning Rate"
                )
                num_epochs = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=3,
                    step=1,
                    label="Number of Epochs"
                )
                batch_size = gr.Slider(
                    minimum=1,
                    maximum=8,
                    value=4,
                    step=1,
                    label="Batch Size"
                )
        with gr.Row():
            train_button = gr.Button("Start Training", variant="primary")
            stop_button = gr.Button("Stop Training", variant="secondary")
        output_text = gr.Textbox(
            label="Training Status",
            placeholder="Training status will appear here...",
            lines=10
        )
        train_button.click(
            fn=interface.train_model,
            inputs=[learning_rate, num_epochs, batch_size],
            outputs=output_text
        )
        stop_button.click(
            fn=interface.stop_training,
            inputs=[],
            outputs=output_text
        )
    return app
if __name__ == "__main__":
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(LOGS_DIR, exist_ok=True)
app = create_training_interface()
app.launch()
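# Launch locally with the database configured, e.g. (illustrative URL):
#   DATABASE_URL=postgresql://user:password@host:5432/dbname python app.py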