import datetime
import json
import os
from opik import Opik
import parameters
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed


class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that serializes date/time objects as ISO-8601 strings.

    Handles ``datetime.datetime``, ``datetime.date`` and ``datetime.time``;
    every other unsupported type is delegated to ``json.JSONEncoder``,
    which raises ``TypeError``.
    """

    def default(self, obj):
        # datetime.datetime is a subclass of datetime.date, so the date check
        # covers both; datetime.time is unrelated and needs its own entry.
        if isinstance(obj, (datetime.date, datetime.time)):
            return obj.isoformat()
        return super().default(obj)

def get_trace_content(opik, trace_id):
    """Fetch one trace through ``opik.get_trace_content`` and return its ``.dict()`` form.

    Best-effort: any exception raised while fetching or converting is printed
    and reported to the caller as ``None`` instead of propagating.
    """
    try:
        payload = opik.get_trace_content(trace_id).dict()
    except Exception as exc:
        print(f"Error getting trace content {trace_id}: {exc}")
        payload = None
    return payload

def get_span_content(opik, trace_id, span):
    """Fetch one span through ``opik.get_span_content`` and wrap it with its ids.

    Returns a dict with the keys ``trace_id``, ``span_id`` and ``content``
    (the ``.dict()`` form of the fetched span). Best-effort: on any exception
    the error is printed and ``None`` is returned.
    """
    try:
        span_record = opik.get_span_content(span.id)
        entry = {
            "trace_id": trace_id,
            "span_id": span.id,
            "content": span_record.dict(),
        }
    except Exception as exc:
        print(f"Error getting span content {span.id}: {exc}")
        return None
    return entry

def get_traces_on_date(start_date_str, end_date_str, project_name, api_key, max_workers=10,
                       workspace='verba-tech-ninja'):
    """Download the traces (and their spans) of a date window and dump them to JSON.

    Args:
        start_date_str: ISO date (``YYYY-MM-DD``); the window starts at
            ``00:00:00Z`` of that day.
        end_date_str: ISO date whose ``00:00:00Z`` is the upper bound of the
            window. When empty/None the window is the single day after
            ``start_date_str``'s midnight.
        project_name: Project used for both the trace search and the span search.
        api_key: API key handed to the ``Opik`` client.
        max_workers: Thread-pool size for the parallel content downloads.
        workspace: Workspace handed to the ``Opik`` client.

    Returns:
        ``(traces, spans)``: a list of trace dicts and a list of
        ``{"trace_id", "span_id", "content"}`` dicts. ``([], [])`` is returned
        when the dates cannot be parsed, the client cannot be created, the
        trace search fails, or an unexpected error occurs.

    Side effects:
        Writes ``all_traces_content.json`` and ``all_spans_content.json`` in
        the current working directory (replacing existing files). If that
        fails, tries ``partial_traces_content.json`` /
        ``partial_spans_content.json`` instead. Progress is printed.
    """
    try:
        print("Step 1: Converting date strings")
        date = datetime.date.fromisoformat(start_date_str)
        start_date_str = date.isoformat() + "T00:00:00Z"

        # No explicit end date means "one day, starting at the start date".
        if not end_date_str:
            end_date = date + datetime.timedelta(days=1)
        else:
            end_date = datetime.date.fromisoformat(end_date_str)
        end_date_str = end_date.isoformat() + "T00:00:00Z"

        print(f"Start: {start_date_str} and end: {end_date_str}")
        filter_string = f'start_time >= "{start_date_str}" and end_time <= "{end_date_str}"'
        print("Filter string: ", filter_string)

        print("Step 2: Initializing Opik client")
        try:
            opik = Opik(api_key=api_key, project_name=project_name, workspace=workspace)
            print("Opik client initialized successfully")
        except Exception as e:
            print(f"Error initializing Opik client: {e}")
            return [], []

        print("Step 3: Searching traces")
        try:
            traces = opik.search_traces(filter_string=filter_string, project_name=project_name)
            print("Total searches: ", len(traces))
        except Exception as e:
            print(f"Error searching traces: {e}")
            return [], []

        print("Step 4: Processing traces in parallel")
        all_traces_content = []
        try:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                trace_futures = [
                    executor.submit(get_trace_content, opik, trace.id) for trace in traces
                ]
                for future in as_completed(trace_futures):
                    result = future.result()
                    # get_trace_content returns None on failure; skip those.
                    if result:
                        all_traces_content.append(result)
            print(f"Completed processing {len(all_traces_content)} traces")
        except Exception as e:
            print(f"Error processing traces in parallel: {e}")

        print("Step 5: Processing spans in parallel")
        all_spans_content = []
        try:
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                span_futures = []
                for i, trace in enumerate(traces):
                    try:
                        print(f"Searching spans for trace_id: {trace.id}:{i+1}/{len(traces)}")
                        # Search in the same project the traces came from, not
                        # in a module-level default that may differ from it.
                        spans = opik.search_spans(project_name=project_name, trace_id=trace.id)
                        print(f"Found {len(spans)} spans for trace_id: {trace.id}")
                        for span in spans:
                            span_futures.append(
                                executor.submit(get_span_content, opik, trace.id, span)
                            )
                    except Exception as e:
                        # One failing trace must not abort the remaining ones.
                        print(f"Error searching spans for trace {trace.id}: {e}")

                for future in as_completed(span_futures):
                    result = future.result()
                    if result:
                        all_spans_content.append(result)
            print(f"Completed processing {len(all_spans_content)} spans")
        except Exception as e:
            print(f"Error processing spans in parallel: {e}")

        print("Step 6: Saving to JSON files")
        traces_file = 'all_traces_content.json'
        spans_file = 'all_spans_content.json'
        try:
            if os.path.exists(traces_file):
                os.remove(traces_file)
                print(f"Removed existing {traces_file}")
            if os.path.exists(spans_file):
                os.remove(spans_file)
                print(f"Removed existing {spans_file}")

            print(f"Writing {len(all_traces_content)} traces to {traces_file}")
            with open(traces_file, 'w') as f:
                json.dump(all_traces_content, f, indent=2, cls=DateTimeEncoder)
            print(f"Saved traces to {traces_file}")

            print(f"Writing {len(all_spans_content)} spans to {spans_file}")
            with open(spans_file, 'w') as f:
                json.dump(all_spans_content, f, indent=2, cls=DateTimeEncoder)
            print(f"Saved spans to {spans_file}")
        except Exception as e:
            print(f"Error saving to JSON files: {e}")
            # The fallback can fail for the same reason as the primary write
            # (e.g. a non-serializable value). Guard it so the data already
            # fetched is still returned instead of being replaced by ([], []).
            try:
                with open('partial_traces_content.json', 'w') as f:
                    json.dump(all_traces_content, f, indent=2, cls=DateTimeEncoder)
                with open('partial_spans_content.json', 'w') as f:
                    json.dump(all_spans_content, f, indent=2, cls=DateTimeEncoder)
                print("Saved partial data to partial_traces_content.json and partial_spans_content.json")
            except Exception as fallback_error:
                print(f"Error saving partial data: {fallback_error}")

        print("Step 7: Returning results")
        return all_traces_content, all_spans_content

    except Exception as e:
        print(f"Main function error: {e}")
        return [], []

def find_errors_and_metrics(traces, spans):
    """Collect spans whose output is empty/null AND that carry error info.

    Args:
        traces: Not used by the analysis; kept for interface compatibility.
        spans: Iterable of dicts with the keys ``trace_id``, ``span_id`` and
            ``content``. ``content`` is read for ``output`` and ``error_info``.

    Returns:
        ``{"total_errors": int, "error_types": {type: {"count", "instances"}}}``
        where each instance is ``{"trace_id", "span_id"}``. Returns ``{}`` if
        the analysis itself fails unexpectedly.

    Side effects:
        Writes the flagged spans to ``error_spans.json`` in the current working
        directory (a failure to write is printed, not raised) and prints
        progress and the resulting metrics.
    """
    try:
        print("Step 8: Analyzing outputs for errors")
        error_spans = []
        error_metrics = defaultdict(list)

        for span in spans:
            content = span['content']
            output = content.get("output")
            # "error_info" can be present with an explicit None value; a plain
            # .get(key, {}) would hand that None to the emptiness check below
            # and abort the whole analysis.
            error_info = content.get("error_info") or {}

            # Outputs may be wrapped as {"output": <value>}; inspect the inner value.
            if isinstance(output, dict) and 'output' in output:
                output_value = output["output"]
            else:
                output_value = output

            # Test the unwrapped value so a wrapped empty list counts as empty too.
            is_empty = output_value is None or (
                isinstance(output_value, list) and not output_value
            )
            if not (is_empty and error_info):
                continue

            error_type = error_info.get("exception_type", "unknown_error")
            error_spans.append({
                "trace_id": span["trace_id"],
                "span_id": span["span_id"],
                "error_content": output,
                "exception_type": error_type,
            })
            error_metrics[error_type].append(
                {"trace_id": span["trace_id"], "span_id": span["span_id"]}
            )

        print(f"Found {len(error_spans)} outputs with errors (empty/null)")

        print("Step 9: Saving error spans")
        error_file = 'error_spans.json'
        try:
            if os.path.exists(error_file):
                os.remove(error_file)
                print(f"Removed existing {error_file}")
            print(f"Writing {len(error_spans)} error outputs to {error_file}")
            with open(error_file, 'w') as f:
                json.dump(error_spans, f, indent=2, cls=DateTimeEncoder)
            print(f"Saved error outputs to {error_file}")
        except Exception as e:
            # Persisting is best-effort; the metrics are still computed below.
            print(f"Error saving error spans: {e}")

        print("Step 10: Calculating metrics")
        metrics = {
            "total_errors": len(error_spans),
            "error_types": {
                error_type: {"count": len(entries), "instances": list(entries)}
                for error_type, entries in error_metrics.items()
            },
        }
        print(f"Metrics calculated: {len(metrics['error_types'])} error types")
        for error_type, data in metrics["error_types"].items():
            print(f"Error Type: {error_type}, Count: {data['count']}")
            for instance in data["instances"]:
                print(f"  Trace ID: {instance['trace_id']}, Span ID: {instance['span_id']}")
        return metrics

    except Exception as e:
        print(f"Error in find_errors_and_metrics: {e}")
        return {}

def process_dates(start_date, end_date, project):
    """Run the fetch + analysis pipeline and render the outcome as an HTML snippet.

    Calls ``get_traces_on_date`` (with ``parameters.api_key``) and
    ``find_errors_and_metrics``, then formats the counts and, per error type,
    up to five example trace/span ids.

    Args:
        start_date: Start date string forwarded to ``get_traces_on_date``.
        end_date: End date string forwarded to ``get_traces_on_date``.
        project: Project name forwarded to ``get_traces_on_date``.

    Returns:
        An HTML string: the metrics report on success, or an "Error Occurred"
        panel containing the exception message if anything raises.
    """
    # Function-scope import so this block does not rely on a file-level import.
    from html import escape

    try:
        print("Pipeline Start: Processing dates")
        traces, spans = get_traces_on_date(start_date, end_date, project, parameters.api_key)
        metrics = find_errors_and_metrics(traces, spans)

        html_output = """
        <div style='font-family: Arial, sans-serif; padding: 20px; background-color: #f5f5f5; border-radius: 10px;'>
            <h2 style='color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 5px;'>Metrics Report</h2>
            
            <div style='margin: 15px 0;'>
                <span style='color: #e74c3c; font-weight: bold;'>Total Empty/Null Outputs Found: </span>
                <span style='color: #2980b9'>{empty_count}</span>
            </div>
            
            <div style='margin: 15px 0;'>
                <span style='color: #27ae60; font-weight: bold;'>Total Traces Found: </span>
                <span style='color: #2980b9'>{traces_count}</span>
                <br>
                <span style='color: #27ae60; font-weight: bold;'>Total Spans Processed: </span>
                <span style='color: #2980b9'>{spans_count}</span>
            </div>
            
            <h3 style='color: #8e44ad; margin-top: 20px;'>Error Metrics</h3>
            {error_section}
        </div>
        """

        error_types = metrics.get('error_types', {})
        if error_types:
            # Collect fragments and join once instead of repeated string +=.
            parts = ["<div style='background-color: #fff; padding: 15px; border-radius: 5px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);'>"]
            for error_type, data in error_types.items():
                # Values originate from traced data: coerce to str (the type may
                # not be a string) and escape so they cannot break the markup.
                label = escape(str(error_type).replace('_', ' ').title())
                parts.append(f"""
                <div style='margin: 10px 0;'>
                    <span style='color: #d35400; font-weight: bold;'>{label}: </span>
                    <span style='color: #2980b9'>{data['count']} occurrences</span>
                    <div style='margin-left: 20px;'>
                        <span style='color: #7f8c8d;'>Instances:</span>
                        <ul style='list-style-type: none; padding-left: 10px;'>
                """)
                # Only the first five instances are listed to keep the report short.
                for instance in data['instances'][:5]:
                    trace_id = escape(str(instance['trace_id']))
                    span_id = escape(str(instance['span_id']))
                    parts.append(f"""
                    <li style='color: #34495e;'>
                        Trace ID: <span style='color: #16a085'>{trace_id}</span>, 
                        Span ID: <span style='color: #16a085'>{span_id}</span>
                    </li>
                    """)
                if len(data['instances']) > 5:
                    parts.append("<li style='color: #7f8c8d;'>... (and more)</li>")
                parts.append("</ul></div></div>")
            parts.append("</div>")
            error_html = "".join(parts)
        else:
            error_html = """
            <div style='color: #27ae60; background-color: #ecf0f1; padding: 10px; border-radius: 5px;'>
                No empty or null outputs with exceptions detected.
            </div>
            """

        # error_html is passed as a format argument, so braces inside it are
        # not re-interpreted as placeholders.
        formatted_output = html_output.format(
            empty_count=metrics.get('total_errors', 0),
            traces_count=len(traces),
            spans_count=len(spans),
            error_section=error_html
        )

        print("Pipeline End: Results formatted")
        return formatted_output

    except Exception as e:
        print(f"Error processing dates: {e}")
        # Exception text may contain '<', '>' or '&'; escape it before embedding.
        error_html = f"""
        <div style='font-family: Arial, sans-serif; padding: 20px; background-color: #f5f5f5; border-radius: 10px;'>
            <h2 style='color: #e74c3c;'>Error Occurred</h2>
            <p style='color: #c0392b;'>Error processing dates: {escape(str(e))}</p>
        </div>
        """
        return error_html