import os
import logging
import sys
import re
import asyncio
from datetime import datetime
from typing import Optional, Dict, Any, List
from functools import lru_cache
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, HTTPException, status, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config


# Configuration
class Config:
    BASE_MODEL_DIR = "./models/"
    MODEL_PATH = os.path.join(BASE_MODEL_DIR, "poeticagpt.pth")
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    BATCH_SIZE = 8            # Requests processed together for better throughput
    CACHE_SIZE = 2048         # Max entries in the tokenization LRU cache
    MAX_QUEUE_SIZE = 16       # Maximum number of requests to queue
    QUANTIZE_MODEL = True     # Enable dynamic quantization (CPU inference only)
    WARMUP_INPUTS = True      # Pre-warm the model with sample inputs
    # Use environment-specific log directory or default to a temp directory
    LOG_DIR = os.environ.get('LOG_DIR', '/tmp/poetry_logs')
    ENABLE_PROFILING = False  # Set to True to enable the profiling endpoint
    REQUEST_TIMEOUT = 30.0    # Timeout for request processing in seconds
    MODEL_CONFIG = GPT2Config(  # compact GPT-2 variant (~30M parameters)
        n_positions=400,
        n_ctx=400,
        n_embd=384,
        n_layer=6,
        n_head=6,
        vocab_size=50257,
        bos_token_id=50256,
        eos_token_id=50256,
        use_cache=True,
    )


config = Config()


# Configure logging with proper error handling
def setup_logging():
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Always add a stdout handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # Try to set up a file handler, but handle permission issues gracefully
    try:
        # Create the log directory if it doesn't exist
        os.makedirs(config.LOG_DIR, exist_ok=True)
        log_file = os.path.join(
            config.LOG_DIR,
            f'poetry_generation_{datetime.now().strftime("%Y%m%d")}.log'
        )
        # Verify the file is writable before attaching the handler
        with open(log_file, 'a'):
            pass
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        print(f"Log file created at: {log_file}")
    except (PermissionError, OSError) as e:
        print(f"Warning: Could not create log file: {e}")
        print("Continuing with console logging only.")

    return logger


# Initialize logger
logger = setup_logging()


# Request models
class GenerateRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=500)
    max_length: Optional[int] = Field(default=100, ge=10, le=500)
    temperature: float = Field(default=0.9, ge=0.1, le=2.0)
    top_k: int = Field(default=50, ge=1, le=100)
    top_p: float = Field(default=0.95, ge=0.1, le=1.0)
    repetition_penalty: float = Field(default=1.2, ge=1.0, le=2.0)
    style: Optional[str] = Field(default="free_verse",
                                 description="Poetry style: free_verse, haiku, sonnet")

    @validator('prompt')
    def validate_prompt(cls, v):
        # Normalize whitespace
        return ' '.join(v.split())
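
# Example payload accepted by the model above (illustrative values):
#   GenerateRequest(prompt="the sea at dusk", style="sonnet", temperature=0.8)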


# Poem formatting module
class PoemFormatter:
    """Efficient poem formatter with optimized text processing"""

    @staticmethod
    def format_free_verse(text: str) -> List[str]:
        # Split on sentence boundaries and newlines
        lines = re.split(r'[.!?]+|\n+', text)
        lines = [line.strip() for line in lines if line.strip()]
        formatted_lines = []
        for line in lines:
            # Break long lines at commas for a more verse-like shape
            if len(line) > 40:
                parts = line.split(',')
                formatted_lines.extend(part.strip() for part in parts if part.strip())
            else:
                formatted_lines.append(line)
        return formatted_lines

    @staticmethod
    def format_haiku(text: str) -> List[str]:
        # Approximate syllables by counting vowel groups (precompiled for speed)
        vowel_pattern = re.compile(r'[aeiou]+')
        words = text.split()
        lines = []
        current_line = []
        syllable_count = 0
        syllable_targets = [5, 7, 5]  # Traditional haiku structure
        current_target_idx = 0
        for word in words:
            syllables = len(vowel_pattern.findall(word.lower())) or 1  # At least 1 syllable
            if current_target_idx >= len(syllable_targets):
                break
            current_target = syllable_targets[current_target_idx]
            if syllable_count + syllables <= current_target:
                current_line.append(word)
                syllable_count += syllables
            else:
                if current_line:
                    lines.append(' '.join(current_line))
                current_line = [word]
                syllable_count = syllables
                current_target_idx += 1
        if current_line and len(lines) < len(syllable_targets):
            lines.append(' '.join(current_line))
        # Ensure we return exactly 3 lines for a haiku
        while len(lines) < 3:
            lines.append("...")
        return lines[:3]

    @staticmethod
    def format_sonnet(text: str) -> List[str]:
        words = text.split()
        lines = []
        current_line = []
        target_line_length = 10  # Roughly approximates iambic pentameter
        for word in words:
            current_line.append(word)
            if len(current_line) >= target_line_length:
                lines.append(' '.join(current_line))
                current_line = []
                if len(lines) >= 14:  # A traditional sonnet has 14 lines
                    break
        if current_line and len(lines) < 14:
            lines.append(' '.join(current_line))
        # Pad to 14 lines for a complete sonnet
        while len(lines) < 14:
            lines.append("...")
        return lines

    @staticmethod
    def generate_title(poem_text: str) -> str:
        words = poem_text.split()[:10]  # Scan the opening words for title candidates
        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'with', 'by'}
        key_words = [word for word in words if word.lower() not in stop_words and len(word) > 2]
        if key_words:
            title = ' '.join(key_words[:3]).strip().capitalize()
            return title if title else "Untitled"
        return "Untitled"


# Request queue for efficient processing
class RequestQueue:
    def __init__(self, max_size=config.MAX_QUEUE_SIZE):
        self.queue = asyncio.Queue(maxsize=max_size)
        self.semaphore = asyncio.Semaphore(max_size)

    async def add_request(self, request_data):
        # Bound concurrency and fail fast if processing exceeds the timeout
        async with self.semaphore:
            return await asyncio.wait_for(
                self._process_request(request_data),
                timeout=config.REQUEST_TIMEOUT
            )

    async def _process_request(self, request_data):
        # Pair the request with a future that the batch worker will resolve
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((request_data, future))
        return await future
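
# Design note: callers block on a per-request future while the batch worker
# (ModelManager._batch_processor_worker below) drains the queue in groups.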


# Optimized Tokenization Service
class TokenizationService:
    def __init__(self):
        self.tokenizer = None
        self._lock = asyncio.Lock()

    @lru_cache(maxsize=config.CACHE_SIZE)
    def cached_tokenize(self, text):
        # Cache encoded prompts so repeated texts skip re-tokenization
        return self.tokenizer.encode(text, return_tensors='pt')

    async def initialize(self):
        async with self._lock:
            if self.tokenizer is None:
                logger.info("Initializing tokenizer")
                self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
                self.tokenizer.pad_token = self.tokenizer.eos_token
        return self.tokenizer

    async def encode(self, text):
        if not self.tokenizer:
            await self.initialize()
        # Offload tokenization of long texts to a thread to avoid blocking the loop
        if len(text) > 100:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None,
                lambda: self.cached_tokenize(text)
            )
        else:
            return self.cached_tokenize(text)

    def decode(self, tokens, skip_special_tokens=True):
        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)


# Model Manager with optimization techniques
class ModelManager:
    def __init__(self):
        self.model = None
        self._lock = asyncio.Lock()
        self.request_count = 0
        self.last_cleanup = datetime.now()
        self.model_ready = asyncio.Event()
        self.tokenization_service = TokenizationService()
        self.request_queue = RequestQueue()
        self.poem_formatter = PoemFormatter()
        self.batch_processor_task = None

    async def initialize(self) -> bool:
        try:
            logger.info(f"Initializing model on device: {config.DEVICE}")
            # Check that the model file exists
            if not os.path.exists(config.MODEL_PATH):
                logger.error(f"Model file not found at {config.MODEL_PATH}")
                # Create the directory so the model can be dropped in later
                os.makedirs(os.path.dirname(config.MODEL_PATH), exist_ok=True)
                return False
            await self.tokenization_service.initialize()
            await self._load_and_optimize_model()
            # Start the batch processing worker
            self.batch_processor_task = asyncio.create_task(self._batch_processor_worker())
            logger.info(f"Model and tokenizer loaded successfully on {config.DEVICE}")
            self.model_ready.set()
            # Warm up the model with dummy inputs
            if config.WARMUP_INPUTS:
                await self._warmup_model()
            return True
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            logger.exception("Detailed traceback:")
            return False

    async def _batch_processor_worker(self):
        """Worker that processes queued requests in batches"""
        logger.info("Starting batch processor worker")
        try:
            while True:
                # Drain up to BATCH_SIZE requests whenever any are waiting
                if not self.request_queue.queue.empty():
                    batch = []
                    batch_futures = []
                    batch_size = min(config.BATCH_SIZE, self.request_queue.queue.qsize())
                    for _ in range(batch_size):
                        if self.request_queue.queue.empty():
                            break
                        request_data, future = await self.request_queue.queue.get()
                        batch.append(request_data)
                        batch_futures.append(future)
                    if batch:
                        try:
                            # Process the batch
                            results = await self._process_batch(batch)
                            # Deliver each result to its waiting future
                            for i, future in enumerate(batch_futures):
                                if not future.done():
                                    future.set_result(results[i])
                        except Exception as e:
                            # Propagate the failure to every request in the batch
                            for future in batch_futures:
                                if not future.done():
                                    future.set_exception(e)
                        finally:
                            # Mark queue tasks as done
                            for _ in range(len(batch)):
                                self.request_queue.queue.task_done()
                else:
                    # Queue is empty; sleep briefly before checking again
                    await asyncio.sleep(0.01)
        except asyncio.CancelledError:
            logger.info("Batch processor worker cancelled")
        except Exception as e:
            logger.error(f"Error in batch processor worker: {str(e)}")
            logger.exception("Detailed traceback")

    async def _process_batch(self, batch_requests):
        """Process a batch of requests efficiently"""
        results = []
        # Disable gradient tracking for the entire batch
        with torch.no_grad():
            for request in batch_requests:
                try:
                    inputs = await self._prepare_inputs(request.prompt)
                    outputs = await self._generate_optimized(inputs, request)
                    result = await self._process_outputs(outputs, request)
                    results.append(result)
                except Exception as e:
                    logger.error(f"Error processing request in batch: {str(e)}")
                    results.append({"error": str(e)})
        return results

    async def _load_and_optimize_model(self):
        """Load and optimize the model with advanced techniques"""
        async with self._lock:
            if not os.path.exists(config.MODEL_PATH):
                raise FileNotFoundError(f"Model file not found at {config.MODEL_PATH}")
            # Create model with configuration
            self.model = GPT2LMHeadModel(config.MODEL_CONFIG)
            # Load state dict
            state_dict = torch.load(config.MODEL_PATH, map_location=config.DEVICE)
            self.model.load_state_dict(state_dict, strict=False)
            # Move model to device and set evaluation mode
            self.model.to(config.DEVICE)
            self.model.eval()
            # Dynamic quantization only supports CPU inference, and
            # quantize_dynamic returns a new model, so keep the result
            if config.QUANTIZE_MODEL and config.DEVICE.type == 'cpu':
                try:
                    self.model = torch.quantization.quantize_dynamic(
                        self.model, {torch.nn.Linear}, dtype=torch.qint8
                    )
                    logger.info("Model quantized successfully")
                except Exception as e:
                    logger.warning(f"Quantization failed, using full precision: {str(e)}")
            # Apply CUDA-specific optimizations
            if config.DEVICE.type == 'cuda':
                torch.backends.cudnn.benchmark = True
                # Enable TF32 matmuls where available (e.g. on A100 GPUs)
                if hasattr(torch.backends.cuda, 'matmul') and hasattr(torch.backends.cuda.matmul, 'allow_tf32'):
                    torch.backends.cuda.matmul.allow_tf32 = True
                # Try TorchScript for faster inference; fall back on failure
                try:
                    self.model = torch.jit.script(self.model)
                    logger.info("Model optimized with TorchScript")
                except Exception as e:
                    logger.warning(f"TorchScript optimization failed: {str(e)}")

    async def _warmup_model(self):
        """Pre-warm the model with sample inputs to avoid cold-start latency"""
        logger.info("Warming up model...")
        # Dummy prompts of different lengths
        dummy_texts = [
            "Write a poem about nature",
            "Write a poem about love and loss in the modern world"
        ]
        dummy_requests = [
            GenerateRequest(prompt=text, max_length=50, temperature=0.9)
            for text in dummy_texts
        ]
        for req in dummy_requests:
            try:
                with torch.no_grad():
                    inputs = await self._prepare_inputs(req.prompt)
                    _ = await self._generate_optimized(inputs, req)
            except Exception as e:
                logger.warning(f"Model warmup error: {str(e)}")
        logger.info("Model warmup completed")

    async def _prepare_inputs(self, prompt: str):
        """Prepare model inputs with optimized tokenization"""
        poetry_prompt = f"Write a poem about: {prompt}\n\nPoem:"
        tokens = await self.tokenization_service.encode(poetry_prompt)
        return tokens.to(config.DEVICE)

    async def _generate_optimized(self, inputs, request: GenerateRequest):
        """Optimized text generation with style-specific parameters"""
        attention_mask = torch.ones(inputs.shape, dtype=torch.long, device=config.DEVICE)
        # Style-specific generation parameters
        style_params = {
            "haiku": {"max_length": 50, "repetition_penalty": 1.4, "no_repeat_ngram_size": 2},
            "sonnet": {"max_length": 200, "repetition_penalty": 1.2, "no_repeat_ngram_size": 3},
            "free_verse": {
                "max_length": request.max_length,
                "repetition_penalty": request.repetition_penalty,
                "no_repeat_ngram_size": 3
            }
        }
        params = style_params.get(request.style, style_params["free_verse"])
        # Block tokens that tend to produce URLs and markup artifacts
        tokenizer = await self.tokenization_service.initialize()
        bad_words = ['http', 'www', 'com', ':', '/', '#', '[', ']', '{', '}']
        bad_words_ids = [[tokenizer.encode(word)[0]]
                         for word in bad_words if len(tokenizer.encode(word)) > 0]
        return self.model.generate(
            inputs,
            attention_mask=attention_mask,
            max_length=params["max_length"],
            num_return_sequences=1,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p,
            repetition_penalty=params["repetition_penalty"],
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
            no_repeat_ngram_size=params["no_repeat_ngram_size"],
            early_stopping=True,
            bad_words_ids=bad_words_ids,
            min_length=20 if request.style != "haiku" else 10,
        )

    async def _process_outputs(self, outputs, request: GenerateRequest):
        """Process and format the generated text into a poem"""
        # Decode generated text
        raw_text = self.tokenization_service.decode(outputs[0], skip_special_tokens=True)
        # Strip the instruction prefix to isolate the poem body
        prompt_pattern = f"Write a poem about: {request.prompt}\n\nPoem:"
        poem_text = raw_text.replace(prompt_pattern, '').strip()
        # Format according to the requested style
        if request.style == "haiku":
            formatted_lines = self.poem_formatter.format_haiku(poem_text)
        elif request.style == "sonnet":
            formatted_lines = self.poem_formatter.format_sonnet(poem_text)
        else:
            formatted_lines = self.poem_formatter.format_free_verse(poem_text)
        # Build the response payload
        return {
            "poem": {
                "title": self.poem_formatter.generate_title(poem_text),
                "lines": formatted_lines,
                "style": request.style
            },
            "original_prompt": request.prompt,
            "parameters": {
                "max_length": request.max_length,
                "temperature": request.temperature,
                "top_k": request.top_k,
                "top_p": request.top_p,
                "repetition_penalty": request.repetition_penalty
            },
            "metadata": {
                "device": config.DEVICE.type,
                "model_type": "GPT2-Optimized",
                "timestamp": datetime.now().isoformat()
            }
        }

    async def generate(self, request: GenerateRequest) -> Dict[str, Any]:
        """Queue a request for generation and await the result"""
        try:
            # Wait for the model to become ready
            await asyncio.wait_for(self.model_ready.wait(), timeout=60.0)
            self.request_count += 1
            # Enqueue the request and await its result
            return await self.request_queue.add_request(request)
        except asyncio.TimeoutError:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail="Model is still initializing or overloaded"
            )
        except Exception as e:
            logger.error(f"Error generating text: {str(e)}")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=str(e)
            )

    async def cleanup(self):
        """Perform memory cleanup operations"""
        if config.DEVICE.type == 'cuda':
            torch.cuda.empty_cache()
        self.last_cleanup = datetime.now()
        logger.info("Memory cleanup performed")

    async def shutdown(self):
        """Cleanly shut down the model manager"""
        # Cancel the batch processor worker
        if self.batch_processor_task:
            self.batch_processor_task.cancel()
            try:
                await self.batch_processor_task
            except asyncio.CancelledError:
                pass
        # Release model and tokenizer references
        if self.model is not None:
            self.model = None
        if self.tokenization_service.tokenizer is not None:
            self.tokenization_service.tokenizer = None
        # Final memory cleanup
        if config.DEVICE.type == 'cuda':
            torch.cuda.empty_cache()


# Create model manager instance
model_manager = ModelManager()


# FastAPI lifespan
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize on startup
    initialized = await model_manager.initialize()
    if not initialized:
        logger.error("Failed to initialize model manager")
    yield
    # Clean up on shutdown
    logger.info("Shutting down Poetry Generation API")
    await model_manager.shutdown()


# Create FastAPI app
app = FastAPI(
    title="Poetry Generation API",
    description="High-Performance API for generating poetry using GPT-2",
    version="3.0.0",
    lifespan=lifespan
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Health check endpoint
@app.get("/health")  # route path assumed
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": model_manager.model is not None,
        "model_ready": model_manager.model_ready.is_set(),
        "tokenizer_loaded": model_manager.tokenization_service.tokenizer is not None,
        "device": config.DEVICE.type,
        "request_count": model_manager.request_count,
        "queue_size": model_manager.request_queue.queue.qsize(),
        "last_cleanup": model_manager.last_cleanup.isoformat(),
        "system_info": {
            "cuda_available": torch.cuda.is_available(),
            "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
            "cuda_memory": {
                "allocated": f"{torch.cuda.memory_allocated() / (1024**2):.2f} MB",
                "reserved": f"{torch.cuda.memory_reserved() / (1024**2):.2f} MB",
                "max_allocated": f"{torch.cuda.max_memory_allocated() / (1024**2):.2f} MB"
            } if torch.cuda.is_available() else {},
        }
    }


# Poetry generation endpoint
@app.post("/generate")  # route path assumed
async def generate_text(
    request: GenerateRequest,
    background_tasks: BackgroundTasks
):
    try:
        result = await model_manager.generate(request)
        # Schedule memory cleanup every 50 requests
        if model_manager.request_count % 50 == 0:
            background_tasks.add_task(model_manager.cleanup)
        return JSONResponse(
            content=result,
            status_code=status.HTTP_200_OK
        )
    except HTTPException:
        # Re-raise HTTP exceptions unchanged
        raise
    except Exception as e:
        logger.error(f"Error in generate_text: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(e)
        )
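
# Example request against the endpoint above (a sketch; assumes the server
# is running locally on port 8000):
#
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "autumn rain on a tin roof", "style": "haiku"}'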


# Add profiling endpoint only if profiling is enabled
if config.ENABLE_PROFILING:
    @app.get("/profiling")  # route path assumed
    async def get_profiling():
        if config.DEVICE.type == 'cuda':
            return {
                "memory": {
                    "allocated": f"{torch.cuda.memory_allocated() / (1024**2):.2f} MB",
                    "reserved": f"{torch.cuda.memory_reserved() / (1024**2):.2f} MB",
                    "max_allocated": f"{torch.cuda.max_memory_allocated() / (1024**2):.2f} MB"
                },
                "request_count": model_manager.request_count,
                "queue_size": model_manager.request_queue.queue.qsize(),
            }
        else:
            return {"device": "cpu", "profiling": "not available"}