# Hugging Face Space (status: Running)
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoProcessor | |
import torch | |
from PIL import Image | |
import io | |
# Load the checkpoint and its matching processor once at import time (using CPU).
folder_path = "diffusers/shot-categorizer-v0"
model = AutoModelForCausalLM.from_pretrained(
    folder_path,
    trust_remote_code=True,
)
model.eval()  # inference only: switch the module out of training mode
processor = AutoProcessor.from_pretrained(
    folder_path,
    trust_remote_code=True,
)
# Define analysis function | |
def analyze_image(image):
    """Describe the color, lighting, and composition of one image.

    The image is run through the module-level ``model`` / ``processor`` pair
    once per task prompt (``<COLOR>``, ``<LIGHTING>``, ``<LIGHTING_TYPE>``,
    ``<COMPOSITION>``) and the parsed answers are assembled into a text
    report.

    Args:
        image: A ``PIL.Image.Image``, the raw encoded bytes of an image, or
            ``None`` when no image was supplied.

    Returns:
        The multi-line analysis report, or a short message asking for an
        image when ``image`` is ``None``.
    """
    # Without this guard a missing image reaches io.BytesIO(None) below and
    # fails with a TypeError instead of giving the user something actionable.
    if image is None:
        return "Please upload an image to analyze."

    # Normalize the input to an RGB PIL image.
    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        img = Image.open(io.BytesIO(image)).convert("RGB")

    # Task prompt -> label used in the report (insertion order is report order).
    task_labels = {
        "<COLOR>": "Color",
        "<LIGHTING>": "Lighting",
        "<LIGHTING_TYPE>": "Lighting Type",
        "<COMPOSITION>": "Composition",
    }
    results = {}

    # Inference only, so gradient tracking is disabled for every pass.
    with torch.no_grad():
        for prompt in task_labels:
            inputs = processor(text=prompt, images=img, return_tensors="pt")
            # Beam search with 3 beams and no sampling.
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                early_stopping=False,
                do_sample=False,
                num_beams=3,
            )
            # Special tokens are kept in the decoded text; presumably the
            # processor's post-processing relies on them (confirm against
            # the model's remote code).
            generated_text = processor.batch_decode(
                generated_ids, skip_special_tokens=False
            )[0]
            results[prompt] = processor.post_process_generation(
                generated_text, task=prompt, image_size=(img.width, img.height)
            )

    # Format the output: header, blank line, then one "Label: answer" line
    # per task, with a trailing newline.
    report_lines = ["Image Analysis Results:", ""]
    report_lines.extend(
        f"{label}: {results[task]}" for task, label in task_labels.items()
    )
    return "\n".join(report_lines) + "\n"
# Build the Gradio UI.
# NOTE(review): nesting below is reconstructed from a flattened source;
# confirm that the examples widget belongs at Blocks level rather than
# inside the second column.
with gr.Blocks(title="Image Analyzer") as demo:
    gr.Markdown(value="# Image Analysis Demo")
    gr.Markdown(
        value="Upload an image to analyze its color, lighting, and composition characteristics."
    )

    with gr.Row():
        # First column: image input and the trigger button.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            analyze_button = gr.Button(value="Analyze Image")
        # Second column: text report.
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Results", lines=10)

    # Clickable sample that fills the image input.
    examples = gr.Examples(
        examples=["shot.jpg"],
        inputs=image_input,
        label="Try with this example",
    )

    # Run the analysis on button click and show the report in the textbox.
    analyze_button.click(analyze_image, inputs=image_input, outputs=output_text)

# Start serving the demo.
demo.launch()