File size: 3,184 Bytes
1c18375
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
""" dspy_utils.py

Utilities for building a DSPy-based retrieval-augmented generation (RAG) model.

:author: Didier Guillevic
:email: [email protected]
:creation: 2024-12-21
"""

import html
import logging
import os
from collections.abc import Iterable

import dspy
from ragatouille import RAGPretrainedModel

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig configures the *root* logger as a side effect of
# importing this module, which affects logging for the whole host process.
# Consider moving it to the application's entry point.
logging.basicConfig(level=logging.INFO)


class DSPyRagModel:
    """Question answering over a RAGatouille index, with answers generated by DSPy.

    For each question, passages are retrieved with the given
    ``RAGPretrainedModel`` and passed as context to a Mistral language model,
    either through a plain ``dspy.Predict`` or through ``dspy.ChainOfThought``.
    """

    def __init__(self, retrieval_model: RAGPretrainedModel):
        """Configure the retriever, the language model and the DSPy predictors.

        Args:
            retrieval_model: RAGatouille model whose ``search`` method is used
                to fetch context passages.

        Raises:
            KeyError: if the ``MISTRAL_API_KEY`` environment variable is not set.
        """
        # Init the retrieval and language model
        self.retrieval_model = retrieval_model
        self.language_model = dspy.LM(
            model="mistral/mistral-large-latest",
            api_key=os.environ["MISTRAL_API_KEY"],
        )

        # Register both models on the module-level ``dspy.settings`` object.
        # This is shared state, so it also applies to other DSPy code running
        # in the same process.
        dspy.settings.configure(
            lm=self.language_model,
            rm=self.retrieval_model
        )

        # Signature shared by both predictors. NOTE: DSPy builds the prompt
        # from a signature's docstring and field descriptions, so the text
        # below is behavior, not just documentation.
        class BasicQA(dspy.Signature):
            """Answer the question given the context provided"""
            context = dspy.InputField(desc="may contain relevant facts")
            question = dspy.InputField()
            answer = dspy.OutputField(desc="Answer the given question.")

        # The direct predictor uses a very low temperature for near-deterministic
        # answers; the chain-of-thought predictor is given no explicit config.
        self.predict = dspy.Predict(BasicQA, temperature=0.01)
        self.predict_chain_of_thought = dspy.ChainOfThought(BasicQA)

    def generate_response(
            self,
            question: str,
            k: int=3,
            method: str = 'chain_of_thought'
        ) -> tuple[str, str, str]:
        """Generate a response to a given question using the specified method.

        Args:
            question: the question to answer
            k: number of passages to retrieve
            method: method for generating the response: ['simple', 'chain_of_thought']

        Returns:
            - the generated answer
            - (html string): the references (origin of the snippets of text used to generate the answer)
            - (html string): the snippets of text used to generate the answer

        Raises:
            ValueError: if ``method`` is not one of the supported methods. This
                is checked before any retrieval or LM call is made.
        """
        # Validate the method up front so a bad argument does not cost a
        # retrieval call before failing.
        if method not in ('simple', 'chain_of_thought'):
            raise ValueError(f"Unknown method: {method}. Expected ['simple', 'chain_of_thought']")
        predictor = self.predict if method == 'simple' else self.predict_chain_of_thought

        # Retrieval
        retrieval_results = self.retrieval_model.search(query=question, k=k)
        passages = [res.get('content') for res in retrieval_results]
        metadatas = [res.get('document_metadata') for res in retrieval_results]

        # Generate response given retrieved passages
        response = predictor(context=passages, question=question).answer

        # Passages and metadata are document data, not markup. Escape them so
        # characters such as '<' or '&' are displayed literally instead of
        # being interpreted as HTML (or injecting markup) when rendered.
        escaped_metadatas = [html.escape(str(metadata)) for metadata in metadatas]
        escaped_passages = [html.escape(str(passage)) for passage in passages]

        # Create an HTML string with the references
        references = "<h4>References</h4>\n" + create_bulleted_list(escaped_metadatas)
        snippets = "<h4>Snippets</h4>\n" + create_bulleted_list(escaped_passages)

        return response, references, snippets


def create_bulleted_list(texts: Iterable[object]) -> str:
    """Render the items of ``texts`` as an HTML unordered list.

    Each item is converted with ``str()`` (via f-string formatting) and wrapped
    in an ``<li>`` element. Items are inserted verbatim, without HTML escaping,
    so callers must escape plain text themselves if it may contain markup
    characters.

    Args:
        texts: items to list; any iterable of objects is accepted (e.g.
            passage strings, or document metadata values that may be dicts
            or None).

    Returns:
        A single ``<ul>...</ul>`` string; ``"<ul></ul>"`` for an empty input.
    """
    return "<ul>" + "".join(f"<li>{item}</li>" for item in texts) + "</ul>"