danube2024 committed
Commit 18b2b4b · verified · 1 Parent(s): 8b1647d

Update app.py

Files changed (1):
  1. app.py +16 -54
app.py CHANGED
@@ -12,8 +12,6 @@ from PIL import Image, ImageEnhance, ImageOps
 device = "cpu"  # or "cuda" if GPU is available
 torch_dtype = torch.float32  # if using CPU or float16 for GPU
 
-# --- Load Base SDXL Model ---
-# (Large model, be sure you have enough memory or use fewer steps)
 print("Loading SDXL Base model...")
 pipe = StableDiffusionXLPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
@@ -21,86 +19,57 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
 )
 pipe.to(device)
 
-# --- Load LoRA Weights from KappaNeuro/bas-relief ---
-# The safetensors file is named "BAS-RELIEF.safetensors"
-# This merges the LoRA into the pipeline so you can use it via the "BAS-RELIEF" token
 print("Loading bas-relief LoRA weights...")
+# IMPORTANT: Pass the first argument as a string to the repo or path,
+# and `weight_name` as a kwarg. That matches the actual function signature.
 pipe.load_lora_weights(
-    repo_id_or_path="KappaNeuro/bas-relief",
+    "KappaNeuro/bas-relief",  # repo / path
     weight_name="BAS-RELIEF.safetensors"
 )
 
-# --- Load Depth Estimation Model ---
-# We'll use Intel's DPT for depth. On CPU, it's also relatively large, so be cautious of performance.
 print("Loading DPT Depth model...")
 feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
 depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
 
 
-############################################
-# 2. Depth Map Enhancement (PIL-based)
-############################################
 def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
     """
-    - Normalize depth to [0, 255]
-    - Auto-contrast to emphasize details
-    - Sharpen edges
+    Normalize depth to [0, 255], auto-contrast, and sharpen.
     """
     d_min, d_max = depth_arr.min(), depth_arr.max()
     depth_stretched = (depth_arr - d_min) / (d_max - d_min + 1e-8)
     depth_stretched = (depth_stretched * 255).astype(np.uint8)
 
     depth_pil = Image.fromarray(depth_stretched)
-
-    # Auto-contrast
     depth_pil = ImageOps.autocontrast(depth_pil)
 
-    # Sharpen
     enhancer = ImageEnhance.Sharpness(depth_pil)
     depth_pil = enhancer.enhance(2.0)
 
     return depth_pil
 
 
-############################################
-# 3. Generation + Depth Inference Function
-############################################
 def generate_bas_relief_and_depth(prompt: str):
-    """
-    1) Generate a 'bas-relief' style image using the LoRA from KappaNeuro/bas-relief.
-       - Must include "BAS-RELIEF" token in the prompt for the style to apply.
-    2) Compute a depth map using Intel/DPT-Large.
-    3) Return (image, depth_map).
-    """
-
-    # -- Step A: Merge the user's prompt with "BAS-RELIEF" instance token --
-    # You can experiment with different prompt styles:
-    # e.g. "BAS-RELIEF sculpture of a woman in shibari, marble, octane render..."
+    # We prepend "BAS-RELIEF" to ensure the LoRA style is triggered.
     full_prompt = f"BAS-RELIEF {prompt}"
 
-    # -- Step B: Generate the image with SDXL + LoRA
-    # Keep resolution modest to avoid timeouts on CPU
     print("Generating bas-relief image...")
     result = pipe(
         prompt=full_prompt,
-        num_inference_steps=15,  # Lower steps => faster (but lower quality)
+        num_inference_steps=15,  # Lower for speed on CPU
         guidance_scale=7.5,
-        height=512,  # can reduce to e.g. 384 if still too slow
+        height=512,
         width=512
     )
-
-    # Extract image from pipeline result
     generated_image = result.images[0]
 
-    # -- Step C: Depth Estimation with DPT
     print("Running depth estimation...")
    inputs = feature_extractor(generated_image, return_tensors="pt").to(device)
-
    with torch.no_grad():
        outputs = depth_model(**inputs)
-    predicted_depth = outputs.predicted_depth  # shape: [batch, height, width]
+    predicted_depth = outputs.predicted_depth
 
-    # Resize to match original image resolution
+    # Resize depth map to match original image
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=generated_image.size[::-1],
@@ -109,31 +78,24 @@ def generate_bas_relief_and_depth(prompt: str):
    ).squeeze(0)
 
    depth_arr = prediction.cpu().numpy()
-    depth_map_pil = enhance_depth_map(depth_arr)
+    depth_pil = enhance_depth_map(depth_arr)
 
-    return generated_image, depth_map_pil
+    return generated_image, depth_pil
 
 
-############################################
-# 4. Gradio Interface
-############################################
-title = "Bas-Relief with SDXL + LoRA + Depth Map"
+title = "Bas-Relief (SDXL + LoRA) + Depth Map"
 description = (
-    "This demo loads SDXL-base on CPU (slow!) and merges LoRA from KappaNeuro/bas-relief. "
-    "Use 'BAS-RELIEF' in your prompt for the style. Then we generate a depth map using DPT."
-    "Lower resolution or fewer steps if you get timeouts."
+    "Load SDXL base on CPU, apply 'BAS-RELIEF.safetensors' LoRA from KappaNeuro/bas-relief. "
+    "Then run DPT for depth estimation."
 )
 
 iface = gr.Interface(
     fn=generate_bas_relief_and_depth,
     inputs=gr.Textbox(
         label="Describe your scene/style",
-        placeholder="sculpture of a woman in shibari, marble, intricate details"
+        placeholder="e.g., 'sculpture of a woman in shibari, marble, intricate details'"
    ),
-    outputs=[
-        gr.Image(label="Bas-Relief Image"),
-        gr.Image(label="Depth Map"),
-    ],
+    outputs=[gr.Image(label="Bas-Relief Image"), gr.Image(label="Depth Map")],
     title=title,
     description=description
 )
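
A note on the `load_lora_weights` change above: `repo_id_or_path` is not a parameter name accepted by diffusers' `load_lora_weights`, so the old keyword call would fail with a TypeError; the repo id (or local path) has to be passed positionally, with the file inside the repo selected via the `weight_name` kwarg. A minimal standalone sketch of the corrected pattern, using the repo and file names from the diff and otherwise assuming a stock recent diffusers install:

# Sketch: corrected LoRA loading on CPU (assumes a recent diffusers release).
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,  # float32 on CPU; float16 is the usual GPU choice
)
pipe.to("cpu")

# Repo id (or local path) goes first and positionally;
# weight_name picks the .safetensors file inside the repo.
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
)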
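
The depth-resize step also hinges on an easy-to-miss detail: PIL's `Image.size` is `(width, height)` while `torch.nn.functional.interpolate` expects `(height, width)`, which is why the code passes `generated_image.size[::-1]`. A small sketch of that step in isolation; the `mode`/`align_corners` values fall outside the diff context, so the bicubic settings below are the ones commonly used with DPT, not necessarily what app.py uses:

# Sketch: DPT depth inference + resize back to the input image's resolution.
import torch
from PIL import Image
from transformers import DPTFeatureExtractor, DPTForDepthEstimation

feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")

image = Image.new("RGB", (512, 512))  # placeholder input for illustration
inputs = feature_extractor(image, return_tensors="pt")
with torch.no_grad():
    predicted_depth = depth_model(**inputs).predicted_depth  # [batch, H', W']

# image.size is (width, height); interpolate wants (height, width),
# hence the [::-1] reversal.
depth = torch.nn.functional.interpolate(
    predicted_depth.unsqueeze(1),  # -> [batch, 1, H', W']
    size=image.size[::-1],
    mode="bicubic",          # assumption: common DPT setting, not shown in the diff
    align_corners=False,     # assumption
).squeeze(0)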