Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -28,14 +28,10 @@ model_repo_id = "K00B404/pix2pix_flux"
|
|
28 |
# Create dataset and dataloader
|
29 |
class Pix2PixDataset(torch.utils.data.Dataset):
|
30 |
def __init__(self, ds):
|
31 |
-
# Filter dataset for 'original' and 'target' images
|
32 |
-
#https://huggingface.co/datasets/K00B404/pix2pix_flux_set/viewer/default/train?f[label][value]=0
|
33 |
-
#https://huggingface.co/datasets/K00B404/pix2pix_flux_set/viewer/default/train?f[label][value]=1
|
34 |
-
|
35 |
self.originals = [x for x in ds["train"] if x['label'] == 0]
|
36 |
self.targets = [x for x in ds["train"] if x['label'] == 1]
|
37 |
-
|
38 |
-
|
39 |
# Ensure the number of original and target images match
|
40 |
assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."
|
41 |
|
@@ -47,13 +43,13 @@ class Pix2PixDataset(torch.utils.data.Dataset):
|
|
47 |
return len(self.originals)
|
48 |
|
49 |
def __getitem__(self, idx):
|
50 |
-
#
|
51 |
original_img = self.originals[idx]['image']
|
52 |
target_img = self.targets[idx]['image']
|
53 |
|
54 |
# Apply the necessary transforms
|
55 |
-
original =
|
56 |
-
target =
|
57 |
|
58 |
# Return transformed original and target images
|
59 |
return transform(original), transform(target)
|
|
|
28 |
# Create dataset and dataloader
|
29 |
class Pix2PixDataset(torch.utils.data.Dataset):
|
30 |
def __init__(self, ds):
|
31 |
+
# Filter dataset for 'original' (label = 0) and 'target' (label = 1) images
|
|
|
|
|
|
|
32 |
self.originals = [x for x in ds["train"] if x['label'] == 0]
|
33 |
self.targets = [x for x in ds["train"] if x['label'] == 1]
|
34 |
+
|
|
|
35 |
# Ensure the number of original and target images match
|
36 |
assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."
|
37 |
|
|
|
43 |
return len(self.originals)
|
44 |
|
45 |
def __getitem__(self, idx):
|
46 |
+
# Directly use the 'image' object without loading via Image.open()
|
47 |
original_img = self.originals[idx]['image']
|
48 |
target_img = self.targets[idx]['image']
|
49 |
|
50 |
# Apply the necessary transforms
|
51 |
+
original = original_img.convert('RGB')
|
52 |
+
target = target_img.convert('RGB')
|
53 |
|
54 |
# Return transformed original and target images
|
55 |
return transform(original), transform(target)
|