MultiModal / app.py
locboyf1's picture
Update app.py
37f6bd5 verified
import shutil
import subprocess
import sys
import os
def list_files_in_directory(directory):
    """Print the full path of every file found anywhere below *directory*.

    The tree is walked top-down with ``os.walk`` and one path is printed per
    line. Nothing is returned; a missing directory simply prints nothing.

    Args:
        directory: Root folder to scan.
    """
    every_path = (
        os.path.join(folder, filename)
        for folder, _subfolders, filenames in os.walk(directory)
        for filename in filenames
    )
    for path in every_path:
        print(path)
# Debug aid: print every file found under "Data" to the startup log.
# NOTE(review): the image search further down uses "data" (lowercase); on a
# case-sensitive filesystem these are different directories — confirm which
# one actually holds the images.
list_files_in_directory("Data")
# pip distribution names that install_packages() is asked to install.
required_packages = [
    "sentence-transformers",
    "faiss-cpu",  # faiss-cpu instead of faiss-gpu to avoid install errors; change to suit your hardware
    "numpy",  # NOTE(review): a plain `pip install` (no --upgrade) does not upgrade an already-installed numpy
    "gradio",
    "matplotlib"
]
# Install any required packages that are not installed yet.
def install_packages(packages):
    """Install, via pip, every package in *packages* that is not installed yet.

    Entries are pip distribution names (e.g. ``"faiss-cpu"``), so presence is
    checked with ``importlib.metadata`` rather than by importing a module:
    the importable module name often differs (``faiss-cpu`` provides
    ``faiss``). Every entry whose distribution cannot be found is handed to
    a single ``pip install`` call, so pip resolves them together. When
    nothing is missing, pip is not run at all.

    Args:
        packages: Iterable of pip distribution names.

    Raises:
        subprocess.CalledProcessError: if ``pip install`` exits with an error.
    """
    import importlib.metadata

    missing = []
    for package in packages:
        try:
            importlib.metadata.distribution(package)
        except importlib.metadata.PackageNotFoundError:
            missing.append(package)
    if not missing:
        return
    subprocess.check_call([sys.executable, "-m", "pip", "install", *missing])
    # Modules installed after start-up can be missed by the import system's
    # cached directory listings; reset them so the imports that follow work.
    importlib.invalidate_caches()
# Make sure the packages listed in required_packages are available before
# the third-party imports below run.
install_packages(required_packages)
# Import the required libraries only after the installation step.
from glob import glob
from PIL import Image
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
import matplotlib.pyplot as plt
import random
import gradio as gr
# Create embeddings: load the CLIP model used to encode both images and text.
model = SentenceTransformer('clip-ViT-B-32')
image_path = "data"
# The startup listing inspects "Data" while this search used "data"; on a
# case-sensitive filesystem (e.g. Linux) they are different directories.
# Keep "data" when it exists; otherwise fall back to "Data" rather than
# failing below with "No image files found".
if not os.path.isdir(image_path) and os.path.isdir("Data"):
    image_path = "Data"
# Find every .jpg file under image_path, including subdirectories.
image_files = glob(os.path.join(image_path, "**/*.jpg"), recursive=True)
print("Found image files:", image_files)
if not image_files:
    raise ValueError("No image files found in the specified directory.")
# Number of image files opened and encoded per call to process_chunk.
chunk_size = 256
# Filled by the encoding loop, in image_files order.
embeddings = []
def process_chunk(chunk):
    """Encode a batch of image files with the module-level ``model``.

    Each file is opened, fully decoded into memory, and closed again before
    encoding. Previously every opened file stayed open until the PIL image
    objects were garbage-collected.

    Args:
        chunk: Sequence of image file paths.

    Returns:
        Whatever ``model.encode`` returns for the list of images (presumably
        one embedding per image, in input order — confirm against the
        SentenceTransformer version in use).
    """
    images = []
    for image_file in chunk:
        with Image.open(image_file) as img:
            # copy() forces a full load and detaches the pixels from the file,
            # so the image stays usable after the with-block closes the file.
            images.append(img.copy())
    chunk_embeddings = model.encode(images)
    return chunk_embeddings
# Encode the images chunk_size files at a time, appending each chunk's
# embeddings to `embeddings` so they stay in image_files order.
for i in range(0, len(image_files), chunk_size):
    chunk = image_files[i:i + chunk_size]
    embeddings.extend(process_chunk(chunk))
print("Number of embeddings:", len(embeddings))
# Build the vector database with FAISS.
dimension = len(embeddings[0])
# CLIP embeddings are meant to be compared by cosine similarity, so the
# stored vectors are L2-normalised before going into an inner-product index;
# on unit vectors the inner product equals the cosine. (The query does not
# need normalising for ranking: its norm scales every score equally.)
index = faiss.IndexFlatIP(dimension)
# IndexIDMap stores an explicit id per vector: its position in image_files.
index = faiss.IndexIDMap(index)
vectors = np.array(embeddings).astype('float32')
faiss.normalize_L2(vectors)  # in place; needs a contiguous float32 array
# FAISS ids are 64-bit; make the dtype explicit instead of relying on the
# platform-default integer (32-bit on e.g. Windows with NumPy < 2).
index.add_with_ids(vectors, np.arange(len(embeddings), dtype='int64'))
# Save the index to a file.
faiss.write_index(index, "index.faiss")
# Write the image paths to a file, one per line in index-id order, so they
# can be loaded again when needed. An explicit UTF-8 encoding keeps
# non-ASCII file names from depending on the platform's locale encoding.
with open("image_files.txt", "w", encoding="utf-8") as f:
    for image_file in image_files:
        f.write(image_file + "\n")
def search_image(query, model, index, image_files, top_k=5):
    """Return the paths of the images most similar to *query*.

    Args:
        query: Text or an image. A NumPy array is first converted to a PIL
            image; anything else is passed to ``model.encode`` unchanged.
        model: Encoder whose ``encode`` method maps the query to a single
            embedding vector (here the CLIP SentenceTransformer).
        index: FAISS index whose ids are positions in *image_files*.
        image_files: Image paths, addressed by FAISS id.
        top_k: Maximum number of results to return.

    Returns:
        Up to *top_k* paths, in the order FAISS ranks them. Fewer are
        returned when the index holds fewer than *top_k* vectors.
    """
    if isinstance(query, np.ndarray):
        query = Image.fromarray(query)
    query_embedding = model.encode(query)
    # FAISS searches a batch of float32 row vectors: shape (1, dimension).
    query_embedding = np.asarray(query_embedding, dtype="float32").reshape(1, -1)
    _distances, indices = index.search(query_embedding, top_k)
    # FAISS pads missing results with id -1 when fewer than top_k vectors are
    # available; unfiltered, image_files[-1] (the last file) would be returned
    # for every padded slot.
    return [image_files[i] for i in indices[0] if i >= 0]
def display(query_image, query_text, top_k=10):
    """UI callback: search by text when text is given, otherwise by image.

    Args:
        query_image: The image input (anything search_image accepts), or
            None when no image was provided.
        query_text: Text from the textbox; whitespace-only counts as empty.
        top_k: Number of results to request (10 fills the 2x5 gallery).

    Returns:
        List of matching image paths. The list is empty when neither a text
        nor an image query was supplied, instead of asking the model to
        encode None.
    """
    text = (query_text or "").strip()
    if text:
        return search_image(text, model, index, image_files, top_k)
    if query_image is None:
        return []
    return search_image(query_image, model, index, image_files, top_k)
# Gradio UI: text/image query inputs in the left column, result gallery in the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            query_text = gr.Textbox(label="Nhập văn bản", lines=3)  # label: "Enter text"
            query_image = gr.Image(label="Tải lên ảnh", height=500)  # label: "Upload an image"
        with gr.Column():
            # label: "Result images"; laid out as 2 columns x 5 rows.
            result_images = gr.Gallery(label="Bộ ảnh kết quả", height=645, columns=2, rows=5)
    # NOTE(review): indentation was reconstructed; the button is assumed to
    # sit below the row, inside the Blocks context — confirm placement.
    button = gr.Button("Tìm kiếm")  # label: "Search"
    # On click, pass the image and text inputs to display() and show the
    # returned image paths in the gallery.
    button.click(fn=lambda query_image, query_text: display(query_image, query_text), inputs=[query_image, query_text], outputs=result_images)
demo.launch()