import base64
import hashlib
import json
import secrets
from functools import lru_cache

from chainlit.utils import mount_chainlit
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from google.auth.transport import requests as google_requests
from google.oauth2 import id_token
from google_auth_oauthlib.flow import Flow
from starlette.exceptions import HTTPException as StarletteHTTPException

from config.constants import (
    OAUTH_GOOGLE_CLIENT_ID,
    OAUTH_GOOGLE_CLIENT_SECRET,
    CHAINLIT_URL,
    EMAIL_ENCRYPTION_KEY,
)
from helpers import (
    get_time,
    reset_tokens_for_user,
    check_user_cooldown,
)
from modules.chat_processor.helpers import get_user_details, update_user_info
from config.config_manager import config_manager

# set config
config = config_manager.get_config().dict()

# set constants
GITHUB_REPO = config["misc"]["github_repo"]
DOCS_WEBSITE = config["misc"]["docs_website"]
ALL_TIME_TOKENS_ALLOCATED = config["token_config"]["all_time_tokens_allocated"]
TOKENS_LEFT = config["token_config"]["tokens_left"]
COOLDOWN_TIME = config["token_config"]["cooldown_time"]
REGEN_TIME = config["token_config"]["regen_time"]

GOOGLE_CLIENT_ID = OAUTH_GOOGLE_CLIENT_ID
GOOGLE_CLIENT_SECRET = OAUTH_GOOGLE_CLIENT_SECRET
GOOGLE_REDIRECT_URI = f"{CHAINLIT_URL}/auth/oauth/google/callback"

app = FastAPI()
app.mount("/public", StaticFiles(directory="public"), name="public")
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wildcard origins; restrict to the deployed front-end(s).
    allow_origins=["*"],  # Update with appropriate origins
    allow_methods=["*"],
    allow_headers=["*"],  # or specify the headers you want to allow
    expose_headers=["X-User-Info"],  # Expose the custom header
)

templates = Jinja2Templates(directory="templates")
# In-memory map of session_token -> session dict. It is per-process and
# unbounded; in this module entries are only removed on logout.
session_store = {}
CHAINLIT_PATH = "/chainlit_tutor"

# Role table keyed by the PBKDF2-hashed email (see get_user_role). Only the
# "admin" role currently grants extra permissions (no token limits/cooldown).
# NOTE(review): this file lives under the "public" directory mounted at
# /public above, so it is downloadable by anyone; consider moving it out of
# the static root.
with open("public/files/students_encrypted.json", "r", encoding="utf-8") as file:
    USER_ROLES = json.load(file)

# Create a Google OAuth flow.
# NOTE(review): this Flow is a single module-level object shared by every
# request. The OAuth callback calls fetch_token() and then reads
# flow.credentials with no ``await`` in between, so callbacks cannot
# interleave; keep those two steps together if that code is ever changed.
flow = Flow.from_client_config(
    {
        "web": {
            "client_id": GOOGLE_CLIENT_ID,
            "client_secret": GOOGLE_CLIENT_SECRET,
            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
            "token_uri": "https://oauth2.googleapis.com/token",
            "redirect_uris": [GOOGLE_REDIRECT_URI],
        }
    },
    # The requested scopes come from this argument; a "scopes" entry inside
    # the client config dict is not read by google-auth-oauthlib.
    scopes=[
        "openid",
        "https://www.googleapis.com/auth/userinfo.email",
        "https://www.googleapis.com/auth/userinfo.profile",
    ],
    redirect_uri=GOOGLE_REDIRECT_URI,
)


@lru_cache(maxsize=1024)
def hash_email(email: str, salt: bytes) -> str:
    """Return the deterministic PBKDF2-HMAC-SHA256 hex digest of *email*.

    This is the key format of ``USER_ROLES``. Results are cached because
    100,000 PBKDF2 iterations are expensive, and role lookups run
    synchronously inside async request handlers, several times per request.

    FIXME: this is a keyed hash, not encryption; it cannot be reversed.
    """
    return hashlib.pbkdf2_hmac("sha256", email.encode(), salt, 100000).hex()


def get_user_role(username: str):
    """Return the role entry for *username* (an email), or ``"unauthorized"``.

    The email is hashed with ``EMAIL_ENCRYPTION_KEY`` as salt and looked up
    in ``USER_ROLES``. Unknown users default to ``["guest"]``, and any entry
    containing ``"guest"`` is reported as ``"unauthorized"``. Otherwise the
    stored entry is returned as-is.
    """
    encrypted_email = hash_email(username, EMAIL_ENCRYPTION_KEY.encode())
    role = USER_ROLES.get(encrypted_email, ["guest"])

    if "guest" in role:
        return "unauthorized"

    return role


async def get_user_info_from_cookie(request: Request):
    """Decode the user-info payload stored in the ``X-User-Info`` cookie.

    The cookie holds base64-encoded JSON, as written by the OAuth callback and
    ``/start-tutor``. Returns the decoded dict. Returns ``None`` when any of
    these hold:

    - the cookie is absent or empty;
    - it cannot be decoded;
    - it does not decode to a dict containing ``"email"``.

    Most callers index ``user_info["email"]``, so such payloads would
    otherwise crash them.

    NOTE(review): the cookie is neither signed nor encrypted, so its contents
    are client-controlled and a payload for any email can be forged.
    Consider signing it or validating it against ``session_store``.
    """
    user_info_encoded = request.cookies.get("X-User-Info")
    if not user_info_encoded:
        return None
    try:
        user_info_json = base64.b64decode(user_info_encoded).decode()
        user_info = json.loads(user_info_json)
    except Exception as e:
        # A malformed cookie is treated as "not signed in", not as an error.
        print(f"Error decoding user info: {e}")
        return None
    if not isinstance(user_info, dict) or "email" not in user_info:
        return None
    return user_info


async def del_user_info_from_cookie(request: Request, response: Response):
    """Clear the auth cookies on *response* and drop the server-side session.

    Queues deletion of the ``X-User-Info`` and ``session_token`` cookies on
    ``response``. If the request carries a ``session_token``, that entry is
    removed from ``session_store``; unknown tokens are ignored.
    """
    for cookie_name in ("X-User-Info", "session_token"):
        response.delete_cookie(cookie_name)

    token = request.cookies.get("session_token")
    if token:
        session_store.pop(token, None)


def get_user_info(request: Request):
    """Return the server-side session for the request's ``session_token`` cookie.

    Returns ``None`` when the cookie is missing or empty, or when the token is
    not in ``session_store``.
    """
    token = request.cookies.get("session_token")
    return session_store.get(token) if token else None


@app.get("/", response_class=HTMLResponse)
async def login_page(request: Request):
    """Serve the login page, or skip it for users already signed in with Google."""
    cookie_info = await get_user_info_from_cookie(request)
    if cookie_info and cookie_info.get("google_signed_in"):
        return RedirectResponse("/post-signin")

    context = {
        "request": request,
        "GITHUB_REPO": GITHUB_REPO,
        "DOCS_WEBSITE": DOCS_WEBSITE,
    }
    return templates.TemplateResponse("login.html", context)


# @app.get("/login/guest")
# async def login_guest():
#     username = "guest"
#     session_token = secrets.token_hex(16)
#     unique_session_id = secrets.token_hex(8)
#     username = f"{username}_{unique_session_id}"
#     session_store[session_token] = {
#         "email": username,
#         "name": "Guest",
#         "profile_image": "",
#         "google_signed_in": False,  # Ensure guest users do not have this flag
#     }
#     user_info_json = json.dumps(session_store[session_token])
#     user_info_encoded = base64.b64encode(user_info_json.encode()).decode()

#     # Set cookies
#     response = RedirectResponse(url="/post-signin", status_code=303)
#     response.set_cookie(key="session_token", value=session_token)
#     response.set_cookie(key="X-User-Info", value=user_info_encoded, httponly=True)
#     return response


@app.get("/unauthorized", response_class=HTMLResponse)
async def unauthorized(request: Request):
    """Render the page shown when a Google sign-in has no assigned role."""
    context = {"request": request}
    return templates.TemplateResponse("unauthorized.html", context)


@app.get("/login/google")
async def login_google(request: Request):
    """Start the Google sign-in flow.

    Users whose cookie already marks them as Google-signed-in go straight to
    ``/post-signin``. Everyone else is redirected (307) to Google's consent
    screen. On that redirect, any existing ``session_token`` and
    ``X-User-Info`` cookies are cleared, so a stale (e.g. guest) session
    cannot conflict with the new one.
    """
    user_info = await get_user_info_from_cookie(request)
    if user_info and user_info.get("google_signed_in"):
        return RedirectResponse("/post-signin")

    # NOTE(review): the OAuth ``state`` value is discarded and never checked
    # in the callback, so sign-in has no CSRF protection. Consider storing it
    # in a cookie here and comparing it to the callback's ``state`` parameter.
    authorization_url, _ = flow.authorization_url(prompt="consent")
    response = RedirectResponse(authorization_url)
    response.delete_cookie(key="session_token")
    response.delete_cookie(key="X-User-Info")
    return response


@app.get("/auth/oauth/google/callback")
async def auth_google(request: Request):
    """Handle Google's OAuth redirect: verify the user and start a session.

    Steps:

    1. Exchange the ``code`` query parameter for tokens and verify the ID
       token against ``GOOGLE_CLIENT_ID``.
    2. Look up the user's role. Users whose email Google has not verified,
       or who have no role, are sent to ``/unauthorized``.
    3. Build the session, including the Literal AI user record with the role
       added to its metadata, and store it in ``session_store``.
    4. Set the ``session_token`` and ``X-User-Info`` cookies and redirect
       (303) to ``/post-signin``.

    Any error is printed and the user is redirected back to ``/``.
    """
    try:
        # ``flow`` is shared by all requests. fetch_token() and the
        # credentials read must stay back-to-back with no ``await`` between
        # them, so concurrent callbacks cannot interleave.
        flow.fetch_token(code=request.query_params.get("code"))
        credentials = flow.credentials
        user_info = id_token.verify_oauth2_token(
            credentials.id_token, google_requests.Request(), GOOGLE_CLIENT_ID
        )

        # Roles are granted by email, so only trust addresses Google verified.
        if not user_info.get("email_verified", False):
            return RedirectResponse("/unauthorized")

        email = user_info["email"]
        role = get_user_role(email)
        if role == "unauthorized":
            return RedirectResponse("/unauthorized")

        session_data = {
            "email": email,
            "name": user_info.get("name", ""),
            "profile_image": user_info.get("picture", ""),
            "google_signed_in": True,  # Set this flag to True for Google-signed users
        }

        # add literalai user info to the session, to be sent to chainlit
        literalai_user = await get_user_details(email)
        session_data["literalai_info"] = literalai_user.to_dict()
        session_data["literalai_info"]["metadata"]["role"] = role

        # Register the session only once it is fully built, so a failure
        # above does not leave an orphaned entry in session_store.
        session_token = secrets.token_hex(16)
        session_store[session_token] = session_data

        user_info_json = json.dumps(session_data)
        user_info_encoded = base64.b64encode(user_info_json.encode()).decode()

        # Set cookies
        response = RedirectResponse(url="/post-signin", status_code=303)
        # NOTE(review): session_token is not httponly; confirm no client JS
        # needs it before tightening.
        response.set_cookie(key="session_token", value=session_token)
        response.set_cookie(key="X-User-Info", value=user_info_encoded, httponly=True)
        return response
    except Exception as e:
        print(f"Error during Google OAuth callback: {e}")
        return RedirectResponse(url="/", status_code=302)


@app.get("/cooldown")
async def cooldown(request: Request):
    """Show the cooldown page, or send the user on once the cooldown is over.

    Non-admin users still in cooldown see ``cooldown.html``. Otherwise the
    ``in_cooldown`` flag is cleared and persisted, and tokens are regenerated
    via ``reset_tokens_for_user``. The user is then redirected to
    ``/post-signin``. Requests without a valid user-info cookie are
    redirected to ``/``.
    """
    user_info = await get_user_info_from_cookie(request)
    if not user_info:
        # Nothing identifies the user; send them to the login page.
        return RedirectResponse("/")

    email = user_info["email"]
    user_details = await get_user_details(email)
    current_datetime = get_time()
    in_cooldown, cooldown_end_time = await check_user_cooldown(
        user_details, current_datetime, COOLDOWN_TIME, TOKENS_LEFT, REGEN_TIME
    )
    print(f"User in cooldown: {in_cooldown}")
    print(f"Cooldown end time: {cooldown_end_time}")

    if in_cooldown:
        # get_user_role runs a PBKDF2 hash; look it up once per request.
        role = get_user_role(email)
        if "admin" not in role:
            return templates.TemplateResponse(
                "cooldown.html",
                {
                    "request": request,
                    "username": email,
                    "role": role,
                    "cooldown_end_time": cooldown_end_time,
                    "tokens_left": user_details.metadata["tokens_left"],
                },
            )

    user_details.metadata["in_cooldown"] = False
    await update_user_info(user_details)
    await reset_tokens_for_user(user_details, TOKENS_LEFT, REGEN_TIME)
    return RedirectResponse("/post-signin")


def _apply_new_user_defaults(metadata, current_datetime):
    """Fill in missing token/cooldown metadata keys; existing values are kept."""
    defaults = {
        "tokens_left": TOKENS_LEFT,
        "last_message_time": current_datetime,
        "all_time_tokens_allocated": ALL_TIME_TOKENS_ALLOCATED,
        "in_cooldown": False,
    }
    for key, value in defaults.items():
        if key not in metadata:
            metadata[key] = value


@app.get("/post-signin", response_class=HTMLResponse)
async def post_signin(request: Request):
    """Post-sign-in landing route: update the user record and show the dashboard.

    Steps:

    1. Resolve the user from the ``X-User-Info`` cookie, falling back to the
       server-side session. Unauthenticated requests are redirected to
       ``/``.
    2. Record ``last_login``, initialise any missing token metadata and
       persist it.
    3. For non-admins, check the cooldown and persist the ``in_cooldown``
       flag. Users in cooldown are redirected to ``/cooldown``; everyone
       else has their tokens regenerated via ``reset_tokens_for_user``.
    4. Render ``dashboard.html``.
    """
    user_info = await get_user_info_from_cookie(request)
    if not user_info:
        user_info = get_user_info(request)
    if not user_info:
        # Neither the cookie nor the session identifies the user.
        return RedirectResponse("/")

    username = user_info["email"]
    user_details = await get_user_details(username)
    current_datetime = get_time()
    user_details.metadata["last_login"] = current_datetime
    _apply_new_user_defaults(user_details.metadata, current_datetime)
    await update_user_info(user_details)

    # get_user_role runs a PBKDF2 hash; look it up once per request.
    role = get_user_role(username)
    if "admin" not in role:
        in_cooldown, _ = await check_user_cooldown(
            user_details, current_datetime, COOLDOWN_TIME, TOKENS_LEFT, REGEN_TIME
        )
        # Persist the flag (as the /cooldown route does) so it is not lost.
        user_details.metadata["in_cooldown"] = bool(in_cooldown)
        await update_user_info(user_details)
        if in_cooldown:
            return RedirectResponse("/cooldown")
        await reset_tokens_for_user(user_details, TOKENS_LEFT, REGEN_TIME)

    return templates.TemplateResponse(
        "dashboard.html",
        {
            "request": request,
            "username": username,
            "role": role,
            # The raw X-User-Info cookie; None if the user came via the session.
            "jwt_token": request.cookies.get("X-User-Info"),
            "tokens_left": user_details.metadata["tokens_left"],
            "all_time_tokens_allocated": user_details.metadata[
                "all_time_tokens_allocated"
            ],
            "total_tokens_allocated": ALL_TIME_TOKENS_ALLOCATED,
        },
    )


@app.get("/start-tutor")
@app.post("/start-tutor")
async def start_tutor(request: Request):
    """Forward a signed-in user to the Chainlit tutor.

    Re-encodes the decoded ``X-User-Info`` payload into a fresh httponly
    cookie and redirects (303) to ``CHAINLIT_PATH``. Users without a valid
    cookie are sent back to ``/``.
    """
    cookie_info = await get_user_info_from_cookie(request)
    if not cookie_info:
        return RedirectResponse(url="/")

    payload = json.dumps(cookie_info).encode()
    encoded = base64.b64encode(payload).decode()
    redirect = RedirectResponse(CHAINLIT_PATH, status_code=303)
    redirect.set_cookie(key="X-User-Info", value=encoded, httponly=True)
    return redirect


@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render HTML error pages for HTTP errors.

    The handler is registered for Starlette's ``HTTPException``, the base
    class of FastAPI's, so it also covers errors raised by Starlette itself.
    Notably, this includes the 404 for unmatched routes, which a handler
    registered only for FastAPI's subclass never sees.

    404s render ``error_404.html``. Every other status renders ``error.html``
    with the exception text. Headers attached to the exception, such as
    ``Allow`` on a 405, are passed through.
    """
    headers = getattr(exc, "headers", None)
    if exc.status_code == 404:
        return templates.TemplateResponse(
            "error_404.html", {"request": request}, status_code=404, headers=headers
        )
    return templates.TemplateResponse(
        "error.html",
        {"request": request, "error": str(exc)},
        status_code=exc.status_code,
        headers=headers,
    )


@app.exception_handler(Exception)
async def exception_handler(request: Request, exc: Exception):
    """Render ``error.html`` with status 500 for any otherwise unhandled exception."""
    # NOTE(review): str(exc) can expose internal details to end users.
    context = {"request": request, "error": str(exc)}
    return templates.TemplateResponse("error.html", context, status_code=500)


@app.get("/logout", response_class=HTMLResponse)
async def logout(request: Request, response: Response):
    """Sign the user out and redirect (302) to the login page.

    Drops the server-side session and clears the ``session_token`` and
    ``X-User-Info`` cookies on the redirect that is actually returned.

    ``response`` is FastAPI's injected response object. It is not used,
    because FastAPI discards headers set on it when the endpoint returns its
    own ``Response``. The parameter is kept so the signature is unchanged.
    """
    redirect = RedirectResponse(url="/", status_code=302)
    # Cookie deletions must be applied to the response that is returned.
    await del_user_info_from_cookie(request=request, response=redirect)
    return redirect


@app.get("/get-tokens-left")
async def get_tokens_left(request: Request):
    """Return ``{"tokens_left": n}`` for the signed-in user after regenerating tokens.

    Returns ``{"tokens_left": 0}`` when there is no valid user-info cookie,
    or when looking up or updating the user fails.
    """
    user_info = await get_user_info_from_cookie(request)
    if not user_info:
        return {"tokens_left": 0}
    try:
        user_details = await get_user_details(user_info["email"])
        await reset_tokens_for_user(user_details, TOKENS_LEFT, REGEN_TIME)
        return {"tokens_left": user_details.metadata["tokens_left"]}
    except Exception as e:
        # Best-effort endpoint: report zero tokens rather than an error page.
        print(f"Error getting tokens left: {e}")
        return {"tokens_left": 0}


# Serve the Chainlit tutor as a sub-application under CHAINLIT_PATH.
mount_chainlit(app=app, target="chainlit_app.py", path=CHAINLIT_PATH)

if __name__ == "__main__":
    # Local development server; uvicorn is only imported when run as a script.
    from uvicorn import run as run_server

    run_server(app, host="127.0.0.1", port=7860)