# File size: 3,154 Bytes
# 7bf4b88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
from typing import List, Union

# Module-level cache of already-loaded embedding models, so repeated calls do
# not reload weights. It is shared by both helpers in this file:
# get_llm2vec_embeddings stores entries under `peft_model_name`, and
# get_gritlm_embeddings stores entries under `model_name`.
loaded_llm_models = {}


def get_llm2vec_embeddings(text: Union[str, List[str]],
                           model_name: str = 'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp',
                           peft_model_name: str = 'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse',
                           instruction: str = '',
                           device: str = 'cuda',
                           norm: bool = True) -> torch.Tensor:
    """
    Get LLM2Vec embeddings for the given text.

    The loaded model is cached in the module-level `loaded_llm_models` dict
    under `peft_model_name`, so later calls with the same PEFT name reuse it.
    NOTE: the cache key does not include `model_name` or `device`; changing
    only those while keeping `peft_model_name` returns the cached model.

    Args:
        text (Union[str, List[str]]): The input text(s) to be embedded. A
            single string is treated as a one-element list.
        model_name (str): The base model passed to `LLM2Vec.from_pretrained`.
        peft_model_name (str): The PEFT model passed as
            `peft_model_name_or_path`; also used as the cache key.
        instruction (str): Optional instruction. When non-empty, every input
            is encoded as the pair `[instruction, text]`.
        device (str): Value forwarded as `device_map` when loading the model.
        norm (bool): If True, L2-normalize the embeddings along dim 1.

    Returns:
        torch.Tensor: The embedding(s), viewed as `(len(text), -1)`.

    Raises:
        ValueError: If `text` is an empty sequence.
        ImportError: If the `llm2vec` package is not installed.
    """
    if isinstance(text, str):
        text = [text]

    # Fail fast, before importing or loading a large model: an empty batch
    # would otherwise reach encode() with batch_size=0 and then view(0, -1).
    if len(text) == 0:
        raise ValueError("`text` must not be empty: expected a string or a non-empty list of strings.")

    try:
        from llm2vec import LLM2Vec
    except ImportError as err:
        raise ImportError("Please install the llm2vec package using `pip install llm2vec`.") from err

    l2v = loaded_llm_models.get(peft_model_name)
    if l2v is None:
        l2v = LLM2Vec.from_pretrained(
            model_name,
            peft_model_name_or_path=peft_model_name,
            device_map=device,
            torch_dtype=torch.bfloat16,
        )
        loaded_llm_models[peft_model_name] = l2v

    num_texts = len(text)
    if instruction:
        text = [[instruction, t] for t in text]

    # All inputs are encoded as a single batch.
    embeddings = l2v.encode(text, batch_size=num_texts)
    if norm:
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.view(num_texts, -1)


def get_gritlm_embeddings(text: Union[str, List[str]],
                          model_name: str = 'GritLM/GritLM-7B',
                          instruction: str = '',
                          device: str = 'cuda'
                          ) -> torch.Tensor:
    """
    Get GritLM embeddings for the given text.

    The loaded model is cached in the module-level `loaded_llm_models` dict
    under `model_name`, so later calls with the same name reuse it.

    Args:
        text (Union[str, List[str]]): The input text(s) to be embedded. A
            single string is treated as a one-element list.
        model_name (str): The model to use for embedding; also the cache key.
        instruction (str): The instruction to be used for GritLM. It is
            wrapped as "<|user|>\\n{instruction}\\n<|embed|>\\n"; when empty,
            the prompt is just "<|embed|>\\n".
        device (str): Accepted for signature parity with the LLM2Vec helper,
            but currently not used by this function.

    Returns:
        torch.Tensor: The embedding(s), viewed as `(len(text), -1)`.

    Raises:
        ValueError: If `text` is an empty sequence.
        ImportError: If the `gritlm` package is not installed.
    """
    if isinstance(text, str):
        text = [text]

    # Fail fast, before importing or loading a large model: an empty batch
    # cannot be reshaped with view(0, -1) afterwards.
    if len(text) == 0:
        raise ValueError("`text` must not be empty: expected a string or a non-empty list of strings.")

    try:
        from gritlm import GritLM
    except ImportError as err:
        raise ImportError("Please install the gritlm package using `pip install gritlm`.") from err

    def gritlm_instruction(instruction):
        # Build the GritLM embedding prompt for an optional instruction.
        return "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"

    gritlm_model = loaded_llm_models.get(model_name)
    if gritlm_model is None:
        gritlm_model = GritLM(model_name, torch_dtype=torch.bfloat16)
        loaded_llm_models[model_name] = gritlm_model

    embeddings = gritlm_model.encode(text, instruction=gritlm_instruction(instruction))
    # The encode() result is handed to torch.from_numpy, so it is expected to
    # be a NumPy array here.
    embeddings = torch.from_numpy(embeddings)
    return embeddings.view(len(text), -1)