abhisheksan committed on
Commit
1ccdaa7
·
1 Parent(s): 23e1d87

Enhance model initialization and download process; add directory verification, error handling, and temporary file management

Browse files
Files changed (1) hide show
  1. main.py +79 -10
main.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from typing import Optional, Dict, Any, Literal
3
  from enum import Enum
4
  from fastapi import FastAPI, HTTPException, status
@@ -136,14 +137,37 @@ class GenerateRequest(BaseModel):
136
  class ModelManager:
137
  def __init__(self):
138
  self.model = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  async def initialize(self):
141
  """Initialize the model with error handling"""
142
- if not MODEL_PATH.exists():
143
- await self.download_model()
144
-
145
- self.model = self.initialize_model(MODEL_PATH)
146
- return self.model is not None
 
 
 
 
 
 
 
147
 
148
  @staticmethod
149
  async def download_model():
@@ -152,42 +176,87 @@ class ModelManager:
152
  from tqdm import tqdm
153
 
154
  if MODEL_PATH.exists():
 
155
  return
156
 
157
- logger.info(f"Downloading model to {MODEL_PATH}")
 
 
 
158
  try:
159
  response = requests.get(MODEL_URL, stream=True)
160
  response.raise_for_status()
161
  total_size = int(response.headers.get('content-length', 0))
162
 
163
- with open(MODEL_PATH, 'wb') as file, tqdm(
 
 
 
 
 
 
 
 
 
164
  desc="Downloading",
165
  total=total_size,
166
  unit='iB',
167
  unit_scale=True,
168
  unit_divisor=1024,
169
  ) as pbar:
170
- for data in response.iter_content(chunk_size=1024):
171
  size = file.write(data)
172
  pbar.update(size)
 
 
 
 
 
 
 
 
 
 
 
 
173
  except Exception as e:
174
  logger.error(f"Error downloading model: {str(e)}")
 
 
 
 
175
  if MODEL_PATH.exists():
176
  MODEL_PATH.unlink()
177
- raise
178
 
179
  def initialize_model(self, model_path: Path):
180
  """Initialize the model with the specified configuration"""
181
  try:
 
 
 
 
 
 
 
 
 
 
182
  model = AutoModelForCausalLM.from_pretrained(
183
  str(model_path.parent),
184
  model_file=model_path.name,
185
  model_type="llama",
186
- max_new_tokens=1500, # Support for longest poems
187
  context_length=2048,
188
  gpu_layers=0
189
  )
 
 
 
 
 
190
  return model
 
191
  except Exception as e:
192
  logger.error(f"Error initializing model: {str(e)}")
193
  return None
 
1
  import os
2
+ import shutil
3
  from typing import Optional, Dict, Any, Literal
4
  from enum import Enum
5
  from fastapi import FastAPI, HTTPException, status
 
137
  class ModelManager:
138
  def __init__(self):
139
  self.model = None
140
+
141
+ def ensure_model_directory(self):
142
+ """Ensure the model directory exists and is writable"""
143
+ try:
144
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
145
+
146
+ # Verify directory exists and is writable
147
+ if not MODEL_DIR.exists():
148
+ raise RuntimeError(f"Failed to create directory: {MODEL_DIR}")
149
+ if not os.access(MODEL_DIR, os.W_OK):
150
+ raise RuntimeError(f"Directory not writable: {MODEL_DIR}")
151
+
152
+ logger.info(f"Model directory verified: {MODEL_DIR}")
153
+ except Exception as e:
154
+ logger.error(f"Error setting up model directory: {str(e)}")
155
+ raise
156
 
157
  async def initialize(self):
158
  """Initialize the model with error handling"""
159
+ try:
160
+ # Ensure directory exists before attempting download
161
+ self.ensure_model_directory()
162
+
163
+ if not MODEL_PATH.exists():
164
+ await self.download_model()
165
+
166
+ self.model = self.initialize_model(MODEL_PATH)
167
+ return self.model is not None
168
+ except Exception as e:
169
+ logger.error(f"Initialization failed: {str(e)}")
170
+ return False
171
 
172
  @staticmethod
173
  async def download_model():
 
176
  from tqdm import tqdm
177
 
178
  if MODEL_PATH.exists():
179
+ logger.info(f"Model already exists at {MODEL_PATH}")
180
  return
181
 
182
+ # Create a temporary file for downloading
183
+ temp_path = MODEL_PATH.with_suffix('.temp')
184
+
185
+ logger.info(f"Downloading model to temporary file: {temp_path}")
186
  try:
187
  response = requests.get(MODEL_URL, stream=True)
188
  response.raise_for_status()
189
  total_size = int(response.headers.get('content-length', 0))
190
 
191
+ # Ensure we have enough disk space
192
+ free_space = shutil.disk_usage(MODEL_DIR).free
193
+ if free_space < total_size * 1.1: # 10% buffer
194
+ raise RuntimeError(
195
+ f"Insufficient disk space. Need {total_size * 1.1 / (1024**3):.2f}GB, "
196
+ f"have {free_space / (1024**3):.2f}GB"
197
+ )
198
+
199
+ # Download to temporary file first
200
+ with open(temp_path, 'wb') as file, tqdm(
201
  desc="Downloading",
202
  total=total_size,
203
  unit='iB',
204
  unit_scale=True,
205
  unit_divisor=1024,
206
  ) as pbar:
207
+ for data in response.iter_content(chunk_size=8192):
208
  size = file.write(data)
209
  pbar.update(size)
210
+
211
+ # Verify file size
212
+ if temp_path.stat().st_size != total_size:
213
+ raise RuntimeError(
214
+ f"Downloaded file size ({temp_path.stat().st_size}) "
215
+ f"doesn't match expected size ({total_size})"
216
+ )
217
+
218
+ # Move temporary file to final location
219
+ temp_path.rename(MODEL_PATH)
220
+ logger.info(f"Model downloaded successfully to {MODEL_PATH}")
221
+
222
  except Exception as e:
223
  logger.error(f"Error downloading model: {str(e)}")
224
+ # Clean up temporary file if it exists
225
+ if temp_path.exists():
226
+ temp_path.unlink()
227
+ # Clean up partial download if it exists
228
  if MODEL_PATH.exists():
229
  MODEL_PATH.unlink()
230
+ raise RuntimeError(f"Model download failed: {str(e)}")
231
 
232
  def initialize_model(self, model_path: Path):
233
  """Initialize the model with the specified configuration"""
234
  try:
235
+ if not model_path.exists():
236
+ raise FileNotFoundError(f"Model file not found: {model_path}")
237
+
238
+ if not model_path.is_file():
239
+ raise RuntimeError(f"Model path is not a file: {model_path}")
240
+
241
+ if not os.access(model_path, os.R_OK):
242
+ raise RuntimeError(f"Model file is not readable: {model_path}")
243
+
244
+ logger.info(f"Initializing model from: {model_path}")
245
  model = AutoModelForCausalLM.from_pretrained(
246
  str(model_path.parent),
247
  model_file=model_path.name,
248
  model_type="llama",
249
+ max_new_tokens=1500,
250
  context_length=2048,
251
  gpu_layers=0
252
  )
253
+
254
+ if model is None:
255
+ raise RuntimeError("Model initialization returned None")
256
+
257
+ logger.info("Model initialized successfully")
258
  return model
259
+
260
  except Exception as e:
261
  logger.error(f"Error initializing model: {str(e)}")
262
  return None