fashxp committed on
Commit 15a969b · 1 Parent(s): 8f80642

added caches

Files changed (3)
  1. requirements.txt +3 -1
  2. src/embeddings.py +7 -3
  3. src/main.py +9 -9
requirements.txt CHANGED
```diff
@@ -7,4 +7,6 @@ sacremoses
 torch
 pillow
 protobuf
-# Optional dependencies for specific features
+
+# Optional dependencies for specific features
+einops
```
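Context for the new dependency (not part of the commit itself): `einops` is listed as optional because custom model code pulled in via `trust_remote_code=True` (see the `src/embeddings.py` change below) may import it at load time. A minimal sketch of a guard that fails early with a clear message when the optional package is missing; the check is hypothetical, not code from this repository:

```python
# Hypothetical startup check (not in this commit): fail early with a clear
# message if an optional dependency required by remote model code is missing.
import importlib.util


def ensure_optional_dependency(package: str) -> None:
    if importlib.util.find_spec(package) is None:
        raise RuntimeError(
            f"Optional dependency '{package}' is not installed; "
            f"add it to requirements.txt or run `pip install {package}`."
        )


ensure_optional_dependency("einops")
```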
src/embeddings.py CHANGED
```diff
@@ -62,7 +62,7 @@ class BaseEmbeddingTaskService:
         """Load and cache processor for the model using AutoProcessor"""
         if model_name not in self._processor_cache:
             try:
-                self._processor_cache[model_name] = AutoProcessor.from_pretrained(model_name)
+                self._processor_cache[model_name] = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
                 self._logger.info(f"Loaded processor for model: {model_name}")
             except Exception as e:
                 self._logger.error(f"Failed to load processor for model '{model_name}': {str(e)}")
@@ -70,6 +70,8 @@ class BaseEmbeddingTaskService:
                     status_code=404,
                     detail=f"Processor for model '{model_name}' could not be loaded: {str(e)}"
                 )
+        else:
+            self._logger.info(f"Using cached processor for model: {model_name}")
         return self._processor_cache[model_name]
 
     def _load_model(self, model_name: str, cache_suffix: str = ""):
@@ -78,7 +80,7 @@ class BaseEmbeddingTaskService:
         if cache_key not in self._model_cache:
             try:
                 device = self._get_device()
-                model = AutoModel.from_pretrained(model_name)
+                model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
                 model.to(device)
                 self._model_cache[cache_key] = model
                 self._logger.info(f"Loaded model: {model_name} on {device}")
@@ -88,6 +90,8 @@ class BaseEmbeddingTaskService:
                     status_code=404,
                     detail=f"Model '{model_name}' could not be loaded: {str(e)}"
                 )
+        else:
+            self._logger.info(f"Using cached model: {model_name} (cache key: {cache_key})")
         return self._model_cache[cache_key]
 
     async def get_embedding_vector_size(self, model_name: str) -> dict:
@@ -335,7 +339,7 @@ class TextEmbeddingTaskService(BaseEmbeddingTaskService):
         """Main method to generate text embeddings"""
         embedding_request: EmbeddingRequest = await self.get_embedding_request(request)
 
-        self._logger.info(f"Generating text embedding for: {embedding_request.inputs[:50]}...")
+        self._logger.info(f"Generating text embedding for: {embedding_request.inputs[:500]}...")
 
         # Load processor and model using auto-detection
         processor = self._load_processor(model_name)
```
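The change above keeps each processor and model in a per-service dictionary so repeat requests skip `from_pretrained`. A condensed sketch of that pattern, with the surrounding service code (logger, HTTP error handling, `_get_device`) stripped away; names mirror the diff but the class itself is simplified for illustration:

```python
# Simplified illustration of the caching pattern from BaseEmbeddingTaskService;
# logging and HTTP error handling from the real service are omitted.
import torch
from transformers import AutoModel, AutoProcessor


class CachedEmbeddingLoader:
    def __init__(self):
        self._processor_cache: dict = {}
        self._model_cache: dict = {}

    def load_processor(self, model_name: str):
        # First request downloads/loads from the Hub; later requests hit the cache.
        if model_name not in self._processor_cache:
            self._processor_cache[model_name] = AutoProcessor.from_pretrained(
                model_name, trust_remote_code=True
            )
        return self._processor_cache[model_name]

    def load_model(self, model_name: str, cache_suffix: str = ""):
        # cache_suffix lets callers keep distinct variants of the same model.
        cache_key = f"{model_name}{cache_suffix}"
        if cache_key not in self._model_cache:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
            model.to(device)
            self._model_cache[cache_key] = model
        return self._model_cache[cache_key]
```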
src/main.py CHANGED
```diff
@@ -29,6 +29,10 @@ logging.basicConfig(format='%(asctime)s %(levelname)-8s %(message)s')
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
+# Create singleton instances of embedding services to enable model caching across requests
+image_embedding_service = ImageEmbeddingTaskService(logger)
+text_embedding_service = TextEmbeddingTaskService(logger)
+
 
 class StreamToLogger(object):
     def __init__(self, logger, log_level):
@@ -333,8 +337,7 @@ async def image_embedding(
     """
 
     model_name = model_name.rstrip("/")
-    imageEmbeddingTask = ImageEmbeddingTaskService(logger)
-    return await imageEmbeddingTask.generate_embedding(request, model_name)
+    return await image_embedding_service.generate_embedding(request, model_name)
 
 
 # =========================
@@ -399,8 +402,7 @@ async def image_embedding_upload(
     """
 
     model_name = model_name.rstrip("/")
-    imageEmbeddingTask = ImageEmbeddingTaskService(logger)
-    return await imageEmbeddingTask.generate_embedding_from_upload(image, model_name)
+    return await image_embedding_service.generate_embedding_from_upload(image, model_name)
 
 
 # =========================
@@ -439,8 +441,7 @@ async def text_embedding(
     """
 
    model_name = model_name.rstrip("/")
-    textEmbeddingTask = TextEmbeddingTaskService(logger)
-    return await textEmbeddingTask.generate_embedding(request, model_name)
+    return await text_embedding_service.generate_embedding(request, model_name)
 
 
 # =========================
@@ -483,6 +484,5 @@ async def embedding_vector_size(
     """
 
     model_name = model_name.rstrip("/")
-    # We can use either ImageEmbeddingTaskService or TextEmbeddingTaskService as they inherit from the same base class
-    embeddingTask = ImageEmbeddingTaskService(logger)
-    return await embeddingTask.get_embedding_vector_size(model_name)
+    # We can use either embedding service as they inherit from the same base class
+    return await image_embedding_service.get_embedding_vector_size(model_name)
```
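Why the endpoint changes matter: the services are now created once at module import instead of once per request, so the caches added in `src/embeddings.py` survive between calls. A minimal FastAPI sketch of that structure; the service class here is a stand-in for illustration, not the real `TextEmbeddingTaskService`:

```python
# Minimal sketch (stand-in service class): a module-level singleton keeps its
# model cache alive for the lifetime of the FastAPI process.
from fastapi import FastAPI, Request

app = FastAPI()


class DummyEmbeddingService:
    def __init__(self):
        self._model_cache: dict = {}  # persists across requests

    async def generate_embedding(self, request: Request, model_name: str) -> dict:
        # The real implementation would load/cache the model and run inference.
        return {"model": model_name, "cached_models": list(self._model_cache)}


# Instantiated once at import time and reused by every request below.
text_embedding_service = DummyEmbeddingService()


@app.post("/text-embedding/{model_name:path}")
async def text_embedding(request: Request, model_name: str):
    model_name = model_name.rstrip("/")
    return await text_embedding_service.generate_embedding(request, model_name)
```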