| | from typing import Any, Dict, List |
| |
|
| | from colbert.infra import ColBERTConfig |
| | from colbert.modeling.checkpoint import Checkpoint |
| | import torch |
| | import logging |
| |
|
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)

# Hugging Face Hub id of the ColBERT-XM checkpoint this endpoint serves.
MODEL = "fdurant/colbert-xm-for-inference-api"
| |
|
class EndpointHandler():
    """Hugging Face Inference Endpoints handler for a ColBERT-XM checkpoint.

    A bare string input is embedded with the ColBERT *query* encoder; a list
    of strings is embedded with the *document* (chunk) encoder.
    """

    def __init__(self, path=""):
        # `path` is the conventional Inference-Endpoints constructor argument
        # (local model directory); this handler ignores it and always loads
        # the pinned MODEL checkpoint. Kept for interface compatibility.
        self._config = ColBERTConfig(
            doc_maxlen=512,   # truncate each document/chunk at 512 tokens
            nbits=2,
            kmeans_niters=4,
            nranks=-1,
            checkpoint=MODEL,
        )
        self._checkpoint = Checkpoint(self._config.checkpoint, colbert_config=self._config, verbose=3)

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj: `str` or :obj: `list` of :obj: `str`)
        Return:
            A :obj:`list` : will be serialized and returned.
            When the input is a single query string, the returned list contains one dictionary with:
            - input (:obj: `str`) : The input query.
            - query_embedding (:obj: `list`) : The query embedding of shape (32, 128).
            When the input is a batch (= list) of chunk strings, the returned list contains a dictionary per chunk:
            - input (:obj: `str`) : The input chunk.
            - chunk_embedding (:obj: `list`) : The chunk embedding of shape (num_tokens, 128).
            - token_ids (:obj: `list`) : The token ids.
            - token_list (:obj: `list`) : The token list.
        Raises:
            ValueError: when inputs is neither a string nor a list of strings,
                or is an empty list.
        """
        inputs = data["inputs"]
        # Dispatch on the *type* of the input, not on the number of texts:
        # a one-element list must still be embedded as a chunk (document),
        # as documented above — only a bare string is a query.
        if isinstance(inputs, str):
            is_query = True
            texts = [inputs]
        elif isinstance(inputs, list) and all(isinstance(text, str) for text in inputs):
            is_query = False
            texts = inputs
        else:
            raise ValueError("Invalid input data format")
        with torch.inference_mode():
            if is_query:
                logger.info(
                    "Received query of 1 text with %d characters and %d words",
                    len(texts[0]), len(texts[0].split()),
                )
                embedding = self._checkpoint.queryFromText(
                    queries=texts,
                    full_length_search=False,  # queries are padded/truncated to the fixed query length
                )
                logger.info("Query embedding shape: %s", embedding.shape)
                return [
                    {"input": inputs, "query_embedding": embedding.tolist()[0]}
                ]
            elif texts:
                logger.info("Received batch of %d chunks", len(texts))
                for i, text in enumerate(texts):
                    logger.info("Chunk %d has %d characters and %d words", i, len(text), len(text.split()))
                # keep_dims=True pads chunk embeddings into one tensor;
                # return_tokens=True also yields the per-chunk token ids.
                embeddings, token_id_lists = self._checkpoint.docFromText(
                    docs=texts,
                    bsize=self._config.bsize,
                    keep_dims=True,
                    return_tokens=True,
                )
                logger.info("Chunk embeddings shape: %s", embeddings.shape)
                token_lists = []
                for text, embedding, token_ids in zip(texts, embeddings, token_id_lists):
                    logger.debug("Chunk: %s", text)
                    logger.debug("Chunk embedding shape: %s", embedding.shape)
                    logger.debug("Chunk token ids: %s", token_ids)
                    # Map ids back to token strings for debuggability of the response.
                    token_list = self._checkpoint.doc_tokenizer.tok.convert_ids_to_tokens(token_ids)
                    token_lists.append(token_list)
                    logger.debug("Chunk tokens: %s", token_list)
                return [
                    {"input": _input, "chunk_embedding": embedding.tolist(), "token_ids": token_ids.tolist(), "token_list": token_list}
                    for _input, embedding, token_ids, token_list in zip(texts, embeddings, token_id_lists, token_lists)
                ]
            else:
                # Empty list input: nothing to embed.
                raise ValueError("No data to process")
| |
|