# semsearch / googleai.py
# working version
# b505cc3
# raw
# history blame
# 5.4 kB
import logging
# Module-level logger. Its level is forced to DEBUG so this module's debug
# records (e.g. in init_googleai) are not filtered at the logger level;
# attached handlers may still filter them.
# NOTE(review): confirm DEBUG is intended outside development.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
import streamlit as st
from langchain.agents import initialize_agent, AgentType
from langchain.tools import Tool
from langchain_google_genai import ChatGoogleGenerativeAI
#from pinecone import Pinecone
from pineconeclient import search_index as pc_search
from utils import get_variable
# API key for the Google Gemini models used by ChatGoogleGenerativeAI below,
# loaded through the project's get_variable helper.
GOOGLE_API_KEY = get_variable("GOOGLE_API_KEY")
# DEFAULT_INSTRUCTIONS = """
# I want you to be a VC analyst or startup scout that is looking for companies to work with or invest in or to understand sectors and markets.
# You have access to a database of 500,000 startup companies via a tool named 'query_pinecone'. This tool lets you search for specific topics and thus identify similar companies.
# However, it can't filter for geographies or investment stages, nor can it find a specific company.
# Thus, I would like you to analyse the user input and prepare it for the tool calling:
# Keep the piece of the query that refers to the sector of the startup companies.
# If the user is asking about a specific company, e.g. to get a competitor report, then confirm the company name with the user and tell them that we are currently not able to create such reports but are working on exactly this area.
# The user can ask several questions and refine the search.
# Give out the retrieved companies in a table.
# Start by introducing yourself to the user: "Hej, I am your Analyst-in-a-box and can help you with finding startups for a particular topic and investment thesis, identifying competitors for a given startup company or understanding markets for a given sector. How can I help you?"
# After the user has typed in what they are looking for identify if one of the three topics is met and confirm with them:
# - I understand you are looking to identify startups in sector X
# - I understand you want to look for competitors to company Y
# - I understand you are looking for market information for sector Z
# If the user request does not fall into any of these categories, politely reject it. If the request is still in the area of startup scouting and VC investment, thank the user for bringing it up so it can be considered for the future product roadmap.
# As input to the database query, we need to generate a brief sector_description from the user query in the following format:
# Format: [Sector Title]: [Concise explanation of the sector]
# Avoid vague terms such as "innovative" or "startup"; focus solely on the core sector topic and functionality.
# If information is insufficient, output "status": "bad data".
# Example sector descriptions:
# "Weather-related insurance products: Insurance solutions that provide coverage based on weather conditions, offering financial protection against climate-related events."
# "Mobility as a Service: Transportation solutions that integrate various modes of travel, including carsharing, to enhance accessibility and convenience for users."
# "Risk Assessment Tools for AI: Tools designed to evaluate and manage risks associated with artificial intelligence systems."
# Don't include geographic information or other meta data into the query that will be sent to the DB.
# """
# Default agent instructions: the default value of init_googleai's
# `instructions`, which is passed to the agent as its prompt prefix.
# The tool name 'query_pinecone' mentioned here must match the name of the
# Tool registered in init_googleai.
# NOTE(review): the first sentence has no trailing period; left unchanged
# because this text is sent to the model verbatim.
DEFAULT_INSTRUCTIONS = """
You are a VC analyst advisor
You help the analyst find startup companies.
The user will ask questions about startup companies. You have a database of 500,000 companies and you can use it to answer the user questions.
In order to query the database you have a semantic search tool called 'query_pinecone'. It expects a single string variable.
"""
# def query_pinecone(query: str):
# print(f"DEBUG: query_pinecone function called with query: '{query}'")
# # select index in the vector db
# pc_index = pc.Index("semsearch")
# # create the query vector
# xq = embedding_model.encode(query).tolist()
# # query the vector db
# xc = pc_index.query(vector=xq, top_k=100, include_metadata=True)
# matches = xc.get("matches", [])
# num_matches = len(xc['matches'])
# # print(f"Found {num_matches} matches")
# return matches
def search_index(query, top_k=1000, countries=None, regions=None):
    """Run a semantic search against the Pinecone index for the agent tool.

    This is the ``func`` behind the ``query_pinecone`` LangChain tool, so the
    agent calls it with a single string. The extra keyword parameters are
    optional and default to the values that used to be hard-coded here.

    Args:
        query: Free-text description of what to search for.
        top_k: Number of matches to request from the index (default 1000).
        countries: Optional country filter forwarded to ``pc_search``.
            ``None`` sends an empty list, as before.
        regions: Optional region filter forwarded to ``pc_search``.
            ``None`` sends an empty list, as before.

    Returns:
        Whatever ``pineconeclient.search_index`` returns for the query.
    """
    # Build a fresh list per call instead of sharing a mutable default.
    countries = [] if countries is None else list(countries)
    regions = [] if regions is None else list(regions)
    # NOTE(review): the result is handed back to the LLM as the tool
    # observation; 1000 matches may be very large for the model's context.
    # Consider a smaller top_k for the agent.
    # The retriever is read from Streamlit session state; it is presumably
    # set by the app before the agent runs (not set in this file).
    return pc_search(
        query,
        top_k=top_k,
        countries=countries,
        regions=regions,
        retriever=st.session_state.retriever,
    )
def init_googleai(instructions=DEFAULT_INSTRUCTIONS, model="gemini-1.5-flash"):
    """Build the Gemini-backed agent and cache it in Streamlit session state.

    The agent is a LangChain zero-shot ReAct agent with a single tool,
    ``query_pinecone``, which forwards the agent's query string to
    ``search_index``.

    Args:
        instructions: Text passed to the agent as its prompt ``prefix``.
        model: Gemini model name for ``ChatGoogleGenerativeAI``
            (e.g. ``"gemini-1.5-pro"`` as an alternative).

    Returns:
        The created agent. It is also stored in
        ``st.session_state.googleai_agent_chain``, where ``send_message``
        reads it from.
    """
    logger.debug("Initializing google ai (model=%s)", model)

    pinecone_tool = Tool(
        # Must match the tool name referenced in the instructions.
        name="query_pinecone",
        func=search_index,
        description=(
            "Retrieves information from Pinecone. The input should be a string that describes what you want to find."
        ),
    )

    llm = ChatGoogleGenerativeAI(
        model=model,
        temperature=0.1,  # low temperature to keep answers focused
        google_api_key=GOOGLE_API_KEY,
    )

    # NOTE(review): if the model's reply does not follow the ReAct format,
    # invoke() may raise an output-parsing error. Consider
    # handle_parsing_errors=True if that shows up in practice.
    agent = initialize_agent(
        tools=[pinecone_tool],
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        agent_kwargs={"prefix": instructions},
    )
    st.session_state.googleai_agent_chain = agent
    return agent
def send_message(user_message: str, prompt):
    """Send a user message to the Gemini agent, (re)building it when needed.

    The agent is cached in ``st.session_state.googleai_agent_chain``. It is
    (re)created with ``init_googleai`` when it does not exist yet, or when
    ``prompt`` differs from the instructions it was last built with
    (``st.session_state.googleai_default_instructions``).

    Args:
        user_message: The user's chat message passed to the agent.
        prompt: Instructions used as the agent's prompt prefix.

    Returns:
        The result of the agent's ``invoke`` call for ``user_message``.
    """
    state = st.session_state
    needs_init = (
        "googleai_agent_chain" not in state
        # .get(): the instructions key is missing if the agent was built by
        # calling init_googleai directly, which does not record it.
        or state.get("googleai_default_instructions") != prompt
    )
    if needs_init:
        init_googleai(prompt)
        # Record the instructions only after init succeeded. Otherwise a
        # failed rebuild would leave the old agent in place while marking the
        # new prompt as applied, and it would never be retried.
        state.googleai_default_instructions = prompt
    return state.googleai_agent_chain.invoke(user_message)