import logging
import os
from typing import List

import chainlit as cl
import chromadb
import nest_asyncio
import openai
import pandas as pd
from pydantic import BaseModel, Field
from sqlalchemy import create_engine

from llama_index import (
    ServiceContext,
    SQLDatabase,
    VectorStoreIndex,
)
from llama_index.agent import OpenAIAgent
from llama_index.callbacks.base import CallbackManager
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter
from llama_index.llms import OpenAI
from llama_index.node_parser.simple import SimpleNodeParser
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.readers.wikipedia import WikipediaReader
from llama_index.retrievers import VectorIndexRetriever
from llama_index.storage.storage_context import StorageContext
from llama_index.tools import FunctionTool
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.types import (
    VectorStoreInfo,
    MetadataInfo,
    ExactMatchFilter,
    MetadataFilters,
)

# Chainlit runs its own asyncio event loop; nest_asyncio lets llama_index
# components that spin up a loop of their own run inside it.
nest_asyncio.apply()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

openai.api_key = os.environ.get("OPENAI_API_KEY")
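
# This Chainlit app wires two llama_index tools into one OpenAI agent: a
# text-to-SQL tool over Barbie/Oppenheimer review tables and an auto-retrieval
# tool over the films' Wikipedia pages. Assuming this file is saved as app.py,
# launch it with: chainlit run app.py -w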
|
@cl.on_chat_start
async def init():
    # One LLM and embedding model shared by every tool in the session.
    embed_model = OpenAIEmbedding()
    chunk_size = 2048
    llm = OpenAI(
        temperature=0,
        model="gpt-3.5-turbo",
        streaming=True,
    )

    service_context = ServiceContext.from_defaults(
        llm=llm,
        chunk_size=chunk_size,
        embed_model=embed_model,
        callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
    )

    text_splitter = TokenTextSplitter(chunk_size=chunk_size)
    node_parser = SimpleNodeParser(text_splitter=text_splitter)
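
    # --- Tool 1: semantic retrieval over the two films' Wikipedia pages. The
    # index is backed by an in-memory Chroma collection, so it is rebuilt for
    # each chat session.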
    chroma_client = chromadb.Client()
    chroma_collection = chroma_client.get_or_create_collection("wikipedia_barbie_opp")

    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    wiki_vector_index = VectorStoreIndex(
        [], storage_context=storage_context, service_context=service_context
    )

    movie_list = ["Barbie (film)", "Oppenheimer (film)"]
    wiki_docs = WikipediaReader().load_data(pages=movie_list, auto_suggest=False)

    # Tag every node with its source title so the retriever can filter by film.
    for movie, wiki_doc in zip(movie_list, wiki_docs):
        nodes = node_parser.get_nodes_from_documents([wiki_doc])
        for node in nodes:
            node.metadata = {"title": movie}
        wiki_vector_index.insert_nodes(nodes)
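
    # Auto-retrieval: rather than exposing a plain retriever, we let the agent
    # fill in both the query string and the metadata filters, so a question
    # about one film can be scoped to its title before the vector search runs.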
    top_k = 3

    vector_store_info = VectorStoreInfo(
        content_info="semantic information about movies",
        metadata_info=[
            MetadataInfo(
                name="title",
                type="str",
                description="title of the movie, one of [Barbie (film), Oppenheimer (film)]",
            )
        ],
    )

    class AutoRetrieveModel(BaseModel):
        query: str = Field(..., description="natural language query string")
        filter_key_list: List[str] = Field(
            ..., description="List of metadata filter field names"
        )
        filter_value_list: List[str] = Field(
            ...,
            description=(
                "List of metadata filter field values "
                "(corresponding to names specified in filter_key_list)"
            ),
        )
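
    # FunctionTool exposes this schema to the OpenAI function-calling API via
    # fn_schema, which is what lets the agent pass structured filter arguments
    # rather than a bare query string.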

    def auto_retrieve_fn(
        query: str, filter_key_list: List[str], filter_value_list: List[str]
    ) -> str:
        """Auto retrieval function.

        Performs auto-retrieval from a vector database, and then applies a set of filters.
        """
        # The calling LLM occasionally sends an empty query; fall back to a
        # placeholder so the retriever still runs.
        query = query or "Query"

        exact_match_filters = [
            ExactMatchFilter(key=k, value=v)
            for k, v in zip(filter_key_list, filter_value_list)
        ]
        # VectorIndexRetriever takes similarity_top_k, not top_k.
        retriever = VectorIndexRetriever(
            wiki_vector_index,
            filters=MetadataFilters(filters=exact_match_filters),
            similarity_top_k=top_k,
        )
        query_engine = RetrieverQueryEngine.from_args(retriever)

        response = query_engine.query(query)
        return str(response)
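
    # Illustrative call the agent might emit for this tool (hypothetical):
    #   auto_retrieve_fn(
    #       query="critical reception of the film",
    #       filter_key_list=["title"],
    #       filter_value_list=["Oppenheimer (film)"],
    #   )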
    description = f"""\
Use this tool to look up semantic information about films.
The vector database schema is given below:
{vector_store_info.json()}
"""

    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="auto_retrieve_tool",
        description=description,
        fn_schema=AutoRetrieveModel,
    )
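
    # --- Tool 2: text-to-SQL over the review CSVs. Both DataFrames are loaded
    # into an in-memory SQLite database, so nothing persists between sessions.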
    barbie_df = pd.read_csv("barbie_data/barbie.csv")
    oppenheimer_df = pd.read_csv("oppenheimer_data/oppenheimer.csv")

    engine = create_engine("sqlite+pysqlite:///:memory:")

    barbie_df.to_sql("barbie", engine)
    oppenheimer_df.to_sql("oppenheimer", engine)

    sql_database = SQLDatabase(engine, include_tables=["barbie", "oppenheimer"])
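
    # NLSQLTableQueryEngine prompts the LLM with the schemas of the included
    # tables, executes the generated SQL against the engine, and synthesizes a
    # natural-language answer from the result.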
    sql_query_engine = NLSQLTableQueryEngine(
        sql_database=sql_database,
        tables=["barbie", "oppenheimer"],
    )
    sql_tool = QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        name="sql_tool",
        description=(
            "Useful for translating a natural language query into a SQL query "
            "over a table containing:\n"
            "1. barbie, containing information related to reviews of the Barbie movie.\n"
            "2. oppenheimer, containing information related to reviews of the Oppenheimer movie."
        ),
    )
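
    # --- Agent: an OpenAI function-calling agent that decides, per message,
    # whether to answer with the SQL tool, the auto-retrieval tool, or both.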
    barbenheimer_agent = OpenAIAgent.from_tools(
        [sql_tool, auto_retrieve_tool], llm=llm, verbose=True
    )

    cl.user_session.set("query_engine", barbenheimer_agent)


@cl.on_message
async def main(message):
    query_engine = cl.user_session.get("query_engine")
    logger.info(f"Received message: {message}")

    # Depending on the Chainlit version, the handler receives either a raw
    # string or a cl.Message; normalize to plain text before querying the agent.
    content = message.content if hasattr(message, "content") else str(message)

    response = query_engine.query(content)
    logger.info("Response object created")

    # json.dumps would wrap the reply in escaped quotes; send the plain string.
    await cl.Message(content=str(response)).send()
|
|