import * as webllm from "@mlc-ai/web-llm";
import rehypeStringify from "rehype-stringify";
import remarkFrontmatter from "remark-frontmatter";
import remarkGfm from "remark-gfm";
import remarkBreaks from "remark-breaks";
import remarkParse from "remark-parse";
import remarkRehype from "remark-rehype";
import rehypeKatex from "rehype-katex";
import { unified } from "unified";
import remarkMath from "remark-math";
import rehypeHighlight from "rehype-highlight";

/*************** WebLLM logic ***************/
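// Markdown-to-HTML formatter for chat bubbles: remark parses GFM, front
// matter, and math; remark-rehype converts to HTML, where KaTeX typesets the
// math and highlight.js colors code blocks (with language auto-detection).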
const messageFormatter = unified()
  .use(remarkParse)
  .use(remarkFrontmatter)
  .use(remarkMath)
  .use(remarkGfm)
  .use(remarkBreaks)
  .use(remarkRehype)
  .use(rehypeKatex)
  .use(rehypeHighlight, {
    detect: true,
    ignoreMissing: true,
  })
  .use(rehypeStringify);
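// Full conversation history, resent with every completion request.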
const messages = [
  {
    content: "You are a helpful AI agent helping users.",
    role: "system",
  },
];

// Callback that surfaces model download/initialization progress in the UI
function updateEngineInitProgressCallback(report) {
  console.log("initialize", report.progress);
  document.getElementById("download-status").textContent = report.text;
}

// Create engine instance
let modelLoaded = false; // set once engine.reload() resolves; gates onMessageSend()
const engine = new webllm.MLCEngine();
engine.setLogLevel("INFO");
engine.setInitProgressCallback(updateEngineInitProgressCallback);

async function initializeWebLLMEngine() {
  const model_size = document.getElementById("model_size").value;
  const quantization = document.getElementById("quantization").value;
  const context_window_size = parseInt(document.getElementById("context").value, 10);
  const temperature = parseFloat(document.getElementById("temperature").value);
  const top_p = parseFloat(document.getElementById("top_p").value);
  const presence_penalty = parseFloat(document.getElementById("presence_penalty").value);
  const frequency_penalty = parseFloat(document.getElementById("frequency_penalty").value);

  document.getElementById("download-status").classList.remove("hidden");
  const selectedModel = `Llama-3.2-${model_size}-Instruct-${quantization}-MLC`;
  const config = {
    temperature,
    top_p,
    frequency_penalty,
    presence_penalty,
    context_window_size,
  };
  console.log(`Loading Model: ${selectedModel}`);
  console.log(`Config: ${JSON.stringify(config)}`);
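  // reload() fetches the weights on first use (the browser caches them for
  // later runs) and initializes the model with the chat options above.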
  await engine.reload(selectedModel, config);
  modelLoaded = true;
}

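// Streams a chat completion: onUpdate receives the accumulated reply after
// each chunk, onFinish the final message plus token usage, onError any failure.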
async function streamingGenerating(messages, onUpdate, onFinish, onError) {
  try {
    let curMessage = "";
    let usage;
    const completion = await engine.chat.completions.create({
      stream: true,
      messages,
      stream_options: { include_usage: true },
    });
    for await (const chunk of completion) {
      // The final usage-only chunk has an empty choices array, hence the optional chaining.
      const curDelta = chunk.choices[0]?.delta?.content;
      if (curDelta) {
        curMessage += curDelta;
      }
      if (chunk.usage) {
        usage = chunk.usage;
      }
      // Await the (async) UI update so chunks render in order.
      await onUpdate(curMessage);
    }
    const finalMessage = await engine.getMessage();
    onFinish(finalMessage, usage);
  } catch (err) {
    onError(err);
  }
}

/*************** UI logic ***************/
function onMessageSend() {
  if (!modelLoaded) {
    return;
  }
  const input = document.getElementById("user-input").value.trim();
  if (input.length === 0) {
    return;
  }
  const message = {
    content: input,
    role: "user",
  };
  document.getElementById("send").disabled = true;

  messages.push(message);
  appendMessage(message);

  document.getElementById("user-input").value = "";
  document
    .getElementById("user-input")
    .setAttribute("placeholder", "Generating...");

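  // Placeholder assistant bubble; its text is replaced as tokens stream in.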
  const aiMessage = {
    content: "typing...",
    role: "assistant",
  };
  appendMessage(aiMessage);

  const onFinishGenerating = async (finalMessage, usage) => {
    // Record the assistant reply so the next request carries the full history.
    messages.push({ content: finalMessage, role: "assistant" });
    await updateLastMessage(finalMessage);
    document.getElementById("send").disabled = false;
    const usageText =
      `prompt_tokens: ${usage.prompt_tokens}, ` +
      `completion_tokens: ${usage.completion_tokens}, ` +
      `prefill: ${usage.extra.prefill_tokens_per_s.toFixed(4)} tokens/sec, ` +
      `decoding: ${usage.extra.decode_tokens_per_s.toFixed(4)} tokens/sec`;
    document.getElementById("chat-stats").classList.remove("hidden");
    document.getElementById("chat-stats").textContent = usageText;

    document
      .getElementById("user-input")
      .setAttribute("placeholder", "Type a message...");
  };

  streamingGenerating(
    messages,
    updateLastMessage,
    onFinishGenerating,
    (err) => {
      // Re-enable input even if generation fails mid-stream.
      document.getElementById("send").disabled = false;
      console.error(err);
    }
  );
}

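// Appends a chat bubble as plain text; assistant bubbles are re-rendered as
// formatted HTML by updateLastMessage() while streaming.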
function appendMessage(message) {
  const chatBox = document.getElementById("chat-box");
  const container = document.createElement("div");
  container.classList.add("message-container");
  const newMessage = document.createElement("div");
  newMessage.classList.add("message");
  newMessage.textContent = message.content;

  if (message.role === "user") {
    container.classList.add("user");
  } else {
    container.classList.add("assistant");
  }

  container.appendChild(newMessage);
  chatBox.appendChild(container);
  chatBox.scrollTop = chatBox.scrollHeight; // Scroll to the latest message
}

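// Replaces the newest bubble's contents with the formatted (HTML) message.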
async function updateLastMessage(content) {
  const formattedMessage = await messageFormatter.process(content);
  const messageDoms = document
    .getElementById("chat-box")
    .querySelectorAll(".message");
  const lastMessageDom = messageDoms[messageDoms.length - 1];
  // process() resolves to a VFile; String() yields the compiled HTML.
  lastMessageDom.innerHTML = String(formattedMessage);
}

/*************** UI binding ***************/
document.addEventListener("DOMContentLoaded", function () {
  document.getElementById("download").addEventListener("click", function () {
    document.getElementById("send").disabled = true;
    initializeWebLLMEngine().then(() => {
      document.getElementById("send").disabled = false;
    });
  });
  document.getElementById("send").addEventListener("click", function () {
    onMessageSend();
  });
  document.getElementById("user-input").addEventListener("keydown", (event) => {
    if (event.key === "Enter") {
      onMessageSend();
    }
  });

  document.getElementById("model_size").addEventListener("change", function () {
    const quantizationSelect = document.getElementById("quantization");
    const selectedSize = document.getElementById("model_size").value;
    const q0Options = Array.from(quantizationSelect.options).filter(
      (option) => option.value === "q0f32" || option.value === "q0f16"
    );

    // Hide the unquantized q0 variants for the 3B model. `hidden` and
    // `disabled` are set as well because some browsers (notably Safari)
    // ignore display styling on <option> elements.
    const hideQ0 = selectedSize === "3B";
    q0Options.forEach((option) => {
      option.hidden = hideQ0;
      option.disabled = hideQ0;
      option.style.display = hideQ0 ? "none" : "";
    });
    if (quantizationSelect.selectedOptions[0]?.disabled) {
      quantizationSelect.value = quantizationSelect.options[0].value;
    }
  });
});

window.onload = function () {
  // Enable the download button once the page has fully loaded.
  document.getElementById("download").textContent = "Download";
  document.getElementById("download").disabled = false;
};