Spaces:
Running
on
Zero
Running
on
Zero
import pytest | |
import time | |
import torch | |
import urllib.error | |
import numpy as np | |
import subprocess | |
from pytest import fixture | |
from comfy_execution.graph_utils import GraphBuilder | |
from tests.inference.test_execution import ComfyClient, run_warmup | |
class TestAsyncNodes: | |
def _server(self, args_pytest, request): | |
pargs = [ | |
'python','main.py', | |
'--output-directory', args_pytest["output_dir"], | |
'--listen', args_pytest["listen"], | |
'--port', str(args_pytest["port"]), | |
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', | |
'--cpu', | |
] | |
use_lru, lru_size = request.param | |
if use_lru: | |
pargs += ['--cache-lru', str(lru_size)] | |
# Running server with args: pargs | |
p = subprocess.Popen(pargs) | |
yield | |
p.kill() | |
torch.cuda.empty_cache() | |
def shared_client(self, args_pytest, _server): | |
client = ComfyClient() | |
n_tries = 5 | |
for i in range(n_tries): | |
time.sleep(4) | |
try: | |
client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) | |
except ConnectionRefusedError: | |
# Retrying... | |
pass | |
else: | |
break | |
yield client | |
del client | |
torch.cuda.empty_cache() | |
def client(self, shared_client, request): | |
shared_client.set_test_name(f"async_nodes[{request.node.name}]") | |
yield shared_client | |
def builder(self, request): | |
yield GraphBuilder(prefix=request.node.name) | |
# Happy Path Tests | |
def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test that a basic async node executes correctly.""" | |
g = builder | |
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.1) | |
output = g.node("SaveImage", images=sleep_node.out(0)) | |
result = client.run(g) | |
# Verify execution completed | |
assert result.did_run(sleep_node), "Async sleep node should have executed" | |
assert result.did_run(output), "Output node should have executed" | |
# Verify the image passed through correctly | |
result_images = result.get_images(output) | |
assert len(result_images) == 1, "Should have 1 image" | |
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black" | |
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test that multiple async nodes execute in parallel.""" | |
# Warmup execution to ensure server is fully initialized | |
run_warmup(client) | |
g = builder | |
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
# Create multiple async sleep nodes with different durations | |
sleep1 = g.node("TestSleep", value=image.out(0), seconds=0.3) | |
sleep2 = g.node("TestSleep", value=image.out(0), seconds=0.4) | |
sleep3 = g.node("TestSleep", value=image.out(0), seconds=0.5) | |
# Add outputs for each | |
_output1 = g.node("PreviewImage", images=sleep1.out(0)) | |
_output2 = g.node("PreviewImage", images=sleep2.out(0)) | |
_output3 = g.node("PreviewImage", images=sleep3.out(0)) | |
start_time = time.time() | |
result = client.run(g) | |
elapsed_time = time.time() - start_time | |
# Should take ~0.5s (max duration) not 1.2s (sum of durations) | |
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s" | |
# Verify all nodes executed | |
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3) | |
def test_async_with_dependencies(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test async nodes with proper dependency handling.""" | |
g = builder | |
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) | |
# Chain of async operations | |
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2) | |
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2) | |
# Average depends on both async results | |
average = g.node("TestVariadicAverage", input1=sleep1.out(0), input2=sleep2.out(0)) | |
output = g.node("SaveImage", images=average.out(0)) | |
result = client.run(g) | |
# Verify execution order | |
assert result.did_run(sleep1) and result.did_run(sleep2) | |
assert result.did_run(average) and result.did_run(output) | |
# Verify averaged result | |
result_images = result.get_images(output) | |
avg_value = np.array(result_images[0]).mean() | |
assert abs(avg_value - 127.5) < 1, f"Average value {avg_value} should be ~127.5" | |
def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test async VALIDATE_INPUTS function.""" | |
g = builder | |
# Create a test node with async validation | |
validation_node = g.node("TestAsyncValidation", value=5.0, threshold=10.0) | |
g.node("SaveImage", images=validation_node.out(0)) | |
# Should pass validation | |
result = client.run(g) | |
assert result.did_run(validation_node) | |
# Test validation failure | |
validation_node.inputs['threshold'] = 3.0 # Will fail since value > threshold | |
with pytest.raises(urllib.error.HTTPError): | |
client.run(g) | |
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test async nodes with lazy evaluation.""" | |
# Warmup execution to ensure server is fully initialized | |
run_warmup(client, prefix="warmup_lazy") | |
g = builder | |
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) | |
mask = g.node("StubMask", value=0.0, height=512, width=512, batch_size=1) | |
# Create async nodes that will be evaluated lazily | |
sleep1 = g.node("TestSleep", value=input1.out(0), seconds=0.3) | |
sleep2 = g.node("TestSleep", value=input2.out(0), seconds=0.3) | |
# Use lazy mix that only needs sleep1 (mask=0.0) | |
lazy_mix = g.node("TestLazyMixImages", image1=sleep1.out(0), image2=sleep2.out(0), mask=mask.out(0)) | |
g.node("SaveImage", images=lazy_mix.out(0)) | |
start_time = time.time() | |
result = client.run(g) | |
elapsed_time = time.time() - start_time | |
# Should only execute sleep1, not sleep2 | |
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s" | |
assert result.did_run(sleep1), "Sleep1 should have executed" | |
assert not result.did_run(sleep2), "Sleep2 should have been skipped" | |
def test_async_check_lazy_status(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test async check_lazy_status function.""" | |
g = builder | |
# Create a node with async check_lazy_status | |
lazy_node = g.node("TestAsyncLazyCheck", | |
input1="value1", | |
input2="value2", | |
condition=True) | |
g.node("SaveImage", images=lazy_node.out(0)) | |
result = client.run(g) | |
assert result.did_run(lazy_node) | |
# Error Handling Tests | |
def test_async_execution_error(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test that async execution errors are properly handled.""" | |
g = builder | |
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
# Create an async node that will error | |
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1) | |
g.node("SaveImage", images=error_node.out(0)) | |
try: | |
client.run(g) | |
assert False, "Should have raised an error" | |
except Exception as e: | |
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}" | |
assert e.args[0]['node_id'] == error_node.id, "Error should be from async error node" | |
def test_async_validation_error(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test async validation error handling.""" | |
g = builder | |
# Node with async validation that will fail | |
validation_node = g.node("TestAsyncValidationError", value=15.0, max_value=10.0) | |
g.node("SaveImage", images=validation_node.out(0)) | |
with pytest.raises(urllib.error.HTTPError) as exc_info: | |
client.run(g) | |
# Verify it's a validation error | |
assert exc_info.value.code == 400 | |
def test_async_timeout_handling(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test handling of async operations that timeout.""" | |
g = builder | |
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
# Very long sleep that would timeout | |
timeout_node = g.node("TestAsyncTimeout", value=image.out(0), timeout=0.5, operation_time=2.0) | |
g.node("SaveImage", images=timeout_node.out(0)) | |
try: | |
client.run(g) | |
assert False, "Should have raised a timeout error" | |
except Exception as e: | |
assert 'timeout' in str(e).lower(), f"Expected timeout error, got: {e}" | |
def test_concurrent_async_error_recovery(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test that workflow can recover after async errors.""" | |
g = builder | |
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
# First run with error | |
error_node = g.node("TestAsyncError", value=image.out(0), error_after=0.1) | |
g.node("SaveImage", images=error_node.out(0)) | |
try: | |
client.run(g) | |
except Exception: | |
pass # Expected | |
# Second run should succeed | |
g2 = GraphBuilder(prefix="recovery_test") | |
image2 = g2.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) | |
sleep_node = g2.node("TestSleep", value=image2.out(0), seconds=0.1) | |
g2.node("SaveImage", images=sleep_node.out(0)) | |
result = client.run(g2) | |
assert result.did_run(sleep_node), "Should be able to run after error" | |
def test_sync_error_during_async_execution(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test handling when sync node errors while async node is executing.""" | |
g = builder | |
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
# Async node that takes time | |
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.5) | |
# Sync node that will error immediately | |
error_node = g.node("TestSyncError", value=image.out(0)) | |
# Both feed into output | |
g.node("PreviewImage", images=sleep_node.out(0)) | |
g.node("PreviewImage", images=error_node.out(0)) | |
try: | |
client.run(g) | |
assert False, "Should have raised an error" | |
except Exception as e: | |
# Verify the sync error was caught even though async was running | |
assert 'prompt_id' in e.args[0] | |
# Edge Cases | |
def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test async nodes with execution blockers.""" | |
g = builder | |
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) | |
# Async sleep nodes | |
sleep1 = g.node("TestSleep", value=image1.out(0), seconds=0.2) | |
sleep2 = g.node("TestSleep", value=image2.out(0), seconds=0.2) | |
# Create list of images | |
image_list = g.node("TestMakeListNode", value1=sleep1.out(0), value2=sleep2.out(0)) | |
# Create list of blocking conditions - [False, True] to block only the second item | |
int1 = g.node("StubInt", value=1) | |
int2 = g.node("StubInt", value=2) | |
block_list = g.node("TestMakeListNode", value1=int1.out(0), value2=int2.out(0)) | |
# Compare each value against 2, so first is False (1 != 2) and second is True (2 == 2) | |
compare = g.node("TestIntConditions", a=block_list.out(0), b=2, operation="==") | |
# Block based on the comparison results | |
blocker = g.node("TestExecutionBlocker", input=image_list.out(0), block=compare.out(0), verbose=False) | |
output = g.node("PreviewImage", images=blocker.out(0)) | |
result = client.run(g) | |
images = result.get_images(output) | |
assert len(images) == 1, "Should have blocked second image" | |
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test that async nodes are properly cached.""" | |
# Warmup execution to ensure server is fully initialized | |
run_warmup(client, prefix="warmup_cache") | |
g = builder | |
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2) | |
g.node("SaveImage", images=sleep_node.out(0)) | |
# First run | |
result1 = client.run(g) | |
assert result1.did_run(sleep_node), "Should run first time" | |
# Second run - should be cached | |
start_time = time.time() | |
result2 = client.run(g) | |
elapsed_time = time.time() - start_time | |
assert not result2.did_run(sleep_node), "Should be cached" | |
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant" | |
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test async nodes within dynamically generated prompts.""" | |
# Warmup execution to ensure server is fully initialized | |
run_warmup(client, prefix="warmup_dynamic") | |
g = builder | |
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) | |
# Node that generates async nodes dynamically | |
dynamic_async = g.node("TestDynamicAsyncGeneration", | |
image1=image1.out(0), | |
image2=image2.out(0), | |
num_async_nodes=3, | |
sleep_duration=0.2) | |
g.node("SaveImage", images=dynamic_async.out(0)) | |
start_time = time.time() | |
result = client.run(g) | |
elapsed_time = time.time() - start_time | |
# Should execute async nodes in parallel within dynamic prompt | |
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s" | |
assert result.did_run(dynamic_async) | |
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test that async resources are properly cleaned up.""" | |
g = builder | |
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
# Create multiple async nodes that use resources | |
resource_nodes = [] | |
for i in range(5): | |
node = g.node("TestAsyncResourceUser", | |
value=image.out(0), | |
resource_id=f"resource_{i}", | |
duration=0.1) | |
resource_nodes.append(node) | |
g.node("PreviewImage", images=node.out(0)) | |
result = client.run(g) | |
# Verify all nodes executed | |
for node in resource_nodes: | |
assert result.did_run(node) | |
# Run again to ensure resources were cleaned up | |
result2 = client.run(g) | |
# Should be cached but not error due to resource conflicts | |
for node in resource_nodes: | |
assert not result2.did_run(node), "Should be cached" | |
def test_async_cancellation(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test cancellation of async operations.""" | |
# This would require implementing cancellation in the client | |
# For now, we'll test that long-running async operations can be interrupted | |
pass # TODO: Implement when cancellation API is available | |
def test_mixed_sync_async_execution(self, client: ComfyClient, builder: GraphBuilder): | |
"""Test workflows with both sync and async nodes.""" | |
g = builder | |
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) | |
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1) | |
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1) | |
# Mix of sync and async operations | |
# Sync: lazy mix images | |
sync_op1 = g.node("TestLazyMixImages", image1=image1.out(0), image2=image2.out(0), mask=mask.out(0)) | |
# Async: sleep | |
async_op1 = g.node("TestSleep", value=sync_op1.out(0), seconds=0.2) | |
# Sync: custom validation | |
sync_op2 = g.node("TestCustomValidation1", input1=async_op1.out(0), input2=0.5) | |
# Async: sleep again | |
async_op2 = g.node("TestSleep", value=sync_op2.out(0), seconds=0.2) | |
output = g.node("SaveImage", images=async_op2.out(0)) | |
result = client.run(g) | |
# Verify all nodes executed in correct order | |
assert result.did_run(sync_op1) | |
assert result.did_run(async_op1) | |
assert result.did_run(sync_op2) | |
assert result.did_run(async_op2) | |
# Image should be a mix of black and white (gray) | |
result_images = result.get_images(output) | |
avg_value = np.array(result_images[0]).mean() | |
assert abs(avg_value - 63.75) < 5, f"Average value {avg_value} should be ~63.75" | |