Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import spaces | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| from sim.simulator import GenieSimulator | |
# Side length, in pixels, of the square frames handed to the UI.
RES = 512

# Open the prompt frame before building the simulator so a missing asset
# fails before the checkpoints are loaded.
_prompt_source = Image.open("sim/assets/langtable_prompt/frame_06.png")

genie = GenieSimulator(
    image_encoder_type='temporalvae',
    image_encoder_ckpt='stabilityai/stable-video-diffusion-img2vid',
    quantize=False,
    backbone_type='stmar',
    backbone_ckpt='data/mar_ckpt/langtable',
    prompt_horizon=3,
    action_stride=1,
    domain='language_table',
)

# Repeat the single prompt frame `prompt_horizon` times along a new leading
# axis (assumes the PNG decodes to a 3-D H x W x C array -- TODO confirm).
_tile_reps = (genie.prompt_horizon, 1, 1, 1)
prompt_image = np.tile(np.array(_prompt_source), _tile_reps).astype(np.uint8)

# All-zero actions for the prompt: one fewer entry than there are prompt
# frames, each holding `action_stride` two-component actions.
_action_shape = (genie.prompt_horizon - 1, genie.action_stride, 2)
prompt_action = np.zeros(_action_shape, dtype=np.float32)

genie.set_initial_state((prompt_image, prompt_action))

# The frame returned by reset(), resized to RES x RES and wrapped as a PIL
# image, is what the UI shows before any button is pressed.
image = Image.fromarray(cv2.resize(genie.reset(), (RES, RES)))
# Two-component action sent to the simulator for each supported direction.
_DIRECTION_TO_ACTION = {
    'right': (0, 0.05),
    'left': (0, -0.05),
    'down': (0.05, 0),
    'up': (-0.05, 0),
}


def model(direction: str, genie=genie):
    """Step the simulator once in ``direction`` and return the predicted frame.

    Args:
        direction: One of ``'right'``, ``'left'``, ``'down'`` or ``'up'``.
        genie: Simulator to step; defaults to the module-level instance.

    Returns:
        The ``'pred_next_frame'`` entry of ``genie.step(...)``, resized to
        ``RES`` x ``RES`` and converted to a PIL image.

    Raises:
        ValueError: If ``direction`` is not one of the supported values.
    """
    try:
        components = _DIRECTION_TO_ACTION[direction]
    except (KeyError, TypeError):
        raise ValueError(f"Invalid direction: {direction}") from None

    # Build a fresh array on every call so the simulator never sees an
    # action object shared between steps.
    action = np.array(components)
    predicted = genie.step(action)['pred_next_frame']
    return Image.fromarray(cv2.resize(predicted, (RES, RES)))
def handle_input(direction):
    """Gradio callback: log the clicked direction and return the next frame.

    Args:
        direction: Direction string forwarded unchanged to ``model``.

    Returns:
        Whatever ``model(direction)`` returns (the image to display).
    """
    print(f"User clicked: {direction}")
    return model(direction)
if __name__ == '__main__':
    # Build the demo UI: the current frame on top, an "Up" button on its own
    # row, and "Left" / "Down" / "Right" on the row below it.
    with gr.Blocks() as demo:
        with gr.Row():
            image_display = gr.Image(value=image, type="pil", label="Generated Image")
        # NOTE(review): the button labels previously showed a mis-encoded
        # character ("β"); they are restored here on the assumption that the
        # intended glyphs were the matching arrows -- confirm.
        with gr.Row():
            up = gr.Button("↑ Up")
        with gr.Row():
            left = gr.Button("← Left")
            down = gr.Button("↓ Down")
            right = gr.Button("→ Right")

        # Define button interactions: each button passes its own fixed
        # direction to handle_input and replaces the displayed image with
        # the returned frame.
        up.click(fn=lambda: handle_input("up"), outputs=image_display, show_progress='hidden')
        down.click(fn=lambda: handle_input("down"), outputs=image_display, show_progress='hidden')
        left.click(fn=lambda: handle_input("left"), outputs=image_display, show_progress='hidden')
        right.click(fn=lambda: handle_input("right"), outputs=image_display, show_progress='hidden')

    demo.launch()