from langchain.chains.llm import LLMChain
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

from prompts.multi_queries import system_template, human_template
from config import (
    OPENAI_API_TYPE,
    OPENAI_API_VERSION,
    OPENAI_API_KEY,
    OPENAI_API_BASE,
    DEPLOYMENT_ID,
)
from chains.azure_openai import CustomAzureOpenAI


class MultiQueries(LLMChain):
    """LLMChain that runs a question through the multi-query chat prompt.

    Both ``llm`` and ``prompt`` are given class-level defaults, so
    ``MultiQueries()`` can be constructed with no arguments:

    * ``llm`` is a ``CustomAzureOpenAI`` client built from the ``config``
      module's Azure settings with ``temperature=0.0``.
    * ``prompt`` is a chat prompt made of a system message
      (``system_template``) followed by a human message (``human_template``).
    """

    # NOTE: these defaults are evaluated once, when the class body executes
    # at import time. They are deliberately left without annotations;
    # presumably the field types are declared on LLMChain itself and an
    # annotation here would redeclare them — confirm before adding one.
    llm = CustomAzureOpenAI(
        deployment_name=DEPLOYMENT_ID,
        openai_api_type=OPENAI_API_TYPE,
        openai_api_base=OPENAI_API_BASE,
        openai_api_version=OPENAI_API_VERSION,
        openai_api_key=OPENAI_API_KEY,
        temperature=0.0,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template(human_template),
        ]
    )


def _split_queries(text):
    """Split raw chain output into the list of queries it contains.

    The output is stripped, then split into paragraphs on blank lines
    (``"\\n\\n"``). From each paragraph, everything up to and including the
    FIRST ``": "`` is dropped (presumably a label such as ``"Query 1: "``;
    the exact format is set by the prompt templates, which are not shown
    here). Splitting only once keeps queries that themselves contain
    ``": "`` intact. A paragraph without ``": "`` is returned unchanged.

    Args:
        text: Raw string returned by ``MultiQueries.predict``.

    Returns:
        A list with one string per paragraph.
    """
    return [part.split(": ", 1)[-1] for part in text.strip().split("\n\n")]


if __name__ == "__main__":
    queries_chain = MultiQueries()
    out = queries_chain.predict(
        question="Where can I request for my event's permit in Penang?"
    )

    paragraphs = out.strip().split("\n\n")
    # Show the second paragraph only if the reply has one; indexing [1]
    # directly would raise IndexError on a short reply.
    if len(paragraphs) > 1:
        print(paragraphs[1])
    print(_split_queries(out))