# app.py — UI-TARS demo (OSS disabled)
import base64
import json
import ast
import os
import re
import io
import math
from datetime import datetime

import gradio as gr
from PIL import ImageDraw

# =========================
# OpenAI client (optional)
# =========================
# If OPENAI_API_KEY is set, we use OpenAI to parse the model output text.
# If ENDPOINT_URL is set, we point the OpenAI client at that base URL (advanced use).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ENDPOINT_URL = os.getenv("ENDPOINT_URL")  # optional
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")  # safe default instead of "tgi"

client = None
if OPENAI_API_KEY:
    try:
        from openai import OpenAI

        if ENDPOINT_URL:
            client = OpenAI(api_key=OPENAI_API_KEY, base_url=ENDPOINT_URL)
        else:
            client = OpenAI(api_key=OPENAI_API_KEY)
        print("✅ OpenAI client initialized.")
    except Exception as e:
        print(f"⚠️ OpenAI client not available: {e}")
else:
    print("ℹ️ OPENAI_API_KEY not set. Running without OpenAI parsing.")

# =========================
# UI-TARS prompt
# =========================
DESCRIPTION = "[UI-TARS](https://github.com/bytedance/UI-TARS)"

prompt = (
    "Output only the coordinate of one box in your response. "
    "Return a tuple like (x,y) with values in 0..1000 for x and y. "
    "Do not include any extra text. "
)

# =========================
# OSS (Aliyun) — DISABLED
# =========================
# The original demo used Aliyun OSS (oss2) to upload images/metadata.
# We disable it fully so no env vars like BUCKET / ENDPOINT are required.
bucket = None
print("⚠️ OSS integration disabled: skipping Aliyun storage.")


def draw_point_area(image, point):
    """Draw a red point and circle at a (0..1000, 0..1000) coordinate on the given PIL image."""
    if not point:
        return image
    radius = min(image.width, image.height) // 15
    x = round(point[0] / 1000 * image.width)
    y = round(point[1] / 1000 * image.height)
    drawer = ImageDraw.Draw(image)
    drawer.ellipse((x - radius, y - radius, x + radius, y + radius), outline="red", width=2)
    drawer.ellipse((x - 2, y - 2, x + 2, y + 2), fill="red")
    return image


def resize_image(image):
    """Rescale the screenshot to a target pixel budget (larger inputs get a larger budget)
    to keep compute stable."""
    if image.width * image.height > 6000 * 28 * 28:
        max_pixels = 2700 * 28 * 28
    else:
        max_pixels = 1340 * 28 * 28
    resize_factor = math.sqrt(max_pixels / (image.width * image.height))
    width, height = int(image.width * resize_factor), int(image.height * resize_factor)
    return image.resize((width, height))
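
# Quick sanity check for the 0..1000 coordinate scheme used throughout this file
# (an illustrative sketch only — `_to_pixels` is a hypothetical helper, not used by
# the demo itself): on a 1920x1080 screenshot the model output (500, 500) maps to
# the pixel center (960, 540), mirroring the arithmetic in draw_point_area/run_ui.
def _to_pixels(point, width, height):
    """Map a (0..1000, 0..1000) point to absolute pixel coordinates."""
    return (round(point[0] / 1000 * width), round(point[1] / 1000 * height))


assert _to_pixels((500, 500), 1920, 1080) == (960, 540)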
Keeps API stable.""" if bucket is None: print("↪️ Skipped OSS upload (no bucket configured).") return img_path = f"{session_id}.png" result_img_path = f"{session_id}-draw.png" metadata = dict( query=query, resize_image=img_path, result_image=result_img_path, session_id=session_id, ) img_bytes = io.BytesIO() image.save(img_bytes, format="png") bucket.put_object(img_path, img_bytes.getvalue()) rst_img_bytes = io.BytesIO() result_image.save(rst_img_bytes, format="png") bucket.put_object(result_img_path, rst_img_bytes.getvalue()) bucket.put_object(f"{session_id}.json", json.dumps(metadata).encode("utf-8")) print("✅ (would) upload images — skipped unless bucket configured") def run_ui(image, query, session_id, is_example_image): """Main inference path: builds the message, asks the model for (x,y), draws, returns results.""" click_xy = None images_during_iterations = [] width, height = image.width, image.height # Resize for throughput + encode image = resize_image(image) buf = io.BytesIO() image.save(buf, format="png") base64_image = base64.standard_b64encode(buf.getvalue()).decode("utf-8") # Prepare prompt for an LLM that returns '(x,y)' messages = [ { "role": "user", "content": [ {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}}, {"type": "text", "text": prompt + query}, ], } ] # If OpenAI client is present, ask it to parse coordinates. Otherwise we return a safe default. output_text = "" if client is not None: try: resp = client.chat.completions.create( model=MODEL_NAME, messages=messages, temperature=1.0, top_p=0.7, max_tokens=128, frequency_penalty=1, stream=False, ) output_text = resp.choices[0].message.content or "" except Exception as e: output_text = "" print(f"⚠️ OpenAI call failed: {e}") # Extract "(x,y)" from the text using regex pattern = r"\((\d+,\s*\d+)\)" match = re.search(pattern, output_text) if match: coordinates = match.group(1) try: click_xy = ast.literal_eval(coordinates) # (x, y) with 0..1000 scale except Exception: click_xy = None # If we still don't have coordinates, fall back to center if click_xy is None: click_xy = (500, 500) # Draw result + convert to absolute pixel coords for display result_image = draw_point_area(image.copy(), click_xy) images_during_iterations.append(result_image) abs_xy = (round(click_xy[0] / 1000 * width), round(click_xy[1] / 1000 * height)) # Upload artifacts only for real (non-example) inputs if str(is_example_image) == "False": upload_images(session_id, image, result_image, query) return images_during_iterations, str(abs_xy) def update_vote(vote_type, image, click_image, prompt_text, is_example): """Simple feedback hook (no external upload when OSS disabled).""" if vote_type == "upvote": return "Everything good" if is_example == "True": return "Do nothing for example" # Example gallery returns file paths; we do nothing here return "Thank you for your feedback!" 
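
# Minimal sketch of the coordinate-parsing path in run_ui, assuming the model answers
# exactly in the prompted "(x,y)" format (illustrative only; not wired into the UI):
_example_reply = "(312, 87)"
_example_match = re.search(r"\((\d+,\s*\d+)\)", _example_reply)
assert _example_match is not None
assert ast.literal_eval(_example_match.group(1)) == (312, 87)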

# Demo examples
examples = [
    ["./examples/solitaire.png", "Play the solitaire collection", True],
    ["./examples/weather_ui.png", "Open map", True],
    ["./examples/football_live.png", "click team 1 win", True],
    ["./examples/windows_panel.png", "switch to documents", True],
    ["./examples/paint_3d.png", "rotate left", True],
    ["./examples/finder.png", "view files from airdrop", True],
    ["./examples/amazon.jpg", "Search bar at the top of the page", True],
    ["./examples/semantic.jpg", "Home", True],
    ["./examples/accweather.jpg", "Select May", True],
    ["./examples/arxiv.jpg", "Home", True],
    ["./examples/health.jpg", "text labeled by 2023/11/26", True],
    ["./examples/ios_setting.png", "Turn off Do not disturb.", True],
]

title_markdown = """
# UI-TARS Pioneering Automated GUI Interaction with Native Agents
[[🤗Model](https://huggingface.co/bytedance-research/UI-TARS-7B-SFT)]
[[⌨️Code](https://github.com/bytedance/UI-TARS)]
[[📑Paper](https://github.com/bytedance/UI-TARS/blob/main/UI_TARS_paper.pdf)]
[🏄[Midscene (Browser Automation)](https://github.com/web-infra-dev/Midscene)]
[🫨[Discord](https://discord.gg/txAE43ps)]
"""

tos_markdown = """
### Terms of use
This demo is governed by the original license of UI-TARS. We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc.
(注:本演示受UI-TARS的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
"""

learn_more_markdown = """
### License
Apache License 2.0
"""

code_adapt_markdown = """
### Acknowledgments
The app code is modified from [ShowUI](https://huggingface.co/spaces/showlab/ShowUI)
"""

block_css = """
#buttons button {
    min-width: min(120px, 100%);
}
#chatbot img {
    max-width: 80%;
    max-height: 80vh;
    width: auto;
    height: auto;
    object-fit: contain;
}
"""


def build_demo():
    with gr.Blocks(title="UI-TARS Demo", theme=gr.themes.Default(), css=block_css) as demo:
        state_session_id = gr.State(value=None)
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                imagebox = gr.Image(type="pil", label="Input Screenshot")
                textbox = gr.Textbox(
                    show_label=True,
                    placeholder="Enter an instruction and press Submit",
                    label="Instruction",
                )
                submit_btn = gr.Button(value="Submit", variant="primary")

            with gr.Column(scale=6):
                output_gallery = gr.Gallery(label="Output with click", object_fit="contain", preview=True)
                gr.HTML(
                    """
                    Notice: The red point with a circle on the output image represents
                    the predicted coordinates for a click.
                    """
                )
                with gr.Row():
                    output_coords = gr.Textbox(label="Final Coordinates")
                    image_size = gr.Textbox(label="Image Size")
                gr.HTML("Expected result or not? Help us improve! ⬇️")
") with gr.Row(elem_id="action-buttons", equal_height=True): upvote_btn = gr.Button(value="👍 Looks good!", variant="secondary") downvote_btn = gr.Button(value="👎 Wrong coordinates!", variant="secondary") clear_btn = gr.Button(value="🗑️ Clear", interactive=True) with gr.Column(scale=3): gr.Examples( examples=[[e[0], e[1]] for e in examples], inputs=[imagebox, textbox], outputs=[textbox], examples_per_page=3, ) is_example_dropdown = gr.Dropdown( choices=["True", "False"], value="False", visible=False, label="Is Example Image", ) def set_is_example(query): for _, example_query, is_example in examples: if query.strip() == example_query.strip(): return str(is_example) return "False" textbox.change(set_is_example, inputs=[textbox], outputs=[is_example_dropdown]) def on_submit(image, query, is_example_image): if image is None: raise ValueError("No image provided. Please upload an image before submitting.") session_id = datetime.now().strftime("%Y%m%d_%H%M%S") images_during_iterations, click_coords = run_ui(image, query, session_id, is_example_image) return images_during_iterations, click_coords, session_id, f"{image.width}x{image.height}" submit_btn.click( on_submit, [imagebox, textbox, is_example_dropdown], [output_gallery, output_coords, state_session_id, image_size], ) clear_btn.click( lambda: (None, None, None, None, None, None), inputs=None, outputs=[imagebox, textbox, output_gallery, output_coords, state_session_id, image_size], queue=False, ) upvote_btn.click( lambda image, click_image, prompt_text, is_example: update_vote("upvote", image, click_image, prompt_text, is_example), inputs=[imagebox, output_gallery, textbox, is_example_dropdown], outputs=[], queue=False, ) downvote_btn.click( lambda image, click_image, prompt_text, is_example: update_vote("downvote", image, click_image, prompt_text, is_example), inputs=[imagebox, output_gallery, textbox, is_example_dropdown], outputs=[], queue=False, ) gr.Markdown(tos_markdown) gr.Markdown(learn_more_markdown) gr.Markdown(code_adapt_markdown) return demo if __name__ == "__main__": demo = build_demo() demo.queue(api_open=False).launch( server_name="0.0.0.0", server_port=7860, debug=True, )