awacke1 commited on
Commit
bfae0ee
Β·
verified Β·
1 Parent(s): 24b7070

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +301 -0
app.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import shutil
4
+ import glob
5
+ import base64
6
+ import streamlit as st
7
+ import pandas as pd
8
+ import torch
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from torch.utils.data import Dataset, DataLoader
11
+ import csv
12
+ import time
13
+ from dataclasses import dataclass
14
+ from typing import Optional
15
+ import zipfile
16
+
17
# Page Configuration
# Collect the settings in one mapping so they read as a single config unit.
_PAGE_SETTINGS = {
    "page_title": "SFT Model Builder 🚀",
    "page_icon": "🤖",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
24
+
25
# Meta class for model configuration
class ModelMeta(type):
    """Metaclass that gives every class it creates its own ``registry`` dict."""

    def __new__(cls, name, bases, attrs):
        # Inject a fresh, per-class registry before the class object is built.
        attrs["registry"] = {}
        new_class = super().__new__(cls, name, bases, attrs)
        return new_class
30
+
31
# Model Configuration Class
@dataclass
class ModelConfig(metaclass=ModelMeta):
    """Describes a model: name, base checkpoint, size tag, and target domain.

    NOTE(review): three mechanisms interact here — @dataclass generates
    __init__/__repr__ from the annotated fields below, the ModelMeta metaclass
    injects a per-class ``registry`` dict, and __init_subclass__ records each
    subclass in the base class's registry. Only annotated names become
    dataclass fields, so ``registry`` stays a plain class attribute.
    """
    name: str                      # directory-safe model name (used by model_path)
    base_model: str                # Hugging Face checkpoint id, e.g. "distilgpt2"
    size: str                      # informal size tag, e.g. "small"
    domain: Optional[str] = None   # optional target-domain label

    def __init_subclass__(cls):
        # Register every subclass on the base class, keyed by class name.
        ModelConfig.registry[cls.__name__] = cls

    @property
    def model_path(self):
        # Local directory this model is saved to / loaded from.
        return f"models/{self.name}"
45
+
46
# Custom Dataset for SFT
class SFTDataset(Dataset):
    """Prompt/response pairs tokenized for causal-LM supervised fine-tuning.

    Each item concatenates prompt and response into one sequence padded to
    ``max_length``. Labels mirror input_ids except padded positions are set
    to -100, the Hugging Face convention for "ignore in the loss" — training
    on pad tokens would otherwise skew the loss toward predicting padding.
    """

    def __init__(self, data, tokenizer, max_length=128):
        # data: list of {"prompt": str, "response": str} dicts
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        prompt = self.data[idx]["prompt"]
        response = self.data[idx]["response"]
        input_text = f"{prompt} {response}"
        encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        # Fix: mask padding out of the loss (-100 is ignored by CrossEntropyLoss).
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
72
+
73
# Model Builder Class
class ModelBuilder:
    """Loads, fine-tunes (SFT), saves, and evaluates a causal language model.

    Wraps a transformers model/tokenizer pair and reports progress through
    Streamlit widgets, so methods must run inside a Streamlit session.
    """

    def __init__(self):
        self.config = None      # ModelConfig describing the current model
        self.model = None       # AutoModelForCausalLM instance once loaded
        self.tokenizer = None   # matching tokenizer
        self.sft_data = None    # list of {"prompt", "response"} dicts from last fine-tune

    def load_model(self, model_path: str, config: Optional[ModelConfig] = None):
        """Load model + tokenizer from *model_path*; optionally attach *config*.

        Returns self so calls can be chained.
        """
        with st.spinner("Loading model... ⏳"):
            self.model = AutoModelForCausalLM.from_pretrained(model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            # GPT-style tokenizers ship without a pad token; reuse EOS so
            # padding works in SFTDataset.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            if config:
                self.config = config
            st.success("Model loaded! ✅")
        return self

    def fine_tune_sft(self, csv_path: str, epochs: int = 3, batch_size: int = 4):
        """Supervised fine-tuning on a CSV with 'prompt' and 'response' columns.

        Raises KeyError if either column is missing. Returns self for chaining.
        """
        self.sft_data = []
        # Fix: newline="" is required by the csv module to handle embedded
        # newlines correctly; utf-8 makes the read platform-independent.
        with open(csv_path, "r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                self.sft_data.append({"prompt": row["prompt"], "response": row["response"]})

        dataset = SFTDataset(self.sft_data, self.tokenizer)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)

        self.model.train()
        for epoch in range(epochs):
            with st.spinner(f"Training epoch {epoch + 1}/{epochs}... ⚙️"):
                total_loss = 0
                for batch in dataloader:
                    optimizer.zero_grad()
                    # .device tracks wherever the HF model currently lives.
                    input_ids = batch["input_ids"].to(self.model.device)
                    attention_mask = batch["attention_mask"].to(self.model.device)
                    labels = batch["labels"].to(self.model.device)
                    outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = outputs.loss
                    loss.backward()
                    optimizer.step()
                    total_loss += loss.item()
                st.write(f"Epoch {epoch + 1} completed. Average loss: {total_loss / len(dataloader):.4f}")
        st.success("SFT Fine-tuning completed! 🎉")
        return self

    def save_model(self, path: str):
        """Save the fine-tuned model and tokenizer to *path*."""
        with st.spinner("Saving model... 💾"):
            # Fix: os.makedirs("") raises FileNotFoundError when *path* has no
            # directory component; only create the parent when one exists.
            parent = os.path.dirname(path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            self.model.save_pretrained(path)
            self.tokenizer.save_pretrained(path)
        st.success(f"Model saved at {path}! ✅")

    def evaluate(self, prompt: str):
        """Generate up to 50 new tokens for *prompt*; return the decoded text."""
        self.model.eval()
        with torch.no_grad():
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(**inputs, max_new_tokens=50)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
138
+
139
# Utility Functions
def get_download_link(file_path, mime_type="text/plain", label="Download"):
    """Return an HTML anchor tag that serves *file_path* as a base64 data URI."""
    with open(file_path, "rb") as handle:
        encoded = base64.b64encode(handle.read()).decode()
    filename = os.path.basename(file_path)
    return (
        f'<a href="data:{mime_type};base64,{encoded}" '
        f'download="{filename}">{label} 📥</a>'
    )
146
+
147
def zip_directory(directory_path, zip_path):
    """Zip *directory_path* into *zip_path*, keeping its top-level folder name.

    Archive names are relative to the directory's parent, so extracting the
    zip recreates the folder itself rather than spilling its contents.
    """
    base = os.path.dirname(directory_path)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for root, _, filenames in os.walk(directory_path):
            for filename in filenames:
                full_path = os.path.join(root, filename)
                archive.write(full_path, os.path.relpath(full_path, base))
155
+
156
def get_model_files():
    """Return the paths of all saved model directories under models/."""
    return list(filter(os.path.isdir, glob.glob("models/*")))
159
+
160
# Main App
st.title("SFT Model Builder 🤖🚀")

# Sidebar for Model Management: pick a previously saved model from models/.
st.sidebar.header("Model Management 🗂️")
model_dirs = get_model_files()
selected_model = st.sidebar.selectbox("Select Saved Model", ["None"] + model_dirs)

# Load the chosen saved model into session state, then rerun so the tabs
# below see model_loaded == True on the next script pass.
if selected_model != "None" and st.sidebar.button("Load Model 📂"):
    if 'builder' not in st.session_state:
        st.session_state['builder'] = ModelBuilder()
    # base_model is unknown for a previously saved model; placeholder config.
    config = ModelConfig(name=os.path.basename(selected_model), base_model="unknown", size="small", domain="general")
    st.session_state['builder'].load_model(selected_model, config)
    st.session_state['model_loaded'] = True
    st.rerun()
175
+
176
# Main UI with Tabs
tab1, tab2, tab3 = st.tabs(["Build New Model 🌱", "Fine-Tune Model 🔧", "Test Model 🧪"])

with tab1:
    # Download a small base checkpoint and save it locally as a new model.
    st.header("Build New Model 🌱")
    base_model = st.selectbox(
        "Select Base Model",
        ["distilgpt2", "gpt2", "EleutherAI/pythia-70m"],
        help="Choose a small model to start with"
    )
    # Timestamp suffix keeps default model names unique across runs.
    model_name = st.text_input("Model Name", f"new-model-{int(time.time())}")
    domain = st.text_input("Target Domain", "general")

    if st.button("Download Model ⬇️"):
        config = ModelConfig(name=model_name, base_model=base_model, size="small", domain=domain)
        builder = ModelBuilder()
        builder.load_model(base_model, config)
        builder.save_model(config.model_path)
        # Stash the builder so the other tabs can use it, then rerun the UI.
        st.session_state['builder'] = builder
        st.session_state['model_loaded'] = True
        st.success(f"Model downloaded and saved to {config.model_path}! 🎉")
        st.rerun()
198
+
199
with tab2:
    # Fine-tune the loaded model on an uploaded prompt/response CSV.
    st.header("Fine-Tune Model 🔧")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
        st.warning("Please download or load a model first! ⚠️")
    else:
        # Generate Sample CSV so the user can see the expected column layout.
        if st.button("Generate Sample CSV 📝"):
            sample_data = [
                {"prompt": "What is AI?", "response": "AI is artificial intelligence, simulating human intelligence in machines."},
                {"prompt": "Explain machine learning", "response": "Machine learning is a subset of AI where models learn from data."},
                {"prompt": "What is a neural network?", "response": "A neural network is a model inspired by the human brain."},
            ]
            csv_path = f"sft_data_{int(time.time())}.csv"
            with open(csv_path, "w", newline="") as f:
                writer = csv.DictWriter(f, fieldnames=["prompt", "response"])
                writer.writeheader()
                writer.writerows(sample_data)
            st.markdown(get_download_link(csv_path, "text/csv", "Download Sample CSV"), unsafe_allow_html=True)
            st.success(f"Sample CSV generated as {csv_path}! ✅")

        # Upload CSV and Fine-Tune
        uploaded_csv = st.file_uploader("Upload CSV for SFT", type="csv")
        if uploaded_csv and st.button("Fine-Tune with Uploaded CSV 🔄"):
            # Persist the upload to disk because fine_tune_sft reads by path.
            csv_path = f"uploaded_sft_data_{int(time.time())}.csv"
            with open(csv_path, "wb") as f:
                f.write(uploaded_csv.read())
            # Derive a new config so the fine-tuned model is saved under a
            # fresh, timestamped directory instead of overwriting the base.
            new_model_name = f"{st.session_state['builder'].config.name}-sft-{int(time.time())}"
            new_config = ModelConfig(
                name=new_model_name,
                base_model=st.session_state['builder'].config.base_model,
                size="small",
                domain=st.session_state['builder'].config.domain
            )
            st.session_state['builder'].config = new_config
            with st.status("Fine-tuning model... ⏳", expanded=True) as status:
                st.session_state['builder'].fine_tune_sft(csv_path)
                st.session_state['builder'].save_model(new_config.model_path)
                status.update(label="Fine-tuning completed! 🎉", state="complete")

            # Create a zip file of the model directory for download.
            zip_path = f"{new_config.model_path}.zip"
            zip_directory(new_config.model_path, zip_path)
            st.markdown(get_download_link(zip_path, "application/zip", "Download Fine-Tuned Model"), unsafe_allow_html=True)
            st.rerun()
243
+
244
with tab3:
    # Test the loaded model and export a standalone demo app.
    st.header("Test Model 🧪")
    if 'builder' not in st.session_state or not st.session_state.get('model_loaded', False):
        st.warning("Please download or load a model first! ⚠️")
    else:
        # Show generations against the first few SFT pairs, if any were trained.
        if st.session_state['builder'].sft_data:
            st.write("Testing with SFT Data:")
            for item in st.session_state['builder'].sft_data[:3]:
                prompt = item["prompt"]
                expected = item["response"]
                generated = st.session_state['builder'].evaluate(prompt)
                st.write(f"**Prompt**: {prompt}")
                st.write(f"**Expected**: {expected}")
                st.write(f"**Generated**: {generated}")
                st.write("---")

        # Free-form prompt testing.
        test_prompt = st.text_area("Enter Test Prompt", "What is AI?")
        if st.button("Run Test ▶️"):
            result = st.session_state['builder'].evaluate(test_prompt)
            st.write(f"**Generated Response**: {result}")

        # Export Model Files: writes a minimal standalone Streamlit demo
        # (app + requirements + README) that loads this saved model.
        if st.button("Export Model Files 📦"):
            config = st.session_state['builder'].config
            # f-string interpolates the saved model's path into the demo app.
            app_code = f"""
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{config.model_path}")
tokenizer = AutoTokenizer.from_pretrained("{config.model_path}")

st.title("SFT Model Demo")
input_text = st.text_area("Enter prompt")
if st.button("Generate"):
    inputs = tokenizer(input_text, return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=50)
    st.write(tokenizer.decode(outputs[0], skip_special_tokens=True))
"""
            with open("sft_app.py", "w") as f:
                f.write(app_code)
            reqs = "streamlit\ntorch\ntransformers\n"
            with open("sft_requirements.txt", "w") as f:
                f.write(reqs)
            readme = f"""
# SFT Model Demo

## How to run
1. Install requirements: `pip install -r sft_requirements.txt`
2. Run the app: `streamlit run sft_app.py`
3. Input a prompt and click "Generate".
"""
            with open("sft_README.md", "w") as f:
                f.write(readme)

            st.markdown(get_download_link("sft_app.py", "text/plain", "Download App"), unsafe_allow_html=True)
            st.markdown(get_download_link("sft_requirements.txt", "text/plain", "Download Requirements"), unsafe_allow_html=True)
            st.markdown(get_download_link("sft_README.md", "text/markdown", "Download README"), unsafe_allow_html=True)
            st.success("Model files exported! ✅")