Nymbo commited on
Commit
cf0a064
·
verified ·
1 Parent(s): f2e4d2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -147
app.py CHANGED
@@ -1,191 +1,167 @@
1
- # main/app.py
2
- # -------------------------------------------------------------------
3
- # FLUX.1-Schnell Space — fixed wiring + proper Space launch settings
4
- # -------------------------------------------------------------------
5
- import os
6
  import io
7
  import random
8
- import requests
 
9
  from PIL import Image
10
- import gradio as gr
11
- from deep_translator import GoogleTranslator # used opportunistically
 
 
12
 
13
- # -----------------------------
14
- # Config / constants (simple)
15
- # -----------------------------
16
  API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
17
- # Try a few common env var names so you don't get bitten by naming
18
- API_TOKEN = (
19
- os.getenv("HF_READ_TOKEN")
20
- or os.getenv("HF_TOKEN")
21
- or os.getenv("HUGGINGFACEHUB_API_TOKEN")
22
- )
23
- HEADERS = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
24
- TIMEOUT = 100 # seconds
25
-
26
- # -----------------------------------------------
27
- # Helper: translate to EN only if it looks Cyrillic
28
- # (keeps behavior without forcing translation)
29
- # -----------------------------------------------
30
- def maybe_translate_to_en(text: str) -> str:
31
- # Count Cyrillic letters; if enough, treat as Russian and translate.
32
  try:
33
- cyr_count = sum("а" <= c <= "я" or "А" <= c <= "Я" for c in text)
34
- if cyr_count >= max(3, len(text) // 10):
35
- return GoogleTranslator(source="ru", target="en").translate(text)
36
- except Exception:
37
- # If translator fails, just return original prompt
38
- pass
39
- return text
40
-
41
- # -------------------------------------------------------
42
- # Core: generate image
43
- # - Signature now matches the UI components by POSITION
44
- # - We send only supported params to the Inference API
45
- # -------------------------------------------------------
46
- def query(
47
- prompt: str, # main user prompt (Textbox)
48
- negative_prompt: str, # negative prompt (Textbox)
49
- steps: int, # sampling steps (Slider)
50
- cfg_scale: float, # guidance scale (Slider)
51
- sampler: str, # UI-only; ignored by FLUX text2img
52
- seed: int, # -1 => random, else fixed
53
- strength: float, # UI-only here (img2img only), ignored
54
- width: int, # output width
55
- height: int, # output height
56
- ):
57
- # Guard empty input
58
  if not prompt:
59
- return None
 
 
 
 
60
 
61
- # Give each generation a small key for readable logs
62
  key = random.randint(0, 999)
63
 
64
- # Optional translation + your style suffix
65
- prompt_en = maybe_translate_to_en(prompt)
66
- prompt_en = f"{prompt_en} | ultra detail, ultra elaboration, ultra quality, perfect."
67
- print(f"\033[1mGeneration {key}:\033[0m {prompt_en}")
68
 
69
- # Build payload using the names expected by the HF Inference API
70
- params = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  "num_inference_steps": int(steps),
72
  "guidance_scale": float(cfg_scale),
 
 
 
73
  "width": int(width),
74
  "height": int(height),
75
  }
76
- if negative_prompt:
77
- params["negative_prompt"] = negative_prompt
78
- if int(seed) >= 0:
79
- params["seed"] = int(seed)
80
-
81
- # Make the request
82
- resp = requests.post(
83
- API_URL,
84
- headers=HEADERS,
85
- json={"inputs": prompt_en, "parameters": params},
86
- timeout=TIMEOUT,
87
- )
88
 
89
- # Handle common error modes: model warming, JSON error body, etc.
90
- if resp.status_code != 200:
91
- try:
92
- err = resp.json()
93
- except Exception:
94
- err = {"error": resp.text}
95
- msg = err.get("error") or err.get("message") or str(err)
96
- if resp.status_code == 503:
97
- # Friendly message when model is spinning up
98
- raise gr.Error("503: The model is warming up. Please try again in a moment.")
99
- raise gr.Error(f"{resp.status_code}: {msg}")
100
-
101
- # Turn returned bytes into a PIL image
102
  try:
103
- image = Image.open(io.BytesIO(resp.content)).convert("RGB")
104
- print(f"\033[1mGeneration {key} completed!\033[0m")
 
 
105
  return image
106
  except Exception as e:
107
- print("Error decoding image:", e)
108
- raise gr.Error("Failed to decode image from the model response.")
109
 
110
- # -----------------------
111
- # Minimal CSS for layout
112
- # -----------------------
113
  css = """
114
- #app-container { max-width: 800px; margin-left: auto; margin-right: auto; }
 
 
 
 
115
  """
116
 
117
- # -----------------------
118
- # UI layout (Blocks API)
119
- # -----------------------
120
- with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as app:
121
- # Title banner (simple HTML)
122
  gr.HTML("<center><h1>FLUX.1-Schnell</h1></center>")
123
-
124
- # Main column
125
  with gr.Column(elem_id="app-container"):
126
- # Prompt area
127
  with gr.Row():
128
  with gr.Column(elem_id="prompt-container"):
129
  with gr.Row():
130
- # Prompt Textbox (user prompt)
131
- text_prompt = gr.Textbox(
132
- label="Prompt",
133
- placeholder="Enter a prompt here",
134
- lines=2,
135
- elem_id="prompt-text-input",
136
- )
137
-
138
- # Advanced settings tucked into an Accordion
139
  with gr.Row():
140
  with gr.Accordion("Advanced Settings", open=False):
141
- # Negative Prompt Textbox (string, not a boolean!)
142
- negative_prompt = gr.Textbox(
143
- label="Negative Prompt",
144
- placeholder="What should not be in the image",
145
- value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos",
146
- lines=3,
147
- elem_id="negative-prompt-text-input",
148
- )
149
- # Size sliders
150
  with gr.Row():
151
  width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
152
  height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
153
- # Quality knobs
154
  steps = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=100, step=1)
155
  cfg = gr.Slider(label="CFG Scale", value=7, minimum=1, maximum=20, step=1)
156
- # Shown for completeness; not used for text2img
157
- strength = gr.Slider(label="Strength (img2img only)", value=0.7, minimum=0, maximum=1, step=0.001)
158
- # -1 => random seed each run
159
- seed = gr.Slider(label="Seed (-1 = random)", value=-1, minimum=-1, maximum=1_000_000_000, step=1)
160
- # UI-only (Flux Inference API ignores this)
161
- sampler = gr.Radio(
162
- label="Sampling method (UI only)",
163
- value="DPM++ 2M Karras",
164
- choices=["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"],
165
- )
166
-
167
- # Run button
168
- with gr.Row():
169
- text_button = gr.Button("Run", variant="primary", elem_id="gen-button")
170
 
171
- # Image output
 
 
 
 
172
  with gr.Row():
173
  image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
174
-
175
- # Wire the click to the function with matching POSITIONS
176
  text_button.click(
177
- fn=query,
178
- inputs=[text_prompt, negative_prompt, steps, cfg, sampler, seed, strength, width, height],
179
  outputs=image_output,
 
180
  )
181
 
182
- # ------------------------------------------------------
183
- # Proper Space launch (no share tunnel, bind container)
184
- # ------------------------------------------------------
185
- if __name__ == "__main__":
186
- app.launch(
187
- server_name="0.0.0.0", # make server reachable in the Space
188
- server_port=int(os.getenv("PORT", "7860")), # Hugging Face sets PORT
189
- show_api=False, # we don't need /docs
190
- share=False, # Spaces already reverse-proxy
191
  )
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
 
 
 
3
  import io
4
  import random
5
+ import os
6
+ import time
7
  from PIL import Image
8
+ from deep_translator import GoogleTranslator
9
+ import json
10
+
11
+ # Project by Nymbo
12
 
 
 
 
13
  API_URL = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"
14
+ timeout = 100
15
+
16
+
17
+ def _translate_text(text: str | None) -> str | None:
18
+ """Translate user input to English when possible while failing gracefully."""
19
+ if not text:
20
+ return text
 
 
 
 
 
 
 
 
21
  try:
22
+ translator = GoogleTranslator(source="auto", target="en")
23
+ translated = translator.translate(text)
24
+ return translated or text
25
+ except Exception as exc:
26
+ print(f"Translation failed, using original text: {exc}")
27
+ return text
28
+
29
+
30
+ # Function to query the API and return the generated image
31
+ def query(prompt, negative_prompt, steps=30, cfg_scale=7, sampler="DPM++ 2M Karras", seed=-1, strength=0.7, width=1024, height=1024):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  if not prompt:
33
+ raise gr.Error("Please provide a prompt before generating an image.")
34
+
35
+ api_token = os.getenv("HF_READ_TOKEN")
36
+ if not api_token:
37
+ raise gr.Error("Missing HF_READ_TOKEN environment variable.")
38
 
39
+ headers = {"Authorization": f"Bearer {api_token}"}
40
  key = random.randint(0, 999)
41
 
42
+ translated_prompt = _translate_text(prompt) or prompt
43
+ translated_negative = _translate_text(negative_prompt) or negative_prompt
 
 
44
 
45
+ if translated_prompt != prompt:
46
+ print(f"\033[1mGeneration {key} translation:\033[0m {translated_prompt}")
47
+
48
+ if translated_negative and translated_negative != negative_prompt:
49
+ print(f"\033[1mGeneration {key} negative translation:\033[0m {translated_negative}")
50
+
51
+ final_prompt = f"{translated_prompt} | ultra detail, ultra elaboration, ultra quality, perfect."
52
+ print(f"\033[1mGeneration {key}:\033[0m {final_prompt}")
53
+
54
+ try:
55
+ seed_int = int(seed)
56
+ except (TypeError, ValueError):
57
+ seed_int = -1
58
+ seed_value = seed_int if seed_int >= 0 else random.randint(1, 1_000_000_000)
59
+
60
+ parameters = {
61
+ "negative_prompt": translated_negative or None,
62
  "num_inference_steps": int(steps),
63
  "guidance_scale": float(cfg_scale),
64
+ "scheduler": sampler,
65
+ "seed": int(seed_value),
66
+ "strength": float(strength),
67
  "width": int(width),
68
  "height": int(height),
69
  }
70
+ parameters = {k: v for k, v in parameters.items() if v is not None}
71
+
72
+ payload = {
73
+ "inputs": final_prompt,
74
+ "parameters": parameters,
75
+ "steps": int(steps),
76
+ "cfg_scale": float(cfg_scale),
77
+ "seed": int(seed_value),
78
+ "strength": float(strength),
79
+ }
 
 
80
 
81
+ if translated_negative:
82
+ payload["negative_prompt"] = translated_negative
83
+
84
+ # Send the request to the API and handle the response
85
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
86
+ if response.status_code != 200:
87
+ print(f"Error: Failed to get image. Response status: {response.status_code}")
88
+ print(f"Response content: {response.text}")
89
+ if response.status_code == 503:
90
+ raise gr.Error(f"{response.status_code} : The model is being loaded")
91
+ raise gr.Error(f"{response.status_code}")
92
+
 
93
  try:
94
+ # Convert the response content into an image
95
+ image_bytes = response.content
96
+ image = Image.open(io.BytesIO(image_bytes))
97
+ print(f'\033[1mGeneration {key} completed!\033[0m ({prompt})')
98
  return image
99
  except Exception as e:
100
+ print(f"Error when trying to open the image: {e}")
101
+ return None
102
 
103
+ # CSS to style the app
 
 
104
  css = """
105
+ #app-container {
106
+ max-width: 800px;
107
+ margin-left: auto;
108
+ margin-right: auto;
109
+ }
110
  """
111
 
112
+ # Build the Gradio UI with Blocks
113
+ with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as app:
114
+ # Add a title to the app
 
 
115
  gr.HTML("<center><h1>FLUX.1-Schnell</h1></center>")
116
+
117
+ # Container for all the UI elements
118
  with gr.Column(elem_id="app-container"):
119
+ # Add a text input for the main prompt
120
  with gr.Row():
121
  with gr.Column(elem_id="prompt-container"):
122
  with gr.Row():
123
+ text_prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt here", lines=2, elem_id="prompt-text-input")
124
+
125
+ # Accordion for advanced settings
 
 
 
 
 
 
126
  with gr.Row():
127
  with gr.Accordion("Advanced Settings", open=False):
128
+ negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="What should not be in the image", value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos", lines=3, elem_id="negative-prompt-text-input")
 
 
 
 
 
 
 
 
129
  with gr.Row():
130
  width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
131
  height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
 
132
  steps = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=100, step=1)
133
  cfg = gr.Slider(label="CFG Scale", value=7, minimum=1, maximum=20, step=1)
134
+ strength = gr.Slider(label="Strength", value=0.7, minimum=0, maximum=1, step=0.001)
135
+ seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1) # Setting the seed to -1 will make it random
136
+ method = gr.Radio(label="Sampling method", value="DPM++ 2M Karras", choices=["DPM++ 2M Karras", "DPM++ SDE Karras", "Euler", "Euler a", "Heun", "DDIM"])
 
 
 
 
 
 
 
 
 
 
 
137
 
138
+ # Add a button to trigger the image generation
139
+ with gr.Row():
140
+ text_button = gr.Button("Run", variant='primary', elem_id="gen-button")
141
+
142
+ # Image output area to display the generated image
143
  with gr.Row():
144
  image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery")
145
+
146
+ # Bind the button to the query function with the added width and height inputs
147
  text_button.click(
148
+ query,
149
+ inputs=[text_prompt, negative_prompt, steps, cfg, method, seed, strength, width, height],
150
  outputs=image_output,
151
+ show_api=False,
152
  )
153
 
154
+ # Launch the Gradio app
155
+ launch_kwargs = {"show_api": False}
156
+ if os.getenv("SPACE_ID"):
157
+ launch_kwargs.update(
158
+ {
159
+ "share": True,
160
+ "server_name": "0.0.0.0",
161
+ "server_port": int(os.getenv("PORT", "7860")),
162
+ }
163
  )
164
+ else:
165
+ launch_kwargs["share"] = False
166
+
167
+ app.launch(**launch_kwargs)