import os
import sys
import uvicorn
from fastapi import FastAPI, Query, HTTPException
from fastapi.responses import HTMLResponse
from starlette.middleware.cors import CORSMiddleware
from datasets import load_dataset, concatenate_datasets
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from loguru import logger
import concurrent.futures
import psutil
import asyncio
import torch
from tenacity import retry, stop_after_attempt, wait_fixed
from huggingface_hub import HfApi, list_datasets
from huggingface_hub.utils import RepositoryNotFoundError
from dotenv import load_dotenv
# Load environment variables
load_dotenv()

# Read the Hugging Face token
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HUGGINGFACE_TOKEN:
    logger.error("Hugging Face token not found. Please set the HUGGINGFACE_TOKEN environment variable.")
    sys.exit(1)
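
# Example .env file (hypothetical token value; a real token comes from
# https://huggingface.co/settings/tokens):
#   HUGGINGFACE_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx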
# Dictionaries for downloaded datasets and generated usage examples
datasets_dict = {}
example_usage_list = []

# Cache configuration
CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_TOKEN"] = HUGGINGFACE_TOKEN

pipeline_instance = None  # A single shared pipeline
# Flag indicating whether initialization has completed
initialization_complete = False
def initialize_model():
    global pipeline_instance, initialization_complete
    try:
        logger.info("Initializing the GPT-2 model and tokenizer.")
        base_model_repo = "gpt2"  # Variants such as "gpt2-medium" or "gpt2-large" also work
        model = AutoModelForCausalLM.from_pretrained(
            base_model_repo,
            cache_dir=CACHE_DIR,
            ignore_mismatched_sizes=True  # Ignore size mismatches between checkpoint and config
        )
        tokenizer = AutoTokenizer.from_pretrained(base_model_repo, cache_dir=CACHE_DIR)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        pipeline_instance = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1
        )
        logger.info("GPT-2 model and tokenizer initialized successfully.")
        initialization_complete = True
    except Exception as e:
        logger.error(f"Error initializing model and tokenizer: {e}", exc_info=True)
        sys.exit(1)
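
# Once initialized, the pipeline can be called directly; a minimal sketch
# (hypothetical prompt, default sampling settings):
#   text = pipeline_instance("Hello world", max_length=30, num_return_sequences=1)[0]['generated_text']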
@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def download_dataset(dataset_name):
    try:
        logger.info(f"Starting download for dataset: {dataset_name}")
        # 'trust_remote_code=True' removed to avoid the error with ParquetConfig
        datasets_dict[dataset_name] = load_dataset(dataset_name, cache_dir=CACHE_DIR)
        create_example_usage(dataset_name)
    except Exception as e:
        logger.error(f"Error loading dataset {dataset_name}: {e}", exc_info=True)
        raise
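
# With the tenacity decorator above, each failed download is retried up to 3 times
# with a fixed 5-second wait; after the final attempt the exception propagates to the caller.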
def upload_model_to_hub():
    try:
        api = HfApi()
        model_repo = "Yhhxhfh/Hhggg"  # Replace with your own repository on the Hugging Face Hub
        try:
            api.repo_info(repo_id=model_repo)
            logger.info(f"Model repository {model_repo} already exists.")
        except RepositoryNotFoundError:
            api.create_repo(repo_id=model_repo, private=False, token=HUGGINGFACE_TOKEN)
            logger.info(f"Created model repository {model_repo}.")
        logger.info(f"Pushing the model and tokenizer to {model_repo}.")
        pipeline_instance.model.push_to_hub(model_repo, token=HUGGINGFACE_TOKEN)
        pipeline_instance.tokenizer.push_to_hub(model_repo, token=HUGGINGFACE_TOKEN)
        logger.info(f"Successfully pushed the model and tokenizer to {model_repo}.")
    except Exception as e:
        logger.error(f"Error uploading model to Hugging Face Hub: {e}", exc_info=True)
def create_example_usage(dataset_name):
    try:
        logger.info(f"Creating example usage for dataset {dataset_name}")
        example_prompts = [
            "Once upon a time,",
            "In a world where AI rules,",
            "The future of technology is",
            "Explain the concept of",
            "Describe a scenario where"
        ]
        examples = []
        for prompt in example_prompts:
            generated_text = pipeline_instance(prompt, max_length=50, num_return_sequences=1)[0]['generated_text']
            examples.append({"prompt": prompt, "response": generated_text})
        example_usage_list.append({"dataset_name": dataset_name, "examples": examples})
        logger.info(f"Example usage created for dataset {dataset_name}")
    except Exception as e:
        logger.error(f"Error creating example usage for dataset {dataset_name}: {e}", exc_info=True)
def unify_datasets():
    try:
        logger.info("Starting to unify datasets")
        unified_dataset = None
        for dataset in datasets_dict.values():
            # load_dataset returns a DatasetDict; use the 'train' split when present
            split = dataset['train'] if 'train' in dataset else dataset
            if unified_dataset is None:
                unified_dataset = split
            else:
                # Dataset has no .concatenate method; concatenate_datasets is the supported API
                unified_dataset = concatenate_datasets([unified_dataset, split])
        datasets_dict['unified'] = unified_dataset
        logger.info("Datasets successfully unified.")
    except Exception as e:
        logger.error(f"Error unifying datasets: {e}", exc_info=True)
# Concurrency configuration: bound download workers by CPU cores, available memory, and GPUs
cpu_count = psutil.cpu_count(logical=False) or 1
memory_available_mb = psutil.virtual_memory().available / (1024 * 1024)
memory_per_download_mb = 100
memory_available = int(memory_available_mb / memory_per_download_mb)
gpu_count = torch.cuda.device_count()
max_concurrent_downloads = min(cpu_count, memory_available, gpu_count * 2 if gpu_count else cpu_count)
max_concurrent_downloads = max(1, max_concurrent_downloads)
max_concurrent_downloads = min(10, max_concurrent_downloads)
logger.info(f"Using up to {max_concurrent_downloads} concurrent workers for downloading datasets.")
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_downloads)
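
# Example sizing with hypothetical numbers: 4 physical cores, 2 GB of free RAM
# (2048 / 100 ≈ 20 memory slots) and no GPU give min(4, 20, 4) = 4, which stays
# inside the [1, 10] clamp above, so the executor gets 4 workers.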
async def download_and_process_datasets():
    # huggingface_hub.list_datasets yields DatasetInfo objects; collect the repo ids
    dataset_names = [d.id for d in list_datasets()]
    logger.info(f"Found {len(dataset_names)} datasets to download.")
    loop = asyncio.get_event_loop()
    tasks = []
    for dataset_name in dataset_names:
        task = loop.run_in_executor(executor, download_dataset, dataset_name)
        tasks.append(task)
    await asyncio.gather(*tasks)
    unify_datasets()
    upload_model_to_hub()
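
# Caution: list_datasets() with no filter iterates every public dataset on the Hub
# (tens of thousands of repos); in practice you would narrow the listing, e.g.
# list_datasets(author="...") or list_datasets(limit=100).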
# Initialize the FastAPI application; startup work runs as a background task
# so the server starts accepting requests immediately
app = FastAPI()

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # For tighter security, list the allowed origins explicitly
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)
message_history = []

@app.on_event("startup")
async def startup_event():
    logger.info("Application startup initiated.")
    loop = asyncio.get_event_loop()
    # Kick off background tasks without blocking the server
    asyncio.create_task(run_initialization(loop))
    logger.info("Startup tasks initiated.")

async def run_initialization(loop):
    try:
        # Initialize the model in a separate thread
        await loop.run_in_executor(None, initialize_model)
        # Download and process the datasets
        await download_and_process_datasets()
        logger.info("All startup tasks completed successfully.")
    except Exception as e:
        logger.error(f"Error during startup tasks: {e}", exc_info=True)
@app.get('/')
async def index():
    html_code = """
<!DOCTYPE html>
<html lang="en">
<head>
<!-- Existing head content -->
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>GPT-2 Chatbot</title>
<!-- Bootstrap CSS for a professional interface -->
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet">
<style>
body {
background-color: #f8f9fa;
font-family: Arial, sans-serif;
}
.container {
max-width: 800px;
margin-top: 50px;
}
.chat-container {
background-color: #ffffff;
border-radius: 10px;
box-shadow: 0 0 15px rgba(0,0,0,0.2);
padding: 20px;
display: flex;
flex-direction: column;
height: 600px;
}
.chat-box {
flex: 1;
overflow-y: auto;
margin-bottom: 15px;
}
.chat-input {
width: 100%;
padding: 10px;
border: 1px solid #ced4da;
border-radius: 5px;
font-size: 16px;
}
.chat-input:focus {
outline: none;
border-color: #80bdff;
box-shadow: 0 0 5px rgba(0,123,255,0.5);
}
.user-message {
text-align: right;
margin-bottom: 10px;
}
.user-message .message {
display: inline-block;
background-color: #007bff;
color: #fff;
padding: 10px 15px;
border-radius: 15px;
max-width: 70%;
}
.bot-message {
text-align: left;
margin-bottom: 10px;
}
.bot-message .message {
display: inline-block;
background-color: #6c757d;
color: #fff;
padding: 10px 15px;
border-radius: 15px;
max-width: 70%;
}
.toggle-history {
text-align: center;
cursor: pointer;
color: #007bff;
margin-top: 10px;
}
.history-container {
display: none;
background-color: #ffffff;
border-radius: 10px;
box-shadow: 0 0 15px rgba(0,0,0,0.2);
padding: 20px;
margin-top: 20px;
max-height: 300px;
overflow-y: auto;
}
</style>
</head>
<body>
<div class="container">
<h1 class="text-center mb-4">GPT-2 Chatbot</h1>
<div class="chat-container">
<div class="chat-box" id="chat-box">
</div>
<input type="text" class="chat-input" id="user-input" placeholder="Type your message..." onkeypress="handleKeyPress(event)">
<button class="btn btn-primary mt-3" onclick="sendMessage()">Send</button>
<div class="toggle-history mt-3" onclick="toggleHistory()">Toggle History</div>
<div class="history-container" id="history-container">
<h3>Chat History</h3>
<div id="history-content"></div>
</div>
</div>
</div>
<!-- Bootstrap JS (optional) -->
<script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
<script>
function toggleHistory() {
const historyContainer = document.getElementById('history-container');
// The stylesheet hides the container with display:none, so toggling a class
// alone never reveals it; toggle the inline style instead.
historyContainer.style.display = historyContainer.style.display === 'block' ? 'none' : 'block';
}
function saveMessage(sender, message) {
const historyContent = document.getElementById('history-content');
const messageElement = document.createElement('div');
messageElement.className = sender === 'user' ? 'user-message' : 'bot-message';
messageElement.innerHTML = `<div class="message">${message}</div>`;
historyContent.appendChild(messageElement);
}
function appendMessage(sender, message) {
const chatBox = document.getElementById('chat-box');
const messageElement = document.createElement('div');
messageElement.className = sender === 'user' ? 'user-message' : 'bot-message';
messageElement.innerHTML = `<div class="message">${message}</div>`;
chatBox.appendChild(messageElement);
chatBox.scrollTop = chatBox.scrollHeight;
}
function handleKeyPress(event) {
if (event.key === 'Enter') {
event.preventDefault();
sendMessage();
}
}
function sendMessage() {
const userInput = document.getElementById('user-input');
const userMessage = userInput.value.trim();
if (userMessage === '') return;
appendMessage('user', userMessage);
saveMessage('user', userMessage);
userInput.value = '';
fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`)
.then(response => {
if (response.status === 503) {
return response.json().then(data => { throw new Error(data.detail); });
}
return response.json();
})
.then(data => {
const botMessages = data.result;
botMessages.forEach(message => {
appendMessage('bot', message);
saveMessage('bot', message);
});
})
.catch(error => {
console.error('Error:', error);
appendMessage('bot', "Sorry, I'm not available right now. Please try again later.");
saveMessage('bot', "Sorry, I'm not available right now. Please try again later.");
});
}
function retryLastMessage() {
const lastUserMessage = document.querySelector('.user-message:last-of-type .message');
if (lastUserMessage) {
const userInput = document.getElementById('user-input');
userInput.value = lastUserMessage.innerText;
sendMessage();
}
}
</script>
</body>
</html>
"""
return HTMLResponse(content=html_code, status_code=200)
@app.get('/autocomplete')
async def autocomplete(q: str = Query(..., title='query')):
    global message_history, pipeline_instance, initialization_complete
    message_history.append(('user', q))
    if not initialization_complete:
        logger.warning("Model is not initialized yet.")
        raise HTTPException(status_code=503, detail="Model is not initialized yet. Please try again later.")
    try:
        response = pipeline_instance(q, max_length=50, num_return_sequences=1)[0]['generated_text']
        logger.debug(f"Autocomplete succeeded; q: {q}, res: {response}")
        return {"result": [response]}
    except Exception as e:
        logger.error(f"Ignored error in autocomplete: {e}", exc_info=True)
        return {"result": ["Sorry, I encountered an error processing your request."]}
if __name__ == '__main__':
    port = 7860  # Hugging Face Spaces expects the server to listen on port 7860
    uvicorn.run(app=app, host='0.0.0.0', port=port)
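
# To run locally (assumes the dependencies are installed and a .env file or
# environment variable provides HUGGINGFACE_TOKEN):
#   python app.py
# Then open http://localhost:7860/ in a browser.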