WilliamGazeley committed on
Commit
7c1f337
·
1 Parent(s): 49c70c9

Implement Azure search tool

Browse files
Files changed (3) hide show
  1. requirements.txt +3 -0
  2. src/config.py +7 -1
  3. src/functions.py +33 -0
requirements.txt CHANGED
@@ -23,3 +23,6 @@ yfinance==0.2.36
23
  transformers==4.40.2
24
  langchain==0.1.9
25
  accelerate==0.27.2
 
 
 
 
23
  transformers==4.40.2
24
  langchain==0.1.9
25
  accelerate==0.27.2
26
+ azure-search-documents==11.6.0b1
27
+ azure-identity==1.16.0
28
+
src/config.py CHANGED
@@ -5,7 +5,13 @@ class Config(BaseSettings):
5
  hf_token: str = Field(...)
6
  hf_model: str = Field("InvestmentResearchAI/LLM-ADE-dev")
7
  headless: bool = Field(False, description="Run in headless mode.")
8
-
 
 
 
 
 
 
9
  chat_template: str = Field("chatml", description="Chat template for prompt formatting")
10
  num_fewshot: int | None = Field(None, description="Option to use json mode examples")
11
  load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
 
5
  hf_token: str = Field(...)
6
  hf_model: str = Field("InvestmentResearchAI/LLM-ADE-dev")
7
  headless: bool = Field(False, description="Run in headless mode.")
8
+
9
+ az_search_endpoint: str = Field("https://analysis-bank.search.windows.net")
10
+ az_search_api_key: str = Field(...)
11
+ az_search_idx_name: str = Field("analysis-index")
12
+ az_search_top_k: int = Field(2, description="Max number of results to retrun")
13
+ az_search_min_score: float = Field(12.0, description="Only results above this confidence score is used")
14
+
15
  chat_template: str = Field("chatml", description="Chat template for prompt formatting")
16
  num_fewshot: int | None = Field(None, description="Option to use json mode examples")
17
  load_in_4bit: str = Field("False", description="Option to load in 4bit with bitsandbytes")
src/functions.py CHANGED
@@ -11,6 +11,38 @@ from utils import inference_logger
11
  from langchain.tools import tool
12
  from langchain_core.utils.function_calling import convert_to_openai_tool
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  @tool
15
  def google_search_and_scrape(query: str) -> dict:
16
  """
@@ -246,6 +278,7 @@ def get_company_profile(symbol: str) -> dict:
246
 
247
  def get_openai_tools() -> List[dict]:
248
  functions = [
 
249
  google_search_and_scrape,
250
  get_current_stock_price,
251
  get_company_news,
 
11
  from langchain.tools import tool
12
  from langchain_core.utils.function_calling import convert_to_openai_tool
13
 
14
+ from azure.core.credentials import AzureKeyCredential
15
+ from azure.search.documents import SearchClient
16
+
17
+
18
# Module-level Azure AI Search client, built once at import time from the
# application config and reused by get_company_analysis.
az_creds = AzureKeyCredential(config.az_search_api_key)
az_search_client = SearchClient(
    endpoint=config.az_search_endpoint,
    index_name=config.az_search_idx_name,
    credential=az_creds,
)
20
+
21
@tool
def get_company_analysis(query: str) -> List[dict]:
    """
    Searches through your database of company and crypto analysis, retrieves the
    top pieces of analysis relevant to your query.

    Args:
        query (str): The search query
    Returns:
        list: A list of dictionaries containing the pieces of analysis, each with
        a "title" and a "content" key. The list is empty when no result scores
        high enough.
    """
    # The number of hits requested is driven by config.az_search_top_k, so the
    # docstring above deliberately does not promise a fixed count. The total
    # match count is not requested because nothing here reads it.
    results = az_search_client.search(
        query_type="simple",
        search_text=query,
        select="title,content",
        top=config.az_search_top_k,
    )

    # Keep only hits whose relevance score reaches the configured minimum, and
    # strip each hit down to the two fields that were selected.
    return [
        {"title": hit["title"], "content": hit["content"]}
        for hit in results
        if hit["@search.score"] >= config.az_search_min_score
    ]
45
+
46
  @tool
47
  def google_search_and_scrape(query: str) -> dict:
48
  """
 
278
 
279
  def get_openai_tools() -> List[dict]:
280
  functions = [
281
+ get_company_analysis,
282
  google_search_and_scrape,
283
  get_current_stock_price,
284
  get_company_news,