"""
WebSocket-compatible streaming pipeline for Matrix-Game V2
This wraps the streaming pipeline to use WebSocket inputs instead of stdin
"""

import logging

import torch

from pipeline import CausalInferenceStreamingPipeline

logger = logging.getLogger(__name__)

class WebSocketStreamingPipeline(CausalInferenceStreamingPipeline):
    """
    A streaming pipeline that accepts actions via parameters instead of stdin
    """
    
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.current_keyboard = None
        self.current_mouse = None
        
    def set_current_action(self, keyboard_state, mouse_state):
        """
        Store the latest action received over the WebSocket.

        Both arguments are expected in the pipeline's list format; the engine
        performs the WebSocket-to-pipeline format conversion before calling this.
        """
        self.current_keyboard = keyboard_state
        self.current_mouse = mouse_state
        
    def get_websocket_action(self, mode="universal"):
        """
        Return the current action from stored WebSocket data instead of stdin,
        in the same format as get_current_action().

        Note: the format conversion is handled at the engine level, so this
        method only needs to reduce a multi-key action to a single dominant one.
        """
        # Get current states (already in pipeline format from engine conversion)
        if self.current_keyboard is None:
            keyboard = [0, 0, 0, 0] if mode == 'universal' else [0, 0]
        else:
            keyboard = self.current_keyboard
            
        if self.current_mouse is None:
            mouse = [0, 0]
        else:
            mouse = self.current_mouse
        
        # Convert multi-action to single dominant action for the streaming pipeline
        # This handles cases where multiple keys might be pressed simultaneously
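        # Example: keyboard=[1, 0, 1, 0] reduces to [1, 0, 0, 0] -- the first
        # pressed key in index order wins and the rest are dropped.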
        dominant_keyboard = []
        for i, val in enumerate(keyboard):
            if val > 0:
                # Create one-hot vector with this action
                dominant = [0] * len(keyboard)
                dominant[i] = 1
                dominant_keyboard = dominant
                break
        
        if not dominant_keyboard:
            # No action pressed
            dominant_keyboard = [0] * len(keyboard)
            
        # Convert to tensors
        mouse_tensor = torch.tensor(mouse, dtype=torch.float32).cuda()
        keyboard_tensor = torch.tensor(dominant_keyboard, dtype=torch.float32).cuda()
        
        logger.debug(f"WebSocket action for mode {mode}: kb={keyboard} -> dominant_kb={dominant_keyboard}, mouse={mouse}")
        
        return {
            'mouse': mouse_tensor,
            'keyboard': keyboard_tensor
        }
    
    def inference(
        self,
        noise: torch.Tensor,
        conditional_dict,
        initial_latent=None,
        return_latents=False,
        output_folder=None,
        name=None,
        mode='universal',
        keyboard_condition=None,
        mouse_condition=None
    ) -> torch.Tensor:
        """
        Override inference to use WebSocket actions instead of stdin
        """
        # Store the provided conditions for use during inference
        if keyboard_condition is not None:
            self.current_keyboard = keyboard_condition
        if mouse_condition is not None:
            self.current_mouse = mouse_condition
            
        # Monkey-patch get_current_action during this inference
        import pipeline.causal_inference as ci
        original_get_current_action = ci.get_current_action
        ci.get_current_action = self.get_websocket_action
        
        try:
            # Call parent inference method
            result = super().inference(
                noise=noise,
                conditional_dict=conditional_dict,
                initial_latent=initial_latent,
                return_latents=return_latents,
                output_folder=output_folder,
                name=name,
                mode=mode
            )
            return result
        finally:
            # Restore original function
            ci.get_current_action = original_get_current_action
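

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not called by the pipeline itself). Shows one
# way a WebSocket handler could feed per-frame actions into set_current_action().
# The {"keyboard": [...], "mouse": [dx, dy]} message schema is an assumption made
# for this example; the pipeline instance is assumed to be constructed elsewhere
# with the same arguments as CausalInferenceStreamingPipeline.
async def example_action_feeder(websocket, pipeline: WebSocketStreamingPipeline):
    """Update the pipeline's current action from incoming WebSocket messages."""
    import json  # local import keeps this example self-contained

    async for message in websocket:
        data = json.loads(message)
        # Lists are expected in the pipeline's format, e.g. [0, 1, 0, 0] and [dx, dy].
        pipeline.set_current_action(data.get("keyboard"), data.get("mouse"))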