import io

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Load model and processor (using CPU).
# trust_remote_code=True is required because the model ships custom modeling code.
folder_path = "diffusers/shot-categorizer-v0"
model = AutoModelForCausalLM.from_pretrained(folder_path, trust_remote_code=True).eval()
processor = AutoProcessor.from_pretrained(folder_path, trust_remote_code=True)

# Task prompts understood by the model; each tag maps to one of the
# attributes reported in the output (color, lighting, lighting type, composition).
PROMPTS = ["<COLOR>", "<LIGHTING>", "<LIGHTING_TYPE>", "<COMPOSITION>"]


# Define analysis function
def analyze_image(image):
    # Convert Gradio image input to a PIL Image. With type="pil" Gradio already
    # passes a PIL Image; accept raw bytes as a fallback.
    if isinstance(image, Image.Image):
        img = image.convert("RGB")
    else:
        img = Image.open(io.BytesIO(image)).convert("RGB")

    results = {}

    # Run each prompt through the model without tracking gradients.
    with torch.no_grad():
        for prompt in PROMPTS:
            inputs = processor(text=prompt, images=img, return_tensors="pt")
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                early_stopping=False,
                do_sample=False,
                num_beams=3,
            )
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            # post_process_generation returns a dict keyed by the task prompt,
            # e.g. {"<COLOR>": "..."}; store just the answer string.
            parsed_answer = processor.post_process_generation(
                generated_text, task=prompt, image_size=(img.width, img.height)
            )
            results[prompt] = parsed_answer[prompt]

    # Format the output
    output_text = "Image Analysis Results:\n\n"
    output_text += f"Color: {results['<COLOR>']}\n"
    output_text += f"Lighting: {results['<LIGHTING>']}\n"
    output_text += f"Lighting Type: {results['<LIGHTING_TYPE>']}\n"
    output_text += f"Composition: {results['<COMPOSITION>']}\n"
    return output_text


# Create Gradio interface
with gr.Blocks(title="Image Analyzer") as demo:
    gr.Markdown("# Image Analysis Demo")
    gr.Markdown("Upload an image to analyze its color, lighting, and composition characteristics.")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload Image")
            analyze_button = gr.Button("Analyze Image")
        with gr.Column():
            output_text = gr.Textbox(label="Analysis Results", lines=10)

    # Add an example image (expects shot.jpg next to this script)
    gr.Examples(
        examples=["shot.jpg"],
        inputs=image_input,
        label="Try with this example",
    )

    # Connect the button to the function
    analyze_button.click(
        fn=analyze_image,
        inputs=image_input,
        outputs=output_text,
    )

# Launch the demo
demo.launch()
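
# Note: beam search (num_beams=3) with max_new_tokens=1024 can be slow on CPU.
# If a CUDA GPU is available, the model and inputs can instead be loaded in
# half precision on the GPU. This is a sketch using standard transformers/torch
# calls, assuming a CUDA device is present and fp16 accuracy is acceptable:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       folder_path, torch_dtype=torch.float16, trust_remote_code=True
#   ).to("cuda").eval()
#   ...
#   inputs = processor(text=prompt, images=img, return_tensors="pt").to("cuda", torch.float16)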