File size: 18,094 Bytes
62bb9d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
import pytest
import time
import torch
import urllib.error
import numpy as np
import subprocess

from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient, run_warmup


@pytest.mark.execution
class TestAsyncNodes:
    """Integration tests for asynchronous node execution against a live server.

    A server subprocess is launched once per cache configuration (see
    ``_server``) and shared by every test in the class. Each test builds a
    small workflow with ``GraphBuilder`` and submits it through a shared
    ``ComfyClient``, asserting on executed nodes, output images, and -- for
    parallelism tests -- wall-clock timing.
    """

    @fixture(scope="class", autouse=True, params=[
        (False, 0),
        (True, 0),
        (True, 100),
    ])
    def _server(self, args_pytest, request):
        """Start the server subprocess for the whole test class.

        Parametrized as ``(use_lru, lru_size)``: the suite runs once with the
        default cache and twice with ``--cache-lru`` (sizes 0 and 100). The
        process is killed (not gracefully shut down) on teardown.
        """
        pargs = [
            'python','main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
            '--cpu',
        ]
        use_lru, lru_size = request.param
        if use_lru:
            pargs += ['--cache-lru', str(lru_size)]
        # Launch the server with the assembled CLI arguments.
        p = subprocess.Popen(pargs)
        yield
        p.kill()
        torch.cuda.empty_cache()

    @fixture(scope="class", autouse=True)
    def shared_client(self, args_pytest, _server):
        """Connect a single ComfyClient to the server, retrying while it boots.

        NOTE(review): if every retry raises ConnectionRefusedError the loop
        falls through and an unconnected client is yielded; the first test
        will then fail on use rather than here. Consider failing fast.
        """
        client = ComfyClient()
        n_tries = 5
        for i in range(n_tries):
            time.sleep(4)
            try:
                client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
            except ConnectionRefusedError:
                # Server not accepting connections yet -- sleep and retry.
                pass
            else:
                break
        yield client
        del client
        torch.cuda.empty_cache()

    @fixture
    def client(self, shared_client, request):
        """Per-test view of the shared client, tagged with the test's name."""
        shared_client.set_test_name(f"async_nodes[{request.node.name}]")
        yield shared_client

    @fixture
    def builder(self, request):
        """Fresh GraphBuilder whose node IDs are prefixed with the test name."""
        yield GraphBuilder(prefix=request.node.name)

    # Happy Path Tests

    def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test that a basic async node executes correctly."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
        output = g.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g)

        # Verify execution completed
        assert result.did_run(sleep_node), "Async sleep node should have executed"
        assert result.did_run(output), "Output node should have executed"

        # Verify the image passed through correctly
        result_images = result.get_images(output)
        assert len(result_images) == 1, "Should have 1 image"
        assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"

    def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test that multiple async nodes execute in parallel."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client)

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create multiple async sleep nodes with different durations
        sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
        sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # Add outputs for each
        _output1 = g.node("PreviewImage", images=sleep1.out(0))
        _output2 = g.node("PreviewImage", images=sleep2.out(0))
        _output3 = g.node("PreviewImage", images=sleep3.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should take ~0.5s (max duration) not 1.2s (sum of durations)
        # NOTE(review): wall-clock bound; may be flaky on heavily loaded CI.
        assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"

        # Verify all nodes executed
        assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)

    def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with proper dependency handling."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Chain of async operations
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # Average depends on both async results
        average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
        output = g.node("SaveImage", images=average.out(0))

        result = client.run(g)

        # Verify execution order
        assert result.did_run(sleep1) and result.did_run(sleep2)
        assert result.did_run(average) and result.did_run(output)

        # Verify averaged result: black (0) and white (255) average to ~127.5
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"

    def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
        """Test async VALIDATE_INPUTS function."""
        g = builder
        # Create a test node with async validation
        validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        # Should pass validation
        result = client.run(g)
        assert result.did_run(validation_node)

        # Test validation failure
        validation_node.inputs['threshold'] = 3.0  # Will fail since value > threshold
        with pytest.raises(urllib.error.HTTPError):
            client.run(g)

    def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with lazy evaluation."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_lazy")

        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)

        # Create async nodes that will be evaluated lazily
        sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)

        # Use lazy mix that only needs sleep1 (mask=0.0)
        lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should only execute sleep1, not sleep2
        assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
        assert result.did_run(sleep1), "Sleep1 should have executed"
        assert not result.did_run(sleep2), "Sleep2 should have been skipped"

    def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
        """Test async check_lazy_status function."""
        g = builder
        # Create a node with async check_lazy_status
        lazy_node = g.node("TestAsyncLazyCheck",
                          input1="value1",
                          input2="value2",
                          condition=True)
        g.node("SaveImage", images=lazy_node.out(0))

        result = client.run(g)
        assert result.did_run(lazy_node)

    # Error Handling Tests

    def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async execution errors are properly handled."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Create an async node that will error
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))

        try:
            client.run(g)
            assert False, "Should have raised an error"
        except Exception as e:
            # Error payload is a dict carrying prompt_id / node_id metadata
            assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
            assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node"

    def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test async validation error handling."""
        g = builder
        # Node with async validation that will fail
        validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        with pytest.raises(urllib.error.HTTPError) as exc_info:
            client.run(g)
        # Verify it's a validation error (HTTP 400 Bad Request)
        assert exc_info.value.code == 400

    def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling of async operations that timeout."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Very long sleep that would timeout
        timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
        g.node("SaveImage", images=timeout_node.out(0))

        try:
            client.run(g)
            assert False, "Should have raised a timeout error"
        except Exception as e:
            assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}"

    def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
        """Test that workflow can recover after async errors."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # First run with error
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))

        try:
            client.run(g)
        except Exception:
            pass  # Expected

        # Second run should succeed
        g2 = GraphBuilder(prefix="recovery_test")
        image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
        g2.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g2)
        assert result.did_run(sleep_node), "Should be able to run after error"

    def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling when sync node errors while async node is executing."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Async node that takes time
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # Sync node that will error immediately
        error_node = g.node("TestSyncError", value=image.out(0))

        # Both feed into output
        g.node("PreviewImage", images=sleep_node.out(0))
        g.node("PreviewImage", images=error_node.out(0))

        try:
            client.run(g)
            assert False, "Should have raised an error"
        except Exception as e:
            # Verify the sync error was caught even though async was running
            assert 'prompt_id' in e.args[0]

    # Edge Cases

    def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with execution blockers."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Async sleep nodes
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # Create list of images
        image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))

        # Create list of blocking conditions - [False, True] to block only the second item
        int1 = g.node("StubInt", value=1)
        int2 = g.node("StubInt", value=2)
        block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))

        # Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
        compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")

        # Block based on the comparison results
        blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)

        output = g.node("PreviewImage", images=blocker.out(0))

        result = client.run(g)
        images = result.get_images(output)
        assert len(images) == 1, "Should have blocked second image"

    def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async nodes are properly cached."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_cache")

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
        g.node("SaveImage", images=sleep_node.out(0))

        # First run
        result1 = client.run(g)
        assert result1.did_run(sleep_node), "Should run first time"

        # Second run - should be cached
        start_time = time.time()
        result2 = client.run(g)
        elapsed_time = time.time() - start_time

        assert not result2.did_run(sleep_node), "Should be cached"
        assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"

    def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes within dynamically generated prompts."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_dynamic")

        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Node that generates async nodes dynamically
        dynamic_async = g.node("TestDynamicAsyncGeneration",
                              image1=image1.out(0),
                              image2=image2.out(0),
                              num_async_nodes=3,
                              sleep_duration=0.2)
        g.node("SaveImage", images=dynamic_async.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should execute async nodes in parallel within dynamic prompt
        assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
        assert result.did_run(dynamic_async)

    def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async resources are properly cleaned up."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create multiple async nodes that use resources
        resource_nodes = []
        for i in range(5):
            node = g.node("TestAsyncResourceUser",
                         value=image.out(0),
                         resource_id=f"resource_{i}",
                         duration=0.1)
            resource_nodes.append(node)
            g.node("PreviewImage", images=node.out(0))

        result = client.run(g)

        # Verify all nodes executed
        for node in resource_nodes:
            assert result.did_run(node)

        # Run again to ensure resources were cleaned up
        result2 = client.run(g)
        # Should be cached but not error due to resource conflicts
        for node in resource_nodes:
            assert not result2.did_run(node), "Should be cached"

    def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
        """Test cancellation of async operations."""
        # This would require implementing cancellation in the client
        # For now, we'll test that long-running async operations can be interrupted
        pass  # TODO: Implement when cancellation API is available

    def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test workflows with both sync and async nodes."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        # Mix of sync and async operations
        # Sync: lazy mix images
        sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
        # Async: sleep
        async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
        # Sync: custom validation
        sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
        # Async: sleep again
        async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)

        output = g.node("SaveImage", images=async_op2.out(0))

        result = client.run(g)

        # Verify all nodes executed in correct order
        assert result.did_run(sync_op1)
        assert result.did_run(async_op1)
        assert result.did_run(sync_op2)
        assert result.did_run(async_op2)

        # Image should be a mix of black and white (gray)
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"