import os

from dotenv import load_dotenv

# Load environment variables (e.g. HF_API_KEY) before anything reads them
load_dotenv()

from typing import Dict, Tuple
from collections.abc import Callable

from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

from policy_rag.text_utils import get_recursive_token_chunks, get_semantic_chunks
# Config Options
CHUNK_METHOD = {
    'token-overlap': get_recursive_token_chunks,
    'semantic': get_semantic_chunks,
}

EMBEDDING_MODEL_SOURCE = {
    'openai': OpenAIEmbeddings,
    'huggingface': HuggingFaceInferenceAPIEmbeddings,
}
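
# Example config shapes the helpers below expect (illustrative values only;
# the exact arg names for 'token-overlap' come from policy_rag.text_utils
# and are assumed here):
#
#   chunk_method = {
#       'method': 'semantic',
#       'args': {
#           'model_source': 'openai',
#           'model_name': 'text-embedding-3-small',
#           'breakpoint_type': 'percentile',
#       },
#   }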
# Helpers
def get_chunk_func(chunk_method: Dict) -> Tuple[Callable, Dict]:
    """Resolve a chunking config into a chunking function and its kwargs."""
    chunk_func = CHUNK_METHOD[chunk_method['method']]
    if chunk_method['method'] == 'token-overlap':
        # token-overlap args pass straight through to the chunking function
        chunk_func_args = chunk_method['args']
    elif chunk_method['method'] == 'semantic':
        args = chunk_method['args']
        chunk_func_args = {
            # build the embedder through get_embedding_model so each model
            # source gets its correct constructor kwargs
            'embedding_model': get_embedding_model(args),
            'breakpoint_type': args['breakpoint_type'],
        }
    else:
        raise ValueError(f"Unsupported chunk method: {chunk_method['method']}")
    return chunk_func, chunk_func_args
def get_embedding_model(config: Dict) -> OpenAIEmbeddings | HuggingFaceInferenceAPIEmbeddings:
    """Instantiate an embedding model from a {'model_source', 'model_name', ...} config."""
    if config['model_source'] == 'openai':
        model = EMBEDDING_MODEL_SOURCE[config['model_source']](model=config['model_name'])
    elif config['model_source'] == 'huggingface':
        model = EMBEDDING_MODEL_SOURCE[config['model_source']](
            api_key=os.getenv('HF_API_KEY'),
            model_name=config['model_name'],
            # api_url is optional; fall back to the default Inference API
            api_url=config.get('api_url'),
        )
    else:
        raise ValueError(f"Unsupported model source: {config['model_source']}")
    return model
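
# Usage sketch (hypothetical config values, not taken from this repo; the
# OpenAI branch needs OPENAI_API_KEY set in the environment to instantiate):
if __name__ == '__main__':
    # 'token-overlap' passes its args straight through, so these kwargs are
    # whatever get_recursive_token_chunks accepts (size/overlap assumed here)
    func, kwargs = get_chunk_func({
        'method': 'token-overlap',
        'args': {'chunk_size': 512, 'chunk_overlap': 64},
    })
    print(func.__name__, kwargs)

    model = get_embedding_model({
        'model_source': 'openai',
        'model_name': 'text-embedding-3-small',
    })
    print(type(model).__name__)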