pandaphd committed on
Commit 6d551b0 · 1 Parent(s): a7d4f3d

fix diffusers

genphoto/data/dataset.py DELETED
@@ -1,950 +0,0 @@
1
- import os
2
- import random
3
- import json
4
- import torch
5
- import math
6
- import torch.nn as nn
7
- import torchvision.transforms as transforms
8
- import torch.nn.functional as F
9
- import numpy as np
10
- from torch.utils.data.dataset import Dataset
11
- from packaging import version as pver
12
- import cv2
13
- from PIL import Image
14
- from einops import rearrange
15
- from transformers import pipeline, CLIPTextModel, CLIPTokenizer
16
-
17
- import sys
18
- sys.path.append('/home/yuan418/data/project/Generative_Photography/genphoto/data/BokehMe/')
19
- from classical_renderer.scatter import ModuleRenderScatter
20
-
21
-
22
-
23
- #### for shutter speed ####
24
- def create_shutter_speed_embedding(shutter_speed_values, target_height, target_width, base_exposure=0.5):
25
- """
26
- Create a shutter_speed embedding tensor using a constant fwc value.
27
- Args:
28
- - shutter_speed_values: Tensor of shape [f, 1] containing shutter_speed values for each frame.
29
- target_height: Height of the image.
30
- target_width: Width of the image.
31
- base_exposure: A base exposure value used to normalize brightness (0.18 is a common base exposure level; this function defaults to 0.5).
32
-
33
- Returns:
34
- - shutter_speed_embedding: Tensor of shape [f, 1, H, W] where each pixel is scaled based on the shutter_speed values.
35
- """
36
- f = shutter_speed_values.shape[0]
37
-
38
- # Set a constant full well capacity (fwc)
39
- fwc = 32000 # Constant value for full well capacity
40
-
41
- # Calculate a brightness scale from the shutter speed relative to base_exposure and the sensor full well capacity (fwc)
42
- scales = (shutter_speed_values / base_exposure) * (fwc / (fwc + 0.0001))
43
-
44
- # Reshape and expand to match image dimensions
45
- scales = scales.unsqueeze(2).unsqueeze(3).expand(f, 3, target_height, target_width)
46
-
47
- # Use scales to create the final shutter_speed embedding
48
- shutter_speed_embedding = scales # Shape [f, 3, H, W]
49
- return shutter_speed_embedding
50
-
51
-
52
- def sensor_image_simulation_numpy(avg_PPP, photon_flux, fwc, Nbits, gain=1):
53
- min_val = 0
54
- max_val = 2 ** Nbits - 1
55
- theta = photon_flux * (avg_PPP / (np.mean(photon_flux) + 0.0001))
56
- theta = np.clip(theta, 0, fwc)
57
- theta = np.round(theta * gain * max_val / fwc)
58
- theta = np.clip(theta, min_val, max_val)
59
- theta = theta.astype(np.float32)
60
- return theta
61
-
62
-
63
- class CameraShutterSpeed(Dataset):
64
- def __init__(
65
- self,
66
- root_path,
67
- annotation_json,
68
- sample_n_frames=5,
69
- sample_size=[256, 384],
70
- is_Train=True,
71
- ):
72
- self.root_path = root_path
73
- self.sample_n_frames = sample_n_frames
74
- self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
75
- self.length = len(self.dataset)
76
- self.is_Train = is_Train
77
- sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
78
- self.sample_size = sample_size
79
-
80
- pixel_transforms = [transforms.Resize(sample_size),
81
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
82
-
83
- self.pixel_transforms = pixel_transforms
84
- self.tokenizer = CLIPTokenizer.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="tokenizer")
85
- self.text_encoder = CLIPTextModel.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="text_encoder")
86
-
87
- def load_image_reader(self, idx):
88
- image_dict = self.dataset[idx]
89
- image_path = os.path.join(self.root_path, image_dict['base_image_path'])
90
- image_reader = cv2.imread(image_path)
91
- image_reader = cv2.cvtColor(image_reader, cv2.COLOR_BGR2RGB)
92
- image_caption = image_dict['caption']
93
-
94
- if self.is_Train:
95
- mean = 0.48
96
- std_dev = 0.25
97
- shutter_speed_values = [random.gauss(mean, std_dev) for _ in range(self.sample_n_frames)]
98
- shutter_speed_values = [max(0.1, min(1.0, ev)) for ev in shutter_speed_values]
99
- print('train shutter_speed values', shutter_speed_values)
100
-
101
- else:
102
- shutter_speed_list_str = image_dict['shutter_speed_list']
103
- shutter_speed_values = json.loads(shutter_speed_list_str)
104
- print('validation shutter_speed_values', shutter_speed_values)
105
-
106
- shutter_speed_values = torch.tensor(shutter_speed_values).unsqueeze(1)
107
- return image_path, image_reader, image_caption, shutter_speed_values
108
-
109
-
110
- def get_batch(self, idx):
111
- image_path, image_reader, image_caption, shutter_speed_values = self.load_image_reader(idx)
112
-
113
- total_frames = len(shutter_speed_values)
114
- if total_frames < 3:
115
- raise ValueError("less than 3 frames")
116
-
117
- # Generate prompts for each shutter speed value and append shutter speed information to caption
118
- prompts = []
119
- for ss in shutter_speed_values:
120
- prompt = f"<exposure: {ss.item()}>"
121
- prompts.append(prompt)
122
-
123
- # Tokenize prompts and encode to get embeddings
124
- with torch.no_grad():
125
- prompt_ids = self.tokenizer(
126
- prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
127
- ).input_ids
128
- # print('tokenizer model_max_length', self.tokenizer.model_max_length)
129
-
130
- encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
131
-
132
- # print('encoder_hidden_states shape', encoder_hidden_states.shape)
133
-
134
- # Calculate differences between consecutive embeddings (ignoring sequence_length)
135
- differences = []
136
- for i in range(1, encoder_hidden_states.size(0)):
137
- diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
138
- diff = diff.unsqueeze(0)
139
- differences.append(diff)
140
-
141
- # Add the difference between the last and the first embedding
142
- final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
143
- final_diff = final_diff.unsqueeze(0)
144
- differences.append(final_diff)
145
-
146
- # Concatenate the differences along the batch dimension (f-1 consecutive diffs plus the wrap-around diff, giving f rows)
147
- concatenated_differences = torch.cat(differences, dim=0)
148
- # print('concatenated_differences shape', concatenated_differences.shape) # f 77 768
149
-
150
- frame = concatenated_differences.size(0)
151
-
152
- concatenated_differences = torch.cat(differences, dim=0)
153
-
154
- # Current shape: (f, 77, 768) Pad the second dimension (77) to 128
155
- pad_length = 128 - concatenated_differences.size(1)
156
- if pad_length > 0:
157
- # Pad along the second dimension (77 -> 128), pad only on the right side
158
- concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
159
-
160
- ## ccl = contrastive camera learning
161
- ccl_embedding = concatenated_differences_padded.reshape(frame, self.sample_size[0], self.sample_size[1])
162
- ccl_embedding = ccl_embedding.unsqueeze(1)
163
- ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
164
-
165
- # Now handle the sensor image simulation
166
- fwc = random.uniform(19000, 64000)
167
- pixel_values = []
168
- for ee in shutter_speed_values:
169
- avg_PPP = (0.6 * ee.item() + 0.1) * fwc
170
- img_sim = sensor_image_simulation_numpy(avg_PPP, image_reader, fwc, Nbits=8, gain=1)
171
- pixel_values.append(img_sim)
172
- pixel_values = np.stack(pixel_values, axis=0)
173
- pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() / 255.
174
-
175
- # Create shutter_speed embedding and concatenate it with CCL embedding
176
- shutter_speed_embedding = create_shutter_speed_embedding(shutter_speed_values, self.sample_size[0], self.sample_size[1])
177
-
178
- camera_embedding = torch.cat((shutter_speed_embedding, ccl_embedding), dim=1)
179
- # print('camera_embedding shape', camera_embedding.shape)
180
-
181
- return pixel_values, image_caption, camera_embedding, shutter_speed_values
182
-
183
- def __len__(self):
184
- return self.length
185
-
186
- def __getitem__(self, idx):
187
- while True:
188
- try:
189
- video, video_caption, camera_embedding, shutter_speed_values = self.get_batch(idx)
190
- break
191
- except Exception as e:
192
- idx = random.randint(0, self.length - 1)
193
-
194
- for transform in self.pixel_transforms:
195
- video = transform(video)
196
-
197
- sample = dict(pixel_values=video, text=video_caption, camera_embedding=camera_embedding, shutter_speed_values=shutter_speed_values)
198
-
199
- return sample
200
-
201
-
202
-
203
-
204
-
205
-
206
-
207
-
208
- #### for focal length ####
209
- def crop_focal_length(image_path, base_focal_length, target_focal_length, target_height, target_width, sensor_height=24.0, sensor_width=36.0):
210
- img = Image.open(image_path)
211
- width, height = img.size
212
-
213
- # Calculate base and target FOV
214
- base_x_fov = 2.0 * math.atan(sensor_width * 0.5 / base_focal_length)
215
- base_y_fov = 2.0 * math.atan(sensor_height * 0.5 / base_focal_length)
216
-
217
- target_x_fov = 2.0 * math.atan(sensor_width * 0.5 / target_focal_length)
218
- target_y_fov = 2.0 * math.atan(sensor_height * 0.5 / target_focal_length)
219
-
220
- # Calculate crop ratio, use the smaller ratio to maintain aspect ratio
221
- crop_ratio = min(target_x_fov / base_x_fov, target_y_fov / base_y_fov)
222
-
223
- crop_width = int(round(crop_ratio * width))
224
- crop_height = int(round(crop_ratio * height))
225
-
226
- # Ensure crop dimensions are within valid bounds
227
- crop_width = max(1, min(width, crop_width))
228
- crop_height = max(1, min(height, crop_height))
229
-
230
- # Crop coordinates
231
- left = int((width - crop_width) / 2)
232
- top = int((height - crop_height) / 2)
233
- right = int((width + crop_width) / 2)
234
- bottom = int((height + crop_height) / 2)
235
-
236
- # Crop the image
237
- zoomed_img = img.crop((left, top, right, bottom))
238
-
239
- # Resize the cropped image to target resolution
240
- resized_img = zoomed_img.resize((target_width, target_height), Image.Resampling.LANCZOS)
241
-
242
- # Convert the PIL image to a numpy array
243
- resized_img_np = np.array(resized_img).astype(np.float32)
244
-
245
- return resized_img_np
246
-
247
-
248
- def create_focal_length_embedding(focal_length_values, base_focal_length, target_height, target_width, sensor_height=24.0, sensor_width=36.0):
249
- device = 'cpu'
250
- focal_length_values = focal_length_values.to(device)
251
-
252
- f = focal_length_values.shape[0] # Number of frames
253
-
254
- # Convert constants to tensors to perform operations with focal_length_values
255
- sensor_width = torch.tensor(sensor_width, device=device)
256
- sensor_height = torch.tensor(sensor_height, device=device)
257
- base_focal_length = torch.tensor(base_focal_length, device=device)
258
-
259
- # Calculate the FOV for the base focal length (min_focal_length)
260
- base_fov_x = 2.0 * torch.atan(sensor_width * 0.5 / base_focal_length)
261
- base_fov_y = 2.0 * torch.atan(sensor_height * 0.5 / base_focal_length)
262
-
263
- # Calculate the FOV for each focal length in focal_length_values
264
- target_fov_x = 2.0 * torch.atan(sensor_width * 0.5 / focal_length_values)
265
- target_fov_y = 2.0 * torch.atan(sensor_height * 0.5 / focal_length_values)
266
-
267
- # Calculate crop ratio: how much of the image is cropped at the current focal length
268
- crop_ratio_xs = target_fov_x / base_fov_x # Crop ratio for horizontal axis
269
- crop_ratio_ys = target_fov_y / base_fov_y # Crop ratio for vertical axis
270
-
271
- # Get the center of the image
272
- center_h, center_w = target_height // 2, target_width // 2
273
-
274
- # Initialize a mask tensor with zeros on CPU
275
- focal_length_embedding = torch.zeros((f, 3, target_height, target_width), dtype=torch.float32) # Shape [f, 3, H, W]
276
-
277
- # Fill the center region with 1 based on the calculated crop dimensions
278
- for i in range(f):
279
- # Crop dimensions calculated using rounded float values
280
- crop_h = torch.round(crop_ratio_ys[i] * target_height).int().item() # Rounded cropped height for the current frame
281
- crop_w = torch.round(crop_ratio_xs[i] * target_width).int().item() # Rounded cropped width for the current frame
282
-
283
- # Ensure the cropped dimensions are within valid bounds
284
- crop_h = max(1, min(target_height, crop_h))
285
- crop_w = max(1, min(target_width, crop_w))
286
-
287
- # Set the center region of the focal_length embedding to 1 for the current frame
288
- focal_length_embedding[i, :,
289
- center_h - crop_h // 2: center_h + crop_h // 2,
290
- center_w - crop_w // 2: center_w + crop_w // 2] = 1.0
291
-
292
- return focal_length_embedding
293
-
294
-
295
- class CameraFocalLength(Dataset):
296
- def __init__(
297
- self,
298
- root_path,
299
- annotation_json,
300
- sample_n_frames=5,
301
- sample_size=[256, 384],
302
- is_Train=True,
303
- ):
304
- self.root_path = root_path
305
- self.sample_n_frames = sample_n_frames
306
- self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
307
- self.length = len(self.dataset)
308
- sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
309
- self.sample_size = sample_size
310
- pixel_transforms = [transforms.Resize(sample_size),
311
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
312
-
313
- self.pixel_transforms = pixel_transforms
314
- self.is_Train = is_Train
315
- self.tokenizer = CLIPTokenizer.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="tokenizer")
316
- self.text_encoder = CLIPTextModel.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="text_encoder")
317
-
318
-
319
- def load_image_reader(self, idx):
320
- image_dict = self.dataset[idx]
321
-
322
- image_path = os.path.join(self.root_path, image_dict['base_image_path'])
323
- image_reader = cv2.imread(image_path)
324
-
325
- image_caption = image_dict['caption']
326
-
327
- if self.is_Train:
328
- focal_length_values = [random.uniform(24.0, 70.0) for _ in range(self.sample_n_frames)]
329
- print('train focal_length_values', focal_length_values)
330
- else:
331
- focal_length_list_str = image_dict['focal_length_list']
332
- focal_length_values = json.loads(focal_length_list_str)
333
- print('validation focal_length_values', focal_length_values)
334
-
335
- focal_length_values = torch.tensor(focal_length_values).unsqueeze(1)
336
-
337
- return image_path, image_reader, image_caption, focal_length_values
338
-
339
-
340
- def get_batch(self, idx):
341
- image_path, image_reader, image_caption, focal_length_values = self.load_image_reader(idx)
342
-
343
- total_frames = len(focal_length_values)
344
- if total_frames < 3:
345
- raise ValueError("less than 3 frames")
346
-
347
- # Generate prompts for each fl value and append fl information to caption
348
- prompts = []
349
- for fl in focal_length_values:
350
- prompt = f"<focal length: {fl.item()}>"
351
- prompts.append(prompt)
352
-
353
- # Tokenize prompts and encode to get embeddings
354
- with torch.no_grad():
355
- prompt_ids = self.tokenizer(
356
- prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
357
- ).input_ids
358
-
359
- encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
360
- # print('encoder_hidden_states shape', encoder_hidden_states.shape)
361
-
362
- # Calculate differences between consecutive embeddings (ignoring sequence_length)
363
- differences = []
364
- for i in range(1, encoder_hidden_states.size(0)):
365
- diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
366
- diff = diff.unsqueeze(0)
367
- differences.append(diff)
368
-
369
- # Add the difference between the last and the first embedding
370
- final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
371
- final_diff = final_diff.unsqueeze(0)
372
- differences.append(final_diff)
373
-
374
- # Concatenate the differences along the batch dimension (f-1 consecutive diffs plus the wrap-around diff, giving f rows)
375
- concatenated_differences = torch.cat(differences, dim=0)
376
- # print('concatenated_differences shape', concatenated_differences.shape) # f 77 768
377
-
378
- frame = concatenated_differences.size(0)
379
-
380
- # Concatenate differences along the batch dimension (f)
381
- concatenated_differences = torch.cat(differences, dim=0)
382
-
383
- # Current shape: (f, 77, 768), Pad the second dimension (77) to 128
384
- pad_length = 128 - concatenated_differences.size(1)
385
- if pad_length > 0:
386
- # Pad along the second dimension (77 -> 128), pad only on the right side
387
- concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
388
-
389
- ## CCL = contrastive camera learning
390
- ccl_embedding = concatenated_differences_padded.reshape(frame, self.sample_size[0], self.sample_size[1])
391
-
392
- ccl_embedding = ccl_embedding.unsqueeze(1)
393
- ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
394
- # print('ccl_embedding shape', ccl_embedding.shape)
395
-
396
- pixel_values = []
397
- for ff in focal_length_values:
398
- img_sim = crop_focal_length(image_path=image_path, base_focal_length=24.0, target_focal_length=ff, target_height=self.sample_size[0], target_width=self.sample_size[1], sensor_height=24.0, sensor_width=36.0)
399
-
400
- pixel_values.append(img_sim)
401
- # save_path = os.path.join(self.root_path, f"simulated_img_focal_length_{fl.item():.2f}.png")
402
- # cv2.imwrite(save_path, img_sim)
403
- # print(f"Saved image: {save_path}")
404
-
405
- pixel_values = np.stack(pixel_values, axis=0)
406
- pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() / 255.
407
-
408
- focal_length_embedding = create_focal_length_embedding(focal_length_values, base_focal_length=24.0, target_height=self.sample_size[0], target_width=self.sample_size[1])
409
- # print('focal_length_embedding shape', focal_length_embedding.shape)
410
-
411
- camera_embedding = torch.cat((focal_length_embedding, ccl_embedding), dim=1)
412
- # print('camera_embedding shape', camera_embedding.shape)
413
-
414
- return pixel_values, image_caption, camera_embedding, focal_length_values
415
-
416
- def __len__(self):
417
- return self.length
418
-
419
- def __getitem__(self, idx):
420
- while True:
421
- try:
422
- video, video_caption, camera_embedding, focal_length_values = self.get_batch(idx)
423
- break
424
- except Exception as e:
425
- idx = random.randint(0, self.length - 1)
426
-
427
- for transform in self.pixel_transforms:
428
- video = transform(video)
429
-
430
- sample = dict(pixel_values=video, text=video_caption, camera_embedding=camera_embedding, focal_length_values=focal_length_values)
431
-
432
- return sample
433
-
434
-
435
-
436
-
437
-
438
-
439
-
440
- #### for color temperature ####
441
- def kelvin_to_rgb(kelvin):
442
- temp = kelvin / 100.0
443
-
444
- if temp <= 66:
445
- red = 255
446
- green = 99.4708025861 * np.log(temp) - 161.1195681661 if temp > 0 else 0
447
- if temp <= 19:
448
- blue = 0
449
- else:
450
- blue = 138.5177312231 * np.log(temp - 10) - 305.0447927307
451
-
452
- elif 66<temp<=88:
453
- red = 0.5 * (255 + 329.698727446 * ((temp - 60) ** -0.19332047592))
454
- green = 0.5 * (288.1221695283 * ((temp - 60) ** -0.1155148492) + (99.4708025861 * np.log(temp) - 161.1195681661 if temp > 0 else 0))
455
- blue = 0.5 * (138.5177312231 * np.log(temp - 10) - 305.0447927307 + 255)
456
-
457
- else:
458
- red = 329.698727446 * ((temp - 60) ** -0.19332047592)
459
- green = 288.1221695283 * ((temp - 60) ** -0.1155148492)
460
- blue = 255
461
-
462
- return np.array([red, green, blue], dtype=np.float32) / 255.0
463
-
464
-
465
-
466
- def create_color_temperature_embedding(color_temperature_values, target_height, target_width, min_color_temperature=2000, max_color_temperature=10000):
467
- """
468
- Create a color_temperature embedding tensor based on color temperature.
469
- Args:
470
- - color_temperature_values: Tensor of shape [f, 1] containing color_temperature values for each frame.
471
- - target_height: Height of the image.
472
- - target_width: Width of the image.
473
- - min_color_temperature: Minimum color_temperature value for normalization.
474
- - max_color_temperature: Maximum color_temperature value for normalization.
475
- Returns:
476
- - color_temperature_embedding: Tensor of shape [f, 3, target_height, target_width] for RGB channel scaling.
477
- """
478
- f = color_temperature_values.shape[0]
479
- rgb_factors = []
480
-
481
- # Compute RGB factors based on kelvin_to_rgb function
482
- for ct in color_temperature_values.squeeze():
483
- kelvin = min_color_temperature + (ct * (max_color_temperature - min_color_temperature)) # Map normalized color_temperature to actual Kelvin
484
- rgb = kelvin_to_rgb(kelvin)
485
- rgb_factors.append(rgb)
486
-
487
- # Convert to tensor and expand to target dimensions
488
- rgb_factors = torch.tensor(rgb_factors).float() # [f, 3]
489
- rgb_factors = rgb_factors.unsqueeze(2).unsqueeze(3) # [f, 3, 1, 1]
490
- color_temperature_embedding = rgb_factors.expand(f, 3, target_height, target_width) # [f, 3, target_height, target_width]
491
- return color_temperature_embedding
492
-
493
-
494
-
495
- def kelvin_to_rgb_smooth(kelvin):
496
- temp = kelvin / 100.0
497
-
498
- if temp <= 66:
499
- red = 255
500
- green = 99.4708025861 * np.log(temp) - 161.1195681661 if temp > 0 else 0
501
- if temp <= 19:
502
- blue = 0
503
- else:
504
- blue = 138.5177312231 * np.log(temp - 10) - 305.0447927307
505
-
506
- elif 66<temp<=88:
507
- red = 0.5 * (255 + 329.698727446 * ((temp - 60) ** -0.19332047592))
508
- green = 0.5 * (288.1221695283 * ((temp - 60) ** -0.1155148492) + (99.4708025861 * np.log(temp) - 161.1195681661 if temp > 0 else 0))
509
- blue = 0.5 * (138.5177312231 * np.log(temp - 10) - 305.0447927307 + 255)
510
-
511
- else:
512
- red = 329.698727446 * ((temp - 60) ** -0.19332047592)
513
- green = 288.1221695283 * ((temp - 60) ** -0.1155148492)
514
- blue = 255
515
-
516
- red = np.clip(red, 0, 255)
517
- green = np.clip(green, 0, 255)
518
- blue = np.clip(blue, 0, 255)
519
- balance_rgb = np.array([red, green, blue], dtype=np.float32)
520
-
521
- return balance_rgb
522
-
523
-
524
- def interpolate_white_balance(image, kelvin):
525
-
526
- balance_rgb = kelvin_to_rgb_smooth(kelvin.item())
527
- image = image.astype(np.float32)
528
-
529
- r, g, b = cv2.split(image)
530
- r = r * (balance_rgb[0] / 255.0)
531
- g = g * (balance_rgb[1] / 255.0)
532
- b = b * (balance_rgb[2] / 255.0)
533
-
534
- balanced_image = cv2.merge([r,g,b])
535
- balanced_image = np.clip(balanced_image, 0, 255).astype(np.uint8)
536
-
537
- return balanced_image
538
-
539
-
540
- class CameraColorTemperature(Dataset):
541
- def __init__(
542
- self,
543
- root_path,
544
- annotation_json,
545
- sample_n_frames=5,
546
- sample_size=[256, 384],
547
- is_Train=True,
548
- ):
549
- self.root_path = root_path
550
- self.sample_n_frames = sample_n_frames
551
- self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
552
-
553
- self.length = len(self.dataset)
554
- self.is_Train = is_Train
555
-
556
- sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
557
- self.sample_size = sample_size
558
-
559
- pixel_transforms = [transforms.Resize(sample_size),
560
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
561
-
562
- self.pixel_transforms = pixel_transforms
563
- self.tokenizer = CLIPTokenizer.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="tokenizer")
564
- self.text_encoder = CLIPTextModel.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="text_encoder")
565
-
566
- def load_image_reader(self, idx):
567
- image_dict = self.dataset[idx]
568
-
569
- image_path = os.path.join(self.root_path, image_dict['base_image_path'])
570
- image_reader = cv2.imread(image_path)
571
- image_reader = cv2.cvtColor(image_reader, cv2.COLOR_BGR2RGB)
572
-
573
- image_caption = image_dict['caption']
574
-
575
- if self.is_Train:
576
- color_temperature_values = [random.uniform(2000.0, 10000.0) for _ in range(self.sample_n_frames)]
577
- print('train color_temperature values', color_temperature_values)
578
-
579
- else:
580
- color_temperature_list_str = image_dict['color_temperature_list']
581
- color_temperature_values = json.loads(color_temperature_list_str)
582
- print('validation color_temperature_values', color_temperature_values)
583
-
584
- color_temperature_values = torch.tensor(color_temperature_values).unsqueeze(1)
585
- return image_path, image_reader, image_caption, color_temperature_values
586
-
587
-
588
- def get_batch(self, idx):
589
- image_path, image_reader, image_caption, color_temperature_values = self.load_image_reader(idx)
590
-
591
- total_frames = len(color_temperature_values)
592
- if total_frames < 3:
593
- raise ValueError("less than 3 frames")
594
-
595
- # Generate prompts for each color_temperature value and append color_temperature information to caption
596
- prompts = []
597
- for cc in color_temperature_values:
598
- prompt = f"<color temperature: {cc.item()}>"
599
- prompts.append(prompt)
600
-
601
- # Tokenize prompts and encode to get embeddings
602
- with torch.no_grad():
603
- prompt_ids = self.tokenizer(
604
- prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
605
- ).input_ids
606
- # print('tokenizer model_max_length', self.tokenizer.model_max_length)
607
-
608
- encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
609
-
610
- # print('encoder_hidden_states shape', encoder_hidden_states.shape)
611
-
612
- # Calculate differences between consecutive embeddings (ignoring sequence_length)
613
- differences = []
614
- for i in range(1, encoder_hidden_states.size(0)):
615
- diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
616
- diff = diff.unsqueeze(0)
617
- differences.append(diff)
618
-
619
- # Add the difference between the last and the first embedding
620
- final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
621
- final_diff = final_diff.unsqueeze(0)
622
- differences.append(final_diff)
623
-
624
- # Concatenate the differences along the batch dimension (f-1 consecutive diffs plus the wrap-around diff, giving f rows)
625
- concatenated_differences = torch.cat(differences, dim=0)
626
- # print('concatenated_differences shape', concatenated_differences.shape) # f 77 768
627
-
628
- frame = concatenated_differences.size(0)
629
-
630
- concatenated_differences = torch.cat(differences, dim=0)
631
-
632
- # Current shape: (f, 77, 768), Pad the second dimension (77) to 128
633
- pad_length = 128 - concatenated_differences.size(1)
634
- if pad_length > 0:
635
- # Pad along the second dimension (77 -> 128), pad only on the right side
636
- concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
637
-
638
- ccl_embedding = concatenated_differences_padded.reshape(frame, self.sample_size[0], self.sample_size[1])
639
- ccl_embedding = ccl_embedding.unsqueeze(1)
640
- ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
641
- # print('ccl_embedding shape', ccl_embedding.shape)
642
-
643
- # Now handle the sensor image simulation
644
- pixel_values = []
645
- for aw in color_temperature_values:
646
- img_sim = interpolate_white_balance(image_reader, aw)
647
- pixel_values.append(img_sim)
648
- pixel_values = np.stack(pixel_values, axis=0)
649
- pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() / 255.
650
-
651
- # Create color_temperature embedding and concatenate it with CCL embedding
652
- color_temperature_embedding = create_color_temperature_embedding(color_temperature_values, self.sample_size[0], self.sample_size[1])
653
- # print('color_temperature_embedding shape', color_temperature_embedding.shape)
654
-
655
- camera_embedding = torch.cat((color_temperature_embedding, ccl_embedding), dim=1)
656
- # print('camera_embedding shape', camera_embedding.shape)
657
-
658
- return pixel_values, image_caption, camera_embedding, color_temperature_values
659
-
660
- def __len__(self):
661
- return self.length
662
-
663
- def __getitem__(self, idx):
664
- while True:
665
- try:
666
- video, video_caption, camera_embedding, color_temperature_values = self.get_batch(idx)
667
- break
668
- except Exception as e:
669
- idx = random.randint(0, self.length - 1)
670
-
671
- for transform in self.pixel_transforms:
672
- video = transform(video)
673
-
674
- sample = dict(pixel_values=video, text=video_caption, camera_embedding=camera_embedding, color_temperature_values=color_temperature_values)
675
-
676
- return sample
677
-
678
-
679
-
680
-
681
-
682
-
683
-
684
-
685
- #### for bokeh (K is the blur parameter) ####
686
- def create_bokehK_embedding(bokehK_values, target_height, target_width):
687
- """
688
- Creates a Bokeh embedding based on the given K values. The larger the K value,
689
- the more the image is blurred.
690
-
691
- Args:
692
- bokehK_values (torch.Tensor): Tensor of K values for bokeh effect.
693
- target_height (int): Desired height of the output embedding.
694
- target_width (int): Desired width of the output embedding.
695
- base_K (float): Base K value to control the minimum blur level (documented but not part of this function's signature).
696
-
697
- Returns:
698
- torch.Tensor: Bokeh embedding tensor. [f 3 h w]
699
- """
700
- f = bokehK_values.shape[0]
701
- bokehK_embedding = torch.zeros((f, 3, target_height, target_width), dtype=bokehK_values.dtype)
702
-
703
- for i in range(f):
704
- K_value = bokehK_values[i].item()
705
-
706
- kernel_size = max(K_value, 1)
707
- sigma = K_value / 3.0
708
-
709
- ax = np.linspace(-(kernel_size / 2), kernel_size / 2, int(np.ceil(kernel_size)))
710
- xx, yy = np.meshgrid(ax, ax)
711
- kernel = np.exp(-(xx ** 2 + yy ** 2) / (2 * sigma ** 2))
712
- kernel /= np.sum(kernel)
713
-
714
- scale = kernel[int(np.ceil(kernel_size) / 2), int(np.ceil(kernel_size) / 2)]
715
- bokehK_embedding[i] = scale
716
-
717
- return bokehK_embedding
718
-
719
-
720
- def bokehK_simulation(image_path, depth_map_path, K, disp_focus, gamma=2.2):
721
- ## depth map image can be inferenced online using following code ##
722
- # model_dir = "/home/modules/"
723
- # pipe = pipeline(
724
- # task="depth-estimation",
725
- # model="depth-anything/Depth-Anything-V2-Small-hf",
726
- # cache_dir=model_dir,
727
- # device=0
728
- # )
729
-
730
- # image_raw = Image.open(image_path)
731
-
732
- # disp = pipe(image_raw)["depth"]
733
- # base_name = os.path.basename(image_path)
734
- # file_name, ext = os.path.splitext(base_name)
735
-
736
- # disp_file_name = f"{file_name}_disp.png"
737
- # disp.save(disp_file_name)
738
-
739
- # disp = np.array(disp)
740
- # disp = disp.astype(np.float32)
741
- # disp /= 255.0
742
-
743
- disp = np.float32(cv2.imread(depth_map_path, cv2.IMREAD_GRAYSCALE))
744
-
745
- disp /= 255.0
746
- disp = (disp - disp.min()) / (disp.max() - disp.min())
747
- min_disp = np.min(disp)
748
- max_disp = np.max(disp)
749
-
750
- device = torch.device('cuda')
751
-
752
- # Initialize renderer
753
- classical_renderer = ModuleRenderScatter().to(device)
754
-
755
- # Load image and disparity
756
- image = cv2.imread(image_path).astype(np.float32) / 255.0
757
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
758
-
759
- # Calculate defocus
760
- defocus = K * (disp - disp_focus) / 10.0
761
-
762
- # Convert to tensors and move to GPU if available
763
- image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(device)
764
-
765
- defocus = defocus.unsqueeze(0).unsqueeze(0).to(device)
766
-
767
- bokeh_classical, defocus_dilate = classical_renderer(image**gamma, defocus*10.0)
768
- bokeh_pred = bokeh_classical ** (1/gamma)
769
- bokeh_pred = bokeh_pred.squeeze(0)
770
- bokeh_pred = bokeh_pred.permute(1, 2, 0)  # change channel order to H x W x C
771
- bokeh_pred = (bokeh_pred * 255).cpu().numpy()
772
- bokeh_pred = np.round(bokeh_pred)
773
- bokeh_pred = bokeh_pred.astype(np.float32)
774
-
775
- return bokeh_pred
776
-
777
-
778
-
779
-
780
- class CameraBokehK(Dataset):
781
- def __init__(
782
- self,
783
- root_path,
784
- annotation_json,
785
- sample_n_frames=5,
786
- sample_size=[256, 384],
787
- is_Train=True,
788
- ):
789
- self.root_path = root_path
790
- self.sample_n_frames = sample_n_frames
791
- self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
792
-
793
- self.length = len(self.dataset)
794
- self.is_Train = is_Train
795
- sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
796
- self.sample_size = sample_size
797
-
798
- pixel_transforms = [transforms.Resize(sample_size),
799
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
800
-
801
- self.pixel_transforms = pixel_transforms
802
- self.tokenizer = CLIPTokenizer.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="tokenizer")
803
- self.text_encoder = CLIPTextModel.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="text_encoder")
804
-
805
- def load_image_reader(self, idx):
806
- image_dict = self.dataset[idx]
807
-
808
- image_path = os.path.join(self.root_path, image_dict['base_image_path'])
809
- depth_map_path = os.path.join(self.root_path, image_dict['depth_map_path'])
810
-
811
- image_caption = image_dict['caption']
812
-
813
-
814
- if self.is_Train:
815
- bokehK_values = [random.uniform(1.0, 30.0) for _ in range(self.sample_n_frames)]
816
- print('train bokehK values', bokehK_values)
817
-
818
- else:
819
- bokehK_list_str = image_dict['bokehK_list']
820
- bokehK_values = json.loads(bokehK_list_str)
821
- print('validation bokehK_values', bokehK_values)
822
-
823
- bokehK_values = torch.tensor(bokehK_values).unsqueeze(1)
824
- return image_path, depth_map_path, image_caption, bokehK_values
825
-
826
-
827
- def get_batch(self, idx):
828
- image_path, depth_map_path, image_caption, bokehK_values = self.load_image_reader(idx)
829
-
830
- total_frames = len(bokehK_values)
831
- if total_frames < 3:
832
- raise ValueError("less than 3 frames")
833
-
834
- # Generate prompts for each bokehK value and append bokehK information to caption
835
- prompts = []
836
- for bb in bokehK_values:
837
- prompt = f"<bokeh kernel size: {bb.item()}>"
838
- prompts.append(prompt)
839
-
840
- # Tokenize prompts and encode to get embeddings
841
- with torch.no_grad():
842
- prompt_ids = self.tokenizer(
843
- prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
844
- ).input_ids
845
- # print('tokenizer model_max_length', self.tokenizer.model_max_length)
846
-
847
- encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
848
-
849
- # print('encoder_hidden_states shape', encoder_hidden_states.shape)
850
-
851
- # Calculate differences between consecutive embeddings (ignoring sequence_length)
852
- differences = []
853
- for i in range(1, encoder_hidden_states.size(0)):
854
- diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
855
- diff = diff.unsqueeze(0)
856
- differences.append(diff)
857
-
858
- # Add the difference between the last and the first embedding
859
- final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
860
- final_diff = final_diff.unsqueeze(0)
861
- differences.append(final_diff)
862
-
863
- # Concatenate the differences along the batch dimension (f-1 consecutive diffs plus the wrap-around diff, giving f rows)
864
- concatenated_differences = torch.cat(differences, dim=0)
865
-
866
- # print('concatenated_differences shape', concatenated_differences.shape) # f 77 768
867
-
868
- frame = concatenated_differences.size(0)
869
-
870
- # Concatenate differences along the batch dimension (f)
871
- concatenated_differences = torch.cat(differences, dim=0)
872
-
873
- # Current shape: (f, 77, 768), Pad the second dimension (77) to 128
874
- pad_length = 128 - concatenated_differences.size(1)
875
- if pad_length > 0:
876
- # Pad along the second dimension (77 -> 128), pad only on the right side
877
- concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
878
-
879
- ## ccl = contrastive camera learning ##
880
- ccl_embedding = concatenated_differences_padded.reshape(frame, self.sample_size[0], self.sample_size[1])
881
- ccl_embedding = ccl_embedding.unsqueeze(1)
882
- ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
883
- # print('ccl_embedding shape', ccl_embedding.shape)
884
-
885
- pixel_values = []
886
- for bk in bokehK_values:
887
- img_sim = bokehK_simulation(image_path, depth_map_path, bk, disp_focus=0.96, gamma=2.2)
888
- # save_path = os.path.join(self.root_path, f"simulated_img_bokeh_{bk.item():.2f}.png")
889
- # cv2.imwrite(save_path, img_sim)
890
- # print(f"Saved image: {save_path}")
891
- pixel_values.append(img_sim)
892
-
893
- pixel_values = np.stack(pixel_values, axis=0)
894
- pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() / 255.
895
-
896
- # Create bokehK embedding and concatenate it with CCL embedding
897
- bokehK_embedding = create_bokehK_embedding(bokehK_values, self.sample_size[0], self.sample_size[1])
898
-
899
- camera_embedding = torch.cat((bokehK_embedding, ccl_embedding), dim=1)
900
- # print('camera_embedding shape', camera_embedding.shape)
901
-
902
- return pixel_values, image_caption, camera_embedding, bokehK_values
903
-
904
- def __len__(self):
905
- return self.length
906
-
907
- def __getitem__(self, idx):
908
- while True:
909
- try:
910
- video, video_caption, camera_embedding, bokehK_values = self.get_batch(idx)
911
- break
912
- except Exception as e:
913
- idx = random.randint(0, self.length - 1)
914
-
915
- for transform in self.pixel_transforms:
916
- video = transform(video)
917
-
918
- sample = dict(pixel_values=video, text=video_caption, camera_embedding=camera_embedding, bokehK_values=bokehK_values)
919
-
920
- return sample
921
-
922
-
923
-
924
- def test_camera_bokehK_dataset():
925
- root_path = '/home/yuan418/data/project/camera_dataset/camera_bokehK/'
926
- annotation_json = 'annotations/inference.json'
927
-
928
- print('------------------')
929
- dataset = CameraBokehK(
930
- root_path=root_path,
931
- annotation_json=annotation_json,
932
- sample_n_frames=4,
933
- sample_size=[256, 384],
934
- is_Train=False,
935
- )
936
-
937
- # choose one sample for testing
938
- idx = 1
939
- sample = dataset[idx]
940
-
941
- pixel_values = sample['pixel_values']
942
- text = sample['text']
943
- camera_embedding = sample['camera_embedding']
944
- print(f"Pixel values shape: {pixel_values.shape}")
945
- print(f"Text: {text}")
946
- print(f"camera embedding shape: {camera_embedding.shape}")
947
-
948
-
949
- if __name__ == "__main__":
950
- test_camera_bokehK_dataset()
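For reference (not part of this commit's diff), the four dataset classes above share one contrastive camera learning (CCL) embedding construction: each frame's camera setting is rendered as a prompt, encoded with the CLIP text encoder, the consecutive embedding differences plus a last-to-first wrap-around difference are stacked, padded from 77 to 128 tokens, and reshaped to the 256x384 sample size (128 * 768 = 256 * 384 = 98304). A minimal sketch of that step; build_ccl_embedding is a name introduced here purely for illustration:

import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTokenizer

def build_ccl_embedding(prompts, tokenizer, text_encoder, height=256, width=384):
    # One camera-setting prompt per frame, e.g. "<exposure: 0.48>".
    with torch.no_grad():
        ids = tokenizer(prompts, max_length=tokenizer.model_max_length,
                        padding="max_length", truncation=True, return_tensors="pt").input_ids
        hidden = text_encoder(input_ids=ids).last_hidden_state        # (f, 77, 768)
    diffs = [hidden[i] - hidden[i - 1] for i in range(1, hidden.size(0))]
    diffs.append(hidden[-1] - hidden[0])                              # wrap-around difference
    diffs = torch.stack(diffs, dim=0)                                 # (f, 77, 768)
    diffs = F.pad(diffs, (0, 0, 0, 128 - diffs.size(1)))              # (f, 128, 768)
    ccl = diffs.reshape(diffs.size(0), height, width)                 # valid because 128 * 768 == height * width
    return ccl.unsqueeze(1).expand(-1, 3, -1, -1)                     # (f, 3, height, width)

# Usage with the same checkpoint path the datasets load:
# tokenizer = CLIPTokenizer.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="tokenizer")
# text_encoder = CLIPTextModel.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="text_encoder")
# ccl = build_ccl_embedding([f"<exposure: {v}>" for v in (0.2, 0.5, 0.8)], tokenizer, text_encoder)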
 
 
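The shutter-speed dataset ties the pixel simulation and the embedding to the same value: each shutter speed ss is mapped to a target photon count avg_PPP = (0.6 * ss + 0.1) * fwc before sensor_image_simulation_numpy quantizes the frame, while create_shutter_speed_embedding scales every pixel by (ss / base_exposure) * fwc / (fwc + 1e-4). A small numeric sketch of those two formulas (the values are illustrative only):

import numpy as np

fwc = 32000.0   # full well capacity, the constant used by create_shutter_speed_embedding
ss = 0.8        # one shutter-speed value from the [0.1, 1.0] training range

# Brightness scale used for the embedding (base_exposure = 0.5)
scale = (ss / 0.5) * (fwc / (fwc + 0.0001))
print(round(scale, 4))        # ~1.6, i.e. this frame is encoded as ~60% brighter than base

# Target photon count fed into the sensor simulation
avg_PPP = (0.6 * ss + 0.1) * fwc
print(avg_PPP)                # 0.58 * 32000 = 18560.0

# 8-bit quantization of a toy photon-flux image, mirroring sensor_image_simulation_numpy
photon_flux = np.array([[10.0, 20.0], [30.0, 40.0]], dtype=np.float32)
theta = photon_flux * (avg_PPP / (np.mean(photon_flux) + 0.0001))
theta = np.clip(theta, 0, fwc)                     # clip to the full well capacity
theta = np.round(theta * 255 / fwc)                # Nbits=8, gain=1 -> max value 255
theta = np.clip(theta, 0, 255).astype(np.float32)
print(theta)                                       # brighter pixels move toward 255 as ss grows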
 
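Both crop_focal_length and create_focal_length_embedding derive their crop window from the ratio of fields of view between the target and base focal lengths, fov(f) = 2 * atan(sensor / (2 * f)). A short sketch of that geometry; crop_ratio is a hypothetical helper that mirrors the logic above, and the numbers are illustrative:

import math

def crop_ratio(base_focal_length, target_focal_length,
               sensor_width=36.0, sensor_height=24.0):
    # Horizontal and vertical fields of view for both focal lengths.
    fov = lambda size, focal: 2.0 * math.atan(size * 0.5 / focal)
    rx = fov(sensor_width, target_focal_length) / fov(sensor_width, base_focal_length)
    ry = fov(sensor_height, target_focal_length) / fov(sensor_height, base_focal_length)
    # The smaller ratio is kept so the crop preserves the aspect ratio.
    return min(rx, ry)

# Zooming from the 24 mm base to 70 mm keeps roughly 37% of the frame along the tighter axis,
# which is the centered crop that crop_focal_length resizes back up to the sample size.
print(round(crop_ratio(24.0, 70.0), 2))   # ~0.37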
genphoto/models/unet.py CHANGED
@@ -11,14 +11,14 @@ from einops import repeat, rearrange
11
  from dataclasses import dataclass
12
  from typing import List, Optional, Tuple, Union, Dict, Any
13
 
14
- from diffusers.configuration_utils import ConfigMixin, register_to_config
15
- from diffusers.models.attention_processor import AttentionProcessor
16
-
17
- from diffusers.models.modeling_utils import ModelMixin
18
- from diffusers.utils import BaseOutput, logging
19
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
20
- from diffusers.models.attention_processor import LoRAAttnProcessor
21
- from diffusers.loaders import AttnProcsLayers, UNet2DConditionLoadersMixin
22
 
23
  from genphoto.models.unet_blocks import (
24
  CrossAttnDownBlock3D,
 
11
  from dataclasses import dataclass
12
  from typing import List, Optional, Tuple, Union, Dict, Any
13
 
14
+ from ..diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from ..diffusers.models.attention_processor import AttentionProcessor
16
+
17
+ from ..diffusers.models.modeling_utils import ModelMixin
18
+ from ..diffusers.utils import BaseOutput, logging
19
+ from ..diffusers.models.embeddings import TimestepEmbedding, Timesteps
20
+ from ..diffusers.models.attention_processor import LoRAAttnProcessor
21
+ from ..diffusers.loaders import AttnProcsLayers, UNet2DConditionLoadersMixin
22
 
23
  from genphoto.models.unet_blocks import (
24
  CrossAttnDownBlock3D,
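The rewritten imports above only resolve if a copy of the diffusers library is vendored inside the genphoto package: from genphoto/models/unet.py (and genphoto/pipelines/pipeline_animation.py below), "..diffusers" points at genphoto/diffusers. A sketch of the layout this commit appears to assume, with the submodule names taken from the imports themselves (the directory itself is not shown in this diff):

genphoto/
    __init__.py
    diffusers/                 # vendored copy of the diffusers library
        __init__.py
        configuration_utils.py
        loaders.py
        models/                # attention_processor, modeling_utils, embeddings, AutoencoderKL, ...
        pipelines/             # pipeline_utils.DiffusionPipeline
        schedulers/            # DDIMScheduler, PNDMScheduler, ...
        utils/                 # BaseOutput, logging, deprecate, is_accelerate_available
    models/
        unet.py                # from ..diffusers.configuration_utils import ConfigMixin, register_to_config
        unet_blocks.py
        camera_adaptor.py
    pipelines/
        pipeline_animation.py  # from ..diffusers.models import AutoencoderKL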
genphoto/pipelines/pipeline_animation.py CHANGED
@@ -7,14 +7,14 @@ import numpy as np
7
 
8
  from typing import Callable, List, Optional, Union
9
  from dataclasses import dataclass
10
- from diffusers.utils import is_accelerate_available
11
  from packaging import version
12
  from einops import rearrange
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
- from diffusers.configuration_utils import FrozenDict
15
- from diffusers.models import AutoencoderKL
16
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
17
- from diffusers.schedulers import (
18
  DDIMScheduler,
19
  DPMSolverMultistepScheduler,
20
  EulerAncestralDiscreteScheduler,
@@ -22,8 +22,8 @@ from diffusers.schedulers import (
22
  LMSDiscreteScheduler,
23
  PNDMScheduler,
24
  )
25
- from diffusers.loaders import LoraLoaderMixin
26
- from diffusers.utils import deprecate, logging, BaseOutput
27
 
28
  from genphoto.models.camera_adaptor import CameraCameraEncoder
29
  from genphoto.models.unet import UNet3DConditionModel
 
7
 
8
  from typing import Callable, List, Optional, Union
9
  from dataclasses import dataclass
10
+ from ..diffusers.utils import is_accelerate_available
11
  from packaging import version
12
  from einops import rearrange
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
+ from ..diffusers.configuration_utils import FrozenDict
15
+ from ..diffusers.models import AutoencoderKL
16
+ from ..diffusers.pipelines.pipeline_utils import DiffusionPipeline
17
+ from ..diffusers.schedulers import (
18
  DDIMScheduler,
19
  DPMSolverMultistepScheduler,
20
  EulerAncestralDiscreteScheduler,
 
22
  LMSDiscreteScheduler,
23
  PNDMScheduler,
24
  )
25
+ from ..diffusers.loaders import LoraLoaderMixin
26
+ from ..diffusers.utils import deprecate, logging, BaseOutput
27
 
28
  from genphoto.models.camera_adaptor import CameraCameraEncoder
29
  from genphoto.models.unet import UNet3DConditionModel
inference_bokehK.py CHANGED
@@ -12,7 +12,7 @@ from omegaconf import OmegaConf
12
  from torch.utils.data import Dataset
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
 
15
- from diffusers import AutoencoderKL, DDIMScheduler
16
  from einops import rearrange
17
 
18
  from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
 
12
  from torch.utils.data import Dataset
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
 
15
+ from .diffusers import AutoencoderKL, DDIMScheduler
16
  from einops import rearrange
17
 
18
  from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
inference_color_temperature.py CHANGED
@@ -12,7 +12,7 @@ from omegaconf import OmegaConf
12
  from torch.utils.data import Dataset
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
 
15
- from diffusers import AutoencoderKL, DDIMScheduler
16
 
17
  from einops import rearrange
18
 
 
12
  from torch.utils.data import Dataset
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
 
15
+ from .diffusers import AutoencoderKL, DDIMScheduler
16
 
17
  from einops import rearrange
18
 
inference_focal_length.py CHANGED
@@ -12,7 +12,7 @@ from omegaconf import OmegaConf
12
  from torch.utils.data import Dataset
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
 
15
- from diffusers import AutoencoderKL, DDIMScheduler
16
 
17
 
18
  from einops import rearrange
 
12
  from torch.utils.data import Dataset
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
 
15
+ from .diffusers import AutoencoderKL, DDIMScheduler
16
 
17
 
18
  from einops import rearrange
inference_shutter_speed.py CHANGED
@@ -12,7 +12,7 @@ from omegaconf import OmegaConf
12
  from torch.utils.data import Dataset
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
 
15
- from diffusers import AutoencoderKL, DDIMScheduler
16
  from einops import rearrange
17
 
18
  from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
 
12
  from torch.utils.data import Dataset
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
 
15
+ from .diffusers import AutoencoderKL, DDIMScheduler
16
  from einops import rearrange
17
 
18
  from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
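Note on the inference scripts: a relative import such as "from .diffusers import AutoencoderKL, DDIMScheduler" only resolves when inference_bokehK.py and the other inference_*.py scripts are loaded as modules of a package that contains a sibling diffusers/ directory; executed directly, Python raises "ImportError: attempted relative import with no known parent package". Assuming such a parent package exists (its name below is hypothetical and not shown in this diff), the scripts would be run roughly as:

# works only when executed as a module of the enclosing package (hypothetical name "genphoto_space"):
#   python -m genphoto_space.inference_bokehK
# fails when run as a plain script:
#   python inference_bokehK.py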