Morgan Funtowicz committed
Commit 5e1abf0 · 1 Parent(s): 6ce5654

feat(embeddings): do not tokenize twice
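
For context, a minimal sketch (not part of the commit) of the idea behind the change, assuming the sentence-transformers behaviour the new code relies on: calling encode(..., output_value=None) returns every output feature for each sentence, so a single forward pass yields both the sentence embedding and the attention mask needed for usage accounting, and no separate tokenize() call is required. The model id and sentences below are illustrative placeholders.

# Illustrative sketch: encode once and reuse the outputs for embeddings and token counts.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")  # placeholder model id
sentences = ["hello world", "embeddings without tokenizing twice"]

# With output_value=None, encode returns one dict per sentence containing all
# output features, including 'sentence_embedding' and 'attention_mask'.
outputs = model.encode(sentences, output_value=None)

embeddings = [output["sentence_embedding"] for output in outputs]
num_tokens = sum(int(output["attention_mask"].sum()) for output in outputs)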

Files changed (1):
  1. handler.py  +61 -20
handler.py CHANGED

@@ -1,13 +1,17 @@
 import platform
-from typing import Union, Sequence, Sized
+from functools import reduce
+from operator import itemgetter
+from typing import Generator, Tuple
 
 import torch
 from hfendpoints.openai import Context, run
 from hfendpoints.openai.embeddings import Embedding, EmbeddingEndpoint, EmbeddingRequest, EmbeddingResponse, Usage
-from hfendpoints import EndpointConfig, Handler, __version__
+from intel_extension_for_pytorch.cpu.runtime import pin
 from loguru import logger
-from torch.backends.mkldnn import VERBOSE_ON_CREATION, VERBOSE_OFF
+from hfendpoints import EndpointConfig, Handler, __version__
 from sentence_transformers import SentenceTransformer
+from torch.nn import Module
+from torch.backends.mkldnn import VERBOSE_ON_CREATION, VERBOSE_OFF
 
 # Not used for now
 SUPPORTED_AMP_DTYPES = {torch.float32, torch.bfloat16}
@@ -27,17 +31,47 @@ def has_bf16_support() -> bool:
     return torch.cpu._is_avx512_bf16_supported() or torch.cpu._is_avx512_supported()
 
 
-def get_usage(tokens: Union[Sized, Sequence[Sized]]) -> Usage:
+def get_cores_pinning_strategy() -> "CPUPool":
+    import intel_extension_for_pytorch as ipex
+
+    # Retrieve the number of nodes
+    num_nodes = ipex.cpu.runtime.runtime_utils.get_num_nodes()
+    cpu_cores_id = [ipex.cpu.runtime.runtime_utils.get_core_list_of_node_id(node_id) for node_id in range(num_nodes)]
+
+    if num_nodes == 1:
+        pinned_cpu_cores_id = cpu_cores_id[0]
+    else:
+        pinned_cpu_cores_id = [core_id for node in cpu_cores_id for core_id in node]
+
+    logger.info(f"Pinning CPU cores to {pinned_cpu_cores_id}")
+    return ipex.cpu.runtime.CPUPool(pinned_cpu_cores_id)
+    # return ipex.cpu.runtime.CPUPool(node_id=0)
+
+
+def get_usage(mask: torch.IntTensor) -> Usage:
     """
     Compute the number of processed tokens and return as Usage object matching OpenAI
-    :param tokens: List or nested List of tokens
+    :param mask: Attention mask tensor, as returned by the model
     :return: Usage object matching OpenAI specifications
     """
-    num_tokens = tokens["attention_mask"].sum().item()
+    num_tokens = sum(m.sum().item() for m in mask)
     return Usage(prompt_tokens=num_tokens, total_tokens=num_tokens)
 
+
+class SentenceTransformerWithUsage(Module):
+    __slots__ = ("_model", )
+
+    def __init__(self, model: SentenceTransformer):
+        super().__init__()
+        self._model = model
+
+    def forward(self, sentences: list[str]) -> Tuple[Generator[torch.Tensor], Generator[torch.Tensor]]:
+        vectors = self._model.encode(sentences, output_value=None)
+        return map(itemgetter('attention_mask'), vectors), map(itemgetter('sentence_embedding'), vectors)
+
+
 class SentenceTransformerHandler(Handler):
-    __slots__ = ("_config", "_dtype", "_model", "_model_name", "_use_amp")
+    __slots__ = ("_config", "_dtype", "_model", "_model_name", "_pinned_cores", "_use_amp")
 
     def __init__(self, config: EndpointConfig):
         self._config = config
@@ -47,44 +81,51 @@ class SentenceTransformerHandler(Handler):
         self._allocate_model()
 
     def _allocate_model(self):
-        # Denormal number is used to store extremely small numbers that are close to 0.
+        # Denormal number is used to store tiny numbers that are close to 0.
         # Computations with denormal numbers are remarkably slower than normalized number.
         torch.set_flush_denormal(True)
 
         dtype = torch.bfloat16 if has_bf16_support() else torch.float32
         model = SentenceTransformer(self._config.model_id, device="cpu", model_kwargs={"torch_dtype": dtype})
 
+
         if platform.machine() == "x86_64":
             import intel_extension_for_pytorch as ipex
             logger.info(f"x64 platform detected: {platform.processor()}")
 
+            # Retrieve all the physical cores ID for all the CPU nodes
+            self._pinned_cores = get_cores_pinning_strategy()
+
+            # Optimize the model for inference
             with torch.inference_mode():
                 model = model.eval()
                 model = model.to(memory_format=torch.channels_last)
-                model = ipex.optimize(model, dtype=dtype, weights_prepack=False, graph_mode=True, concat_linear=True)
+
+                # Apply IPEx optimizations
+                model = ipex.optimize(model, dtype=dtype, weights_prepack=True, graph_mode=True, concat_linear=True)
                 model = torch.compile(model, dynamic=True, backend="ipex")
+
+                # model = ipex.cpu.runtime.MultiStreamModule(SentenceTransformerWithUsage(model), num_streams=1)
+
         else:
             model = torch.compile(model)
 
-        self._model = model
         self._dtype = dtype
         self._use_amp = dtype in SUPPORTED_AMP_DTYPES
+        self._model = SentenceTransformerWithUsage(model)
 
     async def __call__(self, request: EmbeddingRequest, ctx: Context) -> EmbeddingResponse:
        with torch.backends.mkldnn.verbose(VERBOSE_ON_CREATION if self._config.is_debug else VERBOSE_OFF):
            with torch.inference_mode(), torch.amp.autocast("cpu", dtype=self._dtype, enabled=self._use_amp):
-                tokens = self._model.tokenize(request.input)
-                vectors = self._model.encode(request.input)
+                with pin(self._pinned_cores):
+                    mask, vectors = self._model(request.input if request.is_batched else [request.input])
 
                 embeddings = [None] * len(request)
-                if not request.is_batched:
-                    embeddings[0] = Embedding(index=0, embedding=vectors.tolist())
-                else:
-                    for (index, embedding) in enumerate(vectors.tolist()):
-                        embedding = Embedding(index=index, embedding=embedding)
-                        embeddings[index] = embedding
-
-                usage = get_usage(tokens)
+                for (index, embedding) in enumerate(vectors):
+                    embedding = Embedding(index=index, embedding=embedding.tolist())
+                    embeddings[index] = embedding
+
+                usage = get_usage(mask)
                 return EmbeddingResponse(model=self._model_name, embeddings=embeddings, usage=usage)
 
 
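
As an aside, a hedged sketch of how the pinning pieces introduced above compose, using only the IPEX runtime calls that already appear in the diff (get_num_nodes, get_core_list_of_node_id, CPUPool, pin); exact availability of these APIs depends on the installed intel_extension_for_pytorch version.

# Illustrative sketch: build one CPUPool spanning all NUMA nodes and pin work to it,
# mirroring get_cores_pinning_strategy() and the `with pin(...)` block in __call__.
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.cpu.runtime import pin

num_nodes = ipex.cpu.runtime.runtime_utils.get_num_nodes()
cores_per_node = [
    ipex.cpu.runtime.runtime_utils.get_core_list_of_node_id(node_id)
    for node_id in range(num_nodes)
]
# Flatten so a single pool covers every core on every node.
all_cores = [core for node in cores_per_node for core in node]
pool = ipex.cpu.runtime.CPUPool(all_cores)

with pin(pool):
    # Any inference executed in this block runs on the pinned cores.
    pass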