Spaces:
Paused
Paused
Rishi Desai
committed on
Commit
·
ca96bd8
1
Parent(s):
a446ad0
added batching by category
Browse files
main.py
CHANGED
|
@@ -50,7 +50,7 @@ def validate_input_directory(input_dir):
|
|
| 50 |
print(f" - {file}")
|
| 51 |
sys.exit(1)
|
| 52 |
|
| 53 |
-
def process_images(input_dir, output_dir, fix_outfit=False):
|
| 54 |
"""Process all images in the input directory and generate captions."""
|
| 55 |
input_path = Path(input_dir)
|
| 56 |
output_path = Path(output_dir) if output_dir else input_path
|
|
@@ -64,9 +64,9 @@ def process_images(input_dir, output_dir, fix_outfit=False):
|
|
| 64 |
# Track the number of processed images
|
| 65 |
processed_count = 0
|
| 66 |
|
| 67 |
-
# Collect all images into a
|
| 68 |
-
|
| 69 |
-
|
| 70 |
|
| 71 |
# Get all files in the input directory
|
| 72 |
for file_path in input_path.iterdir():
|
|
@@ -74,26 +74,54 @@ def process_images(input_dir, output_dir, fix_outfit=False):
|
|
| 74 |
try:
|
| 75 |
# Load the image
|
| 76 |
image = Image.open(file_path).convert("RGB")
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
except Exception as e:
|
| 80 |
print(f"Error loading {file_path.name}: {e}")
|
| 81 |
|
| 82 |
# Log the number of images found
|
| 83 |
-
|
|
|
|
| 84 |
|
| 85 |
-
if not
|
| 86 |
print("No valid images found to process.")
|
| 87 |
return
|
| 88 |
|
| 89 |
-
#
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
| 97 |
for file_path, caption in zip(image_paths, captions):
|
| 98 |
try:
|
| 99 |
# Create caption file path (same name but with .txt extension)
|
|
@@ -111,18 +139,16 @@ def process_images(input_dir, output_dir, fix_outfit=False):
|
|
| 111 |
# Copy caption to output directory
|
| 112 |
shutil.copy2(caption_path, output_path / caption_filename)
|
| 113 |
|
| 114 |
-
processed_count += 1
|
| 115 |
print(f"Processed {file_path.name} → {caption_filename}")
|
| 116 |
except Exception as e:
|
| 117 |
print(f"Error processing {file_path.name}: {e}")
|
| 118 |
|
| 119 |
-
print(f"\nProcessing complete. {processed_count} images were captioned.")
|
| 120 |
-
|
| 121 |
def main():
|
| 122 |
parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
|
| 123 |
parser.add_argument('--input', type=str, required=True, help='Directory containing images')
|
| 124 |
parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
|
| 125 |
parser.add_argument('--fix_outfit', action='store_true', help='Flag to indicate if character has one outfit')
|
|
|
|
| 126 |
|
| 127 |
args = parser.parse_args()
|
| 128 |
|
|
@@ -132,7 +158,7 @@ def main():
|
|
| 132 |
return
|
| 133 |
|
| 134 |
# Process images
|
| 135 |
-
process_images(args.input, args.output, args.fix_outfit)
|
| 136 |
|
| 137 |
if __name__ == "__main__":
|
| 138 |
main()
|
|
|
|
| 50 |
print(f" - {file}")
|
| 51 |
sys.exit(1)
|
| 52 |
|
| 53 |
+
def process_images(input_dir, output_dir, fix_outfit=False, batch_images=False):
|
| 54 |
"""Process all images in the input directory and generate captions."""
|
| 55 |
input_path = Path(input_dir)
|
| 56 |
output_path = Path(output_dir) if output_dir else input_path
|
|
|
|
| 64 |
# Track the number of processed images
|
| 65 |
processed_count = 0
|
| 66 |
|
| 67 |
+
# Collect all images into a dictionary grouped by category
|
| 68 |
+
images_by_category = {}
|
| 69 |
+
image_paths_by_category = {}
|
| 70 |
|
| 71 |
# Get all files in the input directory
|
| 72 |
for file_path in input_path.iterdir():
|
|
|
|
| 74 |
try:
|
| 75 |
# Load the image
|
| 76 |
image = Image.open(file_path).convert("RGB")
|
| 77 |
+
|
| 78 |
+
# Determine the category from the filename
|
| 79 |
+
category = file_path.stem.rsplit('_', 1)[0]
|
| 80 |
+
|
| 81 |
+
# Add image to the appropriate category
|
| 82 |
+
if category not in images_by_category:
|
| 83 |
+
images_by_category[category] = []
|
| 84 |
+
image_paths_by_category[category] = []
|
| 85 |
+
|
| 86 |
+
images_by_category[category].append(image)
|
| 87 |
+
image_paths_by_category[category].append(file_path)
|
| 88 |
except Exception as e:
|
| 89 |
print(f"Error loading {file_path.name}: {e}")
|
| 90 |
|
| 91 |
# Log the number of images found
|
| 92 |
+
total_images = sum(len(images) for images in images_by_category.values())
|
| 93 |
+
print(f"Found {total_images} images to process.")
|
| 94 |
|
| 95 |
+
if not total_images:
|
| 96 |
print("No valid images found to process.")
|
| 97 |
return
|
| 98 |
|
| 99 |
+
# Process images by category if batch_images is True
|
| 100 |
+
if batch_images:
|
| 101 |
+
for category, images in images_by_category.items():
|
| 102 |
+
image_paths = image_paths_by_category[category]
|
| 103 |
+
try:
|
| 104 |
+
# Generate captions for the entire category
|
| 105 |
+
captions = caption_images(images)
|
| 106 |
+
write_captions(image_paths, captions, input_path, output_path)
|
| 107 |
+
processed_count += len(images)
|
| 108 |
+
except Exception as e:
|
| 109 |
+
print(f"Error generating captions for category '{category}': {e}")
|
| 110 |
+
else:
|
| 111 |
+
# Process all images at once if batch_images is False
|
| 112 |
+
all_images = [img for imgs in images_by_category.values() for img in imgs]
|
| 113 |
+
all_image_paths = [path for paths in image_paths_by_category.values() for path in paths]
|
| 114 |
+
try:
|
| 115 |
+
captions = caption_images(all_images)
|
| 116 |
+
write_captions(all_image_paths, captions, input_path, output_path)
|
| 117 |
+
processed_count += len(all_images)
|
| 118 |
+
except Exception as e:
|
| 119 |
+
print(f"Error generating captions: {e}")
|
| 120 |
|
| 121 |
+
print(f"\nProcessing complete. {processed_count} images were captioned.")
|
| 122 |
+
|
| 123 |
+
def write_captions(image_paths, captions, input_path, output_path):
|
| 124 |
+
"""Helper function to write captions to files."""
|
| 125 |
for file_path, caption in zip(image_paths, captions):
|
| 126 |
try:
|
| 127 |
# Create caption file path (same name but with .txt extension)
|
|
|
|
| 139 |
# Copy caption to output directory
|
| 140 |
shutil.copy2(caption_path, output_path / caption_filename)
|
| 141 |
|
|
|
|
| 142 |
print(f"Processed {file_path.name} → {caption_filename}")
|
| 143 |
except Exception as e:
|
| 144 |
print(f"Error processing {file_path.name}: {e}")
|
| 145 |
|
|
|
|
|
|
|
| 146 |
def main():
|
| 147 |
parser = argparse.ArgumentParser(description='Generate captions for images using GPT-4o.')
|
| 148 |
parser.add_argument('--input', type=str, required=True, help='Directory containing images')
|
| 149 |
parser.add_argument('--output', type=str, help='Directory to save images and captions (defaults to input directory)')
|
| 150 |
parser.add_argument('--fix_outfit', action='store_true', help='Flag to indicate if character has one outfit')
|
| 151 |
+
parser.add_argument('--batch_images', action='store_true', help='Flag to indicate if images should be processed in batches')
|
| 152 |
|
| 153 |
args = parser.parse_args()
|
| 154 |
|
|
|
|
| 158 |
return
|
| 159 |
|
| 160 |
# Process images
|
| 161 |
+
process_images(args.input, args.output, args.fix_outfit, args.batch_images)
|
| 162 |
|
| 163 |
if __name__ == "__main__":
|
| 164 |
main()
|