File size: 4,401 Bytes
f034856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
WebSocket-compatible streaming pipeline for Matrix-Game V2
This wraps the streaming pipeline to use WebSocket inputs instead of stdin
"""

import torch
from pipeline import CausalInferenceStreamingPipeline
from pipeline.causal_inference import cond_current
import logging

logger = logging.getLogger(__name__)

class WebSocketStreamingPipeline(CausalInferenceStreamingPipeline):
    """
    A streaming pipeline that accepts actions via parameters instead of stdin.

    The most recent keyboard/mouse state (set through ``set_current_action``
    or the ``keyboard_condition``/``mouse_condition`` arguments of
    ``inference``) is served to the parent pipeline by temporarily replacing
    ``pipeline.causal_inference.get_current_action`` while ``inference`` runs.
    """

    # Number of slots in the WebSocket keyboard vector:
    # [forward, back, left, right, jump, attack].
    _WS_KEYBOARD_LEN = 6

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Latest WebSocket keyboard state; None means "no input received yet".
        self.current_keyboard = None
        # Latest WebSocket mouse state; None means "no input received yet".
        self.current_mouse = None

    def set_current_action(self, keyboard_state, mouse_state):
        """Set the current action from WebSocket input.

        Args:
            keyboard_state: WebSocket keyboard vector
                ``[forward, back, left, right, jump, attack]``.
            mouse_state: Mouse vector ``[x, y]``.
        """
        self.current_keyboard = keyboard_state
        self.current_mouse = mouse_state

    def _padded_ws_keyboard(self):
        """Return the stored keyboard state, padded to the WebSocket format.

        ``None`` becomes all zeros. A vector shorter than the 6-slot WebSocket
        format is padded with 0 (key not pressed), so the mode-specific
        indexing in ``get_websocket_action`` never raises ``IndexError`` or
        yields a short action vector. Inputs that are already long enough are
        returned as-is.
        """
        ws_keyboard = self.current_keyboard
        if ws_keyboard is None:
            return [0] * self._WS_KEYBOARD_LEN
        missing = self._WS_KEYBOARD_LEN - len(ws_keyboard)
        if missing > 0:
            ws_keyboard = list(ws_keyboard) + [0] * missing
        return ws_keyboard

    def get_websocket_action(self, mode="universal"):
        """
        Get the current action from stored WebSocket data instead of stdin.

        Installed in place of ``pipeline.causal_inference.get_current_action``
        during ``inference``, so it keeps the ``mode`` keyword.

        WebSocket keyboard format: [forward, back, left, right, jump, attack]
        Pipeline expects different formats per mode:
        - universal: keyboard [forward, back, left, right], mouse [x, y]
        - gta_drive: keyboard [forward, back], mouse [x, y]
        - templerun: keyboard [left, right], mouse forced to [0, 0]
        Unknown modes use the universal layout.

        Args:
            mode: Game mode selecting the keyboard/mouse layout.

        Returns:
            dict with float32 tensors moved to CUDA under 'mouse' and
            'keyboard'.
        """
        ws_keyboard = self._padded_ws_keyboard()

        # Map the WebSocket vector to the mode-specific keyboard layout.
        if mode == 'gta_drive':
            keyboard = [ws_keyboard[0], ws_keyboard[1]]  # forward, back
        elif mode == 'templerun':
            keyboard = [ws_keyboard[2], ws_keyboard[3]]  # left, right
        else:
            # 'universal' and any unknown mode: forward, back, left, right.
            keyboard = ws_keyboard[:4]

        # Temple Run doesn't use the mouse, but zeros are still returned
        # for compatibility; missing mouse input also defaults to zeros.
        if mode == 'templerun' or self.current_mouse is None:
            mouse = [0, 0]
        else:
            mouse = self.current_mouse

        # Convert to tensors in the format expected by the pipeline.
        mouse_tensor = torch.tensor(mouse, dtype=torch.float32).cuda()
        keyboard_tensor = torch.tensor(keyboard, dtype=torch.float32).cuda()

        # Lazy %-style args: the message is only rendered when DEBUG is enabled.
        logger.debug(
            "WebSocket action for mode %s: keyboard=%s, mouse=%s",
            mode, keyboard, mouse,
        )

        return {
            'mouse': mouse_tensor,
            'keyboard': keyboard_tensor,
        }

    def inference(
        self,
        noise: torch.Tensor,
        conditional_dict,
        initial_latent=None,
        return_latents=False,
        output_folder=None,
        name=None,
        mode='universal',
        keyboard_condition=None,
        mouse_condition=None
    ) -> torch.Tensor:
        """
        Override inference to use WebSocket actions instead of stdin.

        All arguments except ``keyboard_condition`` and ``mouse_condition``
        are forwarded unchanged to the parent ``inference``.

        Args:
            keyboard_condition: If not None, replaces the stored keyboard
                state before generation starts.
            mouse_condition: If not None, replaces the stored mouse state
                before generation starts.

        Returns:
            Whatever the parent ``inference`` returns.
        """
        # Store the provided conditions for use during inference; None keeps
        # the previously stored state.
        if keyboard_condition is not None:
            self.current_keyboard = keyboard_condition
        if mouse_condition is not None:
            self.current_mouse = mouse_condition

        # Swap the module-level action reader for this instance's WebSocket
        # reader for the duration of the call, restoring it even on error.
        # Presumably the parent calls ``get_current_action`` through its
        # module global — TODO confirm against pipeline.causal_inference.
        # NOTE(review): this mutates module-global state; two inference calls
        # running concurrently (threads or separate instances) would race on
        # the save/restore and could leave the wrong function installed.
        import pipeline.causal_inference as ci
        original_get_current_action = ci.get_current_action
        ci.get_current_action = self.get_websocket_action
        try:
            return super().inference(
                noise=noise,
                conditional_dict=conditional_dict,
                initial_latent=initial_latent,
                return_latents=return_latents,
                output_folder=output_folder,
                name=name,
                mode=mode
            )
        finally:
            # Restore original function
            ci.get_current_action = original_get_current_action