# File-viewer header captured along with this source (not code):
# uploaded by multimodalart - "Upload 537 files", commit 62bb9d8 (verified), 18.1 kB.
import pytest
import time
import torch
import urllib.error
import numpy as np
import subprocess
from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient, run_warmup
@pytest.mark.execution
class TestAsyncNodes:
    """End-to-end tests for async node execution against a live server.

    The class-scoped ``_server`` fixture launches ``main.py`` once per
    parameter set -- ``(use_lru, lru_size)`` in ``[(False, 0), (True, 0),
    (True, 100)]`` -- so every test below runs against each of those
    server configurations.
    """

    @fixture(scope="class", autouse=True, params=[
        (False, 0),
        (True, 0),
        (True, 100),
    ])
    def _server(self, args_pytest, request):
        """Run a CPU-only server subprocess for the lifetime of the test class.

        ``request.param`` is a ``(use_lru, lru_size)`` pair; when ``use_lru``
        is true the server is started with ``--cache-lru <lru_size>``.
        """
        pargs = [
            'python', 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
            '--cpu',
        ]
        use_lru, lru_size = request.param
        if use_lru:
            pargs += ['--cache-lru', str(lru_size)]
        # Running server with args: pargs
        p = subprocess.Popen(pargs)
        yield
        p.kill()
        # Reap the killed process so its port is released before the next
        # parametrized server starts on the same port (and to avoid a zombie).
        p.wait()
        torch.cuda.empty_cache()

    @fixture(scope="class", autouse=True)
    def shared_client(self, args_pytest, _server):
        """Connect a single ComfyClient to the server, shared by the whole class.

        Sleeps 4 s before each of up to ``n_tries`` connection attempts while
        the server starts up, and fails setup with a clear message if the
        server never accepts the connection.
        """
        client = ComfyClient()
        n_tries = 5
        for _ in range(n_tries):
            time.sleep(4)
            try:
                client.connect(listen=args_pytest["listen"], port=args_pytest["port"])
            except ConnectionRefusedError:
                continue  # server still starting; retry
            break
        else:
            # Fail fast: yielding an unconnected client would make every test
            # in the class fail later with an unrelated-looking error.
            pytest.fail(
                f"Could not connect to server at {args_pytest['listen']}:{args_pytest['port']} "
                f"after {n_tries} attempts"
            )
        yield client
        del client
        torch.cuda.empty_cache()

    @fixture
    def client(self, shared_client, request):
        """Yield the shared client, tagged with the current test's name."""
        shared_client.set_test_name(f"async_nodes[{request.node.name}]")
        yield shared_client

    @fixture
    def builder(self, request):
        """Yield a GraphBuilder whose node ids are prefixed with the test name."""
        yield GraphBuilder(prefix=request.node.name)

    # Happy Path Tests

    def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test that a basic async node executes correctly."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1)
        output = g.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g)

        # Verify execution completed
        assert result.did_run(sleep_node), "Async sleep node should have executed"
        assert result.did_run(output), "Output node should have executed"

        # Verify the image passed through correctly
        result_images = result.get_images(output)
        assert len(result_images) == 1, "Should have 1 image"
        pixels = np.array(result_images[0])
        assert pixels.min() == 0 and pixels.max() == 0, "Image should be black"

    def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test that multiple async nodes execute in parallel."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client)

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create multiple async sleep nodes with different durations
        sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4)
        sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # Add outputs for each
        _output1 = g.node("PreviewImage", images=sleep1.out(0))
        _output2 = g.node("PreviewImage", images=sleep2.out(0))
        _output3 = g.node("PreviewImage", images=sleep3.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should take ~0.5s (max duration) not 1.2s (sum of durations)
        assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"

        # Verify all nodes executed
        assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)

    def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with proper dependency handling."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Chain of async operations
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # Average depends on both async results
        average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0))
        output = g.node("SaveImage", images=average.out(0))

        result = client.run(g)

        # Verify execution order
        assert result.did_run(sleep1) and result.did_run(sleep2)
        assert result.did_run(average) and result.did_run(output)

        # Verify averaged result (midpoint of black 0 and white 255)
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5"

    def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder):
        """Test async VALIDATE_INPUTS function."""
        g = builder
        # Create a test node with async validation
        validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        # Should pass validation
        result = client.run(g)
        assert result.did_run(validation_node)

        # Test validation failure
        validation_node.inputs['threshold'] = 3.0  # Will fail since value > threshold
        with pytest.raises(urllib.error.HTTPError):
            client.run(g)

    def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with lazy evaluation."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_lazy")

        g = builder
        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1)

        # Create async nodes that will be evaluated lazily
        sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3)
        sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3)

        # Use lazy mix that only needs sleep1 (mask=0.0)
        lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0))
        g.node("SaveImage", images=lazy_mix.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should only execute sleep1, not sleep2
        assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
        assert result.did_run(sleep1), "Sleep1 should have executed"
        assert not result.did_run(sleep2), "Sleep2 should have been skipped"

    def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder):
        """Test async check_lazy_status function."""
        g = builder
        # Create a node with async check_lazy_status
        lazy_node = g.node("TestAsyncLazyCheck",
                           input1="value1",
                           input2="value2",
                           condition=True)
        g.node("SaveImage", images=lazy_node.out(0))

        result = client.run(g)
        assert result.did_run(lazy_node)

    # Error Handling Tests
    # These use pytest.raises rather than `try: ...; assert False / except
    # Exception`, because AssertionError is itself an Exception: that pattern
    # swallows the "did not raise" case (and for the timeout test, whose
    # failure message contains "timeout", even made it pass).

    def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async execution errors are properly handled."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Create an async node that will error
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))

        with pytest.raises(Exception) as exc_info:
            client.run(g)
        error_data = exc_info.value.args[0]
        assert 'prompt_id' in error_data, f"Did not get proper error message: {exc_info.value}"
        assert error_data['node_id'] == error_node.id, "Error should be from async error node"

    def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder):
        """Test async validation error handling."""
        g = builder
        # Node with async validation that will fail
        validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0)
        g.node("SaveImage", images=validation_node.out(0))

        with pytest.raises(urllib.error.HTTPError) as exc_info:
            client.run(g)
        # Verify it's a validation error
        assert exc_info.value.code == 400

    def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling of async operations that timeout."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        # Very long sleep that would timeout
        timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0)
        g.node("SaveImage", images=timeout_node.out(0))

        with pytest.raises(Exception) as exc_info:
            client.run(g)
        assert 'timeout' in str(exc_info.value).lower(), f"Expected timeout error, got: {exc_info.value}"

    def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder):
        """Test that workflow can recover after async errors."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # First run with error (required: otherwise there is nothing to recover from)
        error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1)
        g.node("SaveImage", images=error_node.out(0))
        with pytest.raises(Exception):
            client.run(g)

        # Second run should succeed
        g2 = GraphBuilder(prefix="recovery_test")
        image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1)
        g2.node("SaveImage", images=sleep_node.out(0))

        result = client.run(g2)
        assert result.did_run(sleep_node), "Should be able to run after error"

    def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test handling when sync node errors while async node is executing."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Async node that takes time
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5)

        # Sync node that will error immediately
        error_node = g.node("TestSyncError", value=image.out(0))

        # Both feed into output
        g.node("PreviewImage", images=sleep_node.out(0))
        g.node("PreviewImage", images=error_node.out(0))

        with pytest.raises(Exception) as exc_info:
            client.run(g)
        # Verify the sync error was caught even though async was running
        assert 'prompt_id' in exc_info.value.args[0]

    # Edge Cases

    def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes with execution blockers."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Async sleep nodes
        sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2)
        sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2)

        # Create list of images
        image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0))

        # Create list of blocking conditions - [False, True] to block only the second item
        int1 = g.node("StubInt", value=1)
        int2 = g.node("StubInt", value=2)
        block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0))

        # Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2)
        compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==")

        # Block based on the comparison results
        blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False)

        output = g.node("PreviewImage", images=blocker.out(0))

        result = client.run(g)
        images = result.get_images(output)
        assert len(images) == 1, "Should have blocked second image"

    def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async nodes are properly cached."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_cache")

        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
        g.node("SaveImage", images=sleep_node.out(0))

        # First run
        result1 = client.run(g)
        assert result1.did_run(sleep_node), "Should run first time"

        # Second run - should be cached
        start_time = time.time()
        result2 = client.run(g)
        elapsed_time = time.time() - start_time

        assert not result2.did_run(sleep_node), "Should be cached"
        assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"

    def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
        """Test async nodes within dynamically generated prompts."""
        # Warmup execution to ensure server is fully initialized
        run_warmup(client, prefix="warmup_dynamic")

        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

        # Node that generates async nodes dynamically
        dynamic_async = g.node("TestDynamicAsyncGeneration",
                               image1=image1.out(0),
                               image2=image2.out(0),
                               num_async_nodes=3,
                               sleep_duration=0.2)
        g.node("SaveImage", images=dynamic_async.out(0))

        start_time = time.time()
        result = client.run(g)
        elapsed_time = time.time() - start_time

        # Should execute async nodes in parallel within dynamic prompt
        assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
        assert result.did_run(dynamic_async)

    def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):
        """Test that async resources are properly cleaned up."""
        g = builder
        image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)

        # Create multiple async nodes that use resources
        resource_nodes = []
        for i in range(5):
            node = g.node("TestAsyncResourceUser",
                          value=image.out(0),
                          resource_id=f"resource_{i}",
                          duration=0.1)
            resource_nodes.append(node)
            g.node("PreviewImage", images=node.out(0))

        result = client.run(g)

        # Verify all nodes executed
        for node in resource_nodes:
            assert result.did_run(node)

        # Run again to ensure resources were cleaned up
        result2 = client.run(g)
        # Should be cached but not error due to resource conflicts
        for node in resource_nodes:
            assert not result2.did_run(node), "Should be cached"

    def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder):
        """Test cancellation of async operations."""
        # This would require implementing cancellation in the client.
        # Report the test as skipped instead of letting an empty body pass.
        pytest.skip("TODO: Implement when cancellation API is available")

    def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder):
        """Test workflows with both sync and async nodes."""
        g = builder
        image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
        image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)

        # Mix of sync and async operations
        # Sync: lazy mix images
        sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0))
        # Async: sleep
        async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2)
        # Sync: custom validation
        sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5)
        # Async: sleep again
        async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2)

        output = g.node("SaveImage", images=async_op2.out(0))

        result = client.run(g)

        # Verify all nodes executed in correct order
        assert result.did_run(sync_op1)
        assert result.did_run(async_op1)
        assert result.did_run(sync_op2)
        assert result.did_run(async_op2)

        # A 0.5 mask mixes black and white into gray (~127.5); the expected
        # ~63.75 presumably comes from TestCustomValidation1 scaling that by
        # input2=0.5 -- confirm against the node implementation.
        result_images = result.get_images(output)
        avg_value = np.array(result_images[0]).mean()
        assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75"