import os
from typing import Optional

from llama_index.core.utilities.sql_wrapper import SQLDatabase
from omegaconf import DictConfig, OmegaConf
from pydantic import BaseModel, Field
from sqlalchemy import create_engine

from dotenv import load_dotenv
load_dotenv(override=True)

from vectara_agentic.agent import Agent
from vectara_agentic.tools import ToolsFactory, VectaraToolFactory

def create_assistant_tools(cfg):
    """Assemble the full list of tools the CFPB complaints agent can call.

    Args:
        cfg: Configuration object exposing ``api_keys`` and ``corpus_keys``;
            both are forwarded to ``VectaraToolFactory``.

    Returns:
        A list made of, in order: the standard tools, the guardrail tools,
        the ``cfpb``-prefixed database tools over ``cfpb_database.db``, and
        the ``ask_complaints`` Vectara RAG tool.
    """

    class QueryCFPBComplaints(BaseModel):
        # Argument schema for ``ask_complaints``: ``query`` is required,
        # ``company`` and ``state`` are optional and default to None.
        query: str = Field(description="The user query.")
        company: Optional[str] = Field(
            default=None,
            description="The company that the complaint is about.",
            examples=[
                'CAPITAL ONE FINANCIAL CORPORATION',
                'BANK OF AMERICA, NATIONAL ASSOCIATION',
                'CITIBANK, N.A.',
                'WELLS FARGO & COMPANY',
                'JPMORGAN CHASE & CO.',
            ],
        )
        state: Optional[str] = Field(
            default=None,
            description="The two-character state code where the consumer lives.",
            examples=['CA', 'FL', 'NY', 'TX', 'GA'],
        )

    # --- Vectara RAG tool over the complaint corpus -------------------------
    vectara_factory = VectaraToolFactory(
        vectara_api_key=cfg.api_keys,
        vectara_corpus_key=cfg.corpus_keys,
    )

    # Chained reranking: a slingshot stage (cutoff 0.2) followed by an MMR
    # stage (diversity_bias 0.4, limit 30).
    rerank_steps = [
        {"type": "slingshot", "cutoff": 0.2},
        {"type": "mmr", "diversity_bias": 0.4, "limit": 30},
    ]

    # Exact bytes of the tool description, including the leading newline,
    # the 8-space indents and the trailing space after "query,".
    description = (
        "\n"
        "        Given a user query, \n"
        "        returns a response to a user question about customer complaints for bank services.\n"
        "        "
    )

    ask_complaints = vectara_factory.create_rag_tool(
        tool_name="ask_complaints",
        tool_description=description,
        tool_args_schema=QueryCFPBComplaints,
        reranker="chain",
        rerank_k=100,
        rerank_chain=rerank_steps,
        n_sentences_before=2,
        n_sentences_after=2,
        lambda_val=0.005,
        vectara_summarizer='vectara-experimental-summary-ext-2023-12-11-med-omni',
        include_citations=True,
        verbose=False,
    )

    # --- Generic tools plus database tools over a local SQLite file ----------
    generic_factory = ToolsFactory()
    # NOTE: a relative sqlite URL is resolved against the current working
    # directory, not this module's location.
    engine = create_engine('sqlite:///cfpb_database.db')
    sql_tools = generic_factory.database_tools(
        tool_name_prefix="cfpb",
        content_description=(
            'Customer complaints about five banks (Bank of America, Wells Fargo, '
            'Capital One, Chase, and CITI Bank) and geographic information '
            '(counties and zip codes)'
        ),
        sql_database=SQLDatabase(engine),
    )

    return (
        generic_factory.standard_tools()
        + generic_factory.guardrail_tools()
        + sql_tools
        + [ask_complaints]
    )

def initialize_agent(_cfg, agent_progress_callback=None):
    """Construct the CFPB complaints ``Agent`` and call ``report()`` on it.

    Args:
        _cfg: Configuration passed unchanged to ``create_assistant_tools``.
        agent_progress_callback: Optional callback forwarded to ``Agent`` as
            ``agent_progress_callback``; defaults to None.

    Returns:
        The newly built ``Agent`` instance.
    """
    # System-style instructions for the agent. The exact bytes (leading
    # newline, 4/6-space indents, trailing spaces on the first two lines and
    # the final 4-space tail) are spelled out explicitly.
    instructions = (
        "\n"
        "    - You are a helpful research assistant, \n"
        "      with expertise in finance and complaints from the CFPB (Consumer Financial Protection Bureau), \n"
        "      in conversation with a user.\n"
        "    - For questions about customers' complaints (the text of the complaint), use the ask_complaints tool.\n"
        "      You only need the query parameter to use this tool, but you can supply other parameters if provided.\n"
        '      Do not include the "References" section in your response.\n'
        "    - Never discuss politics, and always respond politely.\n"
        "    "
    )

    assistant_tools = create_assistant_tools(_cfg)
    agent = Agent(
        tools=assistant_tools,
        topic="Customer complaints from the Consumer Financial Protection Bureau (CFPB)",
        custom_instructions=instructions,
        agent_progress_callback=agent_progress_callback,
    )
    agent.report()
    return agent


def get_agent_config() -> DictConfig:
    """Build the demo configuration from environment variables.

    Reads the required ``VECTARA_CORPUS_KEYS`` and ``VECTARA_API_KEYS``
    variables and the optional ``QUERY_EXAMPLES`` variable (``None`` when
    unset), and combines them with the fixed demo name, welcome text and
    description.

    Returns:
        A ``DictConfig`` with the keys ``corpus_keys``, ``api_keys``,
        ``examples``, ``demo_name``, ``demo_welcome`` and ``demo_description``.

    Raises:
        KeyError: If ``VECTARA_CORPUS_KEYS`` or ``VECTARA_API_KEYS`` is unset.
    """
    # OmegaConf.create() on a dict returns a DictConfig; OmegaConf itself is
    # only a namespace of static helpers, hence the DictConfig return type.
    # os.environ values are always str, so no str() conversion is needed.
    return OmegaConf.create({
        'corpus_keys': os.environ['VECTARA_CORPUS_KEYS'],
        'api_keys': os.environ['VECTARA_API_KEYS'],
        'examples': os.environ.get('QUERY_EXAMPLES'),
        'demo_name': "cfpb-assistant",
        'demo_welcome': "Welcome to the CFPB Customer Complaints demo.",
        'demo_description': "This assistant can help you gain insights into customer complaints to banks recorded by the Consumer Financial Protection Bureau.",
    })