abhisheksan committed on
Commit
23e1d87
·
1 Parent(s): 0d8e806

Refactor model initialization logic; improve error handling and logging for better diagnostics

Browse files
Files changed (1) hide show
  1. main.py +221 -207
main.py CHANGED
@@ -1,57 +1,24 @@
1
  import os
2
- from typing import Optional
 
3
  from fastapi import FastAPI, HTTPException, status
4
  from pathlib import Path
5
  import logging
6
  import sys
7
-
8
- from pydantic import BaseModel
9
  from ctransformers import AutoModelForCausalLM
 
 
 
10
  BASE_DIR = Path("/app")
11
  MODEL_DIR = BASE_DIR / "models"
12
  MODEL_NAME = "llama-2-7b-chat.q4_K_M.gguf"
13
  MODEL_PATH = MODEL_DIR / MODEL_NAME
14
- # Create FastAPI app
15
- app = FastAPI(title="LLM API Server")
16
- # Ensure model directory exists
17
- MODEL_DIR.mkdir(parents=True, exist_ok=True)
18
- # Model download URL
19
  MODEL_URL = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf"
20
def download_model():
    """Download the model to ``MODEL_PATH`` unless it already exists.

    The payload is streamed into a sibling ``.part`` file and renamed into
    place only after the transfer completes, so an interrupted download never
    leaves a truncated file at ``MODEL_PATH`` that a later ``exists()`` check
    would mistake for a finished model.

    Raises:
        Exception: whatever the HTTP request or the file write raised. The
            partial file is removed before the error is re-raised.
    """
    import requests
    from tqdm import tqdm

    if MODEL_PATH.exists():
        logger.info(f"Model already exists at {MODEL_PATH}")
        return

    partial_path = MODEL_PATH.with_name(MODEL_PATH.name + ".part")

    logger.info(f"Downloading model to {MODEL_PATH}")
    try:
        # (connect, read) timeouts in seconds: a stalled server must not hang
        # the caller forever.
        with requests.get(MODEL_URL, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(partial_path, 'wb') as file, tqdm(
                desc="Downloading",
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                # 1 MiB chunks keep the Python-level loop short for a large file.
                for data in response.iter_content(chunk_size=1024 * 1024):
                    size = file.write(data)
                    pbar.update(size)

        # Publish the finished file in a single rename.
        partial_path.replace(MODEL_PATH)
        logger.info("Model downloaded successfully")
    except Exception as e:
        logger.error(f"Error downloading model: {str(e)}")
        if partial_path.exists():
            partial_path.unlink()
        raise
52
- # Enhanced logging configuration
53
  logging.basicConfig(
54
- level=logging.DEBUG, # Changed to DEBUG for more verbose output
55
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
56
  handlers=[
57
  logging.StreamHandler(sys.stdout),
@@ -60,198 +27,245 @@ logging.basicConfig(
60
  )
61
  logger = logging.getLogger(__name__)
62
 
63
def verify_model_path(model_path: Path) -> bool:
    """Return True when *model_path* is a readable regular file of at least 1 MB.

    Each failed check, and any exception raised while checking, is logged as
    an error and reported as False.
    """
    try:
        if not model_path.exists():
            problem = f"Model path does not exist: {model_path}"
        elif not model_path.is_file():
            problem = f"Model path is not a file: {model_path}"
        elif not os.access(model_path, os.R_OK):
            problem = f"Model file is not readable: {model_path}"
        elif model_path.stat().st_size < 1_000_000:
            problem = f"Model file seems too small: {model_path.stat().st_size} bytes"
        else:
            return True

        logger.error(problem)
        return False
    except Exception as exc:
        logger.error(f"Error verifying model path: {str(exc)}")
        return False
90
 
91
def initialize_model(model_path: Path):
    """Load the model at *model_path* and smoke-test it with a short prompt.

    Returns the loaded model, or None when path verification, loading, or the
    smoke test fails. Every failure is logged; nothing is raised.
    """
    logger.info("Starting model initialization...")

    try:
        # Refuse to load anything that fails the on-disk checks.
        if not verify_model_path(model_path):
            logger.error("Model path verification failed")
            return None

        logger.info(f"Attempting to load model from: {model_path}")
        logger.debug(f"Parent directory: {model_path.parent}")
        logger.debug(f"Model filename: {model_path.name}")

        load_options = dict(
            model_file=model_path.name,
            model_type="llama",
            max_new_tokens=512,
            context_length=512,
            gpu_layers=0,  # CPU only
        )
        llm = AutoModelForCausalLM.from_pretrained(str(model_path.parent), **load_options)

        if llm is None:
            logger.error("Model initialization returned None")
            return None

        # A tiny generation proves the model can actually run inference.
        try:
            llm("Test", max_new_tokens=5)
            logger.info("Model successfully generated test output")
        except Exception as smoke_error:
            logger.error(f"Model failed basic inference test: {str(smoke_error)}")
            return None

        logger.info("Model initialized successfully")
        return llm

    except ImportError as import_error:
        logger.error(f"Failed to import required modules: {str(import_error)}")
        logger.error("Please ensure ctransformers is installed: pip install ctransformers")
        return None
    except RuntimeError as runtime_error:
        logger.error(f"Runtime error during model initialization: {str(runtime_error)}")
        return None
    except Exception as unexpected:
        logger.error(f"Unexpected error during model initialization: {str(unexpected)}")
        logger.exception("Full traceback:")
        return None
145
 
146
- # Modified startup event
147
async def startup_event():
    """Initialize the model during startup with enhanced error handling.

    Downloads the model when it is missing and then loads it into the
    module-level ``model`` global. On any failure ``model`` is left as
    ``None`` and the failure is logged; nothing is raised.

    NOTE(review): no ``@app.on_event`` decorator is attached to this function
    in the visible code, so it only runs if something calls it explicitly.
    """
    global model
    # Bind the global up front so code that reads ``model`` never hits a
    # NameError when one of the early returns below skips the load.
    model = None

    logger.info("Starting application...")

    if not MODEL_PATH:
        logger.error("MODEL_PATH is not defined in config")
        return

    # Try to download model if it doesn't exist
    if not MODEL_PATH.exists():
        try:
            logger.info("Model not found, attempting download...")
            download_model()
        except Exception as e:
            logger.error(f"Failed to download model: {str(e)}")
            return

    # Initialize model
    model = initialize_model(MODEL_PATH)

    if model is None:
        logger.warning(
            "Model failed to load. Service will start but /generate endpoint will be unavailable. "
            "Check model_loading.log for details."
        )
    else:
        logger.info("Application started successfully with model loaded")
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  @app.on_event("startup")
178
  async def startup():
179
  """Initialize the model during startup"""
180
- global model
181
-
182
- logger.info("Starting application...")
183
-
184
- if not MODEL_PATH:
185
- logger.error("MODEL_PATH is not defined in config")
186
- return
187
-
188
- # Try to download model if it doesn't exist
189
- if not MODEL_PATH.exists():
190
- try:
191
- logger.info("Model not found, attempting download...")
192
- download_model()
193
- except Exception as e:
194
- logger.error(f"Failed to download model: {str(e)}")
195
- return
196
-
197
- # Initialize model
198
- model = initialize_model(MODEL_PATH)
199
-
200
- if model is None:
201
- logger.warning(
202
- "Model failed to load. Service will start but /generate endpoint will be unavailable. "
203
- "Check model_loading.log for details."
204
- )
205
- else:
206
- logger.info("Application started successfully with model loaded")
207
 
208
  @app.get("/health")
209
  async def health_check():
210
  """Health check endpoint"""
211
  return {
212
  "status": "healthy",
213
- "model_loaded": model is not None
214
  }
215
class GenerateRequest(BaseModel):
    """Request body accepted by the /generate endpoint."""

    # Text handed to the model as its prompt.
    prompt: str
    # Forwarded to the model call as max_new_tokens.
    max_tokens: Optional[int] = 512
    # Forwarded to the model call as temperature.
    temperature: Optional[float] = 0.7
219
  @app.post("/generate")
220
  async def generate_text(request: GenerateRequest):
221
- """Generate text from the model"""
222
- if model is None:
223
- raise HTTPException(
224
- status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
225
- detail="Model is not loaded. Please check server logs."
226
- )
227
-
228
- try:
229
- # Generate response from the model
230
- response = model(
231
- request.prompt,
232
- max_new_tokens=request.max_tokens,
233
- temperature=request.temperature
234
- )
235
-
236
- return {
237
- "generated_text": response,
238
- "prompt": request.prompt,
239
- "max_tokens": request.max_tokens,
240
- "temperature": request.temperature
241
- }
242
- except Exception as e:
243
- logger.error(f"Error generating text: {str(e)}")
244
- raise HTTPException(
245
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
246
- detail=f"Error generating text: {str(e)}"
247
- )
248
 
249
- # Optional shutdown event
250
  @app.on_event("shutdown")
251
  async def shutdown():
252
  """Cleanup on shutdown"""
253
- global model
254
- if model is not None:
255
- del model
256
- model = None
257
- logger.info("Model unloaded during shutdown")
 
1
  import os
2
+ from typing import Optional, Dict, Any, Literal
3
+ from enum import Enum
4
  from fastapi import FastAPI, HTTPException, status
5
  from pathlib import Path
6
  import logging
7
  import sys
8
+ from pydantic import BaseModel, Field, validator
 
9
  from ctransformers import AutoModelForCausalLM
10
+ from dataclasses import dataclass
11
+
12
+ # Constants
13
  BASE_DIR = Path("/app")
14
  MODEL_DIR = BASE_DIR / "models"
15
  MODEL_NAME = "llama-2-7b-chat.q4_K_M.gguf"
16
  MODEL_PATH = MODEL_DIR / MODEL_NAME
 
 
 
 
 
17
  MODEL_URL = "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGUF/resolve/main/llama-2-7b-chat.Q4_K_M.gguf"
18
+
19
+ # Logging configuration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  logging.basicConfig(
21
+ level=logging.DEBUG,
22
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
23
  handlers=[
24
  logging.StreamHandler(sys.stdout),
 
27
  )
28
  logger = logging.getLogger(__name__)
29
 
30
+ # Aligned with frontend enums
31
class PoemStyle(str, Enum):
    """Poem forms a request may ask for; each value is the wire string."""

    SONNET = "sonnet"
    HAIKU = "haiku"
    FREE_VERSE = "free-verse"
    VILLANELLE = "villanelle"
36
+
37
class EmotionalTone(str, Enum):
    """Emotional tones a request may ask for; each value is the wire string."""

    CONTEMPLATIVE = "contemplative"
    JOYFUL = "joyful"
    MELANCHOLIC = "melancholic"
    ROMANTIC = "romantic"
42
+
43
class Length(str, Enum):
    """Requested poem length; each value is the wire string."""

    SHORT = "short"    # 100 words
    MEDIUM = "medium"  # 200 words
    LONG = "long"      # 300 words
 
 
 
 
 
 
 
 
 
 
47
 
48
@dataclass
class StyleConfig:
    """Sampling parameters derived from a request's style settings."""

    # Sampling temperature passed to the model call.
    temperature: float
    # Nucleus-sampling threshold passed to the model call.
    top_p: float
    # Top-k cutoff passed to the model call.
    top_k: int
    # Repetition penalty passed to the model call.
    repetition_penalty: float
    # Passed to the model call as max_new_tokens.
    max_tokens: int
56
+
57
class StyleMapper:
    """Translate user-facing style preferences into sampling parameters."""

    @staticmethod
    def get_style_config(
        style: PoemStyle,
        emotional_tone: EmotionalTone,
        creative_style: float,  # 0-100
        language_variety: float,  # 0-1
        length: Length,
        word_repetition: float,  # 1-2
    ) -> StyleConfig:
        """Build the StyleConfig for one generation request.

        ``style`` is accepted for interface symmetry but does not influence
        any of the returned values. ``top_k`` is always 40.

        Raises:
            KeyError: if ``length`` or ``emotional_tone`` is not a known
                member of its enum.
        """
        # Token budget per length (the original mapping assumes 5 tokens per word).
        tokens_by_length = {
            Length.SHORT: 500,    # ~100 words
            Length.MEDIUM: 1000,  # ~200 words
            Length.LONG: 1500,    # ~300 words
        }
        token_budget = tokens_by_length[length]

        # Temperature shift applied on top of the creativity slider.
        shift_by_tone = {
            EmotionalTone.CONTEMPLATIVE: 0.0,
            EmotionalTone.JOYFUL: 0.1,
            EmotionalTone.MELANCHOLIC: -0.1,
            EmotionalTone.ROMANTIC: 0.2
        }

        # creative_style (0-100) maps linearly onto 0.5-1.0 before the shift.
        raw_temperature = 0.5 + (creative_style / 100) * 0.5
        raw_temperature += shift_by_tone[emotional_tone]

        # Keep the final temperature inside [0.5, 1.0].
        bounded_temperature = max(0.5, min(1.0, raw_temperature))

        # language_variety (0-1) maps linearly onto top_p 0.7-1.0.
        nucleus = 0.7 + (language_variety * 0.3)

        return StyleConfig(
            temperature=bounded_temperature,
            top_p=nucleus,
            top_k=40,
            repetition_penalty=word_repetition,  # slider value is used as-is
            max_tokens=token_budget,
        )
 
 
 
 
 
 
 
 
 
 
108
 
109
class GenerateRequest(BaseModel):
    """Request body for poem generation.

    Field names are camelCase, matching the keys the request JSON carries.
    Range checks are expressed once, through ``Field(ge=..., le=...)``;
    separate validator methods that repeated the same bounds could never
    fire and have been removed.
    """

    # Subject of the poem; appended to a style-specific instruction.
    prompt: str
    style: PoemStyle
    emotionalTone: EmotionalTone
    creativeStyle: float = Field(ge=0, le=100)  # 0-100 slider
    languageVariety: float = Field(ge=0, le=1)  # 0-1 slider
    length: Length
    wordRepetition: float = Field(ge=1, le=2)  # 1-2 slider
135
+
136
class ModelManager:
    """Own the model lifecycle: download the weights, load them, generate text.

    ``model`` stays ``None`` until :meth:`initialize` loads it successfully.
    """

    # Bytes per streamed chunk while downloading the model file.
    _DOWNLOAD_CHUNK_SIZE = 1024 * 1024
    # (connect, read) timeouts in seconds for the download request.
    _DOWNLOAD_TIMEOUT = (10, 60)

    def __init__(self):
        self.model = None

    async def initialize(self):
        """Download the model if it is missing, then load it.

        Returns:
            True when the model was loaded. False when the download or the
            load failed; the failure is logged and nothing is raised, so a
            download error is handled the same way as a load error and the
            application can still start without a model.
        """
        if not MODEL_PATH.exists():
            try:
                await self.download_model()
            except Exception:
                logger.exception("Model download failed; continuing without a model")
                self.model = None
                return False

        self.model = self.initialize_model(MODEL_PATH)
        return self.model is not None

    @staticmethod
    async def download_model():
        """Download the model if it doesn't exist.

        The blocking HTTP transfer runs in the default executor so the event
        loop is not stalled while the file streams to disk.

        Raises:
            Exception: whatever the HTTP request or the file write raised.
        """
        if MODEL_PATH.exists():
            return

        import asyncio

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, ModelManager._download_model_blocking)

    @staticmethod
    def _download_model_blocking():
        """Stream ``MODEL_URL`` to ``MODEL_PATH`` through a temporary ``.part`` file."""
        import requests
        from tqdm import tqdm

        # The target directory may not exist yet; open() would fail without it.
        MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
        partial_path = MODEL_PATH.with_name(MODEL_PATH.name + ".part")

        logger.info(f"Downloading model to {MODEL_PATH}")
        try:
            with requests.get(
                MODEL_URL,
                stream=True,
                timeout=ModelManager._DOWNLOAD_TIMEOUT,
            ) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))

                with open(partial_path, 'wb') as file, tqdm(
                    desc="Downloading",
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as pbar:
                    for data in response.iter_content(
                        chunk_size=ModelManager._DOWNLOAD_CHUNK_SIZE
                    ):
                        size = file.write(data)
                        pbar.update(size)

            # Publish the file only once it is complete, so an interrupted
            # run never leaves a truncated file where exists() would find it.
            partial_path.replace(MODEL_PATH)
        except Exception as e:
            logger.error(f"Error downloading model: {str(e)}")
            if partial_path.exists():
                partial_path.unlink()
            raise

    def initialize_model(self, model_path: Path):
        """Load the model file at *model_path*.

        Returns:
            The loaded model, or None when loading raised (the error is
            logged with its traceback).
        """
        try:
            model = AutoModelForCausalLM.from_pretrained(
                str(model_path.parent),
                model_file=model_path.name,
                model_type="llama",
                max_new_tokens=1500,  # Support for longest poems
                context_length=2048,
                gpu_layers=0
            )
            return model
        except Exception as e:
            logger.exception(f"Error initializing model: {str(e)}")
            return None

    def generate(self, request: GenerateRequest) -> Dict[str, Any]:
        """Generate text based on the request and style parameters.

        Returns:
            A dict with ``generated_text`` (the model output), ``prompt``
            (the style-prefixed prompt sent to the model) and
            ``style_config`` (the sampling parameters that were used).

        Raises:
            HTTPException: 503 when no model is loaded; 500 when building
                the prompt or calling the model fails.
        """
        if self.model is None:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Model not loaded"
            )

        # Get style configuration
        style_config = StyleMapper.get_style_config(
            request.style,
            request.emotionalTone,
            request.creativeStyle,
            request.languageVariety,
            request.length,
            request.wordRepetition
        )

        try:
            # Prepare prompt based on style
            style_prompts = {
                PoemStyle.SONNET: "Write a sonnet about",
                PoemStyle.HAIKU: "Write a haiku about",
                PoemStyle.FREE_VERSE: "Write a free verse poem about",
                PoemStyle.VILLANELLE: "Write a villanelle about"
            }

            styled_prompt = f"{style_prompts[request.style]} {request.prompt}"

            response = self.model(
                styled_prompt,
                max_new_tokens=style_config.max_tokens,
                temperature=style_config.temperature,
                top_p=style_config.top_p,
                top_k=style_config.top_k,
                repetition_penalty=style_config.repetition_penalty
            )

            return {
                "generated_text": response,
                "prompt": styled_prompt,
                "style_config": style_config.__dict__
            }
        except Exception as e:
            logger.exception(f"Error generating text: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=str(e)
            ) from e
244
+
245
+ # Create FastAPI app and model manager
246
+ app = FastAPI(title="Poetry Generation API")
247
+ model_manager = ModelManager()
248
+
249
  @app.on_event("startup")
250
  async def startup():
251
  """Initialize the model during startup"""
252
+ await model_manager.initialize()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  @app.get("/health")
255
  async def health_check():
256
  """Health check endpoint"""
257
  return {
258
  "status": "healthy",
259
+ "model_loaded": model_manager.model is not None
260
  }
261
+
 
 
 
262
  @app.post("/generate")
263
  async def generate_text(request: GenerateRequest):
264
+ """Generate text with style parameters"""
265
+ return model_manager.generate(request)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
 
 
267
  @app.on_event("shutdown")
268
  async def shutdown():
269
  """Cleanup on shutdown"""
270
+ if model_manager.model is not None:
271
+ del model_manager.model