added caches

- requirements.txt +3 -1
- src/embeddings.py +7 -3
- src/main.py +9 -9
requirements.txt
CHANGED

```diff
@@ -7,4 +7,6 @@ sacremoses
 torch
 pillow
 protobuf
-
+
+# Optional dependencies for specific features
+einops
```
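The new einops dependency likely pairs with the `trust_remote_code=True` loading added in src/embeddings.py below: models that ship custom modeling code frequently import einops at load time, though the exact models this serves are not named in the commit.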
src/embeddings.py
CHANGED

```diff
@@ -62,7 +62,7 @@ class BaseEmbeddingTaskService:
         """Load and cache processor for the model using AutoProcessor"""
         if model_name not in self._processor_cache:
             try:
-                self._processor_cache[model_name] = AutoProcessor.from_pretrained(model_name)
+                self._processor_cache[model_name] = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
                 self._logger.info(f"Loaded processor for model: {model_name}")
             except Exception as e:
                 self._logger.error(f"Failed to load processor for model '{model_name}': {str(e)}")
@@ -70,6 +70,8 @@ class BaseEmbeddingTaskService:
                     status_code=404,
                     detail=f"Processor for model '{model_name}' could not be loaded: {str(e)}"
                 )
+        else:
+            self._logger.info(f"Using cached processor for model: {model_name}")
         return self._processor_cache[model_name]
 
     def _load_model(self, model_name: str, cache_suffix: str = ""):
@@ -78,7 +80,7 @@ class BaseEmbeddingTaskService:
         if cache_key not in self._model_cache:
             try:
                 device = self._get_device()
-                model = AutoModel.from_pretrained(model_name)
+                model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
                 model.to(device)
                 self._model_cache[cache_key] = model
                 self._logger.info(f"Loaded model: {model_name} on {device}")
@@ -88,6 +90,8 @@ class BaseEmbeddingTaskService:
                     status_code=404,
                     detail=f"Model '{model_name}' could not be loaded: {str(e)}"
                 )
+        else:
+            self._logger.info(f"Using cached model: {model_name} (cache key: {cache_key})")
         return self._model_cache[cache_key]
 
     async def get_embedding_vector_size(self, model_name: str) -> dict:
@@ -335,7 +339,7 @@ class TextEmbeddingTaskService(BaseEmbeddingTaskService):
         """Main method to generate text embeddings"""
         embedding_request: EmbeddingRequest = await self.get_embedding_request(request)
 
-        self._logger.info(f"Generating text embedding for: {embedding_request.inputs[:
+        self._logger.info(f"Generating text embedding for: {embedding_request.inputs[:500]}...")
 
         # Load processor and model using auto-detection
         processor = self._load_processor(model_name)
```
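For readers skimming the hunks, here is a minimal, self-contained sketch of the load-or-cache pattern this commit settles on. Only what the diff shows is taken from the repo; the class name, constructor, cache dicts, `_get_device()` helper, and cache-key scheme are assumptions filled in so the sketch runs.

```python
import logging

import torch
from fastapi import HTTPException
from transformers import AutoModel, AutoProcessor


class CachingEmbeddingService:
    """Sketch of the caching base service; not the repo's exact class."""

    def __init__(self, logger: logging.Logger):
        self._logger = logger
        self._processor_cache: dict = {}  # assumed: model_name -> processor
        self._model_cache: dict = {}      # assumed: cache_key -> model

    def _get_device(self) -> str:
        # Assumed helper: prefer GPU when available
        return "cuda" if torch.cuda.is_available() else "cpu"

    def _load_model(self, model_name: str, cache_suffix: str = ""):
        cache_key = model_name + cache_suffix  # assumed key scheme; real one is outside the hunk
        if cache_key not in self._model_cache:
            try:
                device = self._get_device()
                # trust_remote_code=True allows models that ship custom
                # modeling code; such models are also why einops was added
                model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
                model.to(device)
                self._model_cache[cache_key] = model
                self._logger.info(f"Loaded model: {model_name} on {device}")
            except Exception as e:
                raise HTTPException(
                    status_code=404,
                    detail=f"Model '{model_name}' could not be loaded: {str(e)}",
                )
        else:
            self._logger.info(f"Using cached model: {model_name} (cache key: {cache_key})")
        return self._model_cache[cache_key]
```

`_load_processor` follows the same shape with `AutoProcessor` and `self._processor_cache`, which is why the commit adds the matching `else` logging branch in both methods.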
src/main.py
CHANGED

```diff
@@ -29,6 +29,10 @@ logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
+# Create singleton instances of embedding services to enable model caching across requests
+image_embedding_service = ImageEmbeddingTaskService(logger)
+text_embedding_service = TextEmbeddingTaskService(logger)
+
 
 class StreamToLogger(object):
     def __init__(self, logger, log_level):
@@ -333,8 +337,7 @@ async def image_embedding(
     """
 
     model_name = model_name.rstrip("/")
-
-    return await imageEmbeddingTask.generate_embedding(request, model_name)
+    return await image_embedding_service.generate_embedding(request, model_name)
 
 
 # =========================
@@ -399,8 +402,7 @@ async def image_embedding_upload(
     """
 
     model_name = model_name.rstrip("/")
-
-    return await imageEmbeddingTask.generate_embedding_from_upload(image, model_name)
+    return await image_embedding_service.generate_embedding_from_upload(image, model_name)
 
 
 # =========================
@@ -439,8 +441,7 @@ async def text_embedding(
     """
 
     model_name = model_name.rstrip("/")
-
-    return await textEmbeddingTask.generate_embedding(request, model_name)
+    return await text_embedding_service.generate_embedding(request, model_name)
 
 
 # =========================
@@ -483,6 +484,5 @@ async def embedding_vector_size(
     """
 
     model_name = model_name.rstrip("/")
-    # We can use either
-
-    return await embeddingTask.get_embedding_vector_size(model_name)
+    # We can use either embedding service as they inherit from the same base class
+    return await image_embedding_service.get_embedding_vector_size(model_name)
```
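A hypothetical, trimmed-down view of why the module-level singletons in src/main.py pay off: every request hits the same service objects, so their per-instance caches survive across calls instead of reloading the model each time. The route path and `EmbeddingRequest` model below are illustrative, not copied from the file.

```python
import logging

from fastapi import FastAPI
from pydantic import BaseModel

from src.embeddings import ImageEmbeddingTaskService, TextEmbeddingTaskService

logger = logging.getLogger(__name__)
app = FastAPI()

# Created once at import time, not inside a handler, so the
# processor/model caches persist for the life of the process
image_embedding_service = ImageEmbeddingTaskService(logger)
text_embedding_service = TextEmbeddingTaskService(logger)


class EmbeddingRequest(BaseModel):
    inputs: str  # illustrative request body


@app.post("/text-embedding/{model_name:path}")
async def text_embedding(request: EmbeddingRequest, model_name: str):
    model_name = model_name.rstrip("/")
    # First call loads and caches the model; later calls skip the load
    return await text_embedding_service.generate_embedding(request, model_name)
```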