File size: 1,927 Bytes
1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 6cd7aaf 1b9e4e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
# NOTE: the original file started with the notebook shell magics
#   !pip install -U adapter-transformers
#   !pip install -U transformers
# which are a SyntaxError in a plain Python script. Install those packages
# ahead of time instead (e.g. list them in requirements.txt).
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

# Load the model and processor once at start-up so every request reuses them.
model = CLIPModel.from_pretrained("Taarhoinc/TaarhoGen1")
processor = CLIPProcessor.from_pretrained("Taarhoinc/TaarhoGen1")

# Candidate labels that the image is scored against (used as text prompts).
FLOORPLAN_COMPONENTS = [
    "bedroom",
    "kitchen",
    "bathroom",
    "living room",
    "dining room",
    "hallway",
    "garage",
    "balcony",
    "stairs",
    "door",
    "window",
]


def describe_floorplan(floorplan_image: Image.Image, top_k: int = 3):
    """Describe a floor plan drawing by listing its most likely components.

    The image is scored against every entry of ``FLOORPLAN_COMPONENTS`` with
    the loaded CLIP model, and the best-scoring labels are returned.

    Args:
        floorplan_image: The floor plan as a PIL image.
        top_k: How many components to report. Values are coerced to ``int``
            and clamped to the range ``0..len(FLOORPLAN_COMPONENTS)``.

    Returns:
        A comma-separated string of the ``top_k`` component names, ordered
        from most to least probable. Empty string when ``top_k`` is 0 or less.

    Raises:
        gr.Error: If no image was supplied.
    """
    if floorplan_image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # from inside the processor.
        raise gr.Error("Please upload a floor plan drawing first.")

    # The value may arrive as a float (e.g. from a UI slider), and slicing
    # requires an int. Clamping also avoids the `[-0:]` pitfall, where a
    # request for zero items would return *all* of them.
    top_k = max(0, min(int(top_k), len(FLOORPLAN_COMPONENTS)))
    if top_k == 0:
        return ""

    # Preprocess the image and text prompts
    inputs = processor(
        text=FLOORPLAN_COMPONENTS,
        images=floorplan_image,
        return_tensors="pt",
        padding=True,
    )

    # Get the logits (similarity scores); no gradients needed for inference.
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image

    # Softmax across the text prompts; [0] selects the single input image.
    probs = logits_per_image.softmax(dim=1).cpu().numpy()[0]

    # Indices sorted by descending probability, truncated to the top-k.
    top_k_indices = probs.argsort()[::-1][:top_k]

    detected_components = [FLOORPLAN_COMPONENTS[i] for i in top_k_indices]
    return ", ".join(detected_components)  # Return as a comma-separated string


# Create the Gradio interface
demo = gr.Interface(
    fn=describe_floorplan,
    inputs=[
        gr.Image(label="Upload a floor plan drawing", type="pil"),
        gr.Slider(1, 10, step=1, value=3, label="Number of components to detect"),
    ],
    outputs=gr.Label(label="Detected Components"),
    title="Floor Plan Description with TaarhoGen1",
    description="Upload a floor plan drawing to get a list of detected components.",
)

demo.launch()