import base64
import io
from typing import Dict, Any
from PIL import Image
import numpy as np
import requests

class DynamicImageOutpainter:
    """
    Image processor for iterative outpainting against a remote inference endpoint.
    
    ## Key Features:
    - Center cropping to a square canvas
    - Iterative outpainting with a configurable number of passes
    - Transparent padding added around each intermediate result
    - Mask generation marking the border regions the model should fill
    
    ## Usage Strategy:
    1. Initialize with an inference endpoint, token, and padding parameters
    2. Repeatedly pad the current image and mask the new border
    3. Send the image, mask, and prompt to the endpoint to fill the border
    """
    
    def __init__(
        self, 
        endpoint_url: str, 
        api_token: str, 
        padding_size: int = 256,
        max_iterations: int = 3
    ):
        """
        Initialize the outpainting processor.
        
        Args:
            endpoint_url (str): AI inference endpoint URL
            api_token (str): Authentication token for API
            padding_size (int): Border width in pixels added on each side per iteration
            max_iterations (int): Number of pad-and-outpaint passes to perform
        """
        self.endpoint_url = endpoint_url
        self.api_token = api_token
        self.padding_size = padding_size
        self.max_iterations = max_iterations
        
        self.headers = {
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
            "Accept": "image/png"
        }
    
    def encode_image(self, image: Image.Image) -> str:
        """
        Base64 encode a PIL Image for API transmission.
        
        Args:
            image (Image.Image): Source image to encode
        
        Returns:
            str: Base64 encoded image string
        """
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
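
    # Illustrative addition (not part of the original interface): some backends
    # return a base64 payload inside JSON rather than raw PNG bytes; in that
    # case a decode step along these lines would be needed.
    def decode_image(self, encoded: str) -> Image.Image:
        """
        Decode a base64 image string back into a PIL Image.

        Args:
            encoded (str): Base64 encoded image string

        Returns:
            Image.Image: Decoded image
        """
        return Image.open(io.BytesIO(base64.b64decode(encoded)))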
    
    def crop_to_center(self, image: Image.Image) -> Image.Image:
        """
        Crop image to its center, maintaining square aspect ratio.
        
        Args:
            image (Image.Image): Source image
        
        Returns:
            Image.Image: Center-cropped image
        """
        width, height = image.size
        size = min(width, height)
        left = (width - size) // 2
        top = (height - size) // 2
        right = left + size
        bottom = top + size
        
        return image.crop((left, top, right, bottom))
    
    def create_padding_mask(self, image: Image.Image) -> Image.Image:
        """
        Generate a mask marking the padded border of an already-padded image.
        
        Args:
            image (Image.Image): Padded source image
        
        Returns:
            Image.Image: Grayscale mask; white (255) marks regions to outpaint
        """
        mask = Image.new('L', image.size, 0)
        mask_array = np.array(mask)
        
        # Set padding regions to white (255)
        mask_array[:self.padding_size, :] = 255  # Top
        mask_array[-self.padding_size:, :] = 255  # Bottom
        mask_array[:, :self.padding_size] = 255  # Left
        mask_array[:, -self.padding_size:] = 255  # Right
        
        return Image.fromarray(mask_array)
    
    def pad_image(self, image: Image.Image) -> Image.Image:
        """
        Add a transparent border of padding_size pixels on each side.
        
        Args:
            image (Image.Image): Source image
        
        Returns:
            Image.Image: Padded image
        """
        padded_size = (
            image.width + 2 * self.padding_size, 
            image.height + 2 * self.padding_size
        )
        padded_image = Image.new('RGBA', padded_size, (0, 0, 0, 0))
        padded_image.paste(image, (self.padding_size, self.padding_size))
        return padded_image
    
    def predict_outpainting(
        self, 
        image: Image.Image, 
        mask_image: Image.Image, 
        prompt: str
    ) -> Image.Image:
        """
        Call AI inference endpoint for outpainting.
        
        Args:
            image (Image.Image): Base image
            mask_image (Image.Image): Padding mask
            prompt (str): Outpainting generation prompt
        
        Returns:
            Image.Image: Outpainted result
        """
        # Request format used by this class: a JSON body carrying the prompt plus
        # base64-encoded image and mask, answered with raw PNG bytes
        payload = {
            "inputs": prompt,
            "image": self.encode_image(image),
            "mask_image": self.encode_image(mask_image)
        }
        
        try:
            response = requests.post(
                self.endpoint_url,
                headers=self.headers,
                json=payload,
                timeout=120  # avoid hanging indefinitely on a slow endpoint
            )
            response.raise_for_status()
            # The endpoint is expected to return raw PNG bytes (per the Accept header)
            return Image.open(io.BytesIO(response.content))
        except requests.RequestException as e:
            print(f"Outpainting request failed: {e}")
            # Fall back to the unmodified input so the iteration loop can continue
            return image
    
    def process_iterative_outpainting(
        self, 
        initial_image: Image.Image, 
        prompt: str
    ) -> Image.Image:
        """
        Execute iterative outpainting process.
        
        Args:
            initial_image (Image.Image): Starting image
            prompt (str): Generation prompt
        
        Returns:
            Image.Image: Final outpainted image
        """
        # Start from a centered square crop of the input
        current_image = self.crop_to_center(initial_image)
        
        for _ in range(self.max_iterations):
            # Each pass grows the canvas by padding_size on every side and asks
            # the model to fill the new transparent border
            padded_image = self.pad_image(current_image)
            mask = self.create_padding_mask(padded_image)
            
            current_image = self.predict_outpainting(
                padded_image, mask, prompt
            )
        
        return current_image
    
    def run(
        self, 
        image_path: str, 
        prompt: str
    ) -> Dict[str, Any]:
        """
        Main processing method for dynamic outpainting.
        
        Args:
            image_path (str): Path to input image
            prompt (str): Outpainting generation prompt
        
        Returns:
            Dict[str, Any]: "status" plus "result_path" and "iterations" on
            success, or "message" describing the failure
        """
        try:
            initial_image = Image.open(image_path)
            result_image = self.process_iterative_outpainting(
                initial_image, prompt
            )
            
            # Save the result to the current working directory; id(self) keeps
            # filenames distinct between instances
            result_path = f"outpainted_result_{id(self)}.png"
            result_image.save(result_path)
            
            return {
                "status": "success",
                "result_path": result_path,
                "iterations": self.max_iterations
            }
        
        except Exception as e:
            return {
                "status": "error",
                "message": str(e)
            }
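
# Illustrative sketch (an addition, not part of the original class): exercises
# the local geometry helpers without any network call. The endpoint URL and
# token below are placeholders and are never used for a request.
def preview_padding_geometry(padding_size: int = 64) -> None:
    """Sanity-check crop/pad/mask sizes offline (no inference request)."""
    outpainter = DynamicImageOutpainter(
        endpoint_url="https://example.invalid/outpaint",  # placeholder
        api_token="dummy-token",                          # placeholder
        padding_size=padding_size,
        max_iterations=1
    )
    sample = Image.new("RGB", (300, 200), (128, 128, 128))  # flat gray test image
    cropped = outpainter.crop_to_center(sample)    # 200x200 square
    padded = outpainter.pad_image(cropped)         # e.g. 328x328 with the default padding_size=64
    mask = outpainter.create_padding_mask(padded)  # white border, black interior
    print(cropped.size, padded.size, mask.size)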

# Usage Example
def main():
    outpainter = DynamicImageOutpainter(
        endpoint_url="https://your-ai-endpoint.com",
        api_token="your_huggingface_token",
        padding_size=256,
        max_iterations=3
    )
    
    result = outpainter.run(
        image_path="input_image.png",
        prompt="Expand the scene with natural, seamless background"
    )
    
    print(result)

if __name__ == "__main__":
    main()