import argparse
import copy
import os
import re
import subprocess
import tempfile
import base64
from pathlib import Path
import fitz
import gradio as gr
import time
import html
from openai import OpenAI
# Module-wide cancellation flag: reset() sets it to True, the streaming loops
# poll it and stop early, and chat()/stream_generate() clear it again when a
# new generation starts.
stop_generation = False
def stream_from_vllm(messages):
    """Stream chat-completion deltas for ``messages`` from the remote endpoint.

    Args:
        messages: Chat messages in the OpenAI chat-completions format.

    Yields:
        The ``delta`` object of the first choice of every streamed chunk that
        has one.

    Iteration stops early once the module-level ``stop_generation`` flag is
    set. No ``api_key`` is passed here, so the client relies on its own
    default credential lookup.
    """
    client = OpenAI(
        base_url="https://open.bigmodel.cn/api/paas/v4"
    )
    response = client.chat.completions.create(
        model="GLM-4.1V-Thinking-Flash",
        messages=messages,
        temperature=0.01,
        stream=True,
        max_tokens=8192
    )
    try:
        for chunk in response:
            if stop_generation:
                break
            if chunk.choices and chunk.choices[0].delta:
                yield chunk.choices[0].delta
    finally:
        # Release the underlying HTTP stream even when we stop early or the
        # consumer abandons this generator; otherwise the connection lingers.
        close = getattr(response, "close", None)
        if callable(close):
            close()
class GLM4VModel:
    """Turn chat inputs into API messages and render streamed replies as HTML.

    The class holds no per-instance state; its methods convert uploaded media
    into message content parts, rebuild the message list from the stored
    history, and format streamed output for the chat widget.
    """

    # File suffix -> MIME type used when embedding an image as a data URL.
    _MIME_BY_SUFFIX = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".bmp": "image/bmp",
        ".tiff": "image/tiff",
        ".tif": "image/tiff",
        ".webp": "image/webp",
    }
    _VIDEO_SUFFIXES = (".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".m4v")
    _IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp")
    _SLIDE_SUFFIXES = (".ppt", ".pptx")
    # Matches the collapsible "Thinking" block produced by _stream_fragment.
    _THINK_BLOCK_RE = re.compile(r"<details.*?</details>", flags=re.DOTALL)
    _LINE_BREAK_RE = re.compile(r"<br\s*/?>")

    def _strip_html(self, text: str) -> str:
        """Return ``text`` with every ``<...>`` tag removed and outer whitespace trimmed."""
        return re.sub(r"<[^>]+>", "", text).strip()

    def _wrap_text(self, text: str):
        """Wrap ``text`` in a single-element list of text content parts."""
        return [{"type": "text", "text": text}]

    @staticmethod
    def _path_of(file_entry):
        """Return the filesystem path of an upload entry.

        Accepts plain path strings / ``os.PathLike`` objects as well as
        objects that expose the path through a ``name`` attribute.
        """
        if isinstance(file_entry, (str, os.PathLike)):
            return os.fspath(file_entry)
        return file_entry.name

    @staticmethod
    def _delta_field(delta, field):
        """Read ``field`` from a delta given as an object or a dict; ``""`` when absent or empty."""
        if isinstance(delta, dict):
            value = delta.get(field)
        else:
            value = getattr(delta, field, None)
        return value or ""

    def _image_to_base64(self, image_path):
        """Read ``image_path`` and return it as a ``data:<mime>;base64,...`` URL.

        Unknown suffixes fall back to ``image/jpeg``.
        """
        suffix = Path(image_path).suffix.lower()
        mime_type = self._MIME_BY_SUFFIX.get(suffix, "image/jpeg")
        with open(image_path, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
        return f"data:{mime_type};base64,{encoded_string}"

    def _pdf_to_imgs(self, pdf_path):
        """Render every page of ``pdf_path`` to a PNG (180 dpi) and return the image paths."""
        # A fresh directory per call keeps uploads that share a file name from
        # overwriting each other's page images in the shared temp directory.
        out_dir = tempfile.mkdtemp()
        stem = Path(pdf_path).stem
        imgs = []
        doc = fitz.open(pdf_path)
        try:
            for i in range(doc.page_count):
                pix = doc.load_page(i).get_pixmap(dpi=180)
                img_p = os.path.join(out_dir, f"{stem}_{i}.png")
                pix.save(img_p)
                imgs.append(img_p)
        finally:
            # Close the document even if rendering a page fails.
            doc.close()
        return imgs

    def _ppt_to_imgs(self, ppt_path):
        """Convert a PPT/PPTX to PDF with headless LibreOffice, then render its pages.

        Raises:
            subprocess.CalledProcessError: if the LibreOffice conversion exits non-zero.
        """
        tmp = tempfile.mkdtemp()
        subprocess.run(
            ["libreoffice", "--headless", "--convert-to", "pdf", "--outdir", tmp, ppt_path],
            check=True,
        )
        pdf_path = os.path.join(tmp, Path(ppt_path).stem + ".pdf")
        return self._pdf_to_imgs(pdf_path)

    def _files_to_content(self, media):
        """Convert uploaded files into message content parts.

        Images, and the rendered pages of PDF/PPT files, become ``image_url``
        parts carrying base64 data URLs; videos become ``video_url`` parts.
        Files with any other suffix are skipped. ``None`` yields an empty list.
        """
        out = []
        for f in media or []:
            path = self._path_of(f)
            ext = Path(path).suffix.lower()
            if ext in self._VIDEO_SUFFIXES:
                # NOTE(review): a file:// URL points at this machine's
                # filesystem; presumably a remote API cannot read it - confirm.
                out.append({"type": "video_url", "video_url": {"url": f"file://{path}"}})
                continue
            if ext in self._IMAGE_SUFFIXES:
                image_paths = [path]
            elif ext in self._SLIDE_SUFFIXES:
                image_paths = self._ppt_to_imgs(path)
            elif ext == ".pdf":
                image_paths = self._pdf_to_imgs(path)
            else:
                continue
            for p in image_paths:
                out.append({"type": "image_url", "image_url": {"url": self._image_to_base64(p)}})
        return out

    def _stream_fragment(self, reasoning_content: str = "", content: str = "", skip_think: bool = False):
        """Render the accumulated reasoning and answer text as an HTML fragment.

        Both texts are HTML-escaped and their newlines turned into ``<br>``.
        The reasoning goes into a ``<details>`` block (omitted when empty or
        when ``skip_think`` is true); the answer goes into a ``<div>``
        (omitted when empty).

        NOTE(review): the inline styles are a reconstruction; only the
        ``<details>...</details>`` wrapper is relied on elsewhere in this class.
        """
        think_html = ""
        if reasoning_content and not skip_think:
            think_content = html.escape(reasoning_content).replace("\n", "<br>")
            think_html = (
                "<details open><summary style='cursor:pointer;font-weight:bold;'>💠Thinking</summary>"
                "<div style='line-height:1.6;padding:15px;margin:10px 0;'>"
                + think_content
                + "</div></details>"
            )
        answer_html = ""
        if content:
            content_formatted = html.escape(content).replace("\n", "<br>")
            answer_html = f"<div style='margin:0.5em 0;line-height:1.6;'>{content_formatted}</div>"
        return think_html + answer_html

    def _html_to_text(self, rendered: str) -> str:
        """Recover the plain answer text from a fragment built by _stream_fragment."""
        without_think = self._THINK_BLOCK_RE.sub("", rendered)
        with_newlines = self._LINE_BREAK_RE.sub("\n", without_think)
        # Unescape only after the tags are gone, so escaped text such as
        # "&lt;b&gt;" is restored to "<b>" instead of being stripped as a tag.
        return html.unescape(self._strip_html(with_newlines)).strip()

    def _build_messages(self, raw_hist, sys_prompt):
        """Build the API message list from the stored history.

        Args:
            raw_hist: History entries, each a dict with ``role`` and ``content``.
            sys_prompt: Optional system prompt; blank or ``None`` adds no system message.

        Returns:
            A list of messages. User entries are passed through unchanged;
            any other entry is reduced to its plain answer text (thinking
            block removed) and dropped when nothing remains.
        """
        msgs = []
        sys_text = (sys_prompt or "").strip()
        if sys_text:
            msgs.append({"role": "system", "content": [{"type": "text", "text": sys_text}]})
        for h in raw_hist:
            if h["role"] == "user":
                msgs.append({"role": "user", "content": h["content"]})
                continue
            clean_content = self._html_to_text(h["content"])
            if clean_content:
                msgs.append({"role": "assistant", "content": self._wrap_text(clean_content)})
        return msgs

    def stream_generate(self, raw_hist, sys_prompt: str, *, skip_special_tokens: bool = False):
        """Stream a reply for ``raw_hist``, yielding the rendered HTML so far after each delta.

        Clears the module-level ``stop_generation`` flag on entry and stops
        once it is set again. Any exception raised while streaming is turned
        into a final fragment containing the error message.
        ``skip_special_tokens`` is accepted but not used.
        """
        global stop_generation
        stop_generation = False
        msgs = self._build_messages(raw_hist, sys_prompt)
        reasoning_buffer = ""
        content_buffer = ""
        try:
            for delta in stream_from_vllm(msgs):
                if stop_generation:
                    break
                # A delta may carry reasoning, answer text, or both; read each
                # field independently so neither is dropped.
                reasoning_buffer += self._delta_field(delta, "reasoning_content")
                content_buffer += self._delta_field(delta, "content")
                yield self._stream_fragment(reasoning_buffer, content_buffer)
        except Exception as e:
            error_msg = f"Error during streaming: {str(e)}"
            yield self._stream_fragment("", error_msg)
def format_display_content(content):
    """Return the text shown in the chat window for one user message.

    A list of content parts is collapsed to its text parts joined by single
    spaces; when it also holds non-text parts, a line stating how many files
    were uploaded is put in front. Anything that is not a list is returned
    unchanged.
    """
    if not isinstance(content, list):
        return content
    texts = [part["text"] for part in content if part["type"] == "text"]
    # Every part that is not text counts as one uploaded file.
    attachments = len(content) - len(texts)
    joined = " ".join(texts)
    if attachments > 0:
        return f"[{attachments} file(s) uploaded]\n{joined}"
    return joined
def create_display_history(raw_hist):
    """Build the list of messages for the chat widget from the raw history.

    User entries have their content condensed by format_display_content;
    every other entry is emitted with role "assistant" and its content as is.
    """

    def to_display(entry):
        if entry["role"] != "user":
            return {"role": "assistant", "content": entry["content"]}
        return {"role": "user", "content": format_display_content(entry["content"])}

    return [to_display(entry) for entry in raw_hist]
# Shared GLM4VModel instance; chat() uses it to build message payloads and to
# stream replies.
glm4v = GLM4VModel()
def check_files(files):
    """Validate the set of uploaded files against the demo's limits.

    Args:
        files: Iterable of uploads, each a path string / ``os.PathLike`` or an
            object exposing its path as ``name``; ``None`` counts as no files.

    Returns:
        ``(True, "")`` when the upload is acceptable, otherwise
        ``(False, message)``. Accepted combinations are: up to 10 images, or
        one video, or one document (PPT or PDF). Files with an unrecognised
        suffix are not counted.
    """
    video_exts = {".mp4", ".avi", ".mkv", ".mov", ".wmv", ".flv", ".webm", ".mpeg", ".m4v"}
    image_exts = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"}
    doc_exts = {".ppt", ".pptx", ".pdf"}

    vids = imgs = docs = 0
    for f in files or []:
        path = os.fspath(f) if isinstance(f, (str, os.PathLike)) else f.name
        ext = Path(path).suffix.lower()
        if ext in video_exts:
            vids += 1
        elif ext in image_exts:
            imgs += 1
        elif ext in doc_exts:
            docs += 1

    # PPT and PDF share one "document" slot, so one of each is also rejected.
    if vids > 1 or docs > 1:
        return False, "Only one video or one PPT or one PDF allowed"
    if imgs > 10:
        return False, "Maximum 10 images allowed"
    kinds_present = sum(1 for count in (vids, imgs, docs) if count)
    if kinds_present > 1:
        return False, "Cannot mix documents, videos, and images"
    return True, ""
def chat(files, msg, raw_hist, sys_prompt):
    """Handle one chat submission and stream the assistant's reply.

    Args:
        files: Uploaded files (or ``None``).
        msg: Text typed by the user (or ``None``).
        raw_hist: Stored history as a list of role/content dicts; ``None`` is
            treated as an empty history. The list is extended in place.
        sys_prompt: Optional system prompt forwarded to the model.

    Yields:
        ``(display_history, raw_history_copy, None, "")`` tuples: the messages
        to show, a deep copy of the stored history, and values that clear the
        file upload and the text box.

    Invalid uploads produce a single assistant message with the validation
    error. A submission with neither text nor usable files sends nothing.
    """
    global stop_generation
    stop_generation = False

    # Normalise inputs before first use so a missing history or message
    # cannot fail on the validation-error path below.
    if raw_hist is None:
        raw_hist = []
    text = (msg or "").strip()

    def snapshot():
        return create_display_history(raw_hist), copy.deepcopy(raw_hist), None, ""

    ok, err = check_files(files)
    if not ok:
        raw_hist.append({"role": "assistant", "content": err})
        yield snapshot()
        return

    payload = glm4v._files_to_content(files) if files else None
    if text:
        if payload is None:
            payload = glm4v._wrap_text(text)
        else:
            payload.append({"type": "text", "text": text})
    if not payload:
        # Nothing to send: avoid posting an empty user message to the API.
        yield snapshot()
        return

    raw_hist.append({"role": "user", "content": payload})
    # Placeholder assistant entry that is filled in as chunks arrive.
    place = {"role": "assistant", "content": ""}
    raw_hist.append(place)
    yield snapshot()

    try:
        # The placeholder itself is excluded from what is sent to the model.
        for chunk in glm4v.stream_generate(raw_hist[:-1], sys_prompt):
            if stop_generation:
                break
            place["content"] = chunk
            yield snapshot()
    except Exception as e:
        # Escape the message: the chat widget renders this content as HTML.
        place["content"] = f"<div style='color: red;'>Error: {html.escape(str(e))}</div>"
    yield snapshot()
def reset():
    """Request that streaming stop and return values that clear the UI.

    Returns:
        A 4-tuple: empty chat messages, empty stored history, ``None`` for the
        file upload, and an empty string for the text box.
    """
    global stop_generation
    # Raise the flag polled by the streaming loops.
    stop_generation = True
    # Short pause before handing back the cleared values.
    time.sleep(0.1)
    cleared_chat, cleared_history = [], []
    return cleared_chat, cleared_history, None, ""
# Gradio UI: chat column on the left, upload + system prompt on the right.
demo = gr.Blocks(title="GLM-4.1V-9B-Thinking", theme=gr.themes.Soft())
with demo:
    # NOTE(review): the header markup/styling is a reconstruction of string
    # literals that were unterminated; adjust the styling as desired.
    gr.Markdown(
        "<div style='text-align:center;font-size:32px;font-weight:bold;margin-bottom:10px;'>"
        "GLM-4.1V-9B-Thinking</div>"
        "<div style='text-align:center;margin-bottom:20px;'>"
        "This demo uses the API version of the service for faster response.</div>"
    )
    # Per-session stored history (list of role/content dicts).
    raw_history = gr.State([])
    with gr.Row():
        with gr.Column(scale=7):
            # sanitize_html=False lets the HTML fragments produced for
            # assistant messages render as markup.
            chatbox = gr.Chatbot(
                label="Chat",
                type="messages",
                height=600,
                elem_classes="chatbot-container",
                sanitize_html=False,
                line_breaks=True
            )
            textbox = gr.Textbox(label="Message", lines=3)
            with gr.Row():
                send = gr.Button("Send", variant="primary")
                clear = gr.Button("Clear")
        with gr.Column(scale=3):
            up = gr.File(label="Upload Files", file_count="multiple", file_types=["file"], type="filepath")
            gr.Markdown("Supports images / videos / PPT / PDF")
            gr.Markdown(
                "The maximum supported input is 10 images or 1 video/PPT/PDF(less than 10 pages) in this demo. "
                "You may upload only one file type at a time (such as an image, video, PDF, or PPT)."
            )
            sys = gr.Textbox(label="System Prompt", lines=6)

    # Both the Send button and Enter in the text box run chat() with the same
    # inputs and outputs; Clear resets the same four outputs.
    chat_inputs = [up, textbox, raw_history, sys]
    chat_outputs = [chatbox, raw_history, up, textbox]
    send.click(chat, inputs=chat_inputs, outputs=chat_outputs)
    textbox.submit(chat, inputs=chat_inputs, outputs=chat_outputs)
    clear.click(reset, outputs=chat_outputs)

if __name__ == "__main__":
    demo.launch()