import json
import os
import torch
import argparse
from chameleon.inference.chameleon import ChameleonInferenceModel, Options
from constants import (
    MODEL_7B_PATH,
    TOKENIZER_TEXT_PATH,
    TOKENIZER_IMAGE_CFG_PATH,
    TOKENIZER_IMAGE_PATH,
)
from typing import List, Tuple
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def split_token_sequence(
    tokens: torch.LongTensor,
    boi: int,
    eoi: int
) -> List[Tuple[str, torch.LongTensor]]:
    """
    Split a sequence of tokens into text and image segments.
    
    Args:
        tokens (torch.LongTensor): The token sequence.
        boi (int): Beginning-of-image token ID.
        eoi (int): End-of-image token ID.
    
    Returns:
        List[Tuple[str, torch.LongTensor]]: List of tuples indicating segment type and tokens.
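
    Example (illustrative; assumes boi=2 and eoi=3 for brevity):
        >>> toks = torch.tensor([[5, 2, 10, 11, 3, 7]])
        >>> split_token_sequence(toks, boi=2, eoi=3)
        [('text_seg', tensor([[5]])), ('image_seg', tensor([[10, 11]])), ('text_seg', tensor([[7]]))]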
    """
    batch_size, _ = tokens.shape
    assert batch_size == 1, "Batch size must be 1"

    device = tokens.device
    dtype = tokens.dtype
    token_list = tokens[0].tolist()  # drop the batch dimension; plain ints avoid per-element tensor comparisons

    def as_segment(ids: List[int]) -> torch.LongTensor:
        return torch.tensor(ids, dtype=dtype, device=device).reshape(1, -1)

    segments: List[Tuple[str, torch.LongTensor]] = []
    current_segment: List[int] = []
    in_image_seg = False

    for token in token_list:
        if token == boi:
            # entering an image segment: flush any pending text segment first
            if current_segment:
                segments.append(("text_seg", as_segment(current_segment)))
                current_segment = []
            in_image_seg = True
        elif token == eoi and in_image_seg:
            # exiting an image segment: flush the collected image tokens
            segments.append(("image_seg", as_segment(current_segment)))
            current_segment = []
            in_image_seg = False
        else:
            current_segment.append(token)
    # flush any trailing tokens
    if current_segment:
        seg_type = "image_seg" if in_image_seg else "text_seg"
        segments.append((seg_type, as_segment(current_segment)))
    return segments

def main(args: argparse.Namespace):
    """Main function to generate and process model output."""
    # Load Chameleon model
    model = ChameleonInferenceModel(
        MODEL_7B_PATH.as_posix(),
        TOKENIZER_TEXT_PATH.as_posix(),
        TOKENIZER_IMAGE_CFG_PATH.as_posix(),
        TOKENIZER_IMAGE_PATH.as_posix(),
    )
    # Log the resolved model and tokenizer paths
    logging.info(f"Model path: {MODEL_7B_PATH}")
    logging.info(f"Text tokenizer path: {TOKENIZER_TEXT_PATH}")
    logging.info(f"Image tokenizer config path: {TOKENIZER_IMAGE_CFG_PATH}")
    logging.info(f"Image tokenizer path: {TOKENIZER_IMAGE_PATH}")
    # Generation options (library defaults)
    options = Options()
    # Prepare prompt
    instructions = [args.instruction]
    batch_prompt_ui = []
    for instruction in instructions:
        if isinstance(instruction, tuple):
            # (text, image_path) pair: place the image before the text in the prompt
            inst, image_path = instruction
            batch_prompt_ui.append([
                {"type": "image", "value": f"file:{image_path}"},
                {"type": "text", "value": inst},
            ])
        else:
            # plain text instruction
            batch_prompt_ui.append([
                {"type": "text", "value": instruction},
            ])
    # Generate a token stream from the prompt batch
    tokens: torch.LongTensor = model.generate(
        batch_prompt_ui=batch_prompt_ui,
        options=options
    )
    # Split the token stream into text and image segments
    boi, eoi = model.vocab.begin_image, model.vocab.end_image   # 8197(boi), 8196(eoi)
    segments = split_token_sequence(tokens, boi, eoi)
    # Decode each segment: save images to disk, collect text
    os.makedirs(args.save_dir, exist_ok=True)
    segments_data = []
    for seg_id, (seg_type, seg_tokens) in enumerate(segments):
        if seg_type == "image_seg":
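            # Chameleon's VQ image tokenizer emits a fixed-length block of 1024 tokens
            # per image (a 32x32 grid of codes); treat any other length as an error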
            assert seg_tokens.shape[1] == 1024, f"unexpected image segment length: {seg_tokens.shape[1]}"
            img = model.decode_image(seg_tokens)[0]
            image_path = os.path.join(args.save_dir, f"{seg_id}.png")
            img.save(image_path)
            segments_data.append({"type": "image", "content": image_path})
        else:
            assert seg_type == "text_seg"
            decoded_text = model.decode_text(seg_tokens)[0]
            segments_data.append({"type": "text", "content": decoded_text})

    # Write the segment manifest alongside the generated images
    jsonl_path = os.path.join(args.save_dir, "segments.jsonl")
    with open(jsonl_path, 'w') as jsonl_file:
        for segment in segments_data:
            jsonl_file.write(json.dumps(segment) + '\n')
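    # Each line of segments.jsonl is one JSON object, e.g. (illustrative):
    #   {"type": "text", "content": "Here is the picture you asked for:"}
    #   {"type": "image", "content": "./outputs/interleaved/1.png"}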

def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="Generate interleaved image-text content based on text instructions.")
    parser.add_argument("-i", "--instruction", type=str, required=True, help="The instruction for interleaved image-text generation.")
    parser.add_argument("-s", "--save_dir", type=str, default="./outputs/interleaved/", help="The directory to save the generated images.")
    args: argparse.Namespace = parser.parse_args()
    return args

if __name__ == "__main__":
    args: argparse.Namespace = parse_arguments()
    main(args)