Spaces:
Sleeping
Sleeping
WilliamGazeley
commited on
Commit
·
7c1f337
1
Parent(s):
49c70c9
Implemene Azure search tool
Browse files- requirements.txt +3 -0
- src/config.py +7 -1
- src/functions.py +33 -0
requirements.txt
CHANGED
@@ -23,3 +23,6 @@ yfinance==0.2.36
|
|
23 |
transformers==4.40.2
|
24 |
langchain==0.1.9
|
25 |
accelerate==0.27.2
|
|
|
|
|
|
|
|
23 |
transformers==4.40.2
|
24 |
langchain==0.1.9
|
25 |
accelerate==0.27.2
|
26 |
+
azure-search-documents==11.6.0b1
|
27 |
+
azure-identity==1.16.0
|
28 |
+
|
src/config.py
CHANGED
@@ -5,7 +5,13 @@ class Config(BaseSettings):
|
|
5 |
hf_token: str = Field(...)
|
6 |
hf_model: str = Field("InvestmentResearchAI/LLM-ADE-dev")
|
7 |
headless: bool = Field(False, description="Run in headless mode.")
|
8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
chat_template: str = Field("chatml", description="Chat template for prompt formatting")
|
10 |
num_fewshot: int | None = Field(None, description="Option to use json mode examples")
|
11 |
load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
|
|
|
5 |
hf_token: str = Field(...)
|
6 |
hf_model: str = Field("InvestmentResearchAI/LLM-ADE-dev")
|
7 |
headless: bool = Field(False, description="Run in headless mode.")
|
8 |
+
|
9 |
+
az_search_endpoint: str = Field("https://analysis-bank.search.windows.net")
|
10 |
+
az_search_api_key: str = Field(...)
|
11 |
+
az_search_idx_name: str = Field("analysis-index")
|
12 |
+
az_search_top_k: int = Field(2, description="Max number of results to retrun")
|
13 |
+
az_search_min_score: float = Field(12.0, description="Only results above this confidence score is used")
|
14 |
+
|
15 |
chat_template: str = Field("chatml", description="Chat template for prompt formatting")
|
16 |
num_fewshot: int | None = Field(None, description="Option to use json mode examples")
|
17 |
load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
|
src/functions.py
CHANGED
@@ -11,6 +11,38 @@ from utils import inference_logger
|
|
11 |
from langchain.tools import tool
|
12 |
from langchain_core.utils.function_calling import convert_to_openai_tool
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
@tool
|
15 |
def google_search_and_scrape(query: str) -> dict:
|
16 |
"""
|
@@ -246,6 +278,7 @@ def get_company_profile(symbol: str) -> dict:
|
|
246 |
|
247 |
def get_openai_tools() -> List[dict]:
|
248 |
functions = [
|
|
|
249 |
google_search_and_scrape,
|
250 |
get_current_stock_price,
|
251 |
get_company_news,
|
|
|
11 |
from langchain.tools import tool
|
12 |
from langchain_core.utils.function_calling import convert_to_openai_tool
|
13 |
|
14 |
+
from azure.core.credentials import AzureKeyCredential
|
15 |
+
from azure.search.documents import SearchClient
|
16 |
+
|
17 |
+
|
18 |
+
az_creds = AzureKeyCredential(config.az_search_api_key)
|
19 |
+
az_search_client = SearchClient(config.az_search_endpoint, config.az_search_idx_name, az_creds)
|
20 |
+
|
21 |
+
@tool
|
22 |
+
def get_company_analysis(query: str) -> dict:
|
23 |
+
"""
|
24 |
+
Searches through your database of company and crypto analysis, retrieves top 2
|
25 |
+
pieces of analysis relevant to your query.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
query (str): The search query
|
29 |
+
Returns:
|
30 |
+
list: A list of dictionaries containing the pieces of analysis.
|
31 |
+
"""
|
32 |
+
results = az_search_client.search(
|
33 |
+
query_type="simple",
|
34 |
+
search_text=query,
|
35 |
+
select="title,content",
|
36 |
+
include_total_count=True,
|
37 |
+
top=config.az_search_top_k
|
38 |
+
)
|
39 |
+
|
40 |
+
output = []
|
41 |
+
for x in results:
|
42 |
+
if x["@search.score"] >= config.az_search_min_score:
|
43 |
+
output.append({"title": x["title"], "content": x["content"]})
|
44 |
+
return output
|
45 |
+
|
46 |
@tool
|
47 |
def google_search_and_scrape(query: str) -> dict:
|
48 |
"""
|
|
|
278 |
|
279 |
def get_openai_tools() -> List[dict]:
|
280 |
functions = [
|
281 |
+
get_company_analysis,
|
282 |
google_search_and_scrape,
|
283 |
get_current_stock_price,
|
284 |
get_company_news,
|