eder0782 commited on
Commit
73c62d4
·
verified ·
1 Parent(s): 6da0e40
Files changed (1) hide show
  1. app.py +35 -10
app.py CHANGED
@@ -7,7 +7,25 @@ from diffusers import DiffusionPipeline
7
  import io
8
  import base64
9
  from PIL import Image
10
- import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  dtype = torch.bfloat16
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -19,6 +37,8 @@ MAX_IMAGE_SIZE = 2048
19
 
20
  @spaces.GPU()
21
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
 
 
22
  if randomize_seed:
23
  seed = random.randint(0, MAX_SEED)
24
  generator = torch.Generator().manual_seed(seed)
@@ -38,13 +58,21 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_in
38
  image.save(buffered, format="PNG")
39
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
40
 
41
- # Retornar JSON com Base64 e seed
42
  return {"image_base64": f"data:image/png;base64,{img_str}", "seed": seed}
43
 
44
- # Função para a API personalizada
45
- def api_infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4):
46
- result = infer(prompt, seed, randomize_seed, width, height, num_inference_steps)
47
- return result # Retorna diretamente o JSON
 
 
 
 
 
 
 
 
 
48
 
49
  examples = [
50
  "a tiny astronaut hatching from an egg on the moon",
@@ -125,7 +153,6 @@ with gr.Blocks(css=css) as demo:
125
  output = infer(prompt, seed, randomize_seed, width, height, num_inference_steps)
126
  return output["image_base64"], output["seed"]
127
 
128
- # Interface Gradio
129
  gr.on(
130
  triggers=[run_button.click, prompt.submit],
131
  fn=format_output,
@@ -133,7 +160,5 @@ with gr.Blocks(css=css) as demo:
133
  outputs=[result, seed_output]
134
  )
135
 
136
- # Endpoint personalizado para a API
137
- demo.queue(api_name="infer_api").launch()
138
-
139
  demo.launch()
 
7
  import io
8
  import base64
9
  from PIL import Image
10
+ import logging
11
+ from fastapi import FastAPI
12
+ from pydantic import BaseModel
13
+
14
+ # Configurar logging para depuração
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Inicializar FastAPI
19
+ app = FastAPI()
20
+
21
+ # Modelo para validação dos parâmetros da API
22
+ class ImageRequest(BaseModel):
23
+ prompt: str
24
+ seed: int = 42
25
+ randomize_seed: bool = False
26
+ width: int = 1024
27
+ height: int = 1024
28
+ num_inference_steps: int = 4
29
 
30
  dtype = torch.bfloat16
31
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
37
 
38
  @spaces.GPU()
39
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
40
+ logger.info(f"Chamando infer com prompt={prompt}, seed={seed}, randomize_seed={randomize_seed}, width={width}, height={height}, num_inference_steps={num_inference_steps}")
41
+
42
  if randomize_seed:
43
  seed = random.randint(0, MAX_SEED)
44
  generator = torch.Generator().manual_seed(seed)
 
58
  image.save(buffered, format="PNG")
59
  img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
60
 
 
61
  return {"image_base64": f"data:image/png;base64,{img_str}", "seed": seed}
62
 
63
+ # Endpoint FastAPI
64
+ @app.post("/api/infer")
65
+ async def api_infer(request: ImageRequest):
66
+ logger.info(f"Requisição API recebida: {request}")
67
+ result = infer(
68
+ prompt=request.prompt,
69
+ seed=request.seed,
70
+ randomize_seed=request.randomize_seed,
71
+ width=request.width,
72
+ height=request.height,
73
+ num_inference_steps=request.num_inference_steps
74
+ )
75
+ return result
76
 
77
  examples = [
78
  "a tiny astronaut hatching from an egg on the moon",
 
153
  output = infer(prompt, seed, randomize_seed, width, height, num_inference_steps)
154
  return output["image_base64"], output["seed"]
155
 
 
156
  gr.on(
157
  triggers=[run_button.click, prompt.submit],
158
  fn=format_output,
 
160
  outputs=[result, seed_output]
161
  )
162
 
163
+ # Iniciar o Gradio (sem queue, pois usamos FastAPI para a API)
 
 
164
  demo.launch()