# BLIP2 captioning GUI tab: loads the BLIP2 model and captions all images in a folder.
from PIL import Image | |
from transformers import Blip2Processor, Blip2ForConditionalGeneration | |
import torch | |
import gradio as gr | |
import os | |
from .common_gui import get_folder_path, scriptdir, list_dirs | |
from .custom_logging import setup_logging | |
# Set up logging
# Module-level logger shared by all functions in this file
log = setup_logging()
def load_model():
    """
    Load the BLIP2 processor and model onto the best available device.

    Returns:
        tuple: (processor, model, device) where device is "cuda" when a GPU
        is available, otherwise "cpu".
    """
    # Set the device to GPU if available, otherwise use CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # float16 halves memory on GPU, but half precision is poorly supported
    # on CPU — use full precision there to avoid runtime errors
    dtype = torch.float16 if device == "cuda" else torch.float32
    # Initialize the BLIP2 processor
    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    # Initialize the BLIP2 model with the device-appropriate dtype
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b", torch_dtype=dtype
    )
    # Move the model to the specified device
    model.to(device)

    return processor, model, device
def get_images_in_directory(directory_path):
    """
    Return a list of image file paths found directly in the given directory.

    The scan is non-recursive and matches extensions case-insensitively.

    Parameters:
    - directory_path: A string representing the path to the directory to search for images.

    Returns:
    - A list of strings, where each string is the full path to an image file
      found in the specified directory.
    """
    # Common image file extensions to look for; a set gives O(1) membership tests
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}

    # Build the full path for every directory entry whose extension matches
    return [
        os.path.join(directory_path, file)
        for file in os.listdir(directory_path)
        if os.path.splitext(file)[1].lower() in image_extensions
    ]
def generate_caption(
    file_list,
    processor,
    model,
    device,
    caption_file_ext=".txt",
    num_beams=5,
    repetition_penalty=1.5,
    length_penalty=1.2,
    max_new_tokens=40,
    min_new_tokens=20,
    do_sample=True,
    temperature=1.0,
    top_p=0.0,
):
    """
    Generate a caption for each image in file_list and write it to a sidecar file.

    Beam search is used when top_p == 0.0; otherwise nucleus (top-p) sampling
    is used. The caption is written next to the image, with the image extension
    replaced by caption_file_ext.

    Parameters:
    - file_list: A list of file paths pointing to the images to be captioned.
    - processor: The preprocessor for the BLIP2 model.
    - model: The BLIP2 model to be used for generating captions.
    - device: The device on which the computation is performed.
    - caption_file_ext: The extension for the output text files. Default: ".txt".
    - num_beams: Number of beams for beam search. Default: 5.
    - repetition_penalty: Penalty for repeating tokens. Default: 1.5.
    - length_penalty: Penalty for sentence length. Default: 1.2.
    - max_new_tokens: Maximum number of new tokens to generate. Default: 40.
    - min_new_tokens: Minimum number of new tokens to generate. Default: 20.
    - do_sample: Whether to sample (nucleus path only). Default: True.
    - temperature: Sampling temperature (nucleus path only). Default: 1.0.
    - top_p: Nucleus threshold; 0.0 selects the beam-search path. Default: 0.0.
    """
    for file_path in file_list:
        # Context manager ensures the image file handle is closed after
        # preprocessing instead of leaking until garbage collection
        with Image.open(file_path) as image:
            # Cast inputs to the model's own dtype so float16 (GPU) and
            # float32 (CPU) models both receive matching tensors
            inputs = processor(images=image, return_tensors="pt").to(
                device, model.dtype
            )

        if top_p == 0.0:
            # Deterministic beam-search decoding
            generated_ids = model.generate(
                **inputs,
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
            )
        else:
            # Stochastic nucleus (top-p) sampling
            generated_ids = model.generate(
                **inputs,
                do_sample=do_sample,
                top_p=top_p,
                max_new_tokens=max_new_tokens,
                min_new_tokens=min_new_tokens,
                temperature=temperature,
            )

        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()

        # Construct the output file path by replacing the original file
        # extension with the specified extension
        output_file_path = os.path.splitext(file_path)[0] + caption_file_ext

        # Write the generated text to the output file
        with open(output_file_path, "w", encoding="utf-8") as output_file:
            output_file.write(generated_text)

        # Log the image file path with a message that the caption was generated
        log.info(f"{file_path} caption was generated")
def caption_images_beam_search(
    directory_path,
    num_beams,
    repetition_penalty,
    length_penalty,
    min_new_tokens,
    max_new_tokens,
    caption_file_ext,
):
    """
    Caption all images in the specified directory with BLIP2 using beam search.

    Parameters:
    - directory_path: Path to the directory containing the images to be captioned.
    - num_beams: Number of beams for beam search.
    - repetition_penalty: Penalty for repeating tokens.
    - length_penalty: Penalty for sentence length.
    - min_new_tokens: Minimum number of new tokens to generate.
    - max_new_tokens: Maximum number of new tokens to generate.
    - caption_file_ext: Extension for the generated caption files.
    """
    log.info("BLIP2 captioning beam...")

    # Bail out early if the folder is missing rather than crashing later
    if not os.path.isdir(directory_path):
        log.error(f"Directory {directory_path} does not exist.")
        return

    processor, model, device = load_model()
    image_files = get_images_in_directory(directory_path)
    generate_caption(
        file_list=image_files,
        processor=processor,
        model=model,
        device=device,
        # Explicit casts: GUI widgets may deliver numbers as floats/strings
        num_beams=int(num_beams),
        repetition_penalty=float(repetition_penalty),
        length_penalty=float(length_penalty),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
        caption_file_ext=caption_file_ext,
    )
def caption_images_nucleus(
    directory_path,
    do_sample,
    temperature,
    top_p,
    min_new_tokens,
    max_new_tokens,
    caption_file_ext,
):
    """
    Caption all images in the specified directory with BLIP2 using nucleus sampling.

    Parameters:
    - directory_path: Path to the directory containing the images to be captioned.
    - do_sample: Whether to use sampling during generation.
    - temperature: Sampling temperature.
    - top_p: Nucleus sampling threshold.
    - min_new_tokens: Minimum number of new tokens to generate.
    - max_new_tokens: Maximum number of new tokens to generate.
    - caption_file_ext: Extension for the generated caption files.
    """
    log.info("BLIP2 captioning nucleus...")

    # Bail out early if the folder is missing rather than crashing later
    if not os.path.isdir(directory_path):
        log.error(f"Directory {directory_path} does not exist.")
        return

    processor, model, device = load_model()
    image_files = get_images_in_directory(directory_path)
    generate_caption(
        file_list=image_files,
        processor=processor,
        model=model,
        device=device,
        do_sample=do_sample,
        # Explicit casts for consistency with the beam-search entry point
        temperature=float(temperature),
        top_p=float(top_p),
        min_new_tokens=int(min_new_tokens),
        max_new_tokens=int(max_new_tokens),
        caption_file_ext=caption_file_ext,
    )
def gradio_blip2_caption_gui_tab(headless=False, directory_path=None):
    """
    Build the "BLIP2 Captioning" Gradio tab.

    Parameters:
    - headless: When True, hide the local folder-picker button (no desktop
      dialogs available).
    - directory_path: Initial image folder; defaults to <scriptdir>/data.
    """
    from .common_gui import create_refresh_button

    # Fall back to the default data directory when no folder was supplied
    directory_path = (
        directory_path
        if directory_path is not None
        else os.path.join(scriptdir, "data")
    )
    current_train_dir = directory_path

    def list_train_dirs(path):
        # Remember the last browsed folder so the refresh button re-lists it
        nonlocal current_train_dir
        current_train_dir = path
        return list(list_dirs(path))

    with gr.Tab("BLIP2 Captioning"):
        gr.Markdown(
            "This utility uses BLIP2 to caption files for each image in a folder."
        )

        # Folder selection row: dropdown of known dirs + refresh + picker button
        with gr.Group(), gr.Row():
            directory_path_dir = gr.Dropdown(
                label="Image folder to caption (containing the images to caption)",
                choices=[""] + list_train_dirs(directory_path),
                value="",
                interactive=True,
                allow_custom_value=True,
            )
            create_refresh_button(
                directory_path_dir,
                lambda: None,
                lambda: {"choices": list_train_dirs(current_train_dir)},
                "open_folder_small",
            )
            button_directory_path_dir_input = gr.Button(
                "📂",
                elem_id="open_folder_small",
                elem_classes=["tool"],
                visible=(not headless),
            )
            button_directory_path_dir_input.click(
                get_folder_path,
                outputs=directory_path_dir,
                show_progress=False,
            )

        # Generation-length controls shared by both decoding strategies
        with gr.Group(), gr.Row():
            min_new_tokens = gr.Number(
                value=20,
                label="Min new tokens",
                interactive=True,
                step=1,
                minimum=5,
                maximum=300,
            )
            max_new_tokens = gr.Number(
                value=40,
                label="Max new tokens",
                interactive=True,
                step=1,
                minimum=5,
                maximum=300,
            )
            caption_file_ext = gr.Textbox(
                label="Caption file extension",
                placeholder="Extension for caption file (e.g., .caption, .txt)",
                value=".txt",
                interactive=True,
            )

        with gr.Row():
            # Tab 1: deterministic beam-search decoding
            with gr.Tab("Beam search"):
                with gr.Row():
                    num_beams = gr.Slider(
                        minimum=1,
                        maximum=16,
                        value=16,
                        step=1,
                        interactive=True,
                        label="Number of beams",
                    )

                    len_penalty = gr.Slider(
                        minimum=-1.0,
                        maximum=2.0,
                        value=1.0,
                        step=0.2,
                        interactive=True,
                        label="Length Penalty",
                        info="increase for longer sequence",
                    )

                    rep_penalty = gr.Slider(
                        minimum=1.0,
                        maximum=5.0,
                        value=1.5,
                        step=0.5,
                        interactive=True,
                        label="Repeat Penalty",
                        info="larger value prevents repetition",
                    )

                caption_button_beam = gr.Button(
                    value="Caption images", interactive=True, variant="primary"
                )
                caption_button_beam.click(
                    caption_images_beam_search,
                    inputs=[
                        directory_path_dir,
                        num_beams,
                        rep_penalty,
                        len_penalty,
                        min_new_tokens,
                        max_new_tokens,
                        caption_file_ext,
                    ],
                )

            # Tab 2: stochastic nucleus (top-p) sampling
            with gr.Tab("Nucleus sampling"):
                with gr.Row():
                    do_sample = gr.Checkbox(label="Sample", value=True)

                    temperature = gr.Slider(
                        minimum=0.5,
                        maximum=1.0,
                        value=1.0,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                        info="used with nucleus sampling",
                    )

                    top_p = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=0.9,
                        step=0.1,
                        interactive=True,
                        label="Top_p",
                    )

                caption_button_nucleus = gr.Button(
                    value="Caption images", interactive=True, variant="primary"
                )
                caption_button_nucleus.click(
                    caption_images_nucleus,
                    inputs=[
                        directory_path_dir,
                        do_sample,
                        temperature,
                        top_p,
                        min_new_tokens,
                        max_new_tokens,
                        caption_file_ext,
                    ],
                )