Spaces:
Runtime error
Runtime error
import shutil | |
import subprocess | |
import sys | |
import os | |
def list_files_in_directory(directory):
    """Print the full path of every file found under ``directory``.

    The tree is traversed with ``os.walk``; one path is printed per line.

    Args:
        directory: Root directory to walk.
    """
    for current_dir, _subdirs, filenames in os.walk(directory):
        paths = (os.path.join(current_dir, name) for name in filenames)
        for path in paths:
            print(path)
# Print every file under "Data" at start-up (presumably a debugging aid).
# NOTE(review): image_path further down is "data" (lower-case); on a
# case-sensitive filesystem these are different directories - confirm which
# spelling is intended.
list_files_in_directory("Data")
# Packages to install at start-up.
required_packages = [
    "sentence-transformers",
    # faiss-cpu is used in place of faiss-gpu to avoid errors; swap it
    # according to your needs.
    "faiss-cpu",
    # Intended to make sure an up-to-date numpy is installed.
    "numpy",
    "gradio",
    "matplotlib",
]
# Install the packages that are not installed yet.
def install_packages(packages):
    """Install every distribution in ``packages`` that is not already present.

    Each entry is looked up in the installed-distribution metadata first, so
    no ``pip`` subprocess is spawned for packages that are already available.
    Entries that cannot be found (including requirement strings carrying a
    version specifier) are handed to ``pip install`` unchanged.

    Args:
        packages: Iterable of pip requirement strings, e.g. ``"faiss-cpu"``.

    Raises:
        subprocess.CalledProcessError: If a ``pip install`` command fails.
    """
    # Imported locally so this function does not depend on a file-level import.
    from importlib import metadata

    for package in packages:
        try:
            metadata.version(package)
        except metadata.PackageNotFoundError:
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", package]
            )
# Runs at module load, before the third-party imports that follow.
install_packages(required_packages)
# Import the required libraries after installation
from glob import glob | |
from PIL import Image | |
import numpy as np | |
import faiss | |
from sentence_transformers import SentenceTransformer | |
import matplotlib.pyplot as plt | |
import random | |
import gradio as gr | |
# Batch settings for the embedding pass: images per encode call, and the
# list that accumulates one embedding per image.
chunk_size = 256
embeddings = []

# Encoder used to create the embeddings.
model = SentenceTransformer('clip-ViT-B-32')

# Recursively collect every .jpg file below image_path.
# NOTE(review): the start-up listing uses "Data" while this uses "data"; on a
# case-sensitive filesystem only one of them can match - confirm.
image_path = "data"
image_files = glob(os.path.join(image_path, "**/*.jpg"), recursive=True)
print("Found image files:", image_files)

# Abort early: nothing can be indexed without at least one image.
if len(image_files) == 0:
    raise ValueError("No image files found in the specified directory.")
def process_chunk(chunk):
    """Encode one batch of image files with the module-level ``model``.

    Args:
        chunk: Iterable of image file paths.

    Returns:
        The result of ``model.encode`` for the list of loaded images
        (presumably one embedding per image - confirm against the
        sentence-transformers documentation).
    """
    images = []
    for image_file in chunk:
        # Image.open is lazy and keeps the file handle open. copy() forces the
        # pixel data to load and returns an image that no longer depends on
        # the file, so the handle is released as soon as the with-block exits.
        with Image.open(image_file) as opened:
            images.append(opened.copy())
    return model.encode(images)
# Embed the images one chunk at a time.
for i in range(0, len(image_files), chunk_size):
    chunk = image_files[i:i + chunk_size]
    embeddings.extend(process_chunk(chunk))
print("Number of embeddings:", len(embeddings))

# Build the vector DB with FAISS: an inner-product index wrapped in an ID map.
# NOTE(review): the vectors are added exactly as returned by process_chunk; no
# normalization is visible here, so inner-product scores equal cosine
# similarity only if the encoder already normalizes its output - confirm.
dimension = len(embeddings[0])
index = faiss.IndexFlatIP(dimension)
index = faiss.IndexIDMap(index)
vectors = np.array(embeddings).astype('float32')
# FAISS ids are 64-bit integers. np.array(range(n)) uses the platform-default
# integer, which can be narrower on some NumPy/OS combinations, so the dtype
# is stated explicitly. Id k corresponds to image_files[k].
index.add_with_ids(vectors, np.arange(len(embeddings), dtype=np.int64))

# Save the index to a file.
faiss.write_index(index, "index.faiss")

# Write the image paths to a file so they can be loaded when needed.
# Explicit UTF-8 keeps non-ASCII paths writable regardless of the locale.
with open("image_files.txt", "w", encoding="utf-8") as f:
    for image_file in image_files:
        f.write(image_file + "\n")
def search_image(query, model, index, image_files, top_k=5):
    """Return the paths of the indexed images that best match ``query``.

    Args:
        query: Text string or image. A NumPy array is converted to a PIL
            image before encoding.
        model: Encoder exposing ``encode(query)``.
        index: Search index exposing ``search(vectors, k)`` and returning a
            ``(distances, indices)`` pair, as FAISS indexes do.
        image_files: Sequence of image paths; position ``i`` must correspond
            to id ``i`` in ``index``.
        top_k: Maximum number of results to return.

    Returns:
        A list of image paths in the order returned by ``index.search``. It
        holds fewer than ``top_k`` entries when the index cannot supply that
        many results.
    """
    # The query can be an image or text; arrays are turned back into PIL images.
    if isinstance(query, np.ndarray):
        query = Image.fromarray(query)
    query_embedding = model.encode(query)
    # The index is searched with a single float32 row vector.
    query_embedding = query_embedding.astype("float32").reshape(1, -1)
    _, indices = index.search(query_embedding, top_k)
    # FAISS pads the id list with -1 when fewer than top_k neighbours exist.
    # A negative id would silently wrap around to the end of image_files, so
    # those placeholders are dropped.
    return [image_files[i] for i in indices[0] if i >= 0]
def display(query_image, query_text):
    """Run a search for the UI, preferring the text query over the image.

    Args:
        query_image: Uploaded image, or ``None`` when nothing was uploaded.
        query_text: Text typed by the user; may be empty or ``None``.

    Returns:
        Up to 10 image paths from ``search_image``, or an empty list when
        neither a text nor an image query was supplied.
    """
    # Whitespace-only text carries no query, so it is treated as empty.
    if query_text and query_text.strip():
        return search_image(query_text, model, index, image_files, 10)
    # Without this guard, None would be passed to the encoder and crash.
    if query_image is None:
        return []
    return search_image(query_image, model, index, image_files, 10)
# Gradio UI: a text box and an image upload in one column, the result gallery
# in the other, and a button that triggers the search.
# NOTE(review): indentation was lost in this copy of the file; the nesting
# below (button placed under the row, inside the Blocks context) is a
# reconstruction - confirm against the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            query_text = gr.Textbox(label="Nhập văn bản", lines=3)
            query_image = gr.Image(label="Tải lên ảnh", height=500)
        with gr.Column():
            result_images = gr.Gallery(
                label="Bộ ảnh kết quả",
                height=645,
                columns=2,
                rows=5,
            )
    button = gr.Button("Tìm kiếm")
    # Inputs are passed in the order (image, text), matching display().
    button.click(
        fn=lambda query_image, query_text: display(query_image, query_text),
        inputs=[query_image, query_text],
        outputs=result_images,
    )
demo.launch()