ouclxy commited on
Commit
cab03fb
·
verified ·
1 Parent(s): bc26b52

Update test_stablehairv2.py

Browse files
Files changed (1) hide show
  1. test_stablehairv2.py +10 -9
test_stablehairv2.py CHANGED
@@ -100,10 +100,11 @@ def log_validation(
100
  gscale = float(_os.getenv('SH_GUIDANCE', getattr(args, 'guidance_scale', 1.5)))
101
  vlen = int(_os.getenv('SH_VIDEO_LENGTH', getattr(args, 'video_length', 21)))
102
  # 统一时序长度:上下文帧数始终等于视频帧数(不再读取 SH_CONTEXT_FRAMES)
103
- cframes = vlen
104
- print("推理步数:",steps)
105
- print("guidance_scale:",gscale)
106
- print("视频帧数:",vlen)
 
107
  # Generate camera trajectory with exactly vlen frames
108
  angles = np.linspace(0, 2 * np.pi, vlen, endpoint=False)
109
  X = 0.4 * np.sin(angles)
@@ -132,7 +133,7 @@ def log_validation(
132
  id_image = cv2.cvtColor(cv2.imread(temp_bald_path), cv2.COLOR_BGR2RGB)
133
  id_image = cv2.resize(id_image, (512, 512))
134
 
135
- id_list = [id_image for _ in range(12)]
136
  if align_enabled:
137
  hair_image = _maybe_align_image(args.validation_hairs[0], output_size=align_size, prefer_cuda=prefer_cuda)
138
  prompt_img = _maybe_align_image(args.validation_ids[0], output_size=align_size, prefer_cuda=prefer_cuda)
@@ -152,8 +153,8 @@ def log_validation(
152
  result = pipeline(
153
  prompt="",
154
  negative_prompt="",
155
- num_inference_steps=30,
156
- guidance_scale=1.5,
157
  width=512,
158
  height=512,
159
  controlnet_condition=id_list,
@@ -165,8 +166,8 @@ def log_validation(
165
  poses=None,
166
  x=x_tensor,
167
  y=y_tensor,
168
- video_length=21,
169
- context_frames=12,
170
  )
171
  video = torch.cat([result.videos, result.videos], dim=0)
172
  video_path = os.path.join(output_dir, f"generated_video_{idx}.mp4")
 
100
  gscale = float(_os.getenv('SH_GUIDANCE', getattr(args, 'guidance_scale', 1.5)))
101
  vlen = int(_os.getenv('SH_VIDEO_LENGTH', getattr(args, 'video_length', 21)))
102
  # 统一时序长度:上下文帧数始终等于视频帧数(不再读取 SH_CONTEXT_FRAMES)
103
+ cframes = int(_os.getenv('SH_CFRAMES', getattr(args, 'cframes', 12)))
104
+ print("[cfg]推理步数:",steps)
105
+ print("[cfg]guidance_scale:",gscale)
106
+ print("[cfg]视频帧数:",vlen)
107
+ print("[cfg]cframes:",cframes)
108
  # Generate camera trajectory with exactly vlen frames
109
  angles = np.linspace(0, 2 * np.pi, vlen, endpoint=False)
110
  X = 0.4 * np.sin(angles)
 
133
  id_image = cv2.cvtColor(cv2.imread(temp_bald_path), cv2.COLOR_BGR2RGB)
134
  id_image = cv2.resize(id_image, (512, 512))
135
 
136
+ id_list = [id_image for _ in range(cframes)]
137
  if align_enabled:
138
  hair_image = _maybe_align_image(args.validation_hairs[0], output_size=align_size, prefer_cuda=prefer_cuda)
139
  prompt_img = _maybe_align_image(args.validation_ids[0], output_size=align_size, prefer_cuda=prefer_cuda)
 
153
  result = pipeline(
154
  prompt="",
155
  negative_prompt="",
156
+ num_inference_steps=steps,
157
+ guidance_scale=gscale,
158
  width=512,
159
  height=512,
160
  controlnet_condition=id_list,
 
166
  poses=None,
167
  x=x_tensor,
168
  y=y_tensor,
169
+ video_length=vlen,
170
+ context_frames=cframes,
171
  )
172
  video = torch.cat([result.videos, result.videos], dim=0)
173
  video_path = os.path.join(output_dir, f"generated_video_{idx}.mp4")