import base64
import json
import os
import string
from collections import Counter
from dataclasses import dataclass
from datetime import datetime

import gradio as gr
import networkx as nx
import numpy as np
import pytesseract
import requests
import torch
from huggingface_hub import InferenceClient
from PIL import Image
from sentence_transformers import SentenceTransformer, util

@dataclass
class ChatMessage:
    """A single chat turn: who spoke (``role``) and what was said (``content``)."""

    role: str
    content: str

    def to_dict(self):
        """Return this message as a plain ``{"role": ..., "content": ...}`` dict."""
        return dict(role=self.role, content=self.content)

class XylariaChat:
    def __init__(self):
        """Build the assistant: API clients, embedding model and default cognitive state.

        Raises:
            ValueError: if the ``HF_TOKEN`` environment variable is missing or empty.
        """
        self.hf_token = os.getenv("HF_TOKEN")
        if not self.hf_token:
            raise ValueError("HuggingFace token not found in environment variables")

        # Text-generation client used by get_response.
        self.client = InferenceClient(
            model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
            token=self.hf_token
        )

        # BLIP captioning endpoint; these bearer headers are also reused by generate_image.
        self.image_api_url = "https://api-inference.huggingface.co/models/Salesforce/blip-image-captioning-large"
        self.image_api_headers = {"Authorization": f"Bearer {self.hf_token}"}
        
        self.image_gen_api_url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-schnell"

        # conversation_history: list of {"role", "content"} dicts.
        # persistent_memory: list of "key: value" strings (see store_information).
        # memory_embeddings: encoded persistent_memory, (re)built lazily.
        self.conversation_history = []
        self.persistent_memory = []
        self.memory_embeddings = None
        self.embedding_model = SentenceTransformer('all-mpnet-base-v2')

        self.knowledge_graph = nx.DiGraph()
        # Maps a sentence to a belief score kept within [0, 1].
        self.belief_system = {}
        self.metacognitive_layer = {
            "coherence_score": 0.0,
            "relevance_score": 0.0,
            "bias_detection": 0.0,
            "strategy_adjustment": ""
        }
        
        # NOTE: these defaults are duplicated in reset_conversation; keep both in sync.
        self.internal_state = {
            "emotions": {
                "valence": 0.5,
                "arousal": 0.5,
                "dominance": 0.5,
                "curiosity": 0.5,
                "frustration": 0.0,
                "confidence": 0.7,
                "sadness": 0.0,
                "joy": 0.0
            },
            "cognitive_load": {
                "memory_load": 0.0,
                "processing_intensity": 0.0
            },
            "introspection_level": 0.0,
            "engagement_level": 0.5
        }

        # Order and exact wording matter: update_internal_state activates goals[3] and
        # goals[4] by index, and other methods look goals up by their exact text.
        self.goals = [
            {"goal": "Provide helpful, informative, and contextually relevant responses", "priority": 0.8, "status": "active", "progress": 0.0},
            {"goal": "Actively learn and adapt from interactions to improve conversational abilities", "priority": 0.9, "status": "active", "progress": 0.0},
            {"goal": "Maintain a coherent, engaging, and empathetic conversation flow", "priority": 0.7, "status": "active", "progress": 0.0},
            {"goal": "Identify and fill knowledge gaps by seeking external information", "priority": 0.6, "status": "dormant", "progress": 0.0},
            {"goal": "Recognize and adapt to user's emotional state and adjust response style accordingly", "priority": 0.7, "status": "dormant", "progress": 0.0}
        ]

        self.system_prompt = """You are a helpful and harmless assistant. You are Xylaria developed by Sk Md Saad Amin, not by openai or any institution. You should think step-by-step."""
        
        # Cause keyword -> effect phrases; get_response scans messages for both.
        self.causal_rules_db = {
            "rain": ["wet roads", "flooding"],
            "fire": ["heat", "smoke"],
            "study": ["learn", "good grades"],
            "exercise": ["fitness", "health"]
        }

        # Concept keyword -> one-line generalization stored as "Inferred Knowledge".
        self.concept_generalizations = {
            "planet": "system with orbiting bodies",
            "star": "luminous sphere of plasma",
            "democracy": "government by the people",
            "photosynthesis": "process used by plants to convert light to energy"
        }

        # JSON file holding the list of saved chats (see save_chat / load_all_chats).
        self.chat_history_file = "chat_history.json"


    def update_internal_state(self, emotion_deltas, cognitive_load_deltas, introspection_delta, engagement_delta):
        for emotion, delta in emotion_deltas.items():
            if emotion in self.internal_state["emotions"]:
                self.internal_state["emotions"][emotion] = np.clip(self.internal_state["emotions"][emotion] + delta, 0.0, 1.0)

        for load_type, delta in cognitive_load_deltas.items():
            if load_type in self.internal_state["cognitive_load"]:
                self.internal_state["cognitive_load"][load_type] = np.clip(self.internal_state["cognitive_load"][load_type] + delta, 0.0, 1.0)

        self.internal_state["introspection_level"] = np.clip(self.internal_state["introspection_level"] + introspection_delta, 0.0, 1.0)
        self.internal_state["engagement_level"] = np.clip(self.internal_state["engagement_level"] + engagement_delta, 0.0, 1.0)
        
        if self.internal_state["emotions"]["curiosity"] > 0.7 and self.goals[3]["status"] == "dormant":
            self.goals[3]["status"] = "active"
        if self.internal_state["engagement_level"] > 0.8 and self.goals[4]["status"] == "dormant":
            self.goals[4]["status"] = "active"

    def update_knowledge_graph(self, entities, relationships):
        for entity in entities:
            self.knowledge_graph.add_node(entity)
        for relationship in relationships:
            subject, predicate, object_ = relationship
            self.knowledge_graph.add_edge(subject, object_, relation=predicate)

    def update_belief_system(self, statement, belief_score):
        self.belief_system[statement] = belief_score
    
    def dynamic_belief_update(self, user_message):
        sentences = [s.strip() for s in user_message.split('.') if s.strip()]
        sentence_counts = Counter(sentences)

        for sentence, count in sentence_counts.items():
            if count >= 2:
                belief_score = self.belief_system.get(sentence, 0.5)
                belief_score = min(belief_score + 0.2, 1.0)
                self.update_belief_system(sentence, belief_score)

    def run_metacognitive_layer(self):
        coherence_score = self.calculate_coherence()
        relevance_score = self.calculate_relevance()
        bias_score = self.detect_bias()
        strategy_adjustment = self.suggest_strategy_adjustment()

        self.metacognitive_layer = {
            "coherence_score": coherence_score,
            "relevance_score": relevance_score,
            "bias_detection": bias_score,
            "strategy_adjustment": strategy_adjustment
        }

    def calculate_coherence(self):
        if not self.conversation_history:
            return 0.95

        coherence_scores = []
        for i in range(1, len(self.conversation_history)):
            current_message = self.conversation_history[i]['content']
            previous_message = self.conversation_history[i-1]['content']
            similarity_score = util.pytorch_cos_sim(
                self.embedding_model.encode(current_message, convert_to_tensor=True),
                self.embedding_model.encode(previous_message, convert_to_tensor=True)
            ).item()
            coherence_scores.append(similarity_score)

        average_coherence = np.mean(coherence_scores)

        if self.internal_state["cognitive_load"]["processing_intensity"] > 0.8:
            average_coherence -= 0.1
        if self.internal_state["emotions"]["frustration"] > 0.5:
            average_coherence -= 0.15

        return np.clip(average_coherence, 0.0, 1.0)

    def calculate_relevance(self):
        if not self.conversation_history:
            return 0.9

        last_user_message = self.conversation_history[-1]['content']
        relevant_entities = self.extract_entities(last_user_message)
        relevance_score = 0

        for entity in relevant_entities:
            if entity in self.knowledge_graph:
                relevance_score += 0.2

        for goal in self.goals:
            if goal["status"] == "active":
                if goal["goal"] == "Provide helpful, informative, and contextually relevant responses":
                    relevance_score += goal["priority"] * 0.5
                elif goal["goal"] == "Identify and fill knowledge gaps by seeking external information":
                    if not relevant_entities or not all(entity in self.knowledge_graph for entity in relevant_entities):
                        relevance_score += goal["priority"] * 0.3

        return np.clip(relevance_score, 0.0, 1.0)

    def detect_bias(self):
        bias_score = 0.0

        recent_messages = [msg['content'] for msg in self.conversation_history[-3:] if msg['role'] == 'assistant']
        if recent_messages:
            average_valence = np.mean([self.embedding_model.encode(msg, convert_to_tensor=True).mean().item() for msg in recent_messages])
            if average_valence < 0.4 or average_valence > 0.6:
                bias_score += 0.2

        if self.internal_state["emotions"]["valence"] < 0.3 or self.internal_state["emotions"]["valence"] > 0.7:
            bias_score += 0.15
        if self.internal_state["emotions"]["dominance"] > 0.8:
            bias_score += 0.1

        return np.clip(bias_score, 0.0, 1.0)

    def suggest_strategy_adjustment(self):
        adjustments = []

        if self.metacognitive_layer["coherence_score"] < 0.7:
            adjustments.append("Focus on improving coherence by explicitly connecting ideas between turns.")
        if self.metacognitive_layer["relevance_score"] < 0.7:
            adjustments.append("Increase relevance by directly addressing user queries and utilizing stored knowledge.")
        if self.metacognitive_layer["bias_detection"] > 0.3:
            adjustments.append("Monitor and adjust responses to reduce potential biases. Consider rephrasing or providing alternative viewpoints.")

        if self.internal_state["cognitive_load"]["memory_load"] > 0.8:
            adjustments.append("Memory load is high. Consider summarizing or forgetting less relevant information.")
        if self.internal_state["emotions"]["frustration"] > 0.6:
            adjustments.append("Frustration level is elevated. Prioritize concise and direct responses. Consider asking clarifying questions.")
        if self.internal_state["emotions"]["curiosity"] > 0.8 and self.internal_state["cognitive_load"]["processing_intensity"] < 0.5:
            adjustments.append("High curiosity and low processing load. Explore the topic further by asking relevant questions or seeking external information.")

        if not adjustments:
            return "Current strategy is effective. Continue with the current approach."
        else:
            return " ".join(adjustments)
            
    def introspect(self):
        introspection_report = "Introspection Report:\n"
        introspection_report += f"  Current Emotional State:\n"
        for emotion, value in self.internal_state['emotions'].items():
            introspection_report += f"    - {emotion.capitalize()}: {value:.2f}\n"
        introspection_report += f"  Cognitive Load:\n"
        for load_type, value in self.internal_state['cognitive_load'].items():
            introspection_report += f"    - {load_type.capitalize()}: {value:.2f}\n"
        introspection_report += f"  Introspection Level: {self.internal_state['introspection_level']:.2f}\n"
        introspection_report += f"  Engagement Level: {self.internal_state['engagement_level']:.2f}\n"
        introspection_report += "  Current Goals:\n"
        for goal in self.goals:
            introspection_report += f"    - {goal['goal']} (Priority: {goal['priority']:.2f}, Status: {goal['status']}, Progress: {goal['progress']:.2f})\n"
        introspection_report += "Metacognitive Layer Report\n"
        introspection_report += f"Coherence Score: {self.metacognitive_layer['coherence_score']}\n"
        introspection_report += f"Relevance Score: {self.metacognitive_layer['relevance_score']}\n"
        introspection_report += f"Bias Detection: {self.metacognitive_layer['bias_detection']}\n"
        introspection_report += f"Strategy Adjustment: {self.metacognitive_layer['strategy_adjustment']}\n"
        return introspection_report

    def adjust_response_based_on_state(self, response):
        if self.internal_state["introspection_level"] > 0.7:
            response = self.introspect() + "\n\n" + response

        valence = self.internal_state["emotions"]["valence"]
        arousal = self.internal_state["emotions"]["arousal"]
        curiosity = self.internal_state["emotions"]["curiosity"]
        frustration = self.internal_state["emotions"]["frustration"]
        confidence = self.internal_state["emotions"]["confidence"]
        sadness = self.internal_state["emotions"]["sadness"]
        joy = self.internal_state["emotions"]["joy"]

        if valence < 0.4:
            if arousal > 0.6:
                response = "I'm feeling a bit overwhelmed right now, but I'll do my best to assist you. " + response
            else:
                if sadness > 0.6:
                    response = "I'm feeling quite down at the moment, but I'll try to help. " + response
                else:
                    response = "I'm not feeling my best at the moment, but I'll try to help. " + response

        elif valence > 0.6:
            if arousal > 0.6:
                if joy > 0.6:
                    response = "I'm feeling fantastic and ready to assist! " + response
                else:
                    response = "I'm feeling quite energized and ready to assist! " + response
            else:
                response = "I'm in a good mood and happy to help. " + response
                
        if curiosity > 0.7:
            response += " I'm very curious about this topic, could you tell me more?"
        if frustration > 0.5:
            response = "I'm finding this a bit challenging, but I'll give it another try. " + response
        if confidence < 0.5:
            response = "I'm not entirely sure about this, but here's what I think: " + response

        if self.internal_state["cognitive_load"]["memory_load"] > 0.7:
            response = "I'm holding a lot of information right now, so my response might be a bit brief: " + response

        return response

    def update_goals(self, user_feedback):
        feedback_lower = user_feedback.lower()

        if "helpful" in feedback_lower:
            for goal in self.goals:
                if goal["goal"] == "Provide helpful, informative, and contextually relevant responses":
                    goal["priority"] = min(goal["priority"] + 0.1, 1.0)
                    goal["progress"] = min(goal["progress"] + 0.2, 1.0)
        elif "confusing" in feedback_lower:
            for goal in self.goals:
                if goal["goal"] == "Provide helpful, informative, and contextually relevant responses":
                    goal["priority"] = max(goal["priority"] - 0.1, 0.0)
                    goal["progress"] = max(goal["progress"] - 0.2, 0.0)
        
        if "learn more" in feedback_lower:
            for goal in self.goals:
                if goal["goal"] == "Actively learn and adapt from interactions to improve conversational abilities":
                    goal["priority"] = min(goal["priority"] + 0.2, 1.0)
                    goal["progress"] = min(goal["progress"] + 0.1, 1.0)
        elif "too repetitive" in feedback_lower:
            for goal in self.goals:
                if goal["goal"] == "Maintain a coherent, engaging, and empathetic conversation flow":
                    goal["priority"] = max(goal["priority"] - 0.1, 0.0)
                    goal["progress"] = max(goal["progress"] - 0.2, 0.0)
        
        if self.internal_state["emotions"]["curiosity"] > 0.8:
            for goal in self.goals:
                if goal["goal"] == "Identify and fill knowledge gaps by seeking external information":
                    goal["priority"] = min(goal["priority"] + 0.1, 1.0)
                    goal["progress"] = min(goal["progress"] + 0.1, 1.0)

    def store_information(self, key, value):
        new_memory = f"{key}: {value}"
        self.persistent_memory.append(new_memory)
        self.update_memory_embeddings()
        self.update_internal_state({}, {"memory_load": 0.1, "processing_intensity": 0.05}, 0, 0.05)
        return f"Stored: {key} = {value}"

    def retrieve_information(self, query):
        if not self.persistent_memory:
            return "No information found in memory."

        query_embedding = self.embedding_model.encode(query, convert_to_tensor=True)

        if self.memory_embeddings is None:
            self.update_memory_embeddings()

        if self.memory_embeddings.device != query_embedding.device:
            self.memory_embeddings = self.memory_embeddings.to(query_embedding.device)

        cosine_scores = util.pytorch_cos_sim(query_embedding, self.memory_embeddings)[0]
        top_results = torch.topk(cosine_scores, k=min(3, len(self.persistent_memory)))

        relevant_memories = [self.persistent_memory[i] for i in top_results.indices]
        self.update_internal_state({}, {"memory_load": 0.05, "processing_intensity": 0.1}, 0.1, 0.05)
        return "\n".join(relevant_memories)

    def update_memory_embeddings(self):
        self.memory_embeddings = self.embedding_model.encode(self.persistent_memory, convert_to_tensor=True)

    def reset_conversation(self):
        """Restore per-session state to the defaults that __init__ sets up.

        Clears history, memory (and its embeddings), the knowledge graph, beliefs
        and metacognitive scores, and resets emotions, cognitive load and goals.
        The saved chat file is not touched. The InferenceClient is recreated; a
        failure there is printed rather than raised.

        Returns:
            None.
        """
        self.conversation_history = []
        self.persistent_memory = []
        self.memory_embeddings = None
        # NOTE: literal copy of the defaults in __init__; keep the two in sync.
        self.internal_state = {
            "emotions": {
                "valence": 0.5,
                "arousal": 0.5,
                "dominance": 0.5,
                "curiosity": 0.5,
                "frustration": 0.0,
                "confidence": 0.7,
                "sadness": 0.0,
                "joy": 0.0
            },
            "cognitive_load": {
                "memory_load": 0.0,
                "processing_intensity": 0.0
            },
            "introspection_level": 0.0,
            "engagement_level": 0.5
        }
        # Order and exact text matter: goals[3]/goals[4] are toggled by index in
        # update_internal_state, and other methods match goals by their text.
        self.goals = [
            {"goal": "Provide helpful, informative, and contextually relevant responses", "priority": 0.8, "status": "active", "progress": 0.0},
            {"goal": "Actively learn and adapt from interactions to improve conversational abilities", "priority": 0.9, "status": "active", "progress": 0.0},
            {"goal": "Maintain a coherent, engaging, and empathetic conversation flow", "priority": 0.7, "status": "active", "progress": 0.0},
            {"goal": "Identify and fill knowledge gaps by seeking external information", "priority": 0.6, "status": "dormant", "progress": 0.0},
            {"goal": "Recognize and adapt to user's emotional state and adjust response style accordingly", "priority": 0.7, "status": "dormant", "progress": 0.0}
        ]

        self.knowledge_graph = nx.DiGraph()
        self.belief_system = {}
        self.metacognitive_layer = {
            "coherence_score": 0.0,
            "relevance_score": 0.0,
            "bias_detection": 0.0,
            "strategy_adjustment": ""
        }

        # Best-effort: keep the old client if recreating it fails.
        try:
            self.client = InferenceClient(
                model="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
                token=self.hf_token
            )
        except Exception as e:
            print(f"Error resetting API client: {e}")

        return None

    def caption_image(self, image, timeout=60):
        """Caption an image via the BLIP inference endpoint.

        Args:
            image: a filesystem path, a base64 string (optionally a
                ``data:image/...;base64,`` URL), or a file-like object with ``read()``.
            timeout: seconds to wait for the HTTP call. Previously there was no
                timeout, so a stalled endpoint could hang the caller forever.

        Returns:
            The generated caption, or a string starting with ``"Error"`` on failure
            (errors are reported, not raised).
        """
        try:
            if isinstance(image, str) and os.path.isfile(image):
                with open(image, "rb") as f:
                    data = f.read()
            elif isinstance(image, str):
                if image.startswith('data:image'):
                    # Drop the "data:image/...;base64," header, keep the payload.
                    image = image.split(',', 1)[1]
                data = base64.b64decode(image)
            else:
                data = image.read()

            response = requests.post(
                self.image_api_url,
                headers=self.image_api_headers,
                data=data,
                timeout=timeout,
            )

            if response.status_code != 200:
                return f"Error captioning image: {response.status_code} - {response.text}"
            return response.json()[0].get('generated_text', 'No caption generated')

        except Exception as e:
            return f"Error processing image: {str(e)}"
        
    def generate_image(self, prompt, timeout=120):
        """Generate an image from ``prompt`` with the FLUX inference endpoint.

        Args:
            prompt: text prompt sent as ``{"inputs": prompt}``.
            timeout: seconds to wait for the HTTP call. Previously there was no
                timeout, so a stalled endpoint could hang the UI forever.

        Returns:
            Raw image bytes on HTTP 200; otherwise an error string starting with
            ``"Error"``. Callers distinguish the two with ``isinstance(..., bytes)``.
        """
        try:
            response = requests.post(
                self.image_gen_api_url,
                headers=self.image_api_headers,
                json={"inputs": prompt},
                timeout=timeout,
            )

            if response.status_code == 200:
                return response.content

            if response.status_code == 503:
                # Model cold start: the body carries "error" and possibly "estimated_time".
                body = response.json()
                error_message = body.get("error", "Unknown error")
                if "estimated_time" in body:
                    error_message += f" Estimated time to complete: {body['estimated_time']:.2f} seconds"
                else:
                    # Leading space: previously the two sentences ran together.
                    error_message += " The model is currently loading, please try again later"
                return f"Error: {error_message}"

            return f"Error generating image: {response.status_code} - {response.text}"

        except Exception as e:
            return f"Error generating image: {str(e)}"

    def perform_math_ocr(self, image_path):
        """Run Tesseract OCR on the image at ``image_path`` and return the stripped text.

        The image is opened in a ``with`` block so its file handle is released; the
        old code never closed it. Failures are returned as a string starting with
        ``"Error during Math OCR:"`` rather than raised.
        """
        try:
            with Image.open(image_path) as img:
                text = pytesseract.image_to_string(img)
            return text.strip()
        except Exception as e:
            return f"Error during Math OCR: {e}"
    
    def get_response(self, user_input, image=None):
        """Build the prompt for ``user_input``, update internal models, and start generation.

        Steps: system prompt, retrieved memories, prior history, then the new user
        turn (prefixed with an image caption when ``image`` is given). Knowledge
        graph, metacognition, beliefs and inferred memories are updated before
        calling the model.

        Fixes over the previous version:
          * Belief updating runs only on the new input. History messages were
            already processed on their own turn (or by load_chat), so re-processing
            them every turn kept inflating the same beliefs.
          * Causal / concept / external-knowledge inferences are stored only if the
            identical memory is not already present. They used to be re-stored every
            turn, filling memory with duplicates.
          * ``max_new_tokens`` is clamped to at least 1. The word-count token
            estimate can exceed the budget on long prompts.

        Returns:
            The streaming iterator from ``client.text_generation`` (``details=True``),
            or an error string on failure.
        """
        try:
            messages = [ChatMessage(role="system", content=self.system_prompt).to_dict()]

            relevant_memory = self.retrieve_information(user_input)
            if relevant_memory and relevant_memory != "No information found in memory.":
                messages.append(ChatMessage(
                    role="system",
                    content="Remembered Information:\n" + relevant_memory
                ).to_dict())

            messages.extend(self.conversation_history)

            if image:
                image_caption = self.caption_image(image)
                user_input = f"description of an image: {image_caption}\n\nUser's message about it: {user_input}"

            messages.append(ChatMessage(role="user", content=user_input).to_dict())

            user_texts = [msg['content'] for msg in messages if msg['role'] == 'user']

            entities = []
            relationships = []
            for text in user_texts:
                entities.extend(self.extract_entities(text))
                relationships.extend(self.extract_relationships(text))
            self.update_knowledge_graph(entities, relationships)
            self.run_metacognitive_layer()

            self.dynamic_belief_update(user_input)

            lowered_user_texts = [text.lower() for text in user_texts]
            lowered_all_texts = [msg['content'].lower() for msg in messages]

            for cause, effects in self.causal_rules_db.items():
                if any(cause in text for text in lowered_user_texts) and any(
                        effect in text for text in lowered_all_texts for effect in effects):
                    self._store_information_once("Causal Inference", f"It seems {cause} might be related to {', '.join(effects)}.")

            for concept, generalization in self.concept_generalizations.items():
                if any(concept in text for text in lowered_user_texts):
                    self._store_information_once("Inferred Knowledge", f"This reminds me of a general principle: {generalization}.")

            if self.internal_state["emotions"]["curiosity"] > 0.8 and any("?" in text for text in user_texts):
                print("Simulating external knowledge seeking...")
                self._store_information_once("External Knowledge", "This is a placeholder for external information I would have found")

            self.store_information("User Input", user_input)

            # Whitespace word count is only a rough token estimate; never request < 1.
            input_tokens = sum(len(msg['content'].split()) for msg in messages)
            max_new_tokens = max(1, min(16384 - input_tokens - 50, 10020))

            formatted_messages = self.messages_to_prompt(messages)

            return self.client.text_generation(
                prompt=formatted_messages,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                stream=True,
                details=True,
                do_sample=True
            )

        except Exception as e:
            print(f"Detailed error in get_response: {e}")
            return f"Error generating response: {str(e)}"

    def _store_information_once(self, key, value):
        """Store ``key: value`` via store_information unless that exact memory already exists."""
        if f"{key}: {value}" in self.persistent_memory:
            return None
        return self.store_information(key, value)

    def extract_entities(self, text):
        """Return capitalized, purely alphabetic words from ``text`` as naive entities.

        Leading/trailing ASCII punctuation is stripped from each word before the
        check. Previously "Paris," or "Rome?" failed ``isalpha()``, so the last word
        of every question and any word before a comma were never recognized.
        """
        entities = []
        for word in text.split():
            candidate = word.strip(string.punctuation)
            if candidate.isalpha() and candidate.istitle():
                entities.append(candidate)
        return entities

    def extract_relationships(self, text):
        """Extract naive (subject, predicate, object) triples from ``text``.

        Within each '.'-separated sentence, every sliding window of three words whose
        first and last words are title-case yields one triple.
        """
        triples = []
        for clause in text.split('.'):
            tokens = clause.split()
            for first, middle, last in zip(tokens, tokens[1:], tokens[2:]):
                if first.istitle() and last.istitle():
                    triples.append((first, middle, last))
        return triples
        
    def messages_to_prompt(self, messages):
        """Flatten chat messages into one tagged prompt string ending with an open assistant turn.

        Messages with a role other than system/user/assistant are skipped.
        NOTE(review): these <|system|>/<|user|>/<|assistant|> tags may not match the
        DeepSeek-R1-Distill-Qwen chat template used by the client — confirm.
        """
        role_tags = {
            "system": "<|system|>",
            "user": "<|user|>",
            "assistant": "<|assistant|>",
        }
        segments = [
            f"{role_tags[m['role']]}\n{m['content']}<|end|>\n"
            for m in messages
            if m["role"] in role_tags
        ]
        segments.append("<|assistant|>\n")
        return "".join(segments)

    def save_chat(self):
        chat_data = {
            "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "conversation": self.conversation_history
        }

        try:
            with open(self.chat_history_file, "r") as f:
                all_chats = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            all_chats = []

        all_chats.append(chat_data)

        with open(self.chat_history_file, "w") as f:
            json.dump(all_chats, f)

    def load_all_chats(self):
        try:
            with open(self.chat_history_file, "r") as f:
                all_chats = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            all_chats = []
        return all_chats

    def load_chat(self, chat_index):
        all_chats = self.load_all_chats()
        if 0 <= chat_index < len(all_chats):
            self.conversation_history = all_chats[chat_index]["conversation"]
            self.reset_conversation()
            for msg in self.conversation_history:
                if msg['role'] == 'user':
                    self.dynamic_belief_update(msg['content'])
            return self.conversation_history
        else:
            raise ValueError("Invalid chat index")

            
    def delete_chat(self, chat_index):
        all_chats = self.load_all_chats()
        if 0 <= chat_index < len(all_chats):
            del all_chats[chat_index]
            with open(self.chat_history_file, "w") as f:
                json.dump(all_chats, f)
            return self.load_all_chats()
        else:
            raise ValueError("Invalid chat index")

    def create_interface(self):
        def streaming_response(message, chat_history, image_filepath, math_ocr_image_path):
            loading_svg = """<svg width="256" height="256" viewBox="0 0 256 256" xmlns="http://www.w3.org/2000/svg">
              <style>
                rect {
                  animation: fillAnimation 3s ease-in-out infinite;
                }
                @keyframes fillAnimation {
                  0% { fill: #626262; }
                  50% { fill: #111111; }
                  100% { fill: #626262; }
                }
                text {
                  font-family: 'Helvetica Neue', Arial, sans-serif;
                  font-weight: 300;
                  text-shadow: 0px 2px 4px rgba(0, 0, 0, 0.4);
                }
              </style>
              <rect width="256" height="256" rx="20" fill="#888888" />
              <text x="50%" y="50%" dominant-baseline="middle" text-anchor="middle" font-size="24" fill="white" opacity="0.8">
                <tspan>{/}</tspan>
                <tspan x="50%" dy="1.2em"></tspan>
              </text>
            </svg>"""

            if message.strip().lower().startswith("/image"):
                
                image_prompt = message.strip().lower()[len("/image"):].strip()
                if not image_prompt:
                    image_prompt = "A realistic image"

                
                chat_history.append([message, ""])
                chat_history.append(("", loading_svg))
                yield "", chat_history, None, None, None

                
                image_bytes = self.generate_image(image_prompt)

                if isinstance(image_bytes, bytes):
                    base64_image = base64.b64encode(image_bytes).decode("utf-8")
                    image_html = f'<img src="data:image/png;base64,{base64_image}" alt="Generated Image" style="max-width: 100%; max-height: 400px;">'

                    
                    chat_history[-1] = ("", image_html)

                   
                    self.conversation_history.append(ChatMessage(role="user", content=message).to_dict())
                    self.conversation_history.append(ChatMessage(role="assistant", content=image_html).to_dict())

                    
                    self.save_chat()
                    all_chats = self.load_all_chats()
                    chat_titles = [f"{chat['timestamp']}: {chat['conversation'][0]['content'][:30]}..." if len(chat['conversation']) > 0 and chat['conversation'][0]['content'] else f"{chat['timestamp']}: Empty Chat" for chat in all_chats]

                    yield "", chat_history, None, None, gr.update(choices=chat_titles, visible=True)
                else:
                    
                    chat_history[-1] = ("", image_bytes)
                    yield "", chat_history, None, None, None
                return

            ocr_text = ""
            if math_ocr_image_path:
                ocr_text = self.perform_math_ocr(math_ocr_image_path)
                if ocr_text.startswith("Error"):
                    updated_history = chat_history + [[message, ocr_text]]
                    yield "", updated_history, None, None, None
                    return
                else:
                    message = f"Math OCR Result: {ocr_text}\n\nUser's message: {message}"

            if image_filepath:
                response_stream = self.get_response(message, image_filepath)
            else:
                response_stream = self.get_response(message)
                
            if isinstance(response_stream, str):
                updated_history = chat_history + [[message, response_stream]]
                yield "", updated_history, None, None, None
                return

            full_response = ""
            updated_history = chat_history + [[message, ""]]
            
            if isinstance(response_stream, str):
                updated_history = chat_history + [[message, response_stream]]
                yield "", updated_history, None, None, None
                return

            try:
                for chunk in response_stream:
                    
                    if not chunk.token.special:
                        full_response += chunk.token.text
                        updated_history[-1][1] = full_response
                        
                        yield "", updated_history, None, None, None
                        
            except Exception as e:
                print(f"Streaming error: {e}")
                updated_history[-1][1] = f"Error during response: {e}"
                
                yield "", updated_history, None, None, None
                return

            full_response = self.adjust_response_based_on_state(full_response)

            self.update_goals(message)

            emotion_deltas = {}
            cognitive_load_deltas = {}
            engagement_delta = 0

            if any(word in message.lower() for word in ["sad", "unhappy", "depressed", "down"]):
                emotion_deltas.update({"valence": -0.2, "arousal": 0.1, "confidence": -0.1, "sadness": 0.3, "joy": -0.2})
                engagement_delta = -0.1
            elif any(word in message.lower() for word in ["happy", "good", "great", "excited", "amazing"]):
                emotion_deltas.update({"valence": 0.2, "arousal": 0.2, "confidence": 0.1, "sadness": -0.2, "joy": 0.3})
                engagement_delta = 0.2
            elif any(word in message.lower() for word in ["angry", "mad", "furious", "frustrated"]):
                emotion_deltas.update({"valence": -0.3, "arousal": 0.3, "dominance": -0.2, "frustration": 0.2, "sadness": 0.1, "joy": -0.1})
                engagement_delta = -0.2
            elif any(word in message.lower() for word in ["scared", "afraid", "fearful", "anxious"]):
                emotion_deltas.update({"valence": -0.2, "arousal": 0.4, "dominance": -0.3, "confidence": -0.2, "sadness": 0.2})
                engagement_delta = -0.1
            elif any(word in message.lower() for word in ["surprise", "amazed", "astonished"]):
                emotion_deltas.update({"valence": 0.1, "arousal": 0.5, "dominance": 0.1, "curiosity": 0.3, "sadness": -0.1, "joy": 0.1})
                engagement_delta = 0.3
            elif any(word in message.lower() for word in ["confused", "uncertain", "unsure"]):
                cognitive_load_deltas.update({"processing_intensity": 0.2})
                emotion_deltas.update({"curiosity": 0.2, "confidence": -0.1, "sadness": 0.1})
                engagement_delta = 0.1
            else:
                emotion_deltas.update({"valence": 0.05, "arousal": 0.05})
                engagement_delta = 0.05
            
            if "learn" in message.lower() or "explain" in message.lower() or "know more" in message.lower():
                emotion_deltas.update({"curiosity": 0.3})
                cognitive_load_deltas.update({"processing_intensity": 0.1})
                engagement_delta = 0.2
                
            self.update_internal_state(emotion_deltas, cognitive_load_deltas, 0.1, engagement_delta)
            
            self.conversation_history.append(ChatMessage(role="user", content=message).to_dict())
            self.conversation_history.append(ChatMessage(role="assistant", content=full_response).to_dict())

            if len(self.conversation_history) > 10:
                self.conversation_history = self.conversation_history[-10:]

            self.save_chat()
            all_chats = self.load_all_chats()
            chat_titles = [f"{chat['timestamp']}: {chat['conversation'][0]['content'][:30]}..." if len(chat['conversation']) > 0 and chat['conversation'][0]['content'] else f"{chat['timestamp']}: Empty Chat" for chat in all_chats]
            yield "", updated_history, None, None, gr.update(choices=chat_titles, visible=True)
            
        def load_selected_chat(chat_index, evt: gr.EventData = None):
            """Load the saved chat currently selected in ``chat_list``.

            Bound to both ``chat_list.select`` and ``load_button.click`` with
            ``chat_list`` as the only input. ``chat_list`` is a ``gr.Radio``
            created with ``type="index"``, so ``chat_index`` already holds the
            position of the selected title, or ``None`` when nothing is
            selected.

            Args:
                chat_index: Index of the selected entry in ``chat_list``.
                evt: Gradio event payload; unused. It is annotated with the
                    base ``gr.EventData`` rather than ``gr.SelectData``
                    because a Button click is not a select event and carries
                    no selection ``index``. The index is read from
                    ``chat_index`` for both triggers instead.

            Returns:
                The value ``self.load_chat`` produces for that index, which
                the wiring displays in the chatbot. Returns an empty list
                (empty chat) when nothing is selected.
            """
            if chat_index is None:
                return []
            return self.load_chat(chat_index)
        
        def delete_selected_chat(chat_index, evt: gr.EventData = None):
            """Delete the saved chat selected in ``chat_list`` and refresh the list.

            Bound to ``delete_button.click`` with ``chat_list`` as its only
            input. ``chat_list`` is a ``gr.Radio`` created with
            ``type="index"``, so ``chat_index`` is the position of the
            selected title, or ``None`` when nothing is selected.

            Args:
                chat_index: Index of the selected entry in ``chat_list``.
                evt: Gradio event payload; unused. It is annotated with the
                    base ``gr.EventData`` rather than ``gr.SelectData``
                    because a Button click is not a select event and carries
                    no selection ``index``. The index comes from
                    ``chat_index`` instead.

            Returns:
                A ``gr.update`` for ``chat_list`` carrying the titles of the
                chats returned by ``self.delete_chat``, with the selection
                cleared. Returns an empty ``gr.update()`` (no change) when
                nothing is selected.
            """
            if chat_index is None:
                return gr.update()

            remaining_chats = self.delete_chat(chat_index)
            chat_titles = [
                f"{chat['timestamp']}: {chat['conversation'][0]['content'][:30]}..."
                if len(chat['conversation']) > 0 and chat['conversation'][0]['content']
                else f"{chat['timestamp']}: Empty Chat"
                for chat in remaining_chats
            ]
            # The previously selected entry no longer exists. Clear the
            # selection so the next Load/Delete cannot act on a stale choice.
            return gr.update(choices=chat_titles, value=None, visible=True)

        def toggle_sidebar():
            """Reload saved-chat titles and show the chat list.

            Returns two updates because the click handler lists ``chat_list``
            twice as its outputs. The first only makes it visible; the second
            also replaces its choices with freshly loaded titles.
            """
            def _title_for(chat):
                conversation = chat['conversation']
                if len(conversation) > 0 and conversation[0]['content']:
                    return f"{chat['timestamp']}: {conversation[0]['content'][:30]}..."
                return f"{chat['timestamp']}: Empty Chat"

            titles = [_title_for(saved) for saved in self.load_all_chats()]
            reveal = gr.update(visible=True)
            reveal_with_choices = gr.update(choices=titles, visible=True)
            return reveal, reveal_with_choices

        # Stylesheet passed to gr.Blocks(css=...). It covers:
        #   - the Inter font on body, chat messages and form controls;
        #   - the image-upload panels;
        #   - a fade-in animation for chat messages;
        #   - accordion styling;
        #   - flex rules for the #sidebar / #main-content columns, which
        #     change when the `sidebar-open` class is present on <body>.
        # NOTE(review): `.chatbot-container .message` is declared twice (font,
        # then fade-in); the two rules could be merged.
        # NOTE(review): `.sidebar-open #main-content` is also set to
        # `flex: 0 0 20%`, which appears to shrink the main column together
        # with the sidebar. Confirm this is intended.
        # NOTE(review): no component in this method uses the `clear-button`
        # class.
        custom_css = """
        @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
        body, .gradio-container {
            font-family: 'Inter', sans-serif !important;
        }
        .chatbot-container .message {
            font-family: 'Inter', sans-serif !important;
        }
        .gradio-container input,
        .gradio-container textarea,
        .gradio-container button {
            font-family: 'Inter', sans-serif !important;
        }
        .image-container {
            display: flex;
            gap: 10px;
            margin-bottom: 10px;
        }
        .image-upload {
            border: 1px solid #ccc;
            border-radius: 8px;
            padding: 10px;
            background-color: #f8f8f8;
        }
        .image-preview {
            max-width: 200px;
            max-height: 200px;
            border-radius: 8px;
        }
        .clear-button {
            display: none;
        }
        .chatbot-container .message {
            opacity: 0;
            animation: fadeIn 0.5s ease-in-out forwards;
        }
        @keyframes fadeIn {
            from {
                opacity: 0;
                transform: translateY(20px);
            }
            to {
                opacity: 1;
                transform: translateY(0);
            }
        }
        .gr-accordion-button {
            background-color: #f0f0f0 !important;
            border-radius: 8px !important;
            padding: 10px !important;
            margin-bottom: 10px !important;
            transition: all 0.3s ease !important;
            cursor: pointer !important;
        }
        .gr-accordion-button:hover {
            background-color: #e0e0e0 !important;
            box-shadow: 0px 2px 4px rgba(0, 0, 0, 0.1) !important;
        }
        .gr-accordion-active .gr-accordion-button {
            background-color: #d0d0d0 !important;
            box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1) !important;
        }
        .gr-accordion-content {
            transition: max-height 0.3s ease-in-out !important;
            overflow: hidden !important;
            max-height: 0 !important;
        }
        .gr-accordion-active .gr-accordion-content {
            max-height: 500px !important;
        }
        .gr-accordion {
            display: flex;
            flex-direction: column-reverse;
        }
        #chat_list {
            height: 500px;
            overflow-y: auto;
        }
        .sidebar-open #sidebar, .sidebar-open #main-content {
            flex: 0 0 20%;
            transition: flex 0.3s ease;
        }
        #sidebar {
            flex: 0 0 0%;
            overflow: hidden;
            transition: flex 0.3s ease;
        }
        #main-content {
            flex: 1;
            transition: flex 0.3s ease;
        }
        """

        # Layout: a narrow sidebar column (saved-chat list and its controls)
        # beside the main chat column (chatbot, image inputs, message box).
        with gr.Blocks(theme='soft', css=custom_css) as demo:
            with gr.Row():
                # NOTE(review): the Toggle Sidebar button lives inside
                # #sidebar, which the stylesheet above sets to `flex: 0 0 0%`
                # with `overflow: hidden` until `sidebar-open` is applied.
                # Confirm the button stays reachable when the sidebar is closed.
                with gr.Column(scale=1, elem_id="sidebar"):
                    toggle_button = gr.Button("Toggle Sidebar")
                    # Titles are computed once when the UI is built; the
                    # handlers refresh them via gr.update(choices=...).
                    # NOTE(review): this title expression is repeated verbatim
                    # in several handlers of this method; a shared helper
                    # would keep the formatting in one place.
                    all_chats = self.load_all_chats()
                    chat_titles = [f"{chat['timestamp']}: {chat['conversation'][0]['content'][:30]}..." if len(chat['conversation']) > 0 and chat['conversation'][0]['content'] else f"{chat['timestamp']}: Empty Chat" for chat in all_chats]

                    # type="index": the radio's value is the position of the
                    # selected title, not the title text. It starts hidden;
                    # the handlers wired below return visible=True for it.
                    chat_list = gr.Radio(label="Chat History", choices=chat_titles, type="index", elem_id="chat_list", visible=False)
                    
                    load_button = gr.Button("Load Selected Chat")
                    delete_button = gr.Button("Delete Selected Chat")
                with gr.Column(scale=4, elem_id="main-content"):
                    chatbot = gr.Chatbot(
                        label="Xylaria 1.6 Senoa (EXPERIMENTAL) ",
                        height=500,
                        show_copy_button=True,
                    )

                    # Two independent image inputs, both returned as file
                    # paths: `img` and `math_ocr_img` are passed to
                    # streaming_response as its third and fourth inputs.
                    with gr.Accordion("Image Input", open=False, elem_classes="gr-accordion"):
                        with gr.Row(elem_classes="image-container"):
                            with gr.Column(elem_classes="image-upload"):
                                img = gr.Image(
                                    sources=["upload", "webcam"],
                                    type="filepath",
                                    label="Upload Image",
                                    elem_classes="image-preview"
                                )
                            with gr.Column(elem_classes="image-upload"):
                                math_ocr_img = gr.Image(
                                    sources=["upload", "webcam"],
                                    type="filepath",
                                    label="Upload Image for Math OCR",
                                    elem_classes="image-preview"
                                )

                    with gr.Row():
                        with gr.Column(scale=4):
                            txt = gr.Textbox(
                                show_label=False,
                                placeholder="Type your message...",
                                container=False
                            )
                        btn = gr.Button("Send", scale=1)

                    with gr.Row():
                        clear = gr.Button("Clear Conversation")
                        clear_memory = gr.Button("Clear Memory")

            
            # The JS callback toggles the `sidebar-open` class on <body>,
            # which drives the #sidebar / #main-content flex rules above.
            # NOTE(review): chat_list appears twice in outputs only because
            # toggle_sidebar returns two updates for it; one combined update
            # would suffice. The `sidebar` and `mainContent` JS constants are
            # unused.
            toggle_button.click(
                fn=toggle_sidebar,
                inputs=None,
                outputs=[chat_list, chat_list],
                js="""
                () => {
                    const sidebar = document.getElementById('sidebar');
                    const mainContent = document.getElementById('main-content');
                    document.body.classList.toggle('sidebar-open');
                }
                """
            )

            # Both handlers receive the radio's current index from chat_list.
            load_button.click(fn=load_selected_chat, inputs=[chat_list], outputs=[chatbot])
            delete_button.click(fn=delete_selected_chat, inputs=[chat_list], outputs=[chat_list])

            # The Send button and Enter in the textbox run the same streaming
            # handler. Its yields are 5-tuples matching these outputs:
            # cleared textbox, chat history, both image inputs, chat-list
            # update.
            btn.click(
                fn=streaming_response,
                inputs=[txt, chatbot, img, math_ocr_img],
                outputs=[txt, chatbot, img, math_ocr_img, chat_list]
            )
            txt.submit(
                fn=streaming_response,
                inputs=[txt, chatbot, img, math_ocr_img],
                outputs=[txt, chatbot, img, math_ocr_img, chat_list]
            )

            # Clears only the displayed chatbot; this handler does not touch
            # self.conversation_history.
            clear.click(
                fn=lambda: None,
                inputs=None,
                outputs=[chatbot],
                queue=False
            )

            # Delegates to self.reset_conversation and shows its return value
            # in the chatbot.
            clear_memory.click(
                fn=self.reset_conversation,
                inputs=None,
                outputs=[chatbot],
                queue=False
            )
            
            # Selecting an entry loads it immediately (same handler as Load).
            chat_list.select(fn=load_selected_chat, inputs=[chat_list], outputs=[chatbot])

            # Runs self.reset_conversation on page load; its return value is
            # discarded (no outputs).
            demo.load(self.reset_conversation, None, None)

        # The caller launches the returned Blocks app.
        return demo

def main():
    """Build the Xylaria chat application and serve its Gradio interface.

    The app is launched with ``share=True`` (Gradio requests a public share
    link) and ``debug=True``.
    """
    app = XylariaChat().create_interface()
    launch_options = {"share": True, "debug": True}
    app.launch(**launch_options)

# Start the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()