import os
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from smolagents import CodeAgent, Model, ChatMessage
import tools.tools as tls  # Your tool definitions

load_dotenv()
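
# The .env file loaded above is expected to provide your Hugging Face token,
# e.g. (hypothetical value):
#
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx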

"""
    enforce_strict_role_alternation()

    Ensures that messages follow the required pattern:
    'user/assistant/user/assistant/...', starting with an optional 'system' message.
    
    This is necessary because many chat-based models (e.g., ChatCompletion APIs)
    expect the conversation format to alternate strictly between user and assistant roles,
    possibly preceded by a single system message.

    Parameters:
    -----------
    messages : list of dict
        The message history. Each message is expected to be a dictionary with a 'role' key 
        ('user', 'assistant', or 'system') and a 'content' key.

    Returns:
    --------
    cleaned : list of dict
        A sanitized version of the messages list that follows the correct role alternation rules.
    """
def enforce_strict_role_alternation(messages):
    cleaned = []        # List to store the cleaned message sequence
    last_role = None    # Tracks the last valid role added to ensure alternation

    for msg in messages:
        role = msg["role"]

        # Skip any message that doesn't have a valid role
        if role not in ("user", "assistant", "system"):
            continue

        # Allow a single 'system' message, and only at the very beginning;
        # system messages that appear mid-conversation are dropped
        if role == "system":
            if not cleaned:
                cleaned.append(msg)
            continue

        # Skip messages with the same role as the previous one (breaks alternation)
        if role == last_role:
            continue

        # Add the valid message to the cleaned list
        cleaned.append(msg)
        last_role = role  # Update the last role for the next iteration

    return cleaned
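
# Quick illustrative check of the sanitizer above (hypothetical history): the
# repeated 'user' turn and the mid-conversation 'system' message are dropped.
_demo_history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
    {"role": "user", "content": "Hi again"},        # repeated role -> skipped
    {"role": "assistant", "content": "Hello!"},
    {"role": "system", "content": "Late system."},  # not at the start -> skipped
]
assert [m["role"] for m in enforce_strict_role_alternation(_demo_history)] == [
    "system",
    "user",
    "assistant",
]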


# Define a custom model class that wraps around Hugging Face's InferenceClient for chat-based models
class HuggingFaceChatModel(Model):
    def __init__(self):
        # Initialize the smolagents Model base class
        super().__init__()

        # Set the model ID for the specific Hugging Face model to use
        model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"

        # Create an InferenceClient with the model ID and the Hugging Face token (HF_TOKEN) from your environment
        self.client = InferenceClient(model=model_id, token=os.getenv("HF_TOKEN"))

    def generate(self, messages, stop_sequences=None, **kwargs):
        """
        Generates a response from the chat model based on the input message history.

        Parameters:
        -----------
        messages : list of dict
            A list of message dicts in OpenAI-style format, e.g.:
            [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi!"}]

        stop_sequences : list of str, optional
            A list of strings that will stop generation when encountered. Default is ["Task"].

        **kwargs
            Extra generation arguments the framework may pass; they are
            accepted and ignored by this wrapper.

        Returns:
        --------
        ChatMessage
            A formatted response object with role='assistant' and the model-generated content.
        """

        # Set default stop sequences if none provided
        if stop_sequences is None:
            stop_sequences = ["Task"]

        # 💡 Preprocess: Enforce valid alternation of user/assistant messages
        cleaned_messages = enforce_strict_role_alternation(messages)

        # 🔧 Call the Hugging Face chat API with cleaned messages
        response = self.client.chat_completion(
            messages=cleaned_messages,
            stop=stop_sequences,
            max_tokens=1024  # Limit the number of tokens generated in the reply
        )

        # 📦 Extract content from the model response and wrap it in a ChatMessage object
        content = response.choices[0].message.content
        return ChatMessage(role="assistant", content=content)
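
# Illustrative direct use of the wrapper above (assumes HF_TOKEN is set; kept
# as a comment so importing this module makes no network calls):
#
#     model = HuggingFaceChatModel()
#     reply = model.generate([{"role": "user", "content": "Say hi in one word."}])
#     print(reply.content)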


# ✅ Basic Agent with smolagents
class BasicAgent:
    def __init__(self):
        # Informative log to indicate that the agent is being initialized
        print("βœ… BasicAgent initialized with Hugging Face chat model.")

        # Instantiate your custom model that wraps the Hugging Face InferenceClient
        self.model = HuggingFaceChatModel()
        
        # Create the CodeAgent, which uses the tools and the chat model
        self.agent = CodeAgent(
            tools=[tls.search_tool, tls.calculate_cargo_travel_time],  # Your list of tools
            model=self.model,  # The model to generate tool-using responses
            additional_authorized_imports=["pandas"],  # Optional: allow use of pandas in generated code
            max_steps=20,  # Limit the number of planning steps (tool calls + reasoning)
        )

    def __call__(self, messages) -> str:
        """
        Handle a call to the agent with either a single question or a message history.

        Parameters:
        -----------
        messages : Union[str, List[Dict[str, str]]]
            The input from the chat interface, either:
            - a plain string (just one message)
            - a list of dicts, like [{"role": "user", "content": "What's the weather?"}]

        Returns:
        --------
        str
            The assistant's response as a string.
        """

        # If the input is a chat history (list of messages), get the most recent user message
        if isinstance(messages, list):
            question = messages[-1]["content"]  # Extract last message content
        else:
            question = messages  # If it's just a string, use it directly

        # Log the input for debugging
        print(f"πŸ“₯ Received question: {question[:60]}...")

        # Run the CodeAgent to get a response (may include tool use)
        response = self.agent.run(question)

        # agent.run may return a non-string final answer (e.g. a number),
        # so coerce it to str before slicing and returning
        response = str(response)

        # Log the response for debugging
        print(f"📤 Response generated: {response[:60]}...")

        return response  # Return final result
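

# Minimal smoke test (the question is a hypothetical example; requires HF_TOKEN
# in your environment and the tools module to be importable):
if __name__ == "__main__":
    agent = BasicAgent()
    print(agent("How long would a cargo plane take to fly from Paris to Tokyo?"))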