badaoui HF Staff commited on
Commit
959ee6e
·
verified ·
1 Parent(s): 900a193

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +203 -219
app.py CHANGED
@@ -1,11 +1,10 @@
1
  import csv
2
  import os
3
  from datetime import datetime
4
- from typing import Optional, Union
5
-
6
  import gradio as gr
7
  from huggingface_hub import HfApi, Repository
8
-
9
  from optimum_neuron_export import convert
10
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
11
  from apscheduler.schedulers.background import BackgroundScheduler
@@ -13,18 +12,16 @@ from apscheduler.schedulers.background import BackgroundScheduler
13
  DATASET_REPO_URL = "https://huggingface.co/datasets/optimum/neuron-exports"
14
  DATA_FILENAME = "exports.csv"
15
  DATA_FILE = os.path.join("data", DATA_FILENAME)
16
-
17
- HF_TOKEN = os.environ.get("HF_WRITE_TOKEN")
18
-
19
  DATADIR = "neuron_exports_data"
20
-
21
  repo: Optional[Repository] = None
 
22
  # Uncomment if you want to push to dataset repo with token
23
  # if HF_TOKEN:
24
- # repo = Repository(local_dir=DATADIR, clone_from=DATASET_REPO_URL, token=HF_TOKEN)
25
 
26
- # Define all possible tasks and their categories for coloring
27
- TASK_CATEGORIES = {
28
  "auto": {"color": "#6b7280", "category": "Auto"},
29
  "feature-extraction": {"color": "#3b82f6", "category": "Feature Extraction"},
30
  "fill-mask": {"color": "#8b5cf6", "category": "NLP"},
@@ -41,13 +38,28 @@ TASK_CATEGORIES = {
41
  "image-classification": {"color": "#ef4444", "category": "Vision"},
42
  "object-detection": {"color": "#ef4444", "category": "Vision"},
43
  "semantic-segmentation": {"color": "#ef4444", "category": "Vision"},
44
- "text-to-image": {"color": "#ec4899", "category": "Multimodal"},
45
- "image-to-image": {"color": "#ec4899", "category": "Multimodal"},
46
- "inpaint": {"color": "#ec4899", "category": "Multimodal"},
47
  "zero-shot-image-classification": {"color": "#ec4899", "category": "Multimodal"},
48
  "sentence-similarity": {"color": "#06b6d4", "category": "Similarity"},
49
  }
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  TAGS = {
52
  "Feature Extraction": {"color": "#3b82f6", "category": "Feature Extraction"},
53
  "NLP": {"color": "#8b5cf6", "category": "NLP"},
@@ -56,15 +68,37 @@ TAGS = {
56
  "Vision": {"color": "#ef4444", "category": "Vision"},
57
  "Multimodal": {"color": "#ec4899", "category": "Multimodal"},
58
  "Similarity": {"color": "#06b6d4", "category": "Similarity"},
 
 
 
 
 
 
 
59
  }
60
 
61
- # Get all tasks for dropdown
62
- ALL_TASKS = list(TASK_CATEGORIES.keys())
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def create_task_tag(task: str) -> str:
65
  """Create a colored HTML tag for a task"""
66
- if task in TASK_CATEGORIES:
67
- color = TASK_CATEGORIES[task]["color"]
 
 
 
68
  return f'<span style="background-color: {color}; color: white; padding: 2px 6px; border-radius: 12px; font-size: 0.75rem; font-weight: 500; margin: 1px;">{task}</span>'
69
  elif task in TAGS:
70
  color = TAGS[task]["color"]
@@ -77,46 +111,91 @@ def format_tasks_for_table(tasks_str: str) -> str:
77
  tasks = [task.strip() for task in tasks_str.split(',')]
78
  return ' '.join([create_task_tag(task) for task in tasks])
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- def neuron_export(model_id: str, task: str, oauth_token: gr.OAuthToken) -> str:
82
- if oauth_token.token is None:
83
- return "You must be logged in to use this space"
84
-
85
  if not model_id:
86
- return f"### Invalid input 🐞 Please specify a model name from the hub."
 
 
 
 
 
 
 
 
 
 
87
 
88
  try:
89
- api = HfApi(token=oauth_token.token)
90
-
91
- error, commit_info = convert(api=api, model_id=model_id, task=task, token=oauth_token.token)
92
- if error != "0":
93
- return commit_info
94
-
95
- print("[commit_info]", commit_info)
96
-
97
- # Save in a private dataset if repo initialized
98
- if repo is not None:
99
- repo.git_pull(rebase=True)
100
- with open(os.path.join(DATADIR, DATA_FILE), "a") as csvfile:
101
- writer = csv.DictWriter(
102
- csvfile, fieldnames=["model_id", "pr_url", "time"]
103
- )
104
- writer.writerow(
105
- {
106
- "model_id": model_id,
107
- "pr_url": commit_info.pr_url,
108
- "time": str(datetime.now()),
109
- }
110
- )
111
- commit_url = repo.push_to_hub()
112
- print("[dataset]", commit_url)
113
-
114
- pr_revision = commit_info.pr_revision.replace("/", "%2F")
115
- return f"#### Success 🔥 Yay! This model was successfully exported and a PR was opened using your token: [{commit_info.pr_url}]({commit_info.pr_url}). If you would like to use the exported model without waiting for the PR to be approved, head to https://huggingface.co/{model_id}/tree/{pr_revision}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  except Exception as e:
118
- return f"#### Error: {e}"
119
-
120
 
121
  TITLE_IMAGE = """
122
  <div style="display: block; margin-left: auto; margin-right: auto; width: 50%;">
@@ -127,205 +206,119 @@ TITLE_IMAGE = """
127
  TITLE = """
128
  <div style="text-align: center; max-width: 1400px; margin: 0 auto;">
129
  <h1 style="font-weight: 900; margin-bottom: 10px; margin-top: 10px; font-size: 2.2rem;">
130
- 🤗 Optimum Neuron Model Exporter 🏎️ (WIP)
131
  </h1>
132
  </div>
133
  """
134
 
 
135
  DESCRIPTION = """
136
- This Space allows you to automatically export 🤗 transformers models hosted on the Hugging Face Hub to AWS Neuron-optimized format for Inferentia/Trainium acceleration. It opens a PR on the target model, and it is up to the owner of the original model to merge the PR to allow people to leverage Neuron optimization!
137
-
138
- **Features:**
139
- - Automatically opens PR with Neuron-optimized model
140
- - Preserves original model weights
141
- - Adds proper tags to model card
142
-
143
- **Requirements:**
144
- - Model must be compatible with [Optimum Neuron](https://huggingface.co/docs/optimum-neuron)
145
- - User must be logged in with write token
146
- """
147
-
148
- # Custom CSS to fix dark mode compatibility and transparency issues
149
- CUSTOM_CSS = """
150
- /* Fix for HuggingfaceHubSearch component visibility in both light and dark modes */
151
- .gradio-container .gr-form {
152
- background: var(--background-fill-primary) !important;
153
- border: 1px solid var(--border-color-primary) !important;
154
- }
155
-
156
- /* Ensure text is visible in both modes */
157
- .gradio-container input[type="text"],
158
- .gradio-container textarea,
159
- .gradio-container .gr-textbox input {
160
- color: var(--body-text-color) !important;
161
- background: var(--input-background-fill) !important;
162
- border: 1px solid var(--border-color-primary) !important;
163
- }
164
-
165
- /* Fix dropdown/search results visibility */
166
- .gradio-container .gr-dropdown,
167
- .gradio-container .gr-dropdown .gr-box,
168
- .gradio-container [data-testid="textbox"] {
169
- background: var(--background-fill-primary) !important;
170
- color: var(--body-text-color) !important;
171
- border: 1px solid var(--border-color-primary) !important;
172
- }
173
-
174
- /* Fix for search component specifically */
175
- .gradio-container .gr-form > div,
176
- .gradio-container .gr-form input {
177
- background: var(--input-background-fill) !important;
178
- color: var(--body-text-color) !important;
179
- }
180
-
181
- /* Ensure proper contrast for placeholder text */
182
- .gradio-container input::placeholder {
183
- color: var(--body-text-color-subdued) !important;
184
- opacity: 0.7;
185
- }
186
 
187
- /* Fix any remaining transparent backgrounds */
188
- .gradio-container .gr-box,
189
- .gradio-container .gr-panel {
190
- background: var(--background-fill-primary) !important;
191
- }
192
 
193
- /* Make sure search results are visible */
194
- .gradio-container .gr-dropdown-item {
195
- color: var(--body-text-color) !important;
196
- background: var(--background-fill-primary) !important;
197
- }
198
 
199
- .gradio-container .gr-dropdown-item:hover {
200
- background: var(--background-fill-secondary) !important;
201
- }
 
202
 
203
- /* Task tag styling improvements */
204
- .task-tags {
205
- line-height: 1.8;
206
- }
207
 
208
- .task-tags span {
209
- display: inline-block;
210
- margin: 2px;
211
- }
212
 
 
213
  /* Primary button styling with warm colors */
214
- button[variant="primary"] {
215
- background: linear-gradient(135deg, #3B82F6, #10B981) !important;
 
216
  color: white !important;
217
  padding: 16px 32px !important;
218
  font-size: 1.1rem !important;
219
  font-weight: 700 !important;
220
  border: none !important;
221
  border-radius: 12px !important;
222
- box-shadow: 0 0 15px rgba(59, 130, 246, 0.5) !important;
223
- transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important;
224
- position: relative;
225
- overflow: hidden;
226
- animation: glow 1.5s ease-in-out infinite alternate;
227
- }
228
- button[variant="primary"]::before {
229
- content: "✨ ";
230
- }
231
- button[variant="primary"]:hover {
232
- transform: translateY(-5px) scale(1.05) !important;
233
- box-shadow: 0 10px 25px rgba(59, 130, 246, 0.7) !important;
234
- }
235
- @keyframes glow {
236
- from {
237
- box-shadow: 0 0 10px rgba(59, 130, 246, 0.5);
238
- }
239
- to {
240
- box-shadow: 0 0 20px rgba(59, 130, 246, 0.8), 0 0 30px rgba(16, 185, 129, 0.5);
241
- }
242
- }
243
-
244
-
245
- /* Login button styling with glow effect using warm colors */
246
- #login-button {
247
- background: linear-gradient(135deg, #FF7A00, #FFD700) !important; /* Orange to Gold */
248
- color: white !important;
249
- font-weight: 700 !important;
250
- border: none !important;
251
- border-radius: 12px !important;
252
- box-shadow: 0 0 15px rgba(255, 165, 0, 0.5) !important; /* Warm glow */
253
  transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important;
254
  position: relative;
255
  overflow: hidden;
256
- animation: glow 1.5s ease-in-out infinite alternate;
257
- max-width: 300px !important;
258
- margin: 0 auto !important;
259
- }
260
-
261
- #login-button::before {
262
- content: "🔑 ";
263
- display: inline-block !important;
264
- vertical-align: middle !important;
265
- margin-right: 5px !important;
266
- line-height: normal !important;
267
- }
268
-
269
- #login-button:hover {
270
- transform: translateY(-3px) scale(1.03) !important;
271
- box-shadow: 0 10px 25px rgba(255, 140, 0, 0.7) !important; /* Deeper warm glow */
272
- }
273
-
274
- #login-button::after {
275
- content: "";
276
- position: absolute;
277
- top: 0;
278
- left: -100%;
279
- width: 100%;
280
- height: 100%;
281
- background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent);
282
- transition: 0.5s;
283
- }
284
-
285
- #login-button:hover::after {
286
- left: 100%;
287
  }
288
  """
289
 
290
- with gr.Blocks(css=CUSTOM_CSS) as demo:
291
- # Login requirement notice and button
292
- gr.Markdown("**You must be logged in to use this space**")
293
- gr.LoginButton(elem_id="login-button", elem_classes="center-button", min_width=250)
294
-
295
- # Centered title and image
296
  gr.HTML(TITLE_IMAGE)
297
  gr.HTML(TITLE)
298
-
299
- # Full-width description
300
  gr.Markdown(DESCRIPTION)
301
 
302
  with gr.Tabs():
303
  with gr.Tab("Export Model"):
304
- # Input controls in a row
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
  with gr.Row():
306
  input_model = HuggingfaceHubSearch(
307
  label="Hub model ID",
308
- placeholder="Search for model ID on the hub",
309
  search_type="model",
310
  )
311
- input_task = gr.Dropdown(
312
- choices=ALL_TASKS,
313
  value="auto",
314
- label='Task (auto could infer task from model)',
315
  )
316
 
317
- # Export button below the inputs
318
  btn = gr.Button("Export to Neuron", size="lg", variant="primary")
319
 
320
- # Output section
321
- output = gr.Markdown(label="Output")
322
-
 
 
 
 
 
 
 
 
 
 
 
 
323
  btn.click(
324
  fn=neuron_export,
325
- inputs=[input_model, input_task],
326
- outputs=output,
 
 
 
 
 
 
327
  )
328
-
329
  with gr.Tab("Supported Architectures"):
330
  gr.HTML(f"""
331
  <div style="margin-bottom: 20px;">
@@ -404,7 +397,6 @@ with gr.Blocks(css=CUSTOM_CSS) as demo:
404
  <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Yolos</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, object-detection")}</td></tr>
405
  </tbody>
406
  </table>
407
-
408
  <h2>🧨 Diffusers</h2>
409
  <table style="width: 100%; border-collapse: collapse; margin: 20px 0;">
410
  <colgroup>
@@ -425,9 +417,10 @@ with gr.Blocks(css=CUSTOM_CSS) as demo:
425
  <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">LCM</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
426
  <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">PixArt-α</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
427
  <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">PixArt-Σ</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
 
 
428
  </tbody>
429
  </table>
430
-
431
  <h2>🤖 Sentence Transformers</h2>
432
  <table style="width: 100%; border-collapse: collapse; margin: 20px 0;">
433
  <colgroup>
@@ -445,7 +438,6 @@ with gr.Blocks(css=CUSTOM_CSS) as demo:
445
  <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">CLIP</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, zero-shot-image-classification")}</td></tr>
446
  </tbody>
447
  </table>
448
-
449
  <div style="margin-top: 20px;">
450
  <p>💡 <strong>Note</strong>: Some architectures may have specific requirements or limitations. DeBERTa models are only supported on INF2 instances.</p>
451
  <p>For more details, check the <a href="https://huggingface.co/docs/optimum-neuron" target="_blank">Optimum Neuron documentation</a>.</p>
@@ -456,12 +448,4 @@ with gr.Blocks(css=CUSTOM_CSS) as demo:
456
  gr.Markdown("<br><br><br><br>")
457
 
458
  if __name__ == "__main__":
459
- def restart_space():
460
- if HF_TOKEN:
461
- HfApi().restart_space(repo_id="optimum/neuron-export", token=HF_TOKEN, factory_reboot=True)
462
-
463
- scheduler = BackgroundScheduler()
464
- scheduler.add_job(restart_space, "interval", seconds=21600) # Restart every 6 hours
465
- scheduler.start()
466
-
467
- demo.launch()
 
1
  import csv
2
  import os
3
  from datetime import datetime
4
+ from typing import Optional, Union, List
 
5
  import gradio as gr
6
  from huggingface_hub import HfApi, Repository
7
+ from huggingface_hub import login
8
  from optimum_neuron_export import convert
9
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
10
  from apscheduler.schedulers.background import BackgroundScheduler
 
12
  DATASET_REPO_URL = "https://huggingface.co/datasets/optimum/neuron-exports"
13
  DATA_FILENAME = "exports.csv"
14
  DATA_FILE = os.path.join("data", DATA_FILENAME)
15
+ HF_TOKEN = os.getenv("HF_TOKEN") # It's better to use environment variables
 
 
16
  DATADIR = "neuron_exports_data"
 
17
  repo: Optional[Repository] = None
18
+
19
  # Uncomment if you want to push to dataset repo with token
20
  # if HF_TOKEN:
21
+ # repo = Repository(local_dir=DATADIR, clone_from=DATASET_REPO_URL, token=HF_TOKEN)
22
 
23
+ # Define transformer tasks and their categories for coloring
24
+ TRANSFORMER_TASKS = {
25
  "auto": {"color": "#6b7280", "category": "Auto"},
26
  "feature-extraction": {"color": "#3b82f6", "category": "Feature Extraction"},
27
  "fill-mask": {"color": "#8b5cf6", "category": "NLP"},
 
38
  "image-classification": {"color": "#ef4444", "category": "Vision"},
39
  "object-detection": {"color": "#ef4444", "category": "Vision"},
40
  "semantic-segmentation": {"color": "#ef4444", "category": "Vision"},
 
 
 
41
  "zero-shot-image-classification": {"color": "#ec4899", "category": "Multimodal"},
42
  "sentence-similarity": {"color": "#06b6d4", "category": "Similarity"},
43
  }
44
 
45
+ # Define diffusion pipeline types
46
+ DIFFUSION_PIPELINES = {
47
+ "text-to-image": {"color": "#ec4899", "category": "Stable Diffusion"},
48
+ "image-to-image": {"color": "#ec4899", "category": "Stable Diffusion"},
49
+ "inpaint": {"color": "#ec4899", "category": "Stable Diffusion"},
50
+ "instruct-pix2pix": {"color": "#ec4899", "category": "Stable Diffusion"},
51
+ "latent-consistency": {"color": "#8b5cf6", "category": "Latent Consistency"},
52
+ "stable-diffusion": {"color": "#10b981", "category": "Stable Diffusion"},
53
+ "stable-diffusion-xl": {"color": "#10b981", "category": "Stable Diffusion XL"},
54
+ "stable-diffusion-xl-img2img": {"color": "#10b981", "category": "Stable Diffusion XL"},
55
+ "stable-diffusion-xl-inpaint": {"color": "#10b981", "category": "Stable Diffusion XL"},
56
+ "controlnet": {"color": "#f59e0b", "category": "ControlNet"},
57
+ "controlnet-xl": {"color": "#f59e0b", "category": "ControlNet XL"},
58
+ "pixart-alpha": {"color": "#ef4444", "category": "PixArt"},
59
+ "pixart-sigma": {"color": "#ef4444", "category": "PixArt"},
60
+ "flux": {"color": "#06b6d4", "category": "Flux"},
61
+ }
62
+
63
  TAGS = {
64
  "Feature Extraction": {"color": "#3b82f6", "category": "Feature Extraction"},
65
  "NLP": {"color": "#8b5cf6", "category": "NLP"},
 
68
  "Vision": {"color": "#ef4444", "category": "Vision"},
69
  "Multimodal": {"color": "#ec4899", "category": "Multimodal"},
70
  "Similarity": {"color": "#06b6d4", "category": "Similarity"},
71
+ "Stable Diffusion": {"color": "#ec4899", "category": "Stable Diffusion"},
72
+ "Stable Diffusion XL": {"color": "#10b981", "category": "Stable Diffusion XL"},
73
+ "ControlNet": {"color": "#f59e0b", "category": "ControlNet"},
74
+ "ControlNet XL": {"color": "#f59e0b", "category": "ControlNet XL"},
75
+ "PixArt": {"color": "#ef4444", "category": "PixArt"},
76
+ "Latent Consistency": {"color": "#8b5cf6", "category": "Latent Consistency"},
77
+ "Flux": {"color": "#06b6d4", "category": "Flux"},
78
  }
79
 
80
+ # UPDATED: New choices for the Pull Request destination UI component
81
+ DEST_NEW_NEURON_REPO = "Create new Neuron-optimized repository"
82
+ DEST_CACHE_REPO = "Create a PR in the cache repository"
83
+ DEST_CUSTOM_REPO = "Create a PR in a custom repository"
84
+
85
+ PR_DESTINATION_CHOICES = [
86
+ DEST_NEW_NEURON_REPO,
87
+ DEST_CACHE_REPO,
88
+ DEST_CUSTOM_REPO
89
+ ]
90
+
91
+ # Get all tasks and pipelines for dropdowns
92
+ ALL_TRANSFORMER_TASKS = list(TRANSFORMER_TASKS.keys())
93
+ ALL_DIFFUSION_PIPELINES = list(DIFFUSION_PIPELINES.keys())
94
 
95
  def create_task_tag(task: str) -> str:
96
  """Create a colored HTML tag for a task"""
97
+ if task in TRANSFORMER_TASKS:
98
+ color = TRANSFORMER_TASKS[task]["color"]
99
+ return f'<span style="background-color: {color}; color: white; padding: 2px 6px; border-radius: 12px; font-size: 0.75rem; font-weight: 500; margin: 1px;">{task}</span>'
100
+ elif task in DIFFUSION_PIPELINES:
101
+ color = DIFFUSION_PIPELINES[task]["color"]
102
  return f'<span style="background-color: {color}; color: white; padding: 2px 6px; border-radius: 12px; font-size: 0.75rem; font-weight: 500; margin: 1px;">{task}</span>'
103
  elif task in TAGS:
104
  color = TAGS[task]["color"]
 
111
  tasks = [task.strip() for task in tasks_str.split(',')]
112
  return ' '.join([create_task_tag(task) for task in tasks])
113
 
114
+ def update_task_dropdown(model_type: str):
115
+ """Update the task dropdown based on selected model type"""
116
+ if model_type == "transformers":
117
+ return gr.Dropdown(
118
+ choices=ALL_TRANSFORMER_TASKS,
119
+ value="auto",
120
+ label="Task (auto can infer task from model)",
121
+ visible=True
122
+ )
123
+ else: # diffusion
124
+ return gr.Dropdown(
125
+ choices=ALL_DIFFUSION_PIPELINES,
126
+ value="text-to-image",
127
+ label="Pipeline Type",
128
+ visible=True
129
+ )
130
+
131
+ def toggle_custom_repo_box(pr_destinations: List[str]):
132
+ """Show or hide the custom repo ID textbox based on checkbox selection."""
133
+ if DEST_CUSTOM_REPO in pr_destinations:
134
+ return gr.Textbox(visible=True)
135
+ else:
136
+ return gr.Textbox(visible=False, value="")
137
 
138
+ # UPDATED: Modified function to handle new repository creation workflow
139
+ def neuron_export(model_id: str, model_type: str, task_or_pipeline: str,
140
+ pr_destinations: List[str], custom_repo_id: str):
 
141
  if not model_id:
142
+ yield "🚫 Invalid input. Please specify a model name from the hub."
143
+ return
144
+
145
+ log_buffer = ""
146
+ def log(msg):
147
+ nonlocal log_buffer
148
+ # Handle cases where the message from the backend is not a string
149
+ if not isinstance(msg, str):
150
+ msg = str(msg)
151
+ log_buffer += msg + "\n"
152
+ return log_buffer
153
 
154
  try:
155
+ api = HfApi()
156
+ yield log(f"🔑 Logging in with provided token...")
157
+ if not HF_TOKEN:
158
+ yield log("❌ HF_TOKEN not found. Please set it as an environment variable in the Space secrets.")
159
+ return
160
+
161
+ login(token=HF_TOKEN)
162
+ yield log("✅ Login successful.")
163
+ yield log(f"🔍 Checking access to `{model_id}`...")
164
+ try:
165
+ api.model_info(model_id, token=HF_TOKEN)
166
+ except Exception as e:
167
+ yield log(f"❌ Could not access model `{model_id}`: {e}")
168
+ return
169
+
170
+ yield log(f"✅ Model `{model_id}` is accessible. Starting Neuron export...")
171
+
172
+ # UPDATED: Build pr_options with new structure
173
+ pr_options = {
174
+ "create_neuron_repo": DEST_NEW_NEURON_REPO in pr_destinations,
175
+ "create_cache_pr": DEST_CACHE_REPO in pr_destinations,
176
+ "create_custom_pr": DEST_CUSTOM_REPO in pr_destinations,
177
+ "custom_repo_id": custom_repo_id.strip() if custom_repo_id else ""
178
+ }
179
+
180
+ # The convert function is a generator, so we iterate through its messages
181
+ for status_code, message in convert(api, model_id, task_or_pipeline, model_type,
182
+ token=HF_TOKEN, pr_options=pr_options):
183
+ if isinstance(message, str):
184
+ yield log(message)
185
+ else: # It's the final result dictionary
186
+ final_message = "🎉 Process finished.\n"
187
+ if message.get("neuron_repo"):
188
+ final_message += f"🏗️ New Neuron Repository: {message['neuron_repo']}\n"
189
+ if message.get("readme_pr"):
190
+ final_message += f"📝 README PR (Original Model): {message['readme_pr']}\n"
191
+ if message.get("cache_pr"):
192
+ final_message += f"🔗 Cache PR: {message['cache_pr']}\n"
193
+ if message.get("custom_pr"):
194
+ final_message += f"🔗 Custom PR: {message['custom_pr']}\n"
195
+ yield log(final_message)
196
 
197
  except Exception as e:
198
+ yield log(f" An unexpected error occurred in the Gradio interface: {e}")
 
199
 
200
  TITLE_IMAGE = """
201
  <div style="display: block; margin-left: auto; margin-right: auto; width: 50%;">
 
206
  TITLE = """
207
  <div style="text-align: center; max-width: 1400px; margin: 0 auto;">
208
  <h1 style="font-weight: 900; margin-bottom: 10px; margin-top: 10px; font-size: 2.2rem;">
209
+ 🤗 Optimum Neuron Model Compiler 🏎️
210
  </h1>
211
  </div>
212
  """
213
 
214
+ # UPDATED: Description to reflect new workflow
215
  DESCRIPTION = """
216
+ This Space allows you to automatically export 🤗 transformers and diffusion models to AWS Neuron-optimized format for Inferentia/Trainium acceleration.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
+ Simply provide a model ID from the Hugging Face Hub, and choose your desired output.
 
 
 
 
219
 
220
+ ### Key Features
 
 
 
 
221
 
222
+ * **🚀 Create a New Optimized Repo**: Automatically converts the model and uploads it to a new repository under your username (e.g., `your-username/model-name-neuron`).
223
+ * **🔗 Link Back to Original**: Creates a Pull Request on the original model's repository to add a link to your new optimized version, making it easily discoverable by the community.
224
+ * **🛠️ PR to a Custom Repo**: For custom workflows, you can create a Pull Request with the optimized files directly into an existing repository you own.
225
+ * **📦 Contribute to Cache**: You can also contribute the generated compilation artifacts to a centralized cache repository, which helps speed up future compilations for everyone.
226
 
227
+ ### ⚙️ How to Use
 
 
 
228
 
229
+ 1. **Model ID**: Enter the ID of the model you want to export (e.g., `bert-base-uncased` or `stabilityai/stable-diffusion-xl-base-1.0`).
230
+ 2. **Export Options**: Select at least one option for where to save the exported model.
231
+ 3. **Convert & Upload**: Click the button and follow the logs for progress!
232
+ """
233
 
234
+ CUSTOM_CSS = """
235
  /* Primary button styling with warm colors */
236
+ button.gradio-button.lg.primary {
237
+ /* Changed the blue/green gradient to an orange/yellow one */
238
+ background: linear-gradient(135deg, #F97316, #FBBF24) !important;
239
  color: white !important;
240
  padding: 16px 32px !important;
241
  font-size: 1.1rem !important;
242
  font-weight: 700 !important;
243
  border: none !important;
244
  border-radius: 12px !important;
245
+ /* Updated the shadow to match the new orange color */
246
+ box-shadow: 0 0 15px rgba(249, 115, 22, 0.5) !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1) !important;
248
  position: relative;
249
  overflow: hidden;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  }
251
  """
252
 
253
+ with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Soft()) as demo:
 
 
 
 
 
254
  gr.HTML(TITLE_IMAGE)
255
  gr.HTML(TITLE)
 
 
256
  gr.Markdown(DESCRIPTION)
257
 
258
  with gr.Tabs():
259
  with gr.Tab("Export Model"):
260
+ with gr.Group():
261
+ with gr.Row():
262
+ pr_destinations_checkbox = gr.CheckboxGroup(
263
+ choices=PR_DESTINATION_CHOICES,
264
+ label="Export Destination",
265
+ value=[DEST_NEW_NEURON_REPO],
266
+ info="Select one or more destinations for the compiled model."
267
+ )
268
+ custom_repo_id_textbox = gr.Textbox(
269
+ label="Custom Repository ID",
270
+ placeholder="e.g., your-username/your-repo-name",
271
+ visible=False,
272
+ interactive=True
273
+ )
274
+ with gr.Row():
275
+ model_type = gr.Radio(
276
+ choices=["transformers", "diffusion"],
277
+ value="transformers",
278
+ label="Model Type",
279
+ info="Choose the type of model you want to export"
280
+ )
281
  with gr.Row():
282
  input_model = HuggingfaceHubSearch(
283
  label="Hub model ID",
284
+ placeholder="Search for a model on the Hub...",
285
  search_type="model",
286
  )
287
+ task_dropdown = gr.Dropdown(
288
+ choices=ALL_TRANSFORMER_TASKS,
289
  value="auto",
290
+ label="Task (auto can infer from model)",
291
  )
292
 
 
293
  btn = gr.Button("Export to Neuron", size="lg", variant="primary")
294
 
295
+ log_box = gr.Textbox(label="Logs", lines=20, interactive=False, show_copy_button=True)
296
+
297
+ # Event Handlers
298
+ model_type.change(
299
+ fn=update_task_dropdown,
300
+ inputs=[model_type],
301
+ outputs=[task_dropdown]
302
+ )
303
+
304
+ pr_destinations_checkbox.change(
305
+ fn=toggle_custom_repo_box,
306
+ inputs=pr_destinations_checkbox,
307
+ outputs=custom_repo_id_textbox
308
+ )
309
+
310
  btn.click(
311
  fn=neuron_export,
312
+ inputs=[
313
+ input_model,
314
+ model_type,
315
+ task_dropdown,
316
+ pr_destinations_checkbox,
317
+ custom_repo_id_textbox
318
+ ],
319
+ outputs=log_box,
320
  )
321
+
322
  with gr.Tab("Supported Architectures"):
323
  gr.HTML(f"""
324
  <div style="margin-bottom: 20px;">
 
397
  <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Yolos</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, object-detection")}</td></tr>
398
  </tbody>
399
  </table>
 
400
  <h2>🧨 Diffusers</h2>
401
  <table style="width: 100%; border-collapse: collapse; margin: 20px 0;">
402
  <colgroup>
 
417
  <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">LCM</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
418
  <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">PixArt-α</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
419
  <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">PixArt-Σ</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
420
+ <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">Flux</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("text-to-image")}</td></tr>
421
+
422
  </tbody>
423
  </table>
 
424
  <h2>🤖 Sentence Transformers</h2>
425
  <table style="width: 100%; border-collapse: collapse; margin: 20px 0;">
426
  <colgroup>
 
438
  <tr><td style="border: 1px solid var(--border-color-primary); padding: 8px; font-weight: bold;">CLIP</td><td style="border: 1px solid var(--border-color-primary); padding: 8px;" class="task-tags">{format_tasks_for_table("feature-extraction, zero-shot-image-classification")}</td></tr>
439
  </tbody>
440
  </table>
 
441
  <div style="margin-top: 20px;">
442
  <p>💡 <strong>Note</strong>: Some architectures may have specific requirements or limitations. DeBERTa models are only supported on INF2 instances.</p>
443
  <p>For more details, check the <a href="https://huggingface.co/docs/optimum-neuron" target="_blank">Optimum Neuron documentation</a>.</p>
 
448
  gr.Markdown("<br><br><br><br>")
449
 
450
  if __name__ == "__main__":
451
+ demo.launch(debug=True)