"""Config-driven helpers for choosing chunking functions and embedding models."""
import os
from dotenv import load_dotenv

# Load API keys and other settings from a local .env file.
load_dotenv()

from typing import Dict, Tuple
from collections.abc import Callable

from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings

from policy_rag.text_utils import get_recursive_token_chunks, get_semantic_chunks



# Config Options
CHUNK_METHOD = {
    'token-overlap': get_recursive_token_chunks,
    'semantic': get_semantic_chunks
}

EMBEDDING_MODEL_SOURCE = {
    'openai': OpenAIEmbeddings,
    'huggingface': HuggingFaceInferenceAPIEmbeddings
}


# Helpers
def get_chunk_func(chunk_method: Dict) -> Tuple[Callable, Dict]:
    """Resolve a chunking config dict into a chunking function plus its kwargs."""
    chunk_func = CHUNK_METHOD[chunk_method['method']]

    if chunk_method['method'] == 'token-overlap':
        # Token-overlap chunking takes its kwargs directly from the config.
        chunk_func_args = chunk_method['args']
    elif chunk_method['method'] == 'semantic':
        # Semantic chunking needs a live embedding model instance.
        args = chunk_method['args']
        chunk_func_args = {
            'embedding_model': EMBEDDING_MODEL_SOURCE[args['model_source']](model=args['model_name']),
            'breakpoint_type': args['breakpoint_type']
        }
    else:
        raise ValueError(f"Unsupported chunk method: {chunk_method['method']!r}")

    return chunk_func, chunk_func_args
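
# Illustrative usage (a sketch: the exact 'args' keys are assumptions based on
# how get_chunk_func consumes them, and the call shape of chunk_func is
# hypothetical):
#
#     chunk_func, chunk_args = get_chunk_func({
#         'method': 'semantic',
#         'args': {
#             'model_source': 'openai',
#             'model_name': 'text-embedding-3-small',
#             'breakpoint_type': 'percentile'
#         }
#     })
#     chunks = chunk_func(docs, **chunk_args)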


def get_embedding_model(config: Dict) -> OpenAIEmbeddings | HuggingFaceInferenceAPIEmbeddings:
    """Instantiate the embedding model described by a config dict."""
    if config['model_source'] == 'openai':
        # OpenAIEmbeddings picks up OPENAI_API_KEY from the environment.
        model = EMBEDDING_MODEL_SOURCE[config['model_source']](model=config['model_name'])
    elif config['model_source'] == 'huggingface':
        # The HF Inference API client needs an explicit key and endpoint URL.
        model = EMBEDDING_MODEL_SOURCE[config['model_source']](
            api_key=os.getenv('HF_API_KEY'),
            model_name=config['model_name'],
            api_url=config['api_url']
        )
    else:
        raise ValueError(f"Unsupported model source: {config['model_source']!r}")

    return model
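

# Illustrative usage (a sketch: the model name and endpoint below are
# placeholders, and HF_API_KEY must be present in the environment / .env):
#
#     embeddings = get_embedding_model({
#         'model_source': 'huggingface',
#         'model_name': 'sentence-transformers/all-MiniLM-L6-v2',
#         'api_url': 'https://<your-inference-endpoint>'
#     })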