import html
import ipaddress
import json
import logging
import random
import re
from datetime import datetime
from uuid import uuid4

import requests
from langchain.schema import (
    HumanMessage,
    SystemMessage,
)

def save_logs(scheduler, JSON_DATASET_PATH, logs, feedback=None) -> None:
    """Append one interaction record to the JSON Lines usage log.

    Every interaction with the app saves the log of question and answer; this
    is used to gather usage statistics of the app and evaluate model
    performance.

    ``logs`` is updated in place: a ``"time"`` entry (POSIX timestamp as a
    string) is always added. When ``feedback`` is truthy, ``"feedback"`` and a
    fresh UUID4 ``"record_id"`` are added too. Only the keys listed in
    ``field_order`` are written, in that order; any other keys are dropped.

    Args:
        scheduler: Object exposing a ``lock`` context manager that guards
            writes to the dataset file (presumably shared with a background
            uploader -- TODO confirm against caller).
        JSON_DATASET_PATH: Path of the JSON Lines file to append to.
        logs: Mutable dict describing the interaction.
        feedback: Optional user feedback to attach to the record.

    Raises:
        Any serialization or file I/O error is logged and re-raised. A record
        that cannot be serialized fails before the file is opened, so no
        partial line is ever written.
    """
    # Keep things clean: record id and time up front, answer and feedback last.
    field_order = (
        "record_id",
        "session_id",
        "time",  # current log time
        "session_duration_seconds",
        "client_location",
        "platform",
        "system_prompt",
        "sources",
        "reports",
        "subtype",
        "year",
        "question",
        "retriever",
        "endpoint_type",
        "reader",
        "docs",
        "answer",
        "feedback",
    )
    try:
        # We get the timestamp here because we are simply recording the time of logging.
        logs["time"] = str(datetime.now().timestamp())

        # Save feedback (if any); feedback records get their own id.
        if feedback:
            logs["feedback"] = feedback
            logs["record_id"] = str(uuid4())

        ordered_logs = {k: logs[k] for k in field_order if k in logs}

        # Serialize fully *before* touching the file: json.dump writes in
        # chunks, so a non-serializable value would otherwise leave a partial
        # line behind that corrupts the next appended record.
        line = json.dumps(ordered_logs) + "\n"

        with scheduler.lock:
            with open(JSON_DATASET_PATH, 'a', encoding='utf-8') as f:
                f.write(line)
        logging.info("logging done")
    except Exception:
        logging.exception("Failed to save logs")
        raise
        

def get_message_template(type, SYSTEM_PROMPT, USER_PROMPT):
    """Build the two-message chat prompt for the given endpoint type.

    Args:
        type: Endpoint kind. ``'NVIDIA'`` yields plain role/content dicts;
            ``'DEDICATED'`` and ``'SERVERLESS'`` yield LangChain
            ``SystemMessage`` / ``HumanMessage`` objects.
        SYSTEM_PROMPT: Text of the system message.
        USER_PROMPT: Text of the user message.

    Returns:
        A list containing the system message followed by the user message.

    Raises:
        ValueError: If ``type`` is not a supported endpoint kind.
    """
    if type == 'NVIDIA':
        return [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": USER_PROMPT},
        ]
    if type in ('DEDICATED', 'SERVERLESS'):
        return [
            SystemMessage(content=SYSTEM_PROMPT),
            HumanMessage(content=USER_PROMPT),
        ]

    logging.error("No message template found")
    # A bare `raise` here would either fail with "No active exception to
    # reraise" or re-raise an unrelated in-flight exception; raise explicitly.
    raise ValueError(f"No message template found for endpoint type {type!r}")


def make_html_source(source,i):
    """Render a retrieved document as an HTML card for the "source" side tab.

    Args:
        source: Document-like object with a ``metadata`` dict (containing
            ``'filename'`` and ``'page'``) and a ``page_content`` string.
        i: Index shown as "Doc {i}"; also used for the ``doc{i}`` anchor id
            that answer citations link to.

    Returns:
        The card's HTML markup as a string.
    """
    meta = source.metadata
    # Document text and filenames come from ingested files, so treat them as
    # untrusted: escape them so stray '<' / '&' neither break the card nor
    # inject markup into the page.
    content = html.escape(source.page_content.strip())
    filename = html.escape(str(meta['filename']), quote=True)
    page = int(meta['page'])

    card = f"""
        <div class="card" id="doc{i}">
            <div class="card-content">
                <h2>Doc {i} - {filename} - Page {page}</h2>
                <p>{content}</p>
            </div>
            <div class="card-footer">
                <span>{filename}</span>
                <a href="{filename}#page={page}" target="_blank" class="pdf-link">
                    <span role="img" aria-label="Open PDF">🔗</span>
                </a>
            </div>
        </div>
        """

    return card


def parse_output_llm_with_sources(output):
    """Replace ``[Doc N]`` citations in LLM output with superscript HTML links.

    A citation group may list several documents, e.g. ``[Doc 1, Doc 2]``.
    Each cited document becomes its own ``<a href="#docN">`` link pointing at
    the matching source card. All other text is returned unchanged.

    Args:
        output: Raw answer text produced by the LLM.

    Returns:
        The text with every citation group replaced by anchor markup.
    """
    # The pattern has exactly one capturing group, so re.split returns plain
    # text at even indices and captured citation groups at odd indices.
    content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
    parts = []
    for index, part in enumerate(content_parts):
        if index % 2 == 0:
            # Ordinary text; it may itself begin with "Doc" (e.g. "Documents"),
            # so it must not be classified by its prefix.
            parts.append(part)
            continue
        doc_ids = [ref.lower().replace("doc", "").strip() for ref in part.split(",")]
        parts.append("".join(
            f"""<a href="#doc{doc_id}" class="a-doc-ref" target="_self"><span class='doc-ref'><sup>{doc_id}</sup></span></a>"""
            for doc_id in doc_ids
        ))
    return "".join(parts)


def get_client_ip(request=None):
    """Get a best-guess client IP address from the request context.

    The first entry of an ``X-Forwarded-For`` header is preferred, because the
    app may sit behind a proxy. Otherwise the socket peer address
    (``request.client.host``) is used.

    NOTE: ``X-Forwarded-For`` is client-supplied unless a trusted proxy
    overwrites it, so the result is only suitable for analytics, not for any
    security decision.

    Args:
        request: Request object exposing ``headers`` and ``client`` (may be
            ``None``, and ``request.client`` itself may be ``None``).

    Returns:
        The detected IP string, or ``"127.0.0.1"`` when no request is given,
        no address can be determined, or detection fails.
    """
    default_ip = "127.0.0.1"
    if not request:
        return default_ip
    try:
        # Check the proxy header first so it is honoured even when the peer
        # address is unavailable.
        forwarded_for = request.headers.get('X-Forwarded-For')
        if forwarded_for:
            # X-Forwarded-For can contain multiple IPs - first one is the client.
            first_hop = forwarded_for.split(',')[0].strip()
            if first_hop:
                logging.debug(f"Client IP detected: {first_hop}")
                return first_hop

        client = getattr(request, "client", None)
        if client is not None and client.host:
            logging.debug(f"Client IP detected: {client.host}")
            return client.host
    except Exception as e:
        # Best-effort lookup: never let IP detection break the request.
        logging.error(f"Error getting client IP: {e}")
    return default_ip


def get_client_location(ip_address) -> dict | None:
    """Get geolocation info using ipapi.co"""
    # Add headers so we don't get blocked...
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    try:
        response = requests.get(
            f'https://ipapi.co/{ip_address}/json/',
            headers=headers,
            timeout=5
        )
        if response.status_code == 200:
            data = response.json()
            # Add random noise between -0.01 and 0.01 degrees (roughly ±1km)
            lat = data.get('latitude')
            lon = data.get('longitude')
            if lat is not None and lon is not None:
                lat += random.uniform(-0.01, 0.01)
                lon += random.uniform(-0.01, 0.01)
            
            return {
                'city': data.get('city'),
                'region': data.get('region'),
                'country': data.get('country_name'),
                'latitude': lat,
                'longitude': lon
            }
        elif response.status_code == 429:
            logging.warning(f"Rate limit exceeded. Response: {response.text}")
            return None
        else:
            logging.error(f"Error: Status code {response.status_code}. Response: {response.text}")
            return None
            
    except requests.exceptions.RequestException as e:
        logging.error(f"Request failed: {str(e)}")
        return None


def get_platform_info(user_agent: str) -> str:
    """Classify a User-Agent string as ``'mobile'`` or ``'desktop'``.

    This is a best-effort heuristic: the agent counts as mobile when its
    lower-cased text contains any of a few well-known mobile-device keywords.

    Args:
        user_agent: Raw User-Agent header value.

    Returns:
        ``'mobile'`` if a mobile keyword is present, otherwise ``'desktop'``.
    """
    agent = user_agent.lower()
    mobile_markers = ('mobile', 'android', 'iphone', 'ipad', 'ipod')
    for marker in mobile_markers:
        if marker in agent:
            return 'mobile'
    return 'desktop'