NON_WORKING_matrix_game_2 / websocket_pipeline.py
jbilcke-hf's picture
jbilcke-hf HF Staff
let's break the loop
f034856
raw
history blame
4.4 kB
"""
WebSocket-compatible streaming pipeline for Matrix-Game V2
This wraps the streaming pipeline to use WebSocket inputs instead of stdin
"""
import torch
from pipeline import CausalInferenceStreamingPipeline
from pipeline.causal_inference import cond_current
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class WebSocketStreamingPipeline(CausalInferenceStreamingPipeline):
    """
    A streaming pipeline that takes its actions from stored WebSocket state
    instead of stdin.

    Callers push the latest input with :meth:`set_current_action` (or pass it
    directly to :meth:`inference`). While :meth:`inference` runs, the
    module-level ``pipeline.causal_inference.get_current_action`` is
    temporarily replaced by :meth:`get_websocket_action`.
    """

    # Length of the WebSocket keyboard state:
    # [forward, back, left, right, jump, attack].
    _WS_KEYBOARD_SIZE = 6
    # Length of the mouse state: [x, y].
    _MOUSE_SIZE = 2

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent pipeline and clear the stored action."""
        super().__init__(*args, **kwargs)
        # Latest action received over the WebSocket; None until one is set.
        self.current_keyboard = None
        self.current_mouse = None

    def set_current_action(self, keyboard_state, mouse_state):
        """Set the current action from WebSocket input."""
        self.current_keyboard = keyboard_state
        self.current_mouse = mouse_state

    @staticmethod
    def _pad_state(state, size):
        """
        Return ``state`` with at least ``size`` entries when it is a plain list/tuple.

        ``None`` becomes ``size`` zeros, and a list/tuple shorter than ``size``
        is right-padded with zeros. Anything else (including a list/tuple that
        is already long enough, or a non-list object such as a tensor) is
        returned unchanged.
        """
        if state is None:
            return [0] * size
        if isinstance(state, (list, tuple)) and len(state) < size:
            return list(state) + [0] * (size - len(state))
        return state

    def get_websocket_action(self, mode="universal"):
        """
        Get the current action from stored WebSocket data instead of stdin.

        Intended to return the same format as ``get_current_action()``.

        WebSocket keyboard format: [forward, back, left, right, jump, attack]

        Pipeline expects different formats per mode:
        - universal: keyboard [forward, back, left, right], mouse [x, y]
        - gta_drive: keyboard [forward, back], mouse [x, y]
        - templerun: keyboard [left, right], no mouse (zeros are returned)

        Any unrecognised ``mode`` is treated like ``universal``.

        Returns:
            dict with float32 tensors under ``'mouse'`` and ``'keyboard'``,
            placed on CUDA when available and on CPU otherwise.
        """
        # A missing or too-short keyboard list is zero-padded so the index
        # lookups below cannot raise IndexError on a partial client payload.
        ws_keyboard = self._pad_state(self.current_keyboard, self._WS_KEYBOARD_SIZE)

        # Map the WebSocket layout to the mode-specific layout.
        if mode == 'gta_drive':
            # Forward and back only.
            keyboard = [ws_keyboard[0], ws_keyboard[1]]
        elif mode == 'templerun':
            # Left and right only.
            keyboard = [ws_keyboard[2], ws_keyboard[3]]
        else:
            # 'universal' and any unknown mode: forward, back, left, right.
            keyboard = ws_keyboard[:4]

        if mode == 'templerun':
            # Temple Run doesn't use the mouse, but return zeros for compatibility.
            mouse = [0, 0]
        else:
            mouse = self._pad_state(self.current_mouse, self._MOUSE_SIZE)

        # Prefer CUDA (the original behaviour) but do not crash on hosts
        # where no CUDA device is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        mouse_tensor = torch.tensor(mouse, dtype=torch.float32).to(device)
        keyboard_tensor = torch.tensor(keyboard, dtype=torch.float32).to(device)

        # Lazy %-formatting: the message is only built when DEBUG is enabled.
        logger.debug(
            "WebSocket action for mode %s: keyboard=%s, mouse=%s",
            mode, keyboard, mouse,
        )

        return {
            'mouse': mouse_tensor,
            'keyboard': keyboard_tensor
        }

    def inference(
        self,
        noise: torch.Tensor,
        conditional_dict,
        initial_latent=None,
        return_latents=False,
        output_folder=None,
        name=None,
        mode='universal',
        keyboard_condition=None,
        mouse_condition=None
    ) -> torch.Tensor:
        """
        Override inference to use WebSocket actions instead of stdin.

        Args:
            noise, conditional_dict, initial_latent, return_latents,
            output_folder, name, mode: forwarded unchanged to the parent
                ``inference``.
            keyboard_condition: if not None, replaces the stored keyboard
                state before inference starts.
            mouse_condition: if not None, replaces the stored mouse state
                before inference starts.

        Returns:
            Whatever the parent ``inference`` returns.
        """
        # Store the provided conditions for use during inference. They stay
        # stored after this call returns.
        if keyboard_condition is not None:
            self.current_keyboard = keyboard_condition
        if mouse_condition is not None:
            self.current_mouse = mouse_condition

        # Monkey-patch get_current_action for the duration of this call.
        # NOTE(review): this swaps a module-level attribute, so it affects
        # every user of pipeline.causal_inference in the process while the
        # call is running; overlapping inference calls would interfere.
        # Presumably the parent looks the function up through the module
        # global at call time - confirm against pipeline.causal_inference.
        import pipeline.causal_inference as ci
        original_get_current_action = ci.get_current_action
        ci.get_current_action = self.get_websocket_action

        try:
            # Call parent inference method.
            return super().inference(
                noise=noise,
                conditional_dict=conditional_dict,
                initial_latent=initial_latent,
                return_latents=return_latents,
                output_folder=output_folder,
                name=name,
                mode=mode
            )
        finally:
            # Always restore the original function, even if inference raised.
            ci.get_current_action = original_get_current_action