import json
import logging
import re
import subprocess
import sys
import uuid
from typing import List

def install(package):
    """Install *package* with pip, using the interpreter running this module.

    Args:
        package: Requirement string handed to ``pip install``.

    Raises:
        subprocess.CalledProcessError: If the pip process exits non-zero.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(pip_command)

# Import pythonmonkey, installing it with pip first if the import fails.
# NOTE(review): this runs at module import time, so merely importing this
# module can trigger a `pip install` as a side effect.
try:
    import pythonmonkey
except ImportError:
    install('pythonmonkey')
    import pythonmonkey

# Bind the `jsonrepair` attribute of the 'jsonrepair' module loaded through
# pythonmonkey.require; parse_json_safely calls it when json.loads fails.
# NOTE(review): presumably the 'jsonrepair' JavaScript package must already be
# resolvable by pythonmonkey -- confirm how it is provisioned.
jsonrepair = pythonmonkey.require('jsonrepair').jsonrepair

def clean_command_string(command_str):
    r"""Remove stray escaping and one pair of wrapping quotes from a string.

    The cleaning happens in three steps, in this order:

    1. Every backslash that does not begin a valid JSON escape sequence
       (``\"``, ``\\``, ``\/``, ``\b``, ``\f``, ``\n``, ``\r``, ``\t`` or
       ``\uXXXX``) is dropped.
    2. Each remaining backslash-quote pair is replaced by a plain ``"``.
    3. If the result both starts and ends with ``"``, that single outer pair
       of quotes is removed.

    Args:
        command_str: The raw string to clean.

    Returns:
        The cleaned string. A string consisting of a single ``"`` is returned
        unchanged rather than being reduced to an empty string.
    """
    cleaned_command = re.sub(r'\\(?!["\\/bfnrt]|u[a-fA-F0-9]{4})', '', command_str)
    cleaned_command = cleaned_command.replace('\\"', '"')
    # Require two characters: a lone '"' is both the first and the last
    # character, and slicing [1:-1] would silently erase it.
    if (
        len(cleaned_command) >= 2
        and cleaned_command.startswith('"')
        and cleaned_command.endswith('"')
    ):
        cleaned_command = cleaned_command[1:-1]
    return cleaned_command

def parse_json_safely(json_str):
    """Parse *json_str* as JSON, falling back to repair and then to the input.

    The string is first handed to ``json.loads``. If that raises
    ``json.JSONDecodeError``, the string is passed through ``jsonrepair`` and
    the repaired text is parsed instead. If the repair path fails for any
    reason, the original string is returned unchanged.

    Args:
        json_str: Text that may or may not be valid JSON.

    Returns:
        The decoded JSON value, or ``json_str`` itself when it could not be
        parsed even after repair.
    """
    try:
        decoded = json.loads(json_str)
    except json.JSONDecodeError:
        pass
    else:
        return decoded

    # Strict parsing failed; best-effort repair. Any failure here means the
    # caller simply gets the raw text back.
    try:
        return json.loads(jsonrepair(json_str))
    except Exception:
        return json_str

def clean_json_object(obj):
    """Recursively clean the string values inside a decoded JSON structure.

    Dict values and list items are processed recursively (dict keys are left
    as they are). Each string is run through ``clean_command_string``; if the
    cleaned text starts with ``{`` or ``[`` it is additionally handed to
    ``parse_json_safely``. Every other value is returned untouched.

    Args:
        obj: A decoded JSON value (dict, list, str, number, bool or None).

    Returns:
        A new structure of the same shape with its strings cleaned.
    """
    if isinstance(obj, str):
        text = clean_command_string(obj)
        if text.startswith(('{', '[')):
            # Looks like embedded JSON; try to decode it.
            return parse_json_safely(text)
        return text
    if isinstance(obj, dict):
        return {key: clean_json_object(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return list(map(clean_json_object, obj))
    return obj

def extract_tool_calls(output_str):
    """Return the raw text of every tool call found in *output_str*.

    A tool call starts at ``starttoolcall`` and runs up to the next
    ``endtoolcall``; if no closing marker follows, it runs to the end of the
    string. The markers themselves are not included in the results.

    Args:
        output_str: Text to scan.

    Returns:
        A list with one string per tool call, in order of appearance.
    """
    tool_call_re = re.compile(r'starttoolcall(.*?)(?:endtoolcall|$)', re.DOTALL)
    return [found.group(1) for found in tool_call_re.finditer(output_str)]

def extract_tool_calls_and_text(output_str):
    """Split *output_str* into ordered text and tool-call segments.

    A tool call starts at ``starttoolcall`` and runs up to the next
    ``endtoolcall``, or to the end of the string when no closing marker
    follows. Everything outside tool calls is reported as text.

    Args:
        output_str: Text to split.

    Returns:
        A list of dicts in order of appearance. Text segments look like
        ``{"text": ..., "type": "text"}`` and tool calls look like
        ``{"tool_call": ..., "type": "function"}``. All content is stripped
        of surrounding whitespace, and text that is empty after stripping is
        omitted.
    """
    tool_call_re = re.compile(r'(starttoolcall(.*?)(?:endtoolcall|$))', re.DOTALL)
    pieces = []
    cursor = 0  # index just past the last tool call consumed so far

    def add_text(raw):
        # Record a text segment unless it is blank after stripping.
        stripped = raw.strip()
        if stripped:
            pieces.append({"text": stripped, "type": "text"})

    for found in tool_call_re.finditer(output_str):
        begin, finish = found.span(1)
        add_text(output_str[cursor:begin])
        pieces.append({"tool_call": found.group(2).strip(), "type": "function"})
        cursor = finish

    # Whatever follows the final tool call (or the whole string if none).
    add_text(output_str[cursor:])
    return pieces

def postprocess_output(output_str: str):
    """Convert raw model output into a list of text and function-call items.

    The output is split with ``extract_tool_calls_and_text``. Each tool-call
    segment is parsed with ``parse_json_safely`` and cleaned with
    ``clean_json_object``. A segment becomes a function item only when the
    cleaned value is a dict containing both ``"name"`` and ``"arguments"``;
    otherwise (including when processing raises) its raw content is emitted
    as a text item.

    Progress is reported through the ``logging`` module at DEBUG level, so
    calling this function does not write to stdout.

    Args:
        output_str: Raw text that may contain ``starttoolcall`` /
            ``endtoolcall`` delimited tool calls.

    Returns:
        A list of dicts in order of appearance. Text items have the form
        ``{"id", "text", "type": "text"}``; function items have the form
        ``{"id", "function", "type": "function"}``. ``id`` is the first eight
        hex characters of a fresh UUID4. When a function's ``arguments`` is a
        dict it is serialized to a JSON string.
    """
    logger = logging.getLogger(__name__)

    def text_item(text):
        # Single place that defines the shape of a text result.
        return {
            "id": uuid.uuid4().hex[:8],
            "text": text,
            "type": "text",
        }

    results = []
    for segment in extract_tool_calls_and_text(output_str):
        logger.debug("processing segment: %s", segment)

        if segment['type'] != 'function':
            results.append(text_item(segment['text']))
            continue

        call = segment['tool_call']
        try:
            cleaned_call = clean_json_object(parse_json_safely(call))
            is_function = (
                isinstance(cleaned_call, dict)
                and 'name' in cleaned_call
                and 'arguments' in cleaned_call
            )
            if is_function and isinstance(cleaned_call['arguments'], dict):
                # Arguments are emitted as a JSON string, not a nested object.
                cleaned_call['arguments'] = json.dumps(cleaned_call['arguments'])
        except Exception:
            # Best effort: an unprocessable call is passed through as text,
            # but the failure is recorded instead of being silently dropped.
            logger.debug("could not process tool call %r", call, exc_info=True)
            is_function = False

        if is_function:
            results.append({
                "id": uuid.uuid4().hex[:8],
                "function": cleaned_call,
                "type": "function",
            })
        else:
            results.append(text_item(call))

    return results

def json_to_markdown(json_obj):
    """Render a list of result items as a markdown string.

    Items whose ``"type"`` is ``"text"`` contribute their ``"text"`` value as
    a paragraph. Items whose ``"type"`` is ``"function"`` contribute their
    ``"function"`` value as an indented JSON document inside a fenced
    ``json`` code block. Items of any other type are skipped.

    Args:
        json_obj: Iterable of dicts with a ``"type"`` key.

    Returns:
        The rendered markdown with leading and trailing whitespace removed.
    """
    chunks = []
    for entry in json_obj:
        kind = entry.get("type")
        if kind == "text":
            chunks.append(entry.get("text", "") + "\n\n")
        elif kind == "function":
            payload = json.dumps(entry.get("function", {}), indent=2)
            chunks.append("```json\n" + payload + "\n```\n\n")
    return "".join(chunks).strip()

if __name__ == "__main__":
    # Manual smoke test: run one sample through postprocess_output and show
    # the result both as JSON and as markdown.
    #
    # Alternative sample inputs, kept for reference:
    # output_str = '''Some text before starttoolcall{"name": "funcA", "arguments": {"param1": 1}endtoolcall
    # More text starttoolcall{"name": "funcB", "arguments": {"param2": "test"}}endtoolcall'''
    #
    # output_str = '''starttoolcall{"name": "get_current_weather", "arguments": {"location": "San Francisco", "unit": "celsius"}}endtoolcall starttoolcall{"name": "get_current_weather", "arguments": {"location": "Tokyo", "unit": "celsius"}}endtoolcall okay great '''

    # Active sample: two complete tool calls followed by one that is cut off
    # before its closing marker.
    sample_output = '''starttoolcall{"name": "get_current_weather", "arguments": {"location": "San Francisco", "unit": "celsius"}}endtoolcall starttoolcall{"name": "get_current_weather", "arguments": {"location": "Tokyo", "unit": "celsius"}}endtoolcall starttoolcall{"name": "get_current_weather", "arguments": {"location": "Paris", "unit": '''
    processed = postprocess_output(sample_output)
    rendered_json = json.dumps(processed, indent=2)
    print(rendered_json)

    print("-----")
    rendered_markdown = json_to_markdown(processed)
    print(rendered_markdown)