File size: 1,927 Bytes
1b9e4e6
 
 
 
 
 
 
6cd7aaf
1b9e4e6
 
 
6cd7aaf
 
 
1b9e4e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cd7aaf
 
 
 
1b9e4e6
6cd7aaf
 
 
 
1b9e4e6
6cd7aaf
 
1b9e4e6
6cd7aaf
 
1b9e4e6
6cd7aaf
 
1b9e4e6
6cd7aaf
1b9e4e6
 
 
6cd7aaf
1b9e4e6
 
 
 
6cd7aaf
 
 
1b9e4e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
!pip install -U adapter-transformers
!pip install -U transformers
import gradio as gr
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

# Hub repository id shared by the CLIP weights and their processor config.
_MODEL_REPO = "Taarhoinc/TaarhoGen1"

# Load both once at import time so every request reuses the same objects.
model = CLIPModel.from_pretrained(_MODEL_REPO)
processor = CLIPProcessor.from_pretrained(_MODEL_REPO)

# Candidate labels the floor plan image is scored against. CLIP ranks the image
# against every entry, so extending this tuple extends what the app can report.
FLOORPLAN_COMPONENTS = (
    "bedroom",
    "kitchen",
    "bathroom",
    "living room",
    "dining room",
    "hallway",
    "garage",
    "balcony",
    "stairs",
    "door",
    "window",
)


def describe_floorplan(floorplan_image: Image.Image, top_k: int = 3) -> str:
    """Describe a floor plan drawing by listing its most likely components.

    The image is scored by CLIP against every label in
    ``FLOORPLAN_COMPONENTS``. The labels with the highest softmax
    probability are returned.

    Args:
        floorplan_image: The uploaded drawing. It may be ``None``, e.g. when
            the Gradio form is submitted without an upload.
        top_k: How many components to report. Float values (which slider
            widgets may deliver) are truncated with ``int()``. The result is
            clamped to ``[0, len(FLOORPLAN_COMPONENTS)]``.

    Returns:
        The selected component names, most likely first, joined with ", ".
        Returns an empty string when ``top_k`` is 0 or negative.

    Raises:
        gr.Error: If no image was provided.
    """
    if floorplan_image is None:
        # gr.Error shows this message in the UI instead of a generic failure.
        raise gr.Error("Please upload a floor plan drawing first.")

    components = list(FLOORPLAN_COMPONENTS)

    # int(): a float cannot be used as a slice bound. Clamping matters too:
    # the old `argsort()[-0:]` returned *every* component when top_k was 0.
    k = max(0, min(int(top_k), len(components)))

    # Tokenize the text labels and preprocess the image in one call.
    inputs = processor(
        text=components, images=floorplan_image, return_tensors="pt", padding=True
    )

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        logits_per_image = model(**inputs).logits_per_image

    # One image was passed, so take row 0. Softmax runs over the label axis.
    probs = logits_per_image.softmax(dim=1)[0].cpu().numpy()

    # Label indices in descending probability order, truncated to k.
    top_indices = probs.argsort()[::-1][:k]

    return ", ".join(components[i] for i in top_indices)

# Build the Gradio UI around describe_floorplan. Binding it to `demo` (the
# conventional name Gradio's reload mode looks for) keeps the object reachable
# without launching a server on import.
demo = gr.Interface(
    fn=describe_floorplan,
    inputs=[
        gr.Image(label="Upload a floor plan drawing", type="pil"),
        gr.Slider(1, 10, step=1, value=3, label="Number of components to detect"),
    ],
    outputs=gr.Label(label="Detected Components"),
    title="Floor Plan Description with TaarhoGen1",
    description="Upload a floor plan drawing to get a list of detected components.",
)

# Start the server only when executed directly (script run or notebook cell),
# not when this module is imported.
if __name__ == "__main__":
    demo.launch()