K00B404 commited on
Commit
23e8de2
·
verified ·
1 Parent(s): 7a72c33

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -9
app.py CHANGED
@@ -28,14 +28,10 @@ model_repo_id = "K00B404/pix2pix_flux"
28
  # Create dataset and dataloader
29
  class Pix2PixDataset(torch.utils.data.Dataset):
30
  def __init__(self, ds):
31
- # Filter dataset for 'original' and 'target' images
32
- #https://huggingface.co/datasets/K00B404/pix2pix_flux_set/viewer/default/train?f[label][value]=0
33
- #https://huggingface.co/datasets/K00B404/pix2pix_flux_set/viewer/default/train?f[label][value]=1
34
-
35
  self.originals = [x for x in ds["train"] if x['label'] == 0]
36
  self.targets = [x for x in ds["train"] if x['label'] == 1]
37
-
38
-
39
  # Ensure the number of original and target images match
40
  assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."
41
 
@@ -47,13 +43,13 @@ class Pix2PixDataset(torch.utils.data.Dataset):
47
  return len(self.originals)
48
 
49
  def __getitem__(self, idx):
50
- # Load original and target images for the given index
51
  original_img = self.originals[idx]['image']
52
  target_img = self.targets[idx]['image']
53
 
54
  # Apply the necessary transforms
55
- original = Image.open(original_img).convert('RGB')
56
- target = Image.open(target_img).convert('RGB')
57
 
58
  # Return transformed original and target images
59
  return transform(original), transform(target)
 
28
  # Create dataset and dataloader
29
  class Pix2PixDataset(torch.utils.data.Dataset):
30
  def __init__(self, ds):
31
+ # Filter dataset for 'original' (label = 0) and 'target' (label = 1) images
 
 
 
32
  self.originals = [x for x in ds["train"] if x['label'] == 0]
33
  self.targets = [x for x in ds["train"] if x['label'] == 1]
34
+
 
35
  # Ensure the number of original and target images match
36
  assert len(self.originals) == len(self.targets), "Mismatch in number of original and target images."
37
 
 
43
  return len(self.originals)
44
 
45
  def __getitem__(self, idx):
46
+ # Directly use the 'image' object without loading via Image.open()
47
  original_img = self.originals[idx]['image']
48
  target_img = self.targets[idx]['image']
49
 
50
  # Apply the necessary transforms
51
+ original = original_img.convert('RGB')
52
+ target = target_img.convert('RGB')
53
 
54
  # Return transformed original and target images
55
  return transform(original), transform(target)