Upload 4 files

- .gitattributes +2 -0
- Screenshot 2025-09-02 at 4.10.14 PM.png +3 -0
- download_GPT_OSS_120B_MXFP4_Q4_Model.py +323 -0
- gpt_oss_ui.py +700 -0
- output.mp4 +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+output.mp4 filter=lfs diff=lfs merge=lfs -text
+Screenshot[[:space:]]2025-09-02[[:space:]]at[[:space:]]4.10.14[[:space:]]PM.png filter=lfs diff=lfs merge=lfs -text
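The two added entries are the form `git lfs track` writes; spaces in a tracked filename are escaped as [[:space:]]. A minimal sketch of the commands that would produce them:

git lfs track "output.mp4"
git lfs track "Screenshot 2025-09-02 at 4.10.14 PM.png"
git add .gitattributes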
Screenshot 2025-09-02 at 4.10.14 PM.png
ADDED
(binary image, stored with Git LFS)
download_GPT_OSS_120B_MXFP4_Q4_Model.py
ADDED
@@ -0,0 +1,323 @@
#!/usr/bin/env python3
"""
MLX GPT-OSS-120B-MXFP4-Q4 Model Downloader

This script downloads the mlx-community/gpt-oss-120b-MXFP4-Q4 model from Hugging Face Hub
with various download options and verification features.
"""

import argparse
import os
import logging
from datetime import datetime

import mlx.core as mx
from huggingface_hub import snapshot_download, HfApi, ModelCard
from transformers import AutoConfig, AutoTokenizer

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def get_model_info(repo_id):
    """Get information about the model from Hugging Face Hub."""
    try:
        api = HfApi()
        model_info = api.model_info(repo_id)

        logger.info("📋 Model Information:")
        logger.info(f"   Name: {model_info.id}")
        logger.info(f"   Downloads: {model_info.downloads:,}")
        logger.info(f"   Likes: {model_info.likes}")
        logger.info(f"   Last Modified: {model_info.lastModified}")
        logger.info(f"   Library: {model_info.library_name}")
        logger.info(f"   Tags: {', '.join(model_info.tags)}")

        # Try to get the model card
        try:
            card = ModelCard.load(repo_id)
            logger.info(f"   Model Card: {card.data.get('model_name', 'N/A')}")
        except Exception:
            logger.info("   Model Card: Not available")

        return model_info
    except Exception as e:
        logger.warning(f"⚠️ Could not fetch model info: {e}")
        return None


def calculate_disk_space_required(repo_id, revision="main"):
    """Estimate the disk space required for the model."""
    try:
        api = HfApi()
        files = api.list_repo_files(repo_id, revision=revision)

        # Exact per-file sizes would require extra API calls, so only count
        # the weight files and report a known approximate total for this repo.
        model_files = [f for f in files if f.endswith(('.safetensors', '.npz'))]

        # GPT-OSS-120B-MXFP4-Q4 is approximately 60-70GB in MXFP4-Q4 format
        logger.info("💾 Estimated download size: ~60-70GB (MXFP4-Q4 format)")
        logger.info(f"   Model files: {len(model_files)} weight files")

        return model_files
    except Exception as e:
        logger.warning(f"⚠️ Could not calculate disk space: {e}")
        return []

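# A minimal sketch of how an exact total could be computed instead of the
# rough estimate above, via files_metadata=True (assumes a recent
# huggingface_hub; this helper is illustrative and not called by the script):
def calculate_exact_size_bytes(repo_id, revision="main"):
    """Sum the file sizes reported by the Hub for the given repo."""
    api = HfApi()
    info = api.model_info(repo_id, revision=revision, files_metadata=True)
    return sum(sibling.size or 0 for sibling in info.siblings)
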
def download_model(args):
    """Download the model with the specified options."""
    repo_id = "mlx-community/gpt-oss-120b-MXFP4-Q4"

    logger.info("=" * 70)
    logger.info("🤗 MLX GPT-OSS-120B-MXFP4-Q4 Model Downloader")
    logger.info("=" * 70)

    # Get model information
    get_model_info(repo_id)
    calculate_disk_space_required(repo_id, args.revision)

    download_kwargs = {
        "repo_id": repo_id,
        "revision": args.revision,
        "local_dir": args.output_dir,
        # Note: local_dir_use_symlinks and resume_download are deprecated
        # no-ops in recent huggingface_hub releases; kept for older installs.
        "local_dir_use_symlinks": False,  # Always copy files, don't symlink
        "resume_download": True,
        "force_download": args.force_download,
    }

    if args.allow_patterns:
        download_kwargs["allow_patterns"] = args.allow_patterns
    if args.ignore_patterns:
        download_kwargs["ignore_patterns"] = args.ignore_patterns

    try:
        logger.info(f"🚀 Starting download of {repo_id}")
        logger.info(f"📁 Output directory: {args.output_dir}")
        logger.info(f"🔖 Revision: {args.revision}")
        logger.info(f"💾 Cache dir: {args.cache_dir}")

        if args.cache_dir:
            download_kwargs["cache_dir"] = args.cache_dir

        # Download the model
        model_path = snapshot_download(**download_kwargs)

        logger.info("✅ Download completed successfully!")
        logger.info(f"📦 Model saved to: {model_path}")

        return model_path

    except Exception as e:
        logger.error(f"❌ Download failed: {e}")
        raise


def verify_model_download(model_path):
    """Verify that the model download is complete and valid."""
    logger.info("🔍 Verifying model download...")

    required_files = [
        "config.json",
        "tokenizer.json",
        "tokenizer_config.json",
        "model.npz",  # MLX models use .npz weight files
        "generation_config.json"
    ]

    missing_files = []
    for file in required_files:
        if not os.path.exists(os.path.join(model_path, file)):
            missing_files.append(file)

    if missing_files:
        logger.warning(f"⚠️ Missing files: {missing_files}")
    else:
        logger.info("✅ All required files present")

    # Try to load the config
    try:
        config = AutoConfig.from_pretrained(model_path)
        logger.info("✅ Config loaded successfully")
        logger.info(f"   Architecture: {config.architectures[0] if config.architectures else 'N/A'}")
        logger.info(f"   Vocab size: {config.vocab_size:,}")
        logger.info(f"   Hidden size: {config.hidden_size}")
        logger.info(f"   Num layers: {config.num_hidden_layers}")
        logger.info(f"   Model type: {config.model_type}")
    except Exception as e:
        logger.warning(f"⚠️ Could not load config: {e}")

    return len(missing_files) == 0


def load_model_for_verification(model_path, args):
    """Optionally load the model to verify it works (memory intensive)."""
    if not args.verify_load:
        return None, None

    logger.info("🧪 Loading model for verification (this may take a while and require significant RAM)...")

    try:
        # Load tokenizer first
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        logger.info("✅ Tokenizer loaded successfully")

        # For MLX models, check that the weights can be loaded
        try:
            weights = mx.load(os.path.join(model_path, "model.npz"))
            logger.info("✅ Model weights loaded successfully")
            logger.info(f"   Number of weight arrays: {len(weights)}")

            # Test the tokenizer if requested
            if args.test_inference:
                logger.info("🧪 Testing tokenizer and basic functionality...")
                test_text = "The capital of France is"
                inputs = tokenizer(test_text, return_tensors="np")

                logger.info(f"📝 Tokenized input: {inputs}")
                logger.info(f"   Input shape: {inputs['input_ids'].shape}")

        except Exception as e:
            logger.warning(f"⚠️ Model weight loading failed: {e}")

        return None, tokenizer

    except Exception as e:
        logger.warning(f"⚠️ Model loading failed: {e}")
        return None, None


def create_readme(model_path, args):
    """Create a README file with download information."""
    readme_content = f"""# GPT-OSS-120B-MXFP4-Q4 Model Download

## Download Information
- **Model**: mlx-community/gpt-oss-120b-MXFP4-Q4
- **Download Date**: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
- **Revision**: {args.revision}
- **Output Directory**: {args.output_dir}

## Download Options Used
- Allow patterns: {args.allow_patterns or 'All files'}
- Ignore patterns: {args.ignore_patterns or 'None'}
- Force download: {args.force_download}
- Verify load: {args.verify_load}
- Test inference: {args.test_inference}

## Model Details
- **Architecture**: Transformer-based causal language model
- **Parameters**: 120 billion
- **Context Length**: 4096 tokens
- **Quantization**: MXFP4-Q4 (4-bit quantization optimized for MLX)
- **Framework**: MLX (Apple Silicon optimized)
- **Languages**: Primarily English

## Usage with MLX
```python
import mlx.core as mx
from transformers import AutoTokenizer

# Load weights
weights = mx.load("{model_path}/model.npz")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("{model_path}")

# Note: You'll need to implement the model architecture to use the weights
```

## Usage with Transformers (for tokenizer only)
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("{model_path}")
```
"""
    readme_path = os.path.join(model_path, "DOWNLOAD_INFO.md")
    with open(readme_path, 'w') as f:
        f.write(readme_content)

    logger.info(f"📝 Created README: {readme_path}")


def main():
    parser = argparse.ArgumentParser(description="Download mlx-community/gpt-oss-120b-MXFP4-Q4 model")

    # Download options
    parser.add_argument("--output-dir", type=str, default="./gpt_oss_120b_mxfp4_q4",
                        help="Directory to save the model")
    parser.add_argument("--cache-dir", type=str, default="./hf_cache",
                        help="Cache directory for Hugging Face")
    parser.add_argument("--revision", type=str, default="main",
                        help="Model revision/branch to download")
    parser.add_argument("--force-download", action="store_true",
                        help="Force re-download even if files exist")

    # Filter options
    parser.add_argument("--allow-patterns", nargs="+",
                        help="Only download files matching these patterns")
    parser.add_argument("--ignore-patterns", nargs="+",
                        default=["*.h5", "*.ot", "*.msgpack", "*.tflite", "*.bin"],
                        help="Skip files matching these patterns")

    # Verification options
    parser.add_argument("--verify-load", action="store_true",
                        help="Load model after download to verify it works")
    parser.add_argument("--test-inference", action="store_true",
                        help="Run a test inference after loading")

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    try:
        # Download the model
        model_path = download_model(args)

        # Verify download
        verify_model_download(model_path)

        # Optionally load and test the model
        if args.verify_load:
            model, tokenizer = load_model_for_verification(model_path, args)

        # Create readme
        create_readme(model_path, args)

        logger.info("🎉 Model download and verification completed successfully!")
        logger.info(f"📁 Model available at: {model_path}")
        logger.info("💡 Next steps: Use the model with the MLX framework:")
        logger.info("   import mlx.core as mx")
        logger.info(f"   weights = mx.load('{model_path}/model.npz')")
        logger.info("   from transformers import AutoTokenizer")
        logger.info(f"   tokenizer = AutoTokenizer.from_pretrained('{model_path}')")

    except Exception as e:
        logger.error(f"❌ Download failed: {e}")
        return 1

    return 0


if __name__ == "__main__":
    exit(main())


## Key Differences for the GPT-OSS-120B-MXFP4-Q4 Model:

## 1. **Model Format**: This model uses MLX's `.npz` format instead of PyTorch's `.safetensors` or `.bin` files
## 2. **Framework**: Optimized for Apple's MLX framework rather than standard PyTorch
## 3. **Quantization**: Uses MXFP4-Q4 quantization (4-bit), which is specific to MLX
## 4. **Size**: At 120B parameters, this is a much larger model than SmolLM3-3B
## 5. **Loading**: The loading process for MLX models differs from standard Transformers models

## Usage Notes:

## 1. This script requires the `mlx` package to be installed for full functionality
## 2. The model is optimized for Apple Silicon devices
## 3. Due to the model's large size (60-70GB), ensure you have sufficient disk space
## 4. The script includes special handling for MLX's file format and quantization

## You can run this script with various options like:
## ```bash
## python download_gpt_oss_120b.py --output-dir ./my_model --verify-load --test-inference
## ```
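## A cheap dry run before committing to the full ~60-70GB transfer is to pull
## only the lightweight metadata files first (the pattern list below is an
## illustrative assumption, not part of the original script):
## ```bash
## python download_gpt_oss_120b.py --allow-patterns "*.json" "tokenizer*"
## ```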
gpt_oss_ui.py
ADDED
@@ -0,0 +1,700 @@
#!/usr/bin/env python3
"""
Enhanced Modern UI for GPT-OSS-120B Chat Interface
"""

import sys
import time
import json
import logging
import re
from datetime import datetime

import markdown
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
                             QHBoxLayout, QTextEdit, QPushButton, QLabel,
                             QScrollArea, QFrame, QGroupBox, QSpinBox,
                             QSizePolicy, QFileDialog, QMessageBox)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, QTimer
from PyQt5.QtGui import QFont, QColor, QTextCharFormat, QSyntaxHighlighter
from mlx_lm import load, generate

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class ModelLoaderThread(QThread):
    """Thread for loading the model to prevent UI freezing"""
    # The loaded objects are passed back through the signal so the model is
    # loaded exactly once (the original reloaded it in the main thread).
    model_loaded = pyqtSignal(object, object)
    model_error = pyqtSignal(str)
    progress_update = pyqtSignal(str)

    def run(self):
        try:
            self.progress_update.emit("Downloading model files...")
            logger.info("🚀 Loading GPT-OSS-120B...")
            model, tokenizer = load("mlx-community/gpt-oss-120b-MXFP4-Q4")
            logger.info("✅ Model loaded successfully!")
            self.progress_update.emit("Model loaded successfully!")
            self.model_loaded.emit(model, tokenizer)
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            self.model_error.emit(str(e))


class GenerationThread(QThread):
    """Thread for generating responses to prevent UI freezing"""
    response_ready = pyqtSignal(str, float)
    error_occurred = pyqtSignal(str)
    progress_update = pyqtSignal(str)

    def __init__(self, model, tokenizer, prompt, max_tokens):
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.prompt = prompt
        self.max_tokens = max_tokens

    def run(self):
        try:
            start_time = time.time()

            # Format the prompt with the chat template
            self.progress_update.emit("Formatting prompt...")
            messages = [{"role": "user", "content": self.prompt}]
            formatted_prompt = self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True
            )

            # Generate response
            self.progress_update.emit("Generating response...")
            response = generate(
                self.model,
                self.tokenizer,
                prompt=formatted_prompt,
                max_tokens=self.max_tokens,
                verbose=False
            )

            # Extract and clean the final response
            self.progress_update.emit("Processing response...")
            final_response = self.extract_final_response(response)
            generation_time = time.time() - start_time

            self.response_ready.emit(final_response, generation_time)

        except Exception as e:
            self.error_occurred.emit(str(e))

    def extract_final_response(self, response: str) -> str:
        """Extract the final assistant response from the chat template output"""
        # Look for the final assistant turn
        if "<|start|>assistant" in response:
            parts = response.split("<|start|>assistant")
            if len(parts) > 1:
                final_part = parts[-1]

                # Remove all channel and message tags
                final_part = re.sub(r'<\|channel\|>[^<]+', '', final_part)
                final_part = final_part.replace('<|message|>', '')
                final_part = final_part.replace('<|end|>', '')

                # Clean up any remaining tags or whitespace
                final_part = re.sub(r'<[^>]+>', '', final_part)
                final_part = final_part.strip()

                if final_part:
                    return final_part

        # Fallback: return the original response cleaned up
        cleaned = re.sub(r'<\|[^>]+\|>', '', response)
        cleaned = re.sub(r'<[^>]+>', '', cleaned)
        return cleaned.strip()

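# Illustration of the cleanup above on a hypothetical harmony-style transcript
# (the sample string is made up for documentation; real model output may vary):
#
#   raw = "<|start|>assistant<|channel|>final<|message|>Paris is the capital.<|end|>"
#   thread.extract_final_response(raw)  # -> "Paris is the capital."
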
class CodeHighlighter(QSyntaxHighlighter):
    """Basic syntax highlighter for code blocks"""
    def __init__(self, parent=None):
        super().__init__(parent)

        self.highlighting_rules = []

        # Keyword format
        keyword_format = QTextCharFormat()
        keyword_format.setForeground(QColor("#569CD6"))
        keyword_format.setFontWeight(QFont.Bold)
        keywords = ["def", "class", "return", "import", "from", "as", "if",
                    "else", "elif", "for", "while", "try", "except", "finally"]
        for word in keywords:
            pattern = r'\b' + word + r'\b'
            self.highlighting_rules.append((re.compile(pattern), keyword_format))

        # String format
        string_format = QTextCharFormat()
        string_format.setForeground(QColor("#CE9178"))
        self.highlighting_rules.append((re.compile(r'\".*\"'), string_format))
        self.highlighting_rules.append((re.compile(r'\'.*\''), string_format))

        # Comment format
        comment_format = QTextCharFormat()
        comment_format.setForeground(QColor("#6A9955"))
        self.highlighting_rules.append((re.compile(r'#.*'), comment_format))

    def highlightBlock(self, text):
        # "fmt" avoids shadowing the built-in format()
        for pattern, fmt in self.highlighting_rules:
            for match in pattern.finditer(text):
                start, end = match.span()
                self.setFormat(start, end - start, fmt)


class ChatMessageWidget(QWidget):
    """Custom widget for displaying chat messages"""
    def __init__(self, is_user, message, timestamp=None, generation_time=None):
        super().__init__()
        self.is_user = is_user

        layout = QVBoxLayout()
        layout.setContentsMargins(15, 8, 15, 8)

        # Header with sender info and timestamp
        header_layout = QHBoxLayout()

        sender_icon = QLabel("👤" if is_user else "🤖")
        sender_label = QLabel("You" if is_user else "GPT-OSS-120B")
        sender_label.setStyleSheet("font-weight: bold; color: #2E86AB;" if is_user else "font-weight: bold; color: #A23B72;")

        time_text = timestamp if timestamp else datetime.now().strftime("%H:%M:%S")
        time_label = QLabel(time_text)
        time_label.setStyleSheet("color: #777; font-size: 11px;")

        header_layout.addWidget(sender_icon)
        header_layout.addWidget(sender_label)
        header_layout.addStretch()
        header_layout.addWidget(time_label)

        if generation_time and not is_user:
            speed_label = QLabel(f"{generation_time:.1f}s")
            speed_label.setStyleSheet("color: #777; font-size: 11px;")
            header_layout.addWidget(speed_label)

        layout.addLayout(header_layout)

        # Message content - use QTextEdit for proper text rendering
        message_display = QTextEdit()
        message_display.setReadOnly(True)

        # Format message with basic markdown support
        formatted_message = self.format_message(message)
        message_display.setHtml(formatted_message)

        message_display.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        message_display.setHorizontalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        message_display.setStyleSheet("""
            QTextEdit {
                background-color: %s;
                border: 1px solid %s;
                border-radius: 12px;
                padding: 12px;
                margin: 2px;
                font-size: 14px;
            }
        """ % ("#E8F4F8" if is_user else "#F8F0F5", "#B8D8E8" if is_user else "#E8C6DE"))

        # Set size policy
        message_display.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred)
        message_display.setMinimumHeight(50)
        message_display.setMaximumHeight(600)

        # Add a syntax highlighter for code-like assistant messages; keep a
        # reference on self so the highlighter is not garbage collected
        if not is_user and self.contains_code(message):
            self.highlighter = CodeHighlighter(message_display.document())

        layout.addWidget(message_display)
        self.setLayout(layout)

    def format_message(self, message):
        """Format message with basic HTML styling"""
        # Convert markdown to basic HTML
        html = markdown.markdown(message)

        # Add some basic styling
        styled_html = f"""
        <div style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
                    line-height: 1.4; color: #333;">
            {html}
        </div>
        """
        return styled_html

    def contains_code(self, message):
        """Rough heuristic for code-like content (can also trigger on plain prose)"""
        code_indicators = ["def ", "class ", "import ", "function ", "var ", "const ", "=", "()", "{}", "[]"]
        return any(indicator in message for indicator in code_indicators)


class GPTOSSChatUI(QMainWindow):
    def __init__(self):
        super().__init__()
        self.model = None
        self.tokenizer = None
        self.conversation_history = []
        self.max_tokens = 2048
        self.generation_thread = None
        self.model_loader_thread = None

        self.init_ui()
        self.load_model_in_background()

    def init_ui(self):
        """Initialize the user interface"""
        self.setWindowTitle("GPT-OSS-120B Chat")
        self.setGeometry(100, 100, 1400, 900)

        # Central widget
        central_widget = QWidget()
        self.setCentralWidget(central_widget)

        # Main layout
        main_layout = QHBoxLayout(central_widget)
        main_layout.setContentsMargins(15, 15, 15, 15)
        main_layout.setSpacing(15)

        # Left panel for settings
        left_panel = QFrame()
        left_panel.setMinimumWidth(280)
        left_panel.setMaximumWidth(350)
        left_panel.setFrameShape(QFrame.StyledPanel)
        left_panel_layout = QVBoxLayout(left_panel)
        left_panel_layout.setContentsMargins(12, 12, 12, 12)

        # App title
        title_label = QLabel("GPT-OSS-120B Chat")
        title_label.setStyleSheet("font-size: 18px; font-weight: bold; color: #2E86AB; margin-bottom: 15px;")
        title_label.setAlignment(Qt.AlignCenter)
        left_panel_layout.addWidget(title_label)

        # Model info
        model_info_group = QGroupBox("🤖 Model Information")
        model_info_group.setStyleSheet("QGroupBox { font-weight: bold; }")
        model_info_layout = QVBoxLayout()

        model_details = [
            ("GPT-OSS-120B", "font-weight: bold; font-size: 14px; color: #333;"),
            ("120B parameters, 4-bit quantized", "color: #666; font-size: 12px;"),
            ("Apple M3 Ultra • 512GB RAM", "color: #666; font-size: 12px;"),
            ("Performance: ~95 tokens/second", "color: #4CAF50; font-size: 12px; font-weight: bold;")
        ]

        for text, style in model_details:
            label = QLabel(text)
            label.setStyleSheet(style)
            label.setWordWrap(True)
            model_info_layout.addWidget(label)

        model_info_group.setLayout(model_info_layout)
        left_panel_layout.addWidget(model_info_group)

        # Generation settings
        settings_group = QGroupBox("⚙️ Generation Settings")
        settings_group.setStyleSheet("QGroupBox { font-weight: bold; }")
        settings_layout = QVBoxLayout()

        # Max tokens setting
        tokens_layout = QHBoxLayout()
        tokens_label = QLabel("Max Tokens:")
        tokens_label.setStyleSheet("font-weight: bold;")
        self.tokens_spinner = QSpinBox()
        self.tokens_spinner.setRange(128, 4096)
        self.tokens_spinner.setValue(2048)
        self.tokens_spinner.valueChanged.connect(self.update_max_tokens)
        self.tokens_spinner.setStyleSheet("padding: 6px; border-radius: 4px;")
        tokens_layout.addWidget(tokens_label)
        tokens_layout.addWidget(self.tokens_spinner)
        settings_layout.addLayout(tokens_layout)

        settings_group.setLayout(settings_layout)
        left_panel_layout.addWidget(settings_group)

        # Conversation management
        conv_group = QGroupBox("💬 Conversation")
        conv_group.setStyleSheet("QGroupBox { font-weight: bold; }")
        conv_layout = QVBoxLayout()

        clear_btn = QPushButton("🗑️ Clear Conversation")
        clear_btn.clicked.connect(self.clear_conversation)
        clear_btn.setStyleSheet("text-align: left; padding: 8px;")
        conv_layout.addWidget(clear_btn)

        export_btn = QPushButton("💾 Export Conversation")
        export_btn.clicked.connect(self.export_conversation)
        export_btn.setStyleSheet("text-align: left; padding: 8px;")
        conv_layout.addWidget(export_btn)

        conv_group.setLayout(conv_layout)
        left_panel_layout.addWidget(conv_group)

        left_panel_layout.addStretch()

        # Status indicator
        self.status_indicator = QLabel("🟡 Loading model...")
        self.status_indicator.setStyleSheet("color: #666; font-size: 11px; margin-top: 10px;")
        left_panel_layout.addWidget(self.status_indicator)

        # Right panel for chat
        right_panel = QWidget()
        right_panel_layout = QVBoxLayout(right_panel)
        right_panel_layout.setContentsMargins(0, 0, 0, 0)

        # Chat history area
        self.chat_scroll = QScrollArea()
        self.chat_scroll.setWidgetResizable(True)
        self.chat_scroll.setVerticalScrollBarPolicy(Qt.ScrollBarAsNeeded)
        self.chat_scroll.setHorizontalScrollBarPolicy(Qt.ScrollBarAlwaysOff)
        self.chat_scroll.setStyleSheet("background-color: #FAFAFA; border: none;")

        self.chat_container = QWidget()
        self.chat_layout = QVBoxLayout(self.chat_container)
        self.chat_layout.setAlignment(Qt.AlignTop)
        self.chat_layout.setSpacing(10)
        self.chat_layout.setContentsMargins(10, 10, 10, 10)

        self.chat_scroll.setWidget(self.chat_container)
        right_panel_layout.addWidget(self.chat_scroll)

        # Input area
        input_frame = QFrame()
        input_frame.setStyleSheet("background-color: white; border-top: 1px solid #EEE;")
        input_layout = QVBoxLayout(input_frame)
        input_layout.setContentsMargins(15, 15, 15, 15)

        # Message input with character count
        input_top_layout = QHBoxLayout()
        self.message_input = QTextEdit()
        self.message_input.setPlaceholderText("Type your message here... (Shift+Enter for new line)")
        self.message_input.setMaximumHeight(100)
        self.message_input.setStyleSheet("""
            QTextEdit {
                padding: 12px;
                border: 2px solid #DDD;
                border-radius: 8px;
                font-size: 14px;
            }
            QTextEdit:focus {
                border-color: #2E86AB;
            }
        """)
        self.message_input.textChanged.connect(self.update_char_count)
        input_top_layout.addWidget(self.message_input)

        self.send_btn = QPushButton("Send")
        self.send_btn.setFixedSize(80, 50)
        self.send_btn.clicked.connect(self.send_message)
        self.send_btn.setStyleSheet("""
            QPushButton {
                background-color: #2E86AB;
                color: white;
                border: none;
                border-radius: 8px;
                font-weight: bold;
            }
            QPushButton:hover {
                background-color: #1F5E7A;
            }
            QPushButton:disabled {
                background-color: #CCCCCC;
            }
        """)
        input_top_layout.addWidget(self.send_btn)

        input_layout.addLayout(input_top_layout)

        # Character count and controls
        bottom_layout = QHBoxLayout()
        self.char_count = QLabel("0 characters")
        self.char_count.setStyleSheet("color: #777; font-size: 11px;")
        bottom_layout.addWidget(self.char_count)

        bottom_layout.addStretch()

        # Add some utility buttons
        clear_input_btn = QPushButton("Clear Input")
        clear_input_btn.setStyleSheet("font-size: 11px; padding: 4px 8px;")
        clear_input_btn.clicked.connect(self.clear_input)
        bottom_layout.addWidget(clear_input_btn)

        input_layout.addLayout(bottom_layout)
        right_panel_layout.addWidget(input_frame)

        # Add panels to main layout
        main_layout.addWidget(left_panel)
        main_layout.addWidget(right_panel)

        # Status bar
        self.statusBar().showMessage("Ready")

        # Set styles
        self.apply_styles()

    def apply_styles(self):
        """Apply modern styling to the UI"""
        self.setStyleSheet("""
            QMainWindow {
                background-color: #F5F5F7;
            }
            QGroupBox {
                font-weight: bold;
                border: 1px solid #E0E0E0;
                border-radius: 8px;
                margin-top: 10px;
                padding-top: 20px;
                background-color: white;
            }
            QGroupBox::title {
                subcontrol-origin: margin;
                left: 10px;
                padding: 0 8px 0 8px;
                color: #2E86AB;
            }
            QPushButton {
                background-color: #2E86AB;
                color: white;
                border: none;
                padding: 8px 16px;
                border-radius: 6px;
                font-weight: bold;
            }
            QPushButton:hover {
                background-color: #1F5E7A;
            }
            QPushButton:disabled {
                background-color: #CCCCCC;
            }
            QScrollArea {
                border: none;
                background-color: #FAFAFA;
            }
            QSpinBox {
                padding: 6px;
                border: 1px solid #DDD;
                border-radius: 4px;
                background-color: white;
            }
            QFrame {
                background-color: white;
                border-radius: 8px;
            }
        """)

    def update_char_count(self):
        """Update the character count label"""
        text = self.message_input.toPlainText()
        self.char_count.setText(f"{len(text)} characters")

    def clear_input(self):
        """Clear the input field"""
        self.message_input.clear()

    def load_model_in_background(self):
        """Load the model in a separate thread to prevent UI freezing"""
        self.statusBar().showMessage("Loading model...")
        self.status_indicator.setText("🟡 Loading model...")
        self.send_btn.setEnabled(False)
        self.message_input.setEnabled(False)
        self.tokens_spinner.setEnabled(False)

        self.model_loader_thread = ModelLoaderThread()
        self.model_loader_thread.model_loaded.connect(self.model_loaded)
        self.model_loader_thread.model_error.connect(self.model_error)
        self.model_loader_thread.progress_update.connect(self.update_progress)
        self.model_loader_thread.start()

    def update_progress(self, message):
        """Update the progress message"""
        self.status_indicator.setText(f"🟡 {message}")

    def model_loaded(self, model, tokenizer):
        """Called when the worker thread has finished loading the model"""
        # Take ownership of the objects loaded in the worker thread instead of
        # loading the model a second time in the main thread
        self.model = model
        self.tokenizer = tokenizer
        self.statusBar().showMessage("Model loaded and ready!")
        self.status_indicator.setText("🟢 Model loaded and ready!")
        self.send_btn.setEnabled(True)
        self.message_input.setEnabled(True)
        self.tokens_spinner.setEnabled(True)

        # Add welcome message
        welcome_msg = """Hello! I'm GPT-OSS-120B, running locally on your M3 Ultra.

I'm a 120 billion parameter open-source language model, and I'm here to assist you with:
- Answering questions
- Generating creative content
- Explaining complex concepts
- Writing and analyzing code
- And much more!

How can I help you today?"""
        self.add_message(False, welcome_msg, 0.0)

        # Scroll to bottom after a short delay to ensure the UI is rendered
        QTimer.singleShot(100, self.scroll_to_bottom)

    def model_error(self, error_msg):
        """Called when model loading fails"""
        self.statusBar().showMessage(f"Error loading model: {error_msg}")
        self.status_indicator.setText(f"🔴 Error: {error_msg}")
        error_widget = ChatMessageWidget(False, f"Error loading model: {error_msg}")
        self.chat_layout.addWidget(error_widget)
        self.send_btn.setEnabled(False)
        self.message_input.setEnabled(False)

    def send_message(self):
        """Send the current message"""
        message = self.message_input.toPlainText().strip()
        if not message or not self.model:
            return

        # Add user message to chat
        self.add_message(True, message)
        self.message_input.clear()

        # Disable input while generating
        self.send_btn.setEnabled(False)
        self.message_input.setEnabled(False)
        self.tokens_spinner.setEnabled(False)
        self.statusBar().showMessage("Generating response...")
        self.status_indicator.setText("🟡 Generating response...")

        # Generate response in a separate thread
        self.generation_thread = GenerationThread(
            self.model, self.tokenizer, message, self.max_tokens
        )
        self.generation_thread.response_ready.connect(self.handle_response)
        self.generation_thread.error_occurred.connect(self.handle_error)
        self.generation_thread.progress_update.connect(self.update_progress)
        self.generation_thread.start()

    def handle_response(self, response, generation_time):
        """Handle the generated response"""
        self.add_message(False, response, generation_time)

        # Re-enable input
        self.send_btn.setEnabled(True)
        self.message_input.setEnabled(True)
        self.tokens_spinner.setEnabled(True)
        self.statusBar().showMessage("Ready")
        self.status_indicator.setText("🟢 Ready")

        # Scroll to bottom
        self.scroll_to_bottom()

    def handle_error(self, error_msg):
        """Handle generation errors"""
        self.add_message(False, f"Error: {error_msg}", 0.0)

        # Re-enable input
        self.send_btn.setEnabled(True)
        self.message_input.setEnabled(True)
        self.tokens_spinner.setEnabled(True)
        self.statusBar().showMessage("Error occurred")
        self.status_indicator.setText("🔴 Error occurred")

        # Scroll to bottom
        self.scroll_to_bottom()

    def add_message(self, is_user, message, generation_time=0.0):
        """Add a message to the chat history"""
        # Add to conversation history
        self.conversation_history.append({
            "is_user": is_user,
            "message": message,
            "timestamp": datetime.now().strftime("%H:%M:%S"),
            "generation_time": generation_time
        })

        # Create and add message widget
        message_widget = ChatMessageWidget(is_user, message, datetime.now().strftime("%H:%M:%S"), generation_time)
        self.chat_layout.addWidget(message_widget)

    def clear_conversation(self):
        """Clear the conversation history"""
        # Clear history
        self.conversation_history = []

        # Remove all message widgets
        for i in reversed(range(self.chat_layout.count())):
            widget = self.chat_layout.itemAt(i).widget()
            if widget:
                widget.setParent(None)

        # Add welcome message again
        welcome_msg = "Hello! I'm GPT-OSS-120B. How can I assist you today?"
        self.add_message(False, welcome_msg, 0.0)

        # Scroll to bottom
        self.scroll_to_bottom()

    def export_conversation(self):
        """Export the conversation to a file"""
        try:
            options = QFileDialog.Options()
            file_path, _ = QFileDialog.getSaveFileName(
                self, "Save Conversation", "conversation.json", "JSON Files (*.json)", options=options
            )

            if file_path:
                if not file_path.endswith('.json'):
                    file_path += '.json'

                with open(file_path, 'w', encoding='utf-8') as f:
                    json.dump(self.conversation_history, f, indent=2, ensure_ascii=False)

                QMessageBox.information(self, "Success", f"Conversation exported to {file_path}")
        except Exception as e:
            QMessageBox.critical(self, "Error", f"Failed to export conversation: {str(e)}")

    def update_max_tokens(self, value):
        """Update the maximum tokens for generation"""
        self.max_tokens = value

    def scroll_to_bottom(self):
        """Scroll the chat area to the bottom"""
        scrollbar = self.chat_scroll.verticalScrollBar()
        scrollbar.setValue(scrollbar.maximum())

    def keyPressEvent(self, event):
        """Handle key press events.

        Note: QTextEdit consumes Return itself, so this only fires when the
        input field does not have focus; an event filter on message_input
        would be needed to intercept Enter while typing.
        """
        if event.key() == Qt.Key_Return and event.modifiers() & Qt.ShiftModifier:
            # Allow Shift+Enter for new lines
            self.message_input.insertPlainText("\n")
        elif event.key() == Qt.Key_Return:
            # Send message on Enter (without Shift)
            self.send_message()
        else:
            super().keyPressEvent(event)


def main():
    app = QApplication(sys.argv)

    # Set application style and font
    app.setStyle('Fusion')
    font = QFont("SF Pro Text", 12)  # Use the macOS system font
    app.setFont(font)

    # Create and show the main window
    chat_ui = GPTOSSChatUI()
    chat_ui.show()

    sys.exit(app.exec_())


if __name__ == "__main__":
    main()
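The UI wraps the standard two-call mlx-lm workflow: load() fetches and initializes the quantized model, and generate() runs sampling against the chat-templated prompt. The same generation can be reproduced headlessly; a minimal sketch, assuming mlx-lm is installed and the ~60-70GB model fits in memory:

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/gpt-oss-120b-MXFP4-Q4")
messages = [{"role": "user", "content": "The capital of France is"}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
print(generate(model, tokenizer, prompt=prompt, max_tokens=128))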
output.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91d93386ebf345500857f3eadda6317c5cc2c70774f790b9e6a290db3f2df01e
size 9746422