import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import requests
from datetime import datetime,timedelta
import re

# Attention maps captured by the forward hooks, keyed by the name passed to hook_fn.
attn_maps = {}


def hook_fn(name):
    """Build a forward hook that stores the processor's attention map under ``name``.

    When the hooked module's ``processor`` carries an ``attn_map`` attribute,
    the returned hook copies the reference into the module-level ``attn_maps``
    dict and then deletes the attribute from the processor.  Modules whose
    processor has no ``attn_map`` are left untouched.

    :param name: Key under which the captured map is stored in ``attn_maps``.
    :return: A callable with the ``(module, input, output)`` forward-hook signature.
    """
    def forward_hook(module, _inputs, _output):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return
        attn_maps[name] = processor.attn_map
        # Remove the attribute so the processor does not keep its own reference.
        del processor.attn_map

    return forward_hook

def register_cross_attention_hook(unet):
    """Attach an attention-capturing forward hook to every ``attn2*`` submodule.

    A submodule is selected when the last dot-separated component of its
    qualified name starts with ``attn2`` (presumably the cross-attention
    layers of a diffusers UNet -- confirm against the model in use).

    :param unet: Object exposing ``named_modules()``; its matching submodules
        get a hook created by ``hook_fn``.
    :return: The same ``unet`` object, to allow call chaining.
    """
    for qualified_name, submodule in unet.named_modules():
        leaf_name = qualified_name.rsplit('.', 1)[-1]
        if not leaf_name.startswith('attn2'):
            continue
        submodule.register_forward_hook(hook_fn(qualified_name))

    return unet

def upscale(attn_map, target_size):
    """Resize a flattened attention map to ``target_size`` and softmax it over dim 0.

    The input is averaged over its first dimension and then transposed, so it
    must be a 3-D tensor whose middle axis holds the flattened spatial
    positions (presumably ``(heads, height * width, tokens)`` -- confirm
    against the attention processor that records the map).

    The spatial grid is recovered by trying the scales 1, 2, 4, 8 and 16: a
    scale fits when ``(H // scale) * (W // scale) == positions * 64``, and the
    grid is then ``(H // (scale * 8), W // (scale * 8))``.  The factor 8
    (64 = 8 * 8) presumably reflects an 8x latent downsampling -- confirm.

    :param attn_map: 3-D attention tensor as described above.
    :param target_size: ``(height, width)`` of the output.
    :return: float32 tensor of shape ``(attn_map.shape[2], height, width)``
        whose values sum to 1 along dim 0.
    :raises AssertionError: If no supported scale matches the number of
        spatial positions.
    """
    # Collapse the leading dimension, then move the flattened spatial axis last.
    attn_map = torch.mean(attn_map, dim=0).permute(1, 0)
    num_positions = attn_map.shape[1]

    temp_size = None
    for exponent in range(5):
        scale = 2 ** exponent
        if (target_size[0] // scale) * (target_size[1] // scale) == num_positions * 64:
            temp_size = (target_size[0] // (scale * 8), target_size[1] // (scale * 8))
            break

    # Raised explicitly instead of via ``assert`` so the check survives
    # ``python -O``; AssertionError is kept so existing callers see the same type.
    if temp_size is None:
        raise AssertionError(
            f"cannot infer the attention grid: {num_positions} spatial positions "
            f"do not match target_size {target_size!r} at any scale in (1, 2, 4, 8, 16)"
        )

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    # interpolate expects a batch dimension; add it, resize, and drop it again.
    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False
    )[0]

    # Normalise across dim 0 so every pixel's values sum to 1.
    attn_map = torch.softmax(attn_map, dim=0)
    return attn_map
def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Average every captured attention map after upscaling it to ``image_size``.

    Each tensor in the module-level ``attn_maps`` dict is split along dim 0
    into ``batch_size`` chunks; one chunk is selected, squeezed, passed to
    ``upscale`` and the results are averaged element-wise across all entries.

    :param image_size: ``(height, width)`` forwarded to ``upscale``.
    :param batch_size: Number of chunks each stored tensor is split into.
    :param instance_or_negative: Selects chunk 0 when truthy, chunk 1 otherwise
        (presumably the two halves of a classifier-free-guidance batch --
        confirm against the caller).
    :param detach: When truthy, each map is moved to the CPU with ``.cpu()``
        first.  NOTE(review): despite the name, ``.detach()`` is not called.
    :return: Element-wise mean of the upscaled maps.
    """
    chunk_index = 0 if instance_or_negative else 1

    def _prepare(raw_map):
        if detach:
            raw_map = raw_map.cpu()
        selected = torch.chunk(raw_map, batch_size)[chunk_index].squeeze()
        return upscale(selected, image_size)

    per_layer_maps = [_prepare(raw_map) for raw_map in attn_maps.values()]

    return torch.mean(torch.stack(per_layer_maps, dim=0), dim=0)

def attnmaps2images(net_attn_maps):
    """Convert attention maps into min-max normalised 8-bit PIL images.

    Each map is rescaled independently so that its minimum becomes 0 and its
    maximum 255, then cast to ``uint8`` and wrapped with ``Image.fromarray``.
    A map with no spread (all values equal, or a range that is not a positive
    number) is rendered as all zeros instead of dividing by zero.

    :param net_attn_maps: Iterable of tensors, one per output image (each is
        expected to be 2-D so that PIL can build an image from it -- confirm
        against the caller).
    :return: List of PIL images, in the same order as the input maps.
    """
    images = []

    for attn_map in net_attn_maps:
        # detach() lets tensors that still require grad be converted to NumPy.
        attn_map = attn_map.detach().cpu().numpy()

        lowest = np.min(attn_map)
        value_range = np.max(attn_map) - lowest

        if value_range > 0:
            normalized_attn_map = (attn_map - lowest) / value_range * 255
        else:
            # Constant map: the original 0/0 produced NaNs and an undefined
            # uint8 cast; emit a well-defined black image instead.
            normalized_attn_map = np.zeros_like(attn_map)

        images.append(Image.fromarray(normalized_attn_map.astype(np.uint8)))

    return images
def is_torch2_available():
    """Return True when ``torch.nn.functional`` exposes ``scaled_dot_product_attention``."""
    sdpa_attribute = "scaled_dot_product_attention"
    return hasattr(F, sdpa_attribute)


class RemoteJson:
    """Time-cached JSON document fetched lazily from a remote URL."""

    def __init__(self, url, refresh_gap_seconds=3600, processor=None, timeout=10):
        """
        Initialize the RemoteJson cache. No request is made until ``get()`` is called.
        :param url: The URL of the remote JSON file.
        :param refresh_gap_seconds: Time in seconds after which the JSON should be refreshed.
        :param processor: Optional callback function to process the JSON after it's loaded successfully.
        :param timeout: Seconds to wait for the server, passed to ``requests.get`` so a
            stalled connection cannot block the caller indefinitely.
        """
        self.url = url
        self.refresh_gap_seconds = refresh_gap_seconds
        self.processor = processor
        self.timeout = timeout
        self.json_data = None
        self.last_updated = None

    def _load_json(self):
        """
        Load JSON from the remote URL. If loading or decoding fails, return None.
        """
        try:
            response = requests.get(self.url, timeout=self.timeout)
            response.raise_for_status()
            return response.json()
        # ValueError covers JSON decode errors on requests versions where the
        # decode error is not a RequestException subclass.
        except (requests.RequestException, ValueError) as e:
            print(f"Failed to fetch JSON: {e}")
            return None

    def _should_refresh(self):
        """
        Check whether the JSON should be refreshed based on the time gap.
        """
        if not self.last_updated:
            return True  # If no last update, always refresh
        return datetime.now() - self.last_updated > timedelta(seconds=self.refresh_gap_seconds)

    def _update_json(self):
        """
        Fetch and load the JSON from the remote URL. If it fails, keep the previous data.

        A falsy payload (``None`` from a failed load, but also an empty JSON
        document such as ``{}`` or ``[]``) is treated as a failed update.
        If the processor raises, the exception propagates and the previously
        cached data and timestamp are left unchanged.
        """
        new_json = self._load_json()
        if not new_json:
            print("Failed to update JSON. Keeping the previous version.")
            return

        # Run the processor before touching any state, so that a failing
        # processor cannot leave raw, unprocessed JSON cached as if it were fresh.
        processed = self.processor(new_json) if self.processor else new_json

        self.json_data = processed
        self.last_updated = datetime.now()
        print("JSON updated successfully.")

    def get(self):
        """
        Get the JSON, checking whether it needs to be refreshed.
        If refresh is required, it fetches the new data and applies the processor.
        :return: The cached (processed) JSON, or None if no load has succeeded yet.
        """
        if self._should_refresh():
            print("Refreshing JSON...")
            self._update_json()
        else:
            print("Using cached JSON.")

        return self.json_data

def extract_key_value_pairs(input_string):
    """Find every ``[key:value]`` marker in ``input_string``.

    Neither part may contain ``]`` and both must be non-empty.  The key group
    is greedy, so when a marker holds several colons the split happens at the
    last one (``[a:b:c]`` yields key ``a:b`` and value ``c``).

    :param input_string: Text to scan.
    :return: List of dicts with ``key``, ``value`` and ``raw`` (the full
        matched marker, brackets included), in order of appearance.
    """
    bracket_pair = re.compile(r"\[([^\]]+):([^\]]+)\]")

    pairs = []
    for found in bracket_pair.finditer(input_string):
        key, value = found.groups()
        pairs.append({"key": key, "value": value, "raw": found.group(0)})

    return pairs

def extract_characters(prefix, input_string):
    """Find placeholders such as ``@name`` that start with the literal ``prefix``.

    A placeholder is ``prefix`` followed by one or more characters that are
    not whitespace, a comma or ``$``, and it must be followed by whitespace,
    a comma, or the end of the string.

    :param prefix: Literal marker that introduces a placeholder (e.g. ``"@"``).
        It is escaped before being used in the pattern, so regex
        metacharacters such as ``.``, ``$`` or ``+`` are matched verbatim.
    :param input_string: Text to scan.
    :return: List of dicts with ``raw`` (prefix plus key) and ``key``, in
        order of appearance.
    """
    # Escape the prefix: ``raw`` below treats it as literal text, so the
    # pattern must match it literally as well.
    pattern = rf"{re.escape(prefix)}([^\s,$]+)(?=\s|,|$)"

    return [
        {"raw": f"{prefix}{match}", "key": match}
        for match in re.findall(pattern, input_string)
    ]