import streamlit as st
import os
from io import StringIO
from llama_index.llms import HuggingFaceInferenceAPI
from llama_index.embeddings import HuggingFaceInferenceAPIEmbedding
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import Document
import uuid
from llama_index.vector_stores.types import MetadataFilters, ExactMatchFilter
from typing import List
from pydantic import BaseModel
import json

# Hugging Face Inference API token, read from Streamlit secrets (.streamlit/secrets.toml).
# NOTE(review): the key "INFRERENCE_API_TOKEN" looks misspelled ("INFRERENCE" vs
# "INFERENCE") — it must match the key actually present in secrets.toml; confirm
# before renaming, since changing only this side would raise a KeyError at startup.
inference_api_key = st.secrets["INFRERENCE_API_TOKEN"]

# embed_model_name = st.text_input(
#     'Embed Model name', "Gooly/gte-small-en-fine-tuned-e-commerce")

# llm_model_name = st.text_input(
#     'LLM Model name', "mistralai/Mistral-7B-Instruct-v0.2")


class PriceModel(BaseModel):
    """Structured output schema for the price extracted from an uploaded page.

    Passed as ``output_cls`` to the query engine so the LLM response is
    parsed into this model instead of returned as free text.
    """
    # Price kept as raw text (e.g. including a currency symbol); no numeric
    # parsing or validation is applied here.
    price: str

# Hugging Face model identifiers used for embedding and generation.
embed_model_name = "jinaai/jina-embedding-s-en-v1"
llm_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Remote LLM: all generation happens via the HF Inference API, no local weights.
llm = HuggingFaceInferenceAPI(
    model_name=llm_model_name,
    token=inference_api_key,
)

# Remote embedding model used when indexing the uploaded document.
# NOTE(review): model_kwargs/encode_kwargs look like sentence-transformers
# options — verify HuggingFaceInferenceAPIEmbedding actually honors them.
embed_model = HuggingFaceInferenceAPIEmbedding(
    model_name=embed_model_name,
    token=inference_api_key,
    model_kwargs={"device": ""},
    encode_kwargs={"normalize_embeddings": True},
)

# Bundle the LLM and embedder so both indexing and querying share them.
service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)

# UI inputs: the question to ask and the HTML page to ask it about.
query = st.text_input('Query', "What is the price of the product?")
html_file = st.file_uploader("Upload a html file", type=["html"])

if html_file is not None:
    # Decode the upload once. The original StringIO(...).read() round-trip was
    # redundant: getvalue().decode() already yields the complete HTML string.
    string_data = html_file.getvalue().decode("utf-8")
    with st.expander("Uploaded HTML"):
        st.code(string_data, language='html')

    # Tag the document with a per-upload UUID so the metadata filter below
    # restricts retrieval to nodes from exactly this upload.
    document_id = str(uuid.uuid4())
    document = Document(text=string_data)
    document.metadata["id"] = document_id
    documents = [document]

    filters = MetadataFilters(
        filters=[ExactMatchFilter(key="id", value=document_id)])

    # NOTE(review): `metadata=` is not a documented parameter of
    # VectorStoreIndex.from_documents — confirm it is not silently ignored.
    index = VectorStoreIndex.from_documents(
        documents, show_progress=True, metadata={"source": "HTML"},
        service_context=service_context)

    # tree_summarize + output_cls asks the engine to return the structured
    # PriceModel rather than free-form text.
    query_engine = index.as_query_engine(
        filters=filters, service_context=service_context,
        response_mode="tree_summarize", output_cls=PriceModel)

    response = query_engine.query(query)

    # NOTE(review): assumes the structured object exposes `.price` directly;
    # some llama_index versions wrap it as `response.response.price` — confirm.
    st.write(f'Price: {response.price}')

# if st.button('Start Pipeline'):
#     if html_file is not None and embed_model_name is not None and llm_model_name is not None and query is not None:
#         st.write('Running Pipeline')
#         llm = HuggingFaceInferenceAPI(
#             model_name=llm_model_name, token=inference_api_key)

#         embed_model = HuggingFaceInferenceAPIEmbedding(
#             model_name=embed_model_name,
#             token=inference_api_key,
#             model_kwargs={"device": ""},
#             encode_kwargs={"normalize_embeddings": True},
#         )

#         service_context = ServiceContext.from_defaults(
#             embed_model=embed_model, llm=llm)

#         stringio = StringIO(html_file.getvalue().decode("utf-8"))
#         string_data = stringio.read()
#         with st.expander("Uploaded HTML"):
#             st.write(string_data)

#         document_id = str(uuid.uuid4())

#         document = Document(text=string_data)
#         document.metadata["id"] = document_id
#         documents = [document]

#         filters = MetadataFilters(
#             filters=[ExactMatchFilter(key="id", value=document_id)])

#         index = VectorStoreIndex.from_documents(
#             documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)

#         retriever = index.as_retriever()

#         ranked_nodes = retriever.retrieve(
#             query)

#         with st.expander("Ranked Nodes"):
#             for node in ranked_nodes:
#                 st.write(node.node.get_content(), "-> Score:", node.score)

#         query_engine = index.as_query_engine(
#             filters=filters, service_context=service_context)

#         response = query_engine.query(query)

#         st.write(response.response)

#         st.write(response.source_nodes)

#     else:
#         st.error('Please fill in all the fields')
# else:
#     st.write('Press start to begin')

# # if html_file is not None:
# #     stringio = StringIO(html_file.getvalue().decode("utf-8"))
# #     string_data = stringio.read()
# #     with st.expander("Uploaded HTML"):
# #         st.write(string_data)

# #     document_id = str(uuid.uuid4())

# #     document = Document(text=string_data)
# #     document.metadata["id"] = document_id
# #     documents = [document]

# #     filters = MetadataFilters(
# #         filters=[ExactMatchFilter(key="id", value=document_id)])

# #     index = VectorStoreIndex.from_documents(
# #         documents, show_progress=True, metadata={"source": "HTML"}, service_context=service_context)

# #     retriever = index.as_retriever()

# #     ranked_nodes = retriever.retrieve(
# #         "Get me all the information about the product")

# #     with st.expander("Ranked Nodes"):
# #         for node in ranked_nodes:
# #             st.write(node.node.get_content(), "-> Score:", node.score)

# #     query_engine = index.as_query_engine(
# #         filters=filters, service_context=service_context)

# #     response = query_engine.query(
# #         "Get me all the information about the product")

# #     st.write(response)