Ziqi commited on
Commit
40f7024
·
1 Parent(s): 54b8f2e
Files changed (1) hide show
  1. inference.py +12 -12
inference.py CHANGED
@@ -54,9 +54,9 @@ def inference_fn(
54
  pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id),torch_dtype=torch.float16).to('cuda')
55
  else:
56
  pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id)).to('cpu')
57
- # make directory to save images
58
- image_root_folder = os.path.join('experiments', model_id, 'inference')
59
- os.makedirs(image_root_folder, exist_ok = True)
60
 
61
  # if prompt is None and args.template_name is None:
62
  # raise ValueError("please input a single prompt through'--prompt' or select a batch of prompts using '--template_name'.")
@@ -77,22 +77,22 @@ def inference_fn(
77
  prompt = prompt.lower().replace("<r>", "<R>").format("<R>")
78
 
79
 
80
- # make sub-folder
81
- image_folder = os.path.join(image_root_folder, prompt, 'samples')
82
- os.makedirs(image_folder, exist_ok = True)
83
 
84
  # batch generation
85
  images = pipe(prompt, num_inference_steps=ddim_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_samples).images
86
 
87
- # save generated images
88
- for idx, image in enumerate(images):
89
- image_name = f"{str(idx).zfill(4)}.png"
90
- image_path = os.path.join(image_folder, image_name)
91
- image.save(image_path)
92
 
93
  # save a grid of images
94
  image_grid = make_image_grid(images, rows=2, cols=math.ceil(num_samples/2))
95
- image_grid_path = os.path.join(image_root_folder, prompt, f'{prompt}.png')
96
 
97
  return image_grid
98
 
 
54
  pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id),torch_dtype=torch.float16).to('cuda')
55
  else:
56
  pipe = StableDiffusionPipeline.from_pretrained(os.path.join('experiments', model_id)).to('cpu')
57
+ # # make directory to save images
58
+ # image_root_folder = os.path.join('experiments', model_id, 'inference')
59
+ # os.makedirs(image_root_folder, exist_ok = True)
60
 
61
  # if prompt is None and args.template_name is None:
62
  # raise ValueError("please input a single prompt through'--prompt' or select a batch of prompts using '--template_name'.")
 
77
  prompt = prompt.lower().replace("<r>", "<R>").format("<R>")
78
 
79
 
80
+ # # make sub-folder
81
+ # image_folder = os.path.join(image_root_folder, prompt, 'samples')
82
+ # os.makedirs(image_folder, exist_ok = True)
83
 
84
  # batch generation
85
  images = pipe(prompt, num_inference_steps=ddim_steps, guidance_scale=guidance_scale, num_images_per_prompt=num_samples).images
86
 
87
+ # # save generated images
88
+ # for idx, image in enumerate(images):
89
+ # image_name = f"{str(idx).zfill(4)}.png"
90
+ # image_path = os.path.join(image_folder, image_name)
91
+ # image.save(image_path)
92
 
93
  # save a grid of images
94
  image_grid = make_image_grid(images, rows=2, cols=math.ceil(num_samples/2))
95
+ # image_grid_path = os.path.join(image_root_folder, prompt, f'{prompt}.png')
96
 
97
  return image_grid
98