|
import modal |
|
from langchain_community.vectorstores import FAISS |
|
from langchain_huggingface import HuggingFaceEmbeddings |
|
|
|
from .app import app |
|
from .image import image |
|
from .volume import volume |
|
|
|
|
|
@app.cls(gpu="T4", image=image, volumes={"/volume": volume})
class TaskModelRetrieverModalApp:
    """Modal GPU service that retrieves candidate models for a vision task.

    For each task in ``TASKS`` a FAISS index named ``faiss_index`` is loaded
    from ``/volume/vector_store/<task>`` on the mounted volume. A free-text
    query is embedded with ``all-MiniLM-L6-v2`` and matched against the
    task's index; the ``model_id`` and ``model_labels`` metadata of each hit
    is returned.
    """

    # Tasks whose vector stores are loaded at container start-up. Each one
    # needs a matching folder under /volume/vector_store/.
    TASKS = ("object-detection", "image-segmentation")

    @modal.enter()
    def setup(self):
        """Load one FAISS vector store per task into ``self.vector_stores``."""
        # A single embedding model is shared by every index. Building one per
        # task would load duplicate copies of the same weights onto the GPU.
        embeddings = HuggingFaceEmbeddings(
            model_name="all-MiniLM-L6-v2",
            model_kwargs={"device": "cuda"},
            encode_kwargs={"normalize_embeddings": True},
            show_progress=True,
        )
        # NOTE: load_local deserializes the stored docstore with pickle,
        # hence the explicit opt-in below. That is acceptable only because
        # these files live on our own volume; never point folder_path at
        # untrusted data.
        self.vector_stores = {
            task: FAISS.load_local(
                folder_path=f"/volume/vector_store/{task}",
                embeddings=embeddings,
                index_name="faiss_index",
                allow_dangerous_deserialization=True,
            )
            for task in self.TASKS
        }

    def forward(self, task: str, query: str, k: int = 7) -> dict[str, object]:
        """Return the models whose index entries best match ``query``.

        Args:
            task: Key into ``self.vector_stores`` (one of ``TASKS``).
            query: Free-text description to embed and search for.
            k: Number of nearest documents to retrieve (default 7).

        Returns:
            Mapping of each hit's ``model_id`` metadata to its
            ``model_labels`` metadata, keyed in the order the hits were
            returned. If several hits share a ``model_id``, the last hit's
            labels win, so the mapping may hold fewer than ``k`` entries.

        Raises:
            KeyError: If ``task`` has no loaded vector store, or a hit lacks
                ``model_id`` / ``model_labels`` metadata.
        """
        docs = self.vector_stores[task].similarity_search(query, k=k)
        return {
            doc.metadata["model_id"]: doc.metadata["model_labels"] for doc in docs
        }

    @modal.method()
    def object_detection_search(self, query: str, k: int = 7) -> dict[str, object]:
        """Search the object-detection index; see ``forward`` for details."""
        return self.forward("object-detection", query, k)

    @modal.method()
    def image_segmentation_search(self, query: str, k: int = 7) -> dict[str, object]:
        """Search the image-segmentation index; see ``forward`` for details."""
        return self.forward("image-segmentation", query, k)
|
|