ovi054 committed on
Commit
223679f
·
verified ·
1 Parent(s): e8e7cda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -2
app.py CHANGED
@@ -5,6 +5,13 @@ import spaces
5
  import torch
6
  from diffusers import QwenImagePipeline
7
 
 
 
 
 
 
 
 
8
  dtype = torch.bfloat16
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
@@ -20,6 +27,46 @@ MAX_IMAGE_SIZE = 2048
20
 
21
  # pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  @spaces.GPU()
24
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=4, num_inference_steps=28, lora_id=None, lora_scale=0.95, progress=gr.Progress(track_tqdm=True)):
25
  if randomize_seed:
@@ -28,8 +75,11 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
28
 
29
 
30
  if lora_id and lora_id.strip() != "":
31
- pipe.unload_lora_weights()
32
- pipe.load_lora_weights(lora_id.strip())
 
 
 
33
 
34
 
35
  try:
 
5
  import torch
6
  from diffusers import QwenImagePipeline
7
 
8
+
9
+ import os
10
+ import requests
11
+ import tempfile
12
+ import shutil
13
+ from urllib.parse import urlparse
14
+
15
  dtype = torch.bfloat16
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
 
27
 
28
  # pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
29
 
30
+
31
def load_lora_auto(pipe, lora_input):
    """Load LoRA weights into *pipe* from a repo ID or a URL.

    Accepted forms of *lora_input*:
      * a Hugging Face repo ID such as ``author/model``;
      * a Hugging Face repo page URL (no ``/blob/`` or ``/resolve/``),
        from which the repo ID is extracted;
      * a ``/blob/`` or ``/resolve/`` (direct-file) URL — the file is
        downloaded into a temporary directory, loaded, and the directory
        is removed afterwards.

    Blank or whitespace-only input is a no-op. Raises
    ``requests.HTTPError`` / ``requests.Timeout`` on failed downloads.
    """
    lora_input = lora_input.strip()
    if not lora_input:
        return

    # Plain repo ID like "author/model" — hand it to diffusers directly.
    if "/" in lora_input and not lora_input.startswith("http"):
        pipe.load_lora_weights(lora_input)
        return

    if lora_input.startswith("http"):
        url = lora_input

        # Repo page URL (no blob/resolve) — derive the repo ID from the path.
        if "huggingface.co" in url and "/blob/" not in url and "/resolve/" not in url:
            repo_id = urlparse(url).path.strip("/")
            pipe.load_lora_weights(repo_id)
            return

        # Blob (HTML viewer) link → convert to the raw-file resolve link.
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")

        # Download the direct file into a throwaway directory.
        tmp_dir = tempfile.mkdtemp()
        # Fall back to a fixed name when the URL path has no basename
        # (e.g. the URL ends with a slash) — otherwise open() would be
        # handed the directory itself.
        filename = os.path.basename(urlparse(url).path) or "lora.safetensors"
        local_path = os.path.join(tmp_dir, filename)

        try:
            print(f"Downloading LoRA from {url}...")
            # timeout keeps a dead host from hanging the app forever;
            # the context manager closes the connection even on error.
            with requests.get(url, stream=True, timeout=60) as resp:
                resp.raise_for_status()
                with open(local_path, "wb") as f:
                    for chunk in resp.iter_content(chunk_size=8192):
                        f.write(chunk)
            print(f"Saved LoRA to {local_path}")
            pipe.load_lora_weights(local_path)
        finally:
            # Weights are already loaded into the pipeline; the downloaded
            # file is disposable.
            shutil.rmtree(tmp_dir, ignore_errors=True)
69
+
70
  @spaces.GPU()
71
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=4, num_inference_steps=28, lora_id=None, lora_scale=0.95, progress=gr.Progress(track_tqdm=True)):
72
  if randomize_seed:
 
75
 
76
 
77
  if lora_id and lora_id.strip() != "":
78
+ try:
79
+ pipe.unload_lora_weights()
80
+ load_lora_auto(pipe, lora_id)
81
+ except Exception as e:
82
+ return f"Error loading LoRA: {e}", seed
83
 
84
 
85
  try: