Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	File size: 4,298 Bytes
			
			| 50eec37 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 | import pytest
from comfy_execution.validation import validate_node_input
def test_exact_match():
    """Identical type lists validate, regardless of the order of their entries."""
    identical_pairs = (
        ("STRING", "STRING"),
        ("STRING,INT", "STRING,INT"),
        # Same set of types, listed in a different order.
        ("INT,STRING", "STRING,INT"),
    )
    for received, input_type in identical_pairs:
        assert validate_node_input(received, input_type)
def test_strict_mode():
    """Strict mode: validation passes only when every received type is accepted.

    Received types that form a subset of the input types pass; any received
    type missing from the input types makes validation fail.
    """
    subset_cases = (
        ("STRING", "STRING,INT"),
        ("INT", "STRING,INT"),
        ("STRING,INT", "STRING,INT,BOOLEAN"),
    )
    non_subset_cases = (
        ("STRING,INT", "STRING"),
        ("STRING,BOOLEAN", "STRING"),
        ("INT,BOOLEAN", "STRING,INT"),
    )
    for received, input_type in subset_cases:
        assert validate_node_input(received, input_type, strict=True)
    for received, input_type in non_subset_cases:
        assert not validate_node_input(received, input_type, strict=True)
def test_non_strict_mode():
    """Default (non-strict) mode: any shared type is enough to validate.

    Pairs with at least one type in common pass; disjoint pairs fail.
    """
    overlapping = (
        ("STRING,BOOLEAN", "STRING,INT"),
        ("STRING,INT", "INT,BOOLEAN"),
        ("STRING", "STRING,INT"),
    )
    disjoint = (
        ("BOOLEAN", "STRING,INT"),
        ("FLOAT", "STRING,INT"),
        ("FLOAT,BOOLEAN", "STRING,INT"),
    )
    for received, input_type in overlapping:
        assert validate_node_input(received, input_type)
    for received, input_type in disjoint:
        assert not validate_node_input(received, input_type)
def test_whitespace_handling():
    """Whitespace around comma-separated types, on either side, is ignored."""
    padded_pairs = (
        ("STRING, INT", "STRING,INT"),
        ("STRING,INT", "STRING, INT"),
        (" STRING , INT ", "STRING,INT"),
        ("STRING,INT", " STRING , INT "),
    )
    for received, input_type in padded_pairs:
        assert validate_node_input(received, input_type)
def test_empty_strings():
    """An empty type string matches only another empty type string."""
    assert validate_node_input("", "")
    # Empty on one side and a real type on the other must not validate.
    for received, input_type in (("STRING", ""), ("", "STRING")):
        assert not validate_node_input(received, input_type)
def test_single_vs_multiple():
    """A single type checked against a multi-type list, in both directions."""
    single = "STRING"
    several = "STRING,INT,BOOLEAN"
    assert validate_node_input(single, several)
    # Received list overlaps the input type: fine when not strict ...
    assert validate_node_input(several, single, strict=False)
    # ... but strict mode rejects the extra INT and BOOLEAN received types.
    assert not validate_node_input(several, single, strict=True)
def test_non_string():
    """Non-string types: an object matches itself but not a distinct object."""
    first, second = object(), object()
    assert validate_node_input(first, first)
    assert not validate_node_input(first, second)
class NotEqualsOverrideTest(str):
    """Test class for ``__ne__`` override.

    A ``str`` subclass whose ``!=`` result is decided by its own value:

    * ``"*"`` on either side acts as a wildcard and never compares unequal.
    * ``"LONGER_THAN_2"`` is unequal to any value of length 2 or less.
    * Any other value raises ``TypeError``.
    """

    def __ne__(self, value: object) -> bool:
        # Only ``==`` is used on ``self`` below; ``!=`` would recurse here.
        wildcard_involved = self == "*" or value == "*"
        if wildcard_involved:
            return False
        if self == "LONGER_THAN_2":
            # Values longer than two elements/characters count as a match.
            return len(value) <= 2
        raise TypeError("This is a class for unit tests only.")
def test_ne_override():
    """Test ``__ne__`` any override.

    A ``NotEqualsOverrideTest("*")`` received type is expected to validate
    against every input type below, including non-string values.
    """
    # Named ``any_type`` rather than ``any`` so the builtin ``any()`` is not shadowed.
    any_type = NotEqualsOverrideTest("*")
    input_types = (any_type, "INVALID_TYPE", object(), {}, [], [1, 2, 3])
    for input_type in input_types:
        assert validate_node_input(any_type, input_type)
def test_ne_custom_override():
    """Test ``__ne__`` custom override.

    ``NotEqualsOverrideTest("LONGER_THAN_2")`` is expected to validate against
    itself, the ``"*"`` wildcard and inputs longer than two, and to fail
    against inputs of length two or less.
    """
    special = NotEqualsOverrideTest("LONGER_THAN_2")
    accepted = (special, "*", "INVALID_TYPE", [1, 2, 3])
    rejected = ([1, 2], "TY")
    for input_type in accepted:
        assert validate_node_input(special, input_type)
    # Should fail: too short for the LONGER_THAN_2 rule.
    for input_type in rejected:
        assert not validate_node_input(special, input_type)
@pytest.mark.parametrize(
    ("received", "input_type", "strict", "expected"),
    [
        ("STRING", "STRING", False, True),
        ("STRING,INT", "STRING,INT", False, True),
        ("STRING", "STRING,INT", True, True),
        ("STRING,INT", "STRING", True, False),
        ("BOOLEAN", "STRING,INT", False, False),
        ("STRING,BOOLEAN", "STRING,INT", False, True),
    ],
)
def test_parametrized_cases(received, input_type, strict, expected):
    """Table-driven checks covering exact, strict, and non-strict matching."""
    outcome = validate_node_input(received, input_type, strict=strict)
    assert outcome == expected
 | 
