abhisheksan committed
Commit 1994657 · 1 Parent(s): f4f946e

Add initial log entries for model initialization and loading status

Files changed (2)
  1. logs/poetry_generation_20241117.log +4 -0
  2. main.py +310 -126
logs/poetry_generation_20241117.log ADDED
@@ -0,0 +1,4 @@
+2024-11-17 00:08:48,570 - main - INFO - Initializing model on device: cpu
+2024-11-17 00:08:50,303 - main - INFO - Model and tokenizer loaded successfully
+2024-11-17 00:13:06,341 - main - INFO - Initializing model on device: cpu
+2024-11-17 00:13:07,660 - main - INFO - Model and tokenizer loaded successfully
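
The file name matches the dated pattern that the rewritten `_setup_logging` in this commit builds at runtime; a minimal standalone sketch of that derivation (standard library only; the run date of 2024-11-17 is assumed):

import os
from datetime import datetime

# Mirrors the path construction in ModelManager._setup_logging below.
log_dir = os.path.join(os.getcwd(), 'logs')
log_name = f'poetry_generation_{datetime.now().strftime("%Y%m%d")}.log'
print(os.path.join(log_dir, log_name))  # e.g. logs/poetry_generation_20241117.log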
main.py CHANGED
@@ -1,51 +1,38 @@
 import os
-from typing import Optional, Dict, Any
-from fastapi import FastAPI, HTTPException, status
+from typing import Optional, Dict, Any, List
+from fastapi import FastAPI, HTTPException, status, BackgroundTasks
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
 import logging
 import sys
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, validator
 import torch
 from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
-import json
+from contextlib import asynccontextmanager
+import asyncio
+from functools import lru_cache
+import numpy as np
+from datetime import datetime
+import re
 
-# Define base model directory
+# Constants
 BASE_MODEL_DIR = "./models/"
 MODEL_PATH = os.path.join(BASE_MODEL_DIR, "poeticagpt.pth")
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+BATCH_SIZE = 4
+CACHE_SIZE = 1024
+
 MODEL_CONFIG = GPT2Config(
-    n_positions=128,  # MAX_LENGTH from training
+    n_positions=128,
     n_ctx=128,
-    n_embd=384,  # Same as training
-    n_layer=6,  # Same as training
-    n_head=6,  # Same as training
+    n_embd=384,
+    n_layer=6,
+    n_head=6,
     vocab_size=50257,
     bos_token_id=50256,
     eos_token_id=50256,
+    use_cache=True,
 )
-def setup_logging():
-    logger = logging.getLogger(__name__)
-    logger.setLevel(logging.DEBUG)
-
-    formatter = logging.Formatter(
-        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-    )
-
-    stdout_handler = logging.StreamHandler(sys.stdout)
-    stdout_handler.setFormatter(formatter)
-    logger.addHandler(stdout_handler)
-
-    try:
-        log_dir = os.path.join(os.getcwd(), 'logs')
-        os.makedirs(log_dir, exist_ok=True)
-
-        file_handler = logging.FileHandler(os.path.join(log_dir, 'poetry_generation.log'))
-        file_handler.setFormatter(formatter)
-        logger.addHandler(file_handler)
-    except (OSError, PermissionError) as e:
-        print(f"Warning: Could not create log file (using stdout only): {e}")
-
-    return logger
-
-logger = setup_logging()
 
 class GenerateRequest(BaseModel):
     prompt: str = Field(..., min_length=1, max_length=500)
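
The new constants keep the same 6-layer, 384-dimensional shape used in training. A quick way to sanity-check that MODEL_CONFIG still matches the shipped checkpoint — a sketch, assuming torch and transformers are installed and ./models/poeticagpt.pth is present — is to instantiate the architecture and compare key names:

import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Same hyperparameters as MODEL_CONFIG in the diff above.
config = GPT2Config(n_positions=128, n_ctx=128, n_embd=384, n_layer=6,
                    n_head=6, vocab_size=50257,
                    bos_token_id=50256, eos_token_id=50256)
model = GPT2LMHeadModel(config)
print(f"parameter count: {sum(p.numel() for p in model.parameters()):,}")

# Compare against the trained weights referenced by MODEL_PATH.
state_dict = torch.load("./models/poeticagpt.pth", map_location="cpu")
missing = [k for k in model.state_dict() if k not in state_dict]
print(f"config keys absent from checkpoint: {len(missing)}")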
@@ -54,41 +41,100 @@ class GenerateRequest(BaseModel):
     top_k: int = Field(default=50, ge=1, le=100)
     top_p: float = Field(default=0.95, ge=0.1, le=1.0)
     repetition_penalty: float = Field(default=1.2, ge=1.0, le=2.0)
+    style: Optional[str] = Field(default="free_verse",
+                                 description="Poetry style: free_verse, haiku, sonnet")
+
+    @validator('prompt')
+    def validate_prompt(cls, v):
+        v = ' '.join(v.split())
+        return v
+
+class PoemFormatter:
+    """Handles poem formatting and processing"""
+
+    @staticmethod
+    def format_free_verse(text: str) -> List[str]:
+        lines = re.split(r'[.!?]+|\n+', text)
+        lines = [line.strip() for line in lines if line.strip()]
+        formatted_lines = []
+        for line in lines:
+            if len(line) > 40:
+                parts = line.split(',')
+                formatted_lines.extend(part.strip() for part in parts if part.strip())
+            else:
+                formatted_lines.append(line)
+        return formatted_lines
+
+    @staticmethod
+    def format_haiku(text: str) -> List[str]:
+        words = text.split()
+        lines = []
+        current_line = []
+        syllable_count = 0
+
+        for word in words:
+            syllables = len(re.findall(r'[aeiou]+', word.lower()))
+            if syllable_count + syllables <= 5 and len(lines) == 0:
+                current_line.append(word)
+                syllable_count += syllables
+            elif syllable_count + syllables <= 7 and len(lines) == 1:
+                current_line.append(word)
+                syllable_count += syllables
+            elif syllable_count + syllables <= 5 and len(lines) == 2:
+                current_line.append(word)
+                syllable_count += syllables
+            else:
+                if current_line:
+                    lines.append(' '.join(current_line))
+                current_line = [word]
+                syllable_count = syllables
+
+            if len(lines) == 3:
+                break
+
+        if current_line and len(lines) < 3:
+            lines.append(' '.join(current_line))
+
+        return lines[:3]
+
+    @staticmethod
+    def format_sonnet(text: str) -> List[str]:
+        words = text.split()
+        lines = []
+        current_line = []
+        target_line_length = 10
+
+        for word in words:
+            current_line.append(word)
+            if len(current_line) >= target_line_length:
+                lines.append(' '.join(current_line))
+                current_line = []
+
+            if len(lines) >= 14:
+                break
+
+        if current_line and len(lines) < 14:
+            lines.append(' '.join(current_line))
+
+        return lines[:14]
 
 class ModelManager:
     def __init__(self):
         self.model = None
         self.tokenizer = None
+        self._lock = asyncio.Lock()
+        self.request_count = 0
+        self.last_cleanup = datetime.now()
+        self.poem_formatter = PoemFormatter()
 
-    def initialize(self):
-        """Initialize the model and tokenizer"""
+    async def initialize(self) -> bool:
         try:
-            logger.info("Loading tokenizer...")
-            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
-            self.tokenizer.pad_token = self.tokenizer.eos_token
+            self._setup_logging()
 
-            logger.info("Loading model...")
-            if not os.path.exists(MODEL_PATH):
-                logger.error(f"Model file not found at {MODEL_PATH}")
-                return False
+            logger.info(f"Initializing model on device: {DEVICE}")
 
-            # Initialize model with the same configuration as training
-            self.model = GPT2LMHeadModel(MODEL_CONFIG)
-
-            # Load your trained weights
-            state_dict = torch.load(MODEL_PATH, map_location='cpu')
-
-            # Load the state dictionary into the model
-            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
-
-            if missing_keys:
-                logger.warning(f"Missing keys: {missing_keys}")
-            if unexpected_keys:
-                logger.warning(f"Unexpected keys: {unexpected_keys}")
-
-            # Force model to CPU and eval mode
-            self.model.to('cpu')
-            self.model.eval()
+            self.tokenizer = await self._load_tokenizer()
+            await self._load_and_optimize_model()
 
             logger.info("Model and tokenizer loaded successfully")
             return True
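
The new format_haiku approximates syllables by counting runs of vowels. The heuristic is cheap but rough, which a standalone check makes visible (sample words are illustrative only):

import re

def approx_syllables(word: str) -> int:
    # The same vowel-run count used by PoemFormatter.format_haiku.
    return len(re.findall(r'[aeiou]+', word.lower()))

for word in ('autumn', 'poem', 'sky', 'quiet'):
    print(word, approx_syllables(word))
# 'poem' scores 1 (a single 'oe' run) despite having two syllables,
# and 'sky' scores 0 because 'y' is never matched.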
@@ -98,85 +144,223 @@ class ModelManager:
             logger.exception("Detailed traceback:")
             return False
 
-
-    def generate(self, request: GenerateRequest) -> Dict[str, Any]:
-        """Generate poetry based on the request parameters"""
-        if self.model is None or self.tokenizer is None:
-            raise HTTPException(
-                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                detail="Model or tokenizer not loaded"
-            )
+    @staticmethod
+    def _setup_logging():
+        global logger
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.INFO)
+
+        formatter = logging.Formatter(
+            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+        )
+
+        handlers = [logging.StreamHandler(sys.stdout)]
 
         try:
-            # Encode input
-            inputs = self.tokenizer.encode(request.prompt, return_tensors='pt')
-            attention_mask = torch.ones(inputs.shape, dtype=torch.long)
+            log_dir = os.path.join(os.getcwd(), 'logs')
+            os.makedirs(log_dir, exist_ok=True)
+            handlers.append(logging.FileHandler(
+                os.path.join(log_dir, f'poetry_generation_{datetime.now().strftime("%Y%m%d")}.log')
+            ))
+        except Exception as e:
+            print(f"Warning: Could not create log file: {e}")
+
+        for handler in handlers:
+            handler.setFormatter(formatter)
+            logger.addHandler(handler)
+
+    @lru_cache(maxsize=CACHE_SIZE)
+    async def _load_tokenizer(self):
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        tokenizer.pad_token = tokenizer.eos_token
+        return tokenizer
+
+    async def _load_and_optimize_model(self):
+        if not os.path.exists(MODEL_PATH):
+            raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
+
+        self.model = GPT2LMHeadModel(MODEL_CONFIG)
+
+        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
+        self.model.load_state_dict(state_dict, strict=False)
+
+        self.model.to(DEVICE)
+        self.model.eval()
+
+        if DEVICE.type == 'cuda':
+            torch.backends.cudnn.benchmark = True
+            self.model = torch.jit.script(self.model)
 
-            # Generate
-            with torch.no_grad():
-                outputs = self.model.generate(
-                    inputs,
-                    attention_mask=attention_mask,
-                    max_length=request.max_length,
-                    num_return_sequences=1,
-                    temperature=request.temperature,
-                    top_k=request.top_k,
-                    top_p=request.top_p,
-                    repetition_penalty=request.repetition_penalty,
-                    do_sample=True,
-                    pad_token_id=self.tokenizer.eos_token_id,
+        dummy_input = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
+        with torch.no_grad():
+            self.model(dummy_input)
+
+    @torch.no_grad()
+    async def generate(self, request: GenerateRequest) -> Dict[str, Any]:
+        async with self._lock:
+            try:
+                self.request_count += 1
+                await self._check_cleanup()
+
+                inputs = await self._prepare_inputs(request.prompt)
+                outputs = await self._generate_optimized(inputs, request)
+
+                return await self._process_outputs(outputs, request)
+
+            except Exception as e:
+                logger.error(f"Error generating text: {str(e)}")
+                raise HTTPException(
+                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                    detail=str(e)
                 )
-
-            # Decode the generated text
-            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-            return {
-                "generated_text": generated_text,
-                "prompt": request.prompt,
-                "parameters": {
-                    "max_length": request.max_length,
-                    "temperature": request.temperature,
-                    "top_k": request.top_k,
-                    "top_p": request.top_p,
-                    "repetition_penalty": request.repetition_penalty
-                }
+
+    async def _prepare_inputs(self, prompt: str):
+        poetry_prompt = f"Write a poem about: {prompt}\n\nPoem:"
+        tokens = self.tokenizer.encode(poetry_prompt, return_tensors='pt')
+        return tokens.to(DEVICE)
+
+    async def _generate_optimized(self, inputs, request: GenerateRequest):
+        attention_mask = torch.ones(inputs.shape, dtype=torch.long, device=DEVICE)
+
+        style_params = {
+            "haiku": {"max_length": 50, "repetition_penalty": 1.3},
+            "sonnet": {"max_length": 200, "repetition_penalty": 1.2},
+            "free_verse": {"max_length": request.max_length, "repetition_penalty": request.repetition_penalty}
+        }
+
+        params = style_params.get(request.style, style_params["free_verse"])
+
+        return self.model.generate(
+            inputs,
+            attention_mask=attention_mask,
+            max_length=params["max_length"],
+            num_return_sequences=1,
+            temperature=request.temperature,
+            top_k=request.top_k,
+            top_p=request.top_p,
+            repetition_penalty=params["repetition_penalty"],
+            do_sample=True,
+            pad_token_id=self.tokenizer.eos_token_id,
+            use_cache=True,
+            no_repeat_ngram_size=3,
+            early_stopping=True,
+            bad_words_ids=[[self.tokenizer.encode(word)[0]] for word in
+                           ['http', 'www', 'com', ':', '/', '#']],
+            min_length=20,
+        )
+
+    async def _process_outputs(self, outputs, request: GenerateRequest):
+        raw_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        prompt_pattern = f"Write a poem about: {request.prompt}\n\nPoem:"
+        poem_text = raw_text.replace(prompt_pattern, '').strip()
+
+        if request.style == "haiku":
+            formatted_lines = PoemFormatter.format_haiku(poem_text)
+        elif request.style == "sonnet":
+            formatted_lines = PoemFormatter.format_sonnet(poem_text)
+        else:
+            formatted_lines = PoemFormatter.format_free_verse(poem_text)
+
+        return {
+            "poem": {
+                "title": self._generate_title(poem_text),
+                "lines": formatted_lines,
+                "style": request.style
+            },
+            "original_prompt": request.prompt,
+            "parameters": {
+                "max_length": request.max_length,
+                "temperature": request.temperature,
+                "top_k": request.top_k,
+                "top_p": request.top_p,
+                "repetition_penalty": request.repetition_penalty
+            },
+            "metadata": {
+                "device": DEVICE.type,
+                "model_type": "GPT2",
+                "timestamp": datetime.now().isoformat()
             }
-
-        except Exception as e:
-            logger.error(f"Error generating text: {str(e)}")
-            raise HTTPException(
-                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                detail=str(e)
-            )
-
-# Create FastAPI app and model manager
-app = FastAPI(title="Poetry Generation API")
-model_manager = ModelManager()
+        }
+
+    def _generate_title(self, poem_text: str) -> str:
+        words = poem_text.split()[:6]
+        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to'}
+        key_words = [word for word in words if word.lower() not in stop_words]
+
+        if key_words:
+            title = ' '.join(key_words[:3]).capitalize()
+            return title
+        return "Untitled"
+
+    async def _check_cleanup(self):
+        if self.request_count % 100 == 0:
+            if DEVICE.type == 'cuda':
+                torch.cuda.empty_cache()
+            self.last_cleanup = datetime.now()
 
-@app.on_event("startup")
-async def startup():
-    """Initialize the model during startup"""
-    if not model_manager.initialize():
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    if not await model_manager.initialize():
         logger.error("Failed to initialize model manager")
+    yield
+    if model_manager.model is not None:
+        del model_manager.model
+    if model_manager.tokenizer is not None:
+        del model_manager.tokenizer
+    if DEVICE.type == 'cuda':
+        torch.cuda.empty_cache()
+
+app = FastAPI(
+    title="Poetry Generation API",
+    description="Optimized API for generating poetry using GPT-2",
+    version="2.0.0",
+    lifespan=lifespan
+)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+model_manager = ModelManager()
 
 @app.get("/health")
 async def health_check():
-    """Health check endpoint"""
     return {
         "status": "healthy",
         "model_loaded": model_manager.model is not None,
-        "tokenizer_loaded": model_manager.tokenizer is not None
+        "tokenizer_loaded": model_manager.tokenizer is not None,
+        "device": DEVICE.type,
+        "request_count": model_manager.request_count,
+        "last_cleanup": model_manager.last_cleanup.isoformat(),
+        "system_info": {
+            "cuda_available": torch.cuda.is_available(),
+            "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
+        }
     }
 
 @app.post("/generate")
-async def generate_text(request: GenerateRequest):
-    """Generate poetry with parameters"""
-    return model_manager.generate(request)
-
-@app.on_event("shutdown")
-async def shutdown():
-    """Cleanup on shutdown"""
-    if model_manager.model is not None:
-        del model_manager.model
-    if model_manager.tokenizer is not None:
-        del model_manager.tokenizer
+async def generate_text(
+    request: GenerateRequest,
+    background_tasks: BackgroundTasks
+):
+    try:
+        result = await model_manager.generate(request)
+
+        if model_manager.request_count % 100 == 0:
+            background_tasks.add_task(torch.cuda.empty_cache)
+
+        return JSONResponse(
+            content=result,
+            status_code=status.HTTP_200_OK
+        )
+    except Exception as e:
+        logger.error(f"Error in generate_text: {str(e)}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=str(e)
+        )