import os
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from smolagents import CodeAgent, Model, ChatMessage
import tools.tools as tls # Your tool definitions
load_dotenv()
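
# Assumes a .env file next to this script providing the Hugging Face token, e.g.:
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxx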
"""
enforce_strict_role_alternation()
Ensures that messages follow the required pattern:
'user/assistant/user/assistant/...', starting with an optional 'system' message.
This is necessary because many chat-based models (e.g., ChatCompletion APIs)
expect the conversation format to alternate strictly between user and assistant roles,
possibly preceded by a single system message.
Parameters:
-----------
messages : list of dict
The message history. Each message is expected to be a dictionary with a 'role' key
('user', 'assistant', or 'system') and a 'content' key.
Returns:
--------
cleaned : list of dict
A sanitized version of the messages list that follows the correct role alternation rules.
"""
def enforce_strict_role_alternation(messages):
    cleaned = []      # List to store the cleaned message sequence
    last_role = None  # Tracks the last valid role added, to ensure alternation
    for msg in messages:
        role = msg.get("role")  # .get() so messages without a 'role' are skipped, not fatal
        # Skip any message that doesn't have a valid role
        if role not in ("user", "assistant", "system"):
            continue
        if role == "system":
            # Allow a single 'system' message, and only at the very beginning;
            # drop any later 'system' messages so they can't break alternation
            if not cleaned:
                cleaned.append(msg)
            continue
        # Skip messages with the same role as the previous one (they would break alternation)
        if role == last_role:
            continue
        # Add the valid message to the cleaned list
        cleaned.append(msg)
        last_role = role  # Update the last role for the next iteration
    return cleaned
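
# A quick illustration (not part of the agent flow): duplicate-role messages are
# dropped so the sequence alternates. For example,
#   enforce_strict_role_alternation([
#       {"role": "system", "content": "You are helpful."},
#       {"role": "user", "content": "Hi"},
#       {"role": "user", "content": "Hi again"},   # dropped: same role twice in a row
#       {"role": "assistant", "content": "Hello!"},
#   ])
# keeps the system message, the first 'user' message, and the 'assistant' reply.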

# Define a custom model class that wraps Hugging Face's InferenceClient for chat-based models
class HuggingFaceChatModel(Model):
    def __init__(self):
        super().__init__()  # Initialize the smolagents Model base class
        # Set the model ID for the specific Hugging Face model to use
        model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
        # Create an InferenceClient with the model ID and the Hugging Face token from the environment
        self.client = InferenceClient(model=model_id, token=os.getenv("HF_TOKEN"))

    def generate(self, messages, stop_sequences=None):
        """
        Generate a response from the chat model based on the input message history.

        Parameters
        ----------
        messages : list of dict
            A list of message dicts in OpenAI-style format, e.g.:
            [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi!"}]
        stop_sequences : list of str, optional
            Strings that stop generation when encountered. Default is ["Task"].

        Returns
        -------
        ChatMessage
            A formatted response object with role='assistant' and the model-generated content.
        """
        # Set default stop sequences if none provided
        if stop_sequences is None:
            stop_sequences = ["Task"]
        # 💡 Preprocess: enforce valid alternation of user/assistant messages
        cleaned_messages = enforce_strict_role_alternation(messages)
        # 🧠 Call the Hugging Face chat API with the cleaned messages
        response = self.client.chat_completion(
            messages=cleaned_messages,
            stop=stop_sequences,
            max_tokens=1024,  # Limit the number of tokens generated in the reply
        )
        # 📦 Extract the content from the model response and wrap it in a ChatMessage
        content = response.choices[0].message.content
        return ChatMessage(role="assistant", content=content)
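
# A minimal standalone sketch (an assumption: HF_TOKEN is set and the model is
# reachable through the Inference API; inside the agent, smolagents calls
# generate() for you, so this is only for manual testing):
#
#   model = HuggingFaceChatModel()
#   reply = model.generate([{"role": "user", "content": "Say hello in one word."}])
#   print(reply.content)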

# ✅ Basic agent built with smolagents
class BasicAgent:
    def __init__(self):
        # Informative log to indicate that the agent is being initialized
        print("✅ BasicAgent initialized with Hugging Face chat model.")
        # Instantiate the custom model that wraps the Hugging Face InferenceClient
        self.model = HuggingFaceChatModel()
        # Create the CodeAgent, which combines the tools with the chat model
        self.agent = CodeAgent(
            tools=[tls.search_tool, tls.calculate_cargo_travel_time],  # Your list of tools
            model=self.model,  # The model that generates tool-using responses
            additional_authorized_imports=["pandas"],  # Optional: allow pandas in generated code
            max_steps=20,  # Limit the number of planning steps (tool calls + reasoning)
        )

    def __call__(self, messages) -> str:
        """
        Handle a call to the agent with either a single question or a message history.

        Parameters
        ----------
        messages : Union[str, List[Dict[str, str]]]
            The input from the chat interface, either:
            - a plain string (just one message)
            - a list of dicts, like [{"role": "user", "content": "What's the weather?"}]

        Returns
        -------
        str
            The assistant's response as a string.
        """
        # If the input is a chat history (list of messages), take the most recent message
        if isinstance(messages, list):
            question = messages[-1]["content"]  # Extract the last message's content
        else:
            question = messages  # If it's just a string, use it directly
        # Log the input for debugging
        print(f"📥 Received question: {question[:60]}...")
        # Run the CodeAgent to get a response (may include tool use)
        response = str(self.agent.run(question))  # Coerce to str: run() may return a non-string result
        # Log the response for debugging
        print(f"📤 Response generated: {response[:60]}...")
        return response  # Return the final result
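
# Minimal smoke test: a sketch assuming HF_TOKEN is set in .env and that
# tools/tools.py defines search_tool and calculate_cargo_travel_time.
# The question below is purely illustrative.
if __name__ == "__main__":
    agent = BasicAgent()
    answer = agent([{"role": "user", "content": "How long would a cargo plane take to fly from Paris to New York?"}])
    print(answer)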