Build error
Update app.py
app.py
CHANGED
@@ -5,27 +5,31 @@ from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 from transformers import (
     AutoConfig,
-    pipeline,
     AutoModelForCausalLM,
     AutoTokenizer,
     GenerationConfig,
-    StoppingCriteriaList
+    StoppingCriteriaList,
+    pipeline
 )
 import asyncio
 from io import BytesIO
+from botocore.exceptions import NoCredentialsError
+import boto3
 
-# Global dictionary to store the tokens
+# Global dictionary to store model tokens and configurations
 token_dict = {}
 
-#
+# Configuration for accessing models on Hugging Face or S3
 AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
 AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
 AWS_REGION = os.getenv("AWS_REGION")
 S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
 HUGGINGFACE_HUB_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
 
+# FastAPI application initialization
 app = FastAPI()
 
+# Request model for the API
 class GenerateRequest(BaseModel):
     model_name: str
     input_text: str
@@ -42,14 +46,19 @@ class GenerateRequest(BaseModel):
     stop_sequences: list[str] = []
 
 class S3ModelLoader:
-    def __init__(self, bucket_name,
+    def __init__(self, bucket_name, aws_access_key_id=None, aws_secret_access_key=None, aws_region=None):
         self.bucket_name = bucket_name
-        self.s3_client =
+        self.s3_client = boto3.client(
+            's3',
+            aws_access_key_id=aws_access_key_id,
+            aws_secret_access_key=aws_secret_access_key,
+            region_name=aws_region
+        )
 
     def _get_s3_uri(self, model_name):
         return f"s3://{self.bucket_name}/{model_name.replace('/', '-')}"
 
-
+    def load_model_and_tokenizer(self, model_name):
         if model_name in token_dict:
             return token_dict[model_name]
 
@@ -69,55 +78,14 @@ class S3ModelLoader:
             }
 
             return token_dict[model_name]
+        except NoCredentialsError:
+            raise HTTPException(status_code=500, detail="AWS credentials not found.")
         except Exception as e:
             raise HTTPException(status_code=500, detail=f"Error loading model: {e}")
 
-model_loader = S3ModelLoader(S3_BUCKET_NAME,
-
-@app.post("/generate")
-async def generate(request: GenerateRequest):
-    try:
-        model_name = request.model_name
-        input_text = request.input_text
-        temperature = request.temperature
-        max_new_tokens = request.max_new_tokens
-        stream = request.stream
-        top_p = request.top_p
-        top_k = request.top_k
-        repetition_penalty = request.repetition_penalty
-        num_return_sequences = request.num_return_sequences
-        do_sample = request.do_sample
-        chunk_delay = request.chunk_delay
-        stop_sequences = request.stop_sequences
-
-        # Load model and tokenizer from S3
-        model_data = await model_loader.load_model_and_tokenizer(model_name)
-        model = model_data["model"]
-        tokenizer = model_data["tokenizer"]
-        pad_token_id = model_data["pad_token_id"]
-        eos_token_id = model_data["eos_token_id"]
-
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-
-        generation_config = GenerationConfig(
-            temperature=temperature,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            do_sample=do_sample,
-            num_return_sequences=num_return_sequences,
-        )
-
-        return StreamingResponse(
-            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
-            media_type="text/plain"
-        )
-
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+model_loader = S3ModelLoader(S3_BUCKET_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION)
 
+# Function to stream text, generating one token at a time
 async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
     encoded_input = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=max_length).to(device)
     input_length = encoded_input["input_ids"].shape[1]
@@ -159,20 +127,52 @@ async def stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay, max_length=2048):
             yield output_text
             return
 
-
-
-
-
-
-
-
-
-
-
-
-
+# Endpoint for text generation
+@app.post("/generate")
+async def generate(request: GenerateRequest):
+    try:
+        model_name = request.model_name
+        input_text = request.input_text
+        temperature = request.temperature
+        max_new_tokens = request.max_new_tokens
+        stream = request.stream
+        top_p = request.top_p
+        top_k = request.top_k
+        repetition_penalty = request.repetition_penalty
+        num_return_sequences = request.num_return_sequences
+        do_sample = request.do_sample
+        chunk_delay = request.chunk_delay
+        stop_sequences = request.stop_sequences
+
+        # Load the model and tokenizer from S3
+        model_data = model_loader.load_model_and_tokenizer(model_name)
+        model = model_data["model"]
+        tokenizer = model_data["tokenizer"]
+        pad_token_id = model_data["pad_token_id"]
+        eos_token_id = model_data["eos_token_id"]
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+
+        generation_config = GenerationConfig(
+            temperature=temperature,
+            max_new_tokens=max_new_tokens,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            do_sample=do_sample,
+            num_return_sequences=num_return_sequences,
+        )
+
+        return StreamingResponse(
+            stream_text(model, tokenizer, input_text, generation_config, stop_sequences, device, chunk_delay),
+            media_type="text/plain"
         )
 
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
+# Endpoint for image generation
 @app.post("/generate-image")
 async def generate_image(request: GenerateRequest):
     try:
@@ -191,6 +191,7 @@ async def generate_image(request: GenerateRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
+# Endpoint for text-to-speech generation
 @app.post("/generate-text-to-speech")
 async def generate_text_to_speech(request: GenerateRequest):
     try:
@@ -209,6 +210,7 @@ async def generate_text_to_speech(request: GenerateRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
+# Endpoint for video generation
 @app.post("/generate-video")
 async def generate_video(request: GenerateRequest):
     try:
@@ -226,6 +228,7 @@ async def generate_video(request: GenerateRequest):
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
 
+# Configuration to run the server
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860)
+    uvicorn.run(app, host="0.0.0.0", port=7860)
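The body of the new load_model_and_tokenizer method sits largely outside the hunks shown above; only its signature, the token_dict cache check, the returned dictionary keys, and the exception handlers are visible. The following is a minimal sketch of a body consistent with those fragments, assuming a plain Hugging Face Hub load (the commit's actual S3 download path is not shown); it relies on the module-level token_dict, HUGGINGFACE_HUB_TOKEN, HTTPException, and NoCredentialsError already present in app.py.

    def load_model_and_tokenizer(self, model_name):
        # Return the cached entry if this model has already been loaded.
        if model_name in token_dict:
            return token_dict[model_name]
        try:
            # Placeholder: the commit's S3 download logic is not visible in the diff,
            # so this sketch loads directly from the Hugging Face Hub with the token.
            tokenizer = AutoTokenizer.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
            model = AutoModelForCausalLM.from_pretrained(model_name, token=HUGGINGFACE_HUB_TOKEN)
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token_id = tokenizer.eos_token_id
            token_dict[model_name] = {
                "model": model,
                "tokenizer": tokenizer,
                "pad_token_id": tokenizer.pad_token_id,
                "eos_token_id": tokenizer.eos_token_id,
            }
            return token_dict[model_name]
        except NoCredentialsError:
            raise HTTPException(status_code=500, detail="AWS credentials not found.")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error loading model: {e}")

The dictionary keys here match what the /generate handler reads back (model, tokenizer, pad_token_id, eos_token_id).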
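Once the Space builds, the new /generate endpoint returns a plain-text StreamingResponse served by uvicorn on port 7860. A hypothetical client call follows, with field names taken from GenerateRequest and the generate handler; the values themselves are illustrative only.

import requests

payload = {
    "model_name": "gpt2",              # illustrative model id
    "input_text": "Once upon a time",
    "temperature": 0.7,
    "max_new_tokens": 64,
    "stream": True,
    "top_p": 0.9,
    "top_k": 50,
    "repetition_penalty": 1.1,
    "num_return_sequences": 1,
    "do_sample": True,
    "chunk_delay": 0.0,
    "stop_sequences": ["\n\n"],
}

# The endpoint streams plain text, so read the response incrementally.
with requests.post("http://localhost:7860/generate", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)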
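The import list keeps StoppingCriteriaList and GenerateRequest carries stop_sequences, but the part of stream_text that enforces them is not visible in these hunks. Purely as an assumption about how stop sequences might be wired into generation, the usual pattern is a custom StoppingCriteria along these lines.

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequences(StoppingCriteria):
    """Stop generation once any stop sequence appears in the newly generated text."""

    def __init__(self, stop_sequences, tokenizer, prompt_length):
        self.stop_sequences = stop_sequences
        self.tokenizer = tokenizer
        self.prompt_length = prompt_length  # number of prompt tokens to skip

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        generated = self.tokenizer.decode(
            input_ids[0][self.prompt_length:], skip_special_tokens=True
        )
        return any(seq in generated for seq in self.stop_sequences)

# Sketch of how it would plug into a generate() call:
# stopping = StoppingCriteriaList([StopOnSequences(stop_sequences, tokenizer, input_length)])
# model.generate(**encoded_input, generation_config=generation_config, stopping_criteria=stopping)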