import asyncio
import concurrent.futures
import contextvars
import functools
import inspect
import logging
import os
import textwrap
import threading
from enum import Enum
from typing import Optional, Type, get_origin, get_args


class TypeTracker:
    """Tracks types discovered during stub generation for automatic import generation."""

    def __init__(self):
        self.discovered_types = {}  # type_name -> (module, qualname)
        self.builtin_types = {
            "Any",
            "Dict",
            "List",
            "Optional",
            "Tuple",
            "Union",
            "Set",
            "Sequence",
            "cast",
            "NamedTuple",
            "str",
            "int",
            "float",
            "bool",
            "None",
            "bytes",
            "object",
            "type",
            "dict",
            "list",
            "tuple",
            "set",
        }
        self.already_imported = set()  # Track types already imported to avoid duplicates

    def track_type(self, annotation):
        """Track a type annotation and record its module/import info."""
        if annotation is None or annotation is type(None):
            return
        # Skip builtins and typing module types we already import
        type_name = getattr(annotation, "__name__", None)
        if type_name and (
            type_name in self.builtin_types or type_name in self.already_imported
        ):
            return
        # Get module and qualname
        module = getattr(annotation, "__module__", None)
        qualname = getattr(annotation, "__qualname__", type_name or "")
        # Skip types from the typing module (they're already imported)
        if module == "typing":
            return
        # Skip UnionType and GenericAlias from the types module; they're handled specially
        if module == "types" and type_name in ("UnionType", "GenericAlias"):
            return
        if module and module not in ["builtins", "__main__"]:
            # Store the type info
            if type_name:
                self.discovered_types[type_name] = (module, qualname)

    def get_imports(self, main_module_name: str) -> list[str]:
        """Generate import statements for all discovered types."""
        imports = []
        imports_by_module = {}
        for type_name, (module, qualname) in sorted(self.discovered_types.items()):
            # Skip types from the main module (they're already imported)
            if main_module_name and module == main_module_name:
                continue
            if module not in imports_by_module:
                imports_by_module[module] = []
            if type_name not in imports_by_module[module]:  # Avoid duplicates
                imports_by_module[module].append(type_name)
        # Generate import statements
        for module, types in sorted(imports_by_module.items()):
            if len(types) == 1:
                imports.append(f"from {module} import {types[0]}")
            else:
                imports.append(f"from {module} import {', '.join(sorted(set(types)))}")
        return imports


class AsyncToSyncConverter:
    """
    Provides utilities to convert async classes to sync classes with proper type hints.
    """

    _thread_pool: Optional[concurrent.futures.ThreadPoolExecutor] = None
    _thread_pool_lock = threading.Lock()
    _thread_pool_initialized = False

    @classmethod
    def get_thread_pool(cls, max_workers=None) -> concurrent.futures.ThreadPoolExecutor:
        """Get or create the shared thread pool with proper thread-safe initialization."""
        # Fast path - check if already initialized without acquiring the lock
        if cls._thread_pool_initialized:
            assert cls._thread_pool is not None, "Thread pool should be initialized"
            return cls._thread_pool
        # Slow path - acquire the lock and create the pool if needed
        with cls._thread_pool_lock:
            if not cls._thread_pool_initialized:
                cls._thread_pool = concurrent.futures.ThreadPoolExecutor(
                    max_workers=max_workers, thread_name_prefix="async_to_sync_"
                )
                cls._thread_pool_initialized = True
        # This should never be None at this point, but add an assertion for the type checker
        assert cls._thread_pool is not None
        return cls._thread_pool

    @classmethod
    def run_async_in_thread(cls, coro_func, *args, **kwargs):
        """
        Run an async function in a separate thread from the thread pool.
        Blocks until the async function completes.
        Properly propagates contextvars between threads and manages event loops.
        """
        # Capture the current context - this includes all context variables
        context = contextvars.copy_context()

        # Store the result and any exception that occurs
        result_container: dict = {"result": None, "exception": None}

        # Function that runs in the thread pool
        def run_in_thread():
            # Create a new event loop for this thread
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                # Create the coroutine within the context
                async def run_with_context():
                    # The coroutine function might access context variables
                    return await coro_func(*args, **kwargs)

                # Run the coroutine with the captured context.
                # This ensures all context variables are available in the async function.
                result = context.run(loop.run_until_complete, run_with_context())
                result_container["result"] = result
            except Exception as e:
                # Store the exception to re-raise in the calling thread
                result_container["exception"] = e
            finally:
                # Ensure the event loop is properly closed to prevent warnings
                try:
                    # Cancel any remaining tasks
                    pending = asyncio.all_tasks(loop)
                    for task in pending:
                        task.cancel()
                    # Run the loop briefly to handle cancellations
                    if pending:
                        loop.run_until_complete(
                            asyncio.gather(*pending, return_exceptions=True)
                        )
                except Exception:
                    pass  # Ignore errors during cleanup
                # Close the event loop
                loop.close()
                # Clear the event loop from the thread
                asyncio.set_event_loop(None)

        # Submit to the thread pool and wait for the result
        thread_pool = cls.get_thread_pool()
        future = thread_pool.submit(run_in_thread)
        future.result()  # Wait for completion
        # Re-raise any exception that occurred in the thread
        if result_container["exception"] is not None:
            raise result_container["exception"]
        return result_container["result"]

    @classmethod
    def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
        """
        Creates a new class with synchronous versions of all async methods.

        Args:
            async_class: The async class to convert
            thread_pool_size: Size of thread pool to use

        Returns:
            A new class with sync versions of all async methods
        """
        sync_class_name = "ComfyAPISyncStub"
        cls.get_thread_pool(thread_pool_size)

        # Create a proper class with docstrings and proper base classes
        sync_class_dict = {
            "__doc__": async_class.__doc__,
            "__module__": async_class.__module__,
            "__qualname__": sync_class_name,
            "__orig_class__": async_class,  # Store original class for typing references
        }

        # Create the __init__ method
        def __init__(self, *args, **kwargs):
            self._async_instance = async_class(*args, **kwargs)

            # Handle annotated class attributes (like execution: Execution).
            # Get all annotations from the class hierarchy.
            all_annotations = {}
            for base_class in reversed(inspect.getmro(async_class)):
                if hasattr(base_class, "__annotations__"):
                    all_annotations.update(base_class.__annotations__)

            # For each annotated attribute, check if it needs to be created or wrapped
            for attr_name, attr_type in all_annotations.items():
                if hasattr(self._async_instance, attr_name):
                    # Attribute exists on the instance
                    attr = getattr(self._async_instance, attr_name)
                    # Check if this attribute needs a sync wrapper
                    if hasattr(attr, "__class__"):
                        from comfy_api.internal.singleton import ProxiedSingleton

                        if isinstance(attr, ProxiedSingleton):
                            # Create a sync version of this attribute
                            try:
                                sync_attr_class = cls.create_sync_class(attr.__class__)
                                # Create an instance of the sync wrapper backed by the async instance
                                sync_attr = object.__new__(sync_attr_class)  # type: ignore
                                sync_attr._async_instance = attr
                                setattr(self, attr_name, sync_attr)
                            except Exception:
                                # If we can't create a sync version, keep the original
                                setattr(self, attr_name, attr)
                        else:
                            # Not async, just copy the reference
                            setattr(self, attr_name, attr)
                else:
                    # Attribute doesn't exist but is annotated - create it.
                    # This handles cases like execution: Execution.
                    if isinstance(attr_type, type):
                        # Check if the type is defined as an inner class
                        if hasattr(async_class, attr_type.__name__):
                            inner_class = getattr(async_class, attr_type.__name__)
                            from comfy_api.internal.singleton import ProxiedSingleton

                            # Create an instance of the inner class
                            try:
                                # For ProxiedSingleton classes, get or create the singleton instance
                                if issubclass(inner_class, ProxiedSingleton):
                                    async_instance = inner_class.get_instance()
                                else:
                                    async_instance = inner_class()
                                # Create the sync wrapper
                                sync_attr_class = cls.create_sync_class(inner_class)
                                sync_attr = object.__new__(sync_attr_class)  # type: ignore
                                sync_attr._async_instance = async_instance
                                setattr(self, attr_name, sync_attr)
                                # Also set it on the async instance for consistency
                                setattr(self._async_instance, attr_name, async_instance)
                            except Exception as e:
                                logging.warning(
                                    f"Failed to create instance for {attr_name}: {e}"
                                )

            # Handle other instance attributes that might not be annotated
            for name, attr in inspect.getmembers(self._async_instance):
                if name.startswith("_") or hasattr(self, name):
                    continue
                # If the attribute is an instance of a class defined on the original class,
                # check whether it needs a sync wrapper.
                if isinstance(attr, object) and not isinstance(
                    attr, (str, int, float, bool, list, dict, tuple)
                ):
                    from comfy_api.internal.singleton import ProxiedSingleton

                    if isinstance(attr, ProxiedSingleton):
                        # Create a sync version of this nested class
                        try:
                            sync_attr_class = cls.create_sync_class(attr.__class__)
                            # Create an instance of the sync wrapper backed by the async instance
                            sync_attr = object.__new__(sync_attr_class)  # type: ignore
                            sync_attr._async_instance = attr
                            setattr(self, name, sync_attr)
                        except Exception:
                            # If we can't create a sync version, keep the original
                            setattr(self, name, attr)

        sync_class_dict["__init__"] = __init__

        # Process methods from the async class
        for name, method in inspect.getmembers(
            async_class, predicate=inspect.isfunction
        ):
            if name.startswith("_"):
                continue

            if inspect.iscoroutinefunction(method):
                # Create a sync version of the async method that blocks until the coroutine completes
                def sync_method(self, *args, _method_name=name, **kwargs):
                    async_method = getattr(self._async_instance, _method_name)
                    return AsyncToSyncConverter.run_async_in_thread(
                        async_method, *args, **kwargs
                    )

                # Add to the class dict
                sync_class_dict[name] = sync_method
            else:
                # For regular methods, create a proxy method
                def proxy_method(self, *args, _method_name=name, **kwargs):
                    method = getattr(self._async_instance, _method_name)
                    return method(*args, **kwargs)

                # Add to the class dict
                sync_class_dict[name] = proxy_method

        # Handle property access
        for name, prop in inspect.getmembers(
            async_class, lambda x: isinstance(x, property)
        ):

            def make_property(name, prop_obj):
                def getter(self):
                    value = getattr(self._async_instance, name)
                    if inspect.iscoroutinefunction(value):
                        def sync_fn(*args, **kwargs):
                            return AsyncToSyncConverter.run_async_in_thread(
                                value, *args, **kwargs
                            )

                        return sync_fn
                    return value

                def setter(self, value):
                    setattr(self._async_instance, name, value)

                return property(getter, setter if prop_obj.fset else None)

            sync_class_dict[name] = make_property(name, prop)

        # Create the class
        sync_class = type(sync_class_name, (object,), sync_class_dict)
        return sync_class

    @classmethod
    def _format_type_annotation(
        cls, annotation, type_tracker: Optional[TypeTracker] = None
    ) -> str:
        """Convert a type annotation to its string representation for stub files."""
        if (
            annotation is inspect.Parameter.empty
            or annotation is inspect.Signature.empty
        ):
            return "Any"
        # Handle None type
        if annotation is type(None):
            return "None"
        # Track the type if we have a tracker
        if type_tracker:
            type_tracker.track_type(annotation)
        # Try using typing.get_origin/get_args (Python 3.8+)
        try:
            origin = get_origin(annotation)
            args = get_args(annotation)
            if origin is not None:
                # Track the origin type
                if type_tracker:
                    type_tracker.track_type(origin)
                # Get the origin name
                origin_name = getattr(origin, "__name__", str(origin))
                if "." in origin_name:
                    origin_name = origin_name.split(".")[-1]
                # Special handling for types.UnionType (Python 3.10+ pipe operator):
                # convert to old-style Union for compatibility
                if str(origin) == "<class 'types.UnionType'>" or origin_name == "UnionType":
                    origin_name = "Union"
                # Format arguments recursively
                if args:
                    formatted_args = []
                    for arg in args:
                        # Track each argument type
                        if type_tracker:
                            type_tracker.track_type(arg)
                        formatted_args.append(cls._format_type_annotation(arg, type_tracker))
                    return f"{origin_name}[{', '.join(formatted_args)}]"
                else:
                    return origin_name
        except (AttributeError, TypeError):
            # Fallback for older Python versions or non-generic types
            pass
        # Handle generic types the old way for compatibility
        if hasattr(annotation, "__origin__") and hasattr(annotation, "__args__"):
            origin = annotation.__origin__
            origin_name = (
                origin.__name__
                if hasattr(origin, "__name__")
                else str(origin).split("'")[1]
            )
            # Format each type argument
            args = []
            for arg in annotation.__args__:
                args.append(cls._format_type_annotation(arg, type_tracker))
            return f"{origin_name}[{', '.join(args)}]"
        # Handle regular types with __name__
        if hasattr(annotation, "__name__"):
            return annotation.__name__
        # Handle special module types (like types from the typing module)
        if hasattr(annotation, "__module__") and hasattr(annotation, "__qualname__"):
            # For types like typing.Literal, typing.TypedDict, etc.
            return annotation.__qualname__
        # Last resort: string conversion with cleanup
        type_str = str(annotation)
        # Clean up common patterns more robustly
        if type_str.startswith("<class '") and type_str.endswith("'>"):
            type_str = type_str[8:-2]  # Remove "<class '" and "'>"
        # Remove module prefixes for common modules
        for prefix in ["typing.", "builtins.", "types."]:
            if type_str.startswith(prefix):
                type_str = type_str[len(prefix):]
        # Handle special cases
        if type_str in ("_empty", "inspect._empty"):
            return "None"
        # Fix NoneType (this should rarely be needed now)
        if type_str == "NoneType":
            return "None"
        return type_str

    @classmethod
    def _extract_coroutine_return_type(cls, annotation):
        """Extract the actual return type from a Coroutine annotation."""
        if hasattr(annotation, "__args__") and len(annotation.__args__) > 2:
            # Coroutine[Any, Any, ReturnType] -> extract ReturnType
            return annotation.__args__[2]
        return annotation

    @classmethod
    def _format_parameter_default(cls, default_value) -> str:
        """Format a parameter's default value for stub files."""
        if default_value is inspect.Parameter.empty:
            return ""
        elif default_value is None:
            return " = None"
        elif isinstance(default_value, bool):
            return f" = {default_value}"
        elif default_value == {}:
            return " = {}"
        elif default_value == []:
            return " = []"
        else:
            return f" = {default_value}"

    @classmethod
    def _format_method_parameters(
        cls,
        sig: inspect.Signature,
        skip_self: bool = True,
        type_hints: Optional[dict] = None,
        type_tracker: Optional[TypeTracker] = None,
    ) -> str:
        """Format method parameters for stub files."""
        params = []
        if type_hints is None:
            type_hints = {}
        for i, (param_name, param) in enumerate(sig.parameters.items()):
            if i == 0 and param_name == "self" and skip_self:
                params.append("self")
            else:
                # Prefer the evaluated type hint if available, otherwise use the raw annotation
                annotation = type_hints.get(param_name, param.annotation)
                type_str = cls._format_type_annotation(annotation, type_tracker)
                # Get the default value
                default_str = cls._format_parameter_default(param.default)
                # Combine the parameter parts
                if annotation is inspect.Parameter.empty:
                    params.append(f"{param_name}: Any{default_str}")
                else:
                    params.append(f"{param_name}: {type_str}{default_str}")
        return ", ".join(params)

    @classmethod
    def _generate_method_signature(
        cls,
        method_name: str,
        method,
        is_async: bool = False,
        type_tracker: Optional[TypeTracker] = None,
    ) -> str:
        """Generate a complete method signature for stub files."""
        sig = inspect.signature(method)
        # Try to get evaluated type hints to resolve string annotations
        try:
            from typing import get_type_hints

            type_hints = get_type_hints(method)
        except Exception:
            # Fall back to an empty dict if we can't get type hints
            type_hints = {}
        # For async methods, extract the actual return type
        return_annotation = type_hints.get("return", sig.return_annotation)
        if is_async and inspect.iscoroutinefunction(method):
            return_annotation = cls._extract_coroutine_return_type(return_annotation)
        # Format parameters with type hints
        params_str = cls._format_method_parameters(
            sig, type_hints=type_hints, type_tracker=type_tracker
        )
        # Format the return type
        return_type = cls._format_type_annotation(return_annotation, type_tracker)
        if return_annotation is inspect.Signature.empty:
            return_type = "None"
        return f"def {method_name}({params_str}) -> {return_type}: ..."

    @classmethod
    def _generate_imports(
        cls, async_class: Type, type_tracker: TypeTracker
    ) -> list[str]:
        """Generate import statements for the stub file."""
        imports = []
        # Add standard typing imports
        imports.append(
            "from typing import Any, Dict, List, Optional, Tuple, Union, Set, Sequence, cast, NamedTuple"
        )
        # Add imports from the original module
        if async_class.__module__ != "builtins":
            module = inspect.getmodule(async_class)
            additional_types = []
            if module:
                # Check if the module has __all__ defined
                module_all = getattr(module, "__all__", None)
                for name, obj in sorted(inspect.getmembers(module)):
                    if isinstance(obj, type):
                        # Skip if __all__ is defined and this name isn't in it,
                        # unless it's already been tracked as used in type annotations
                        if module_all is not None and name not in module_all:
                            # Check if this type was actually used in annotations
                            if name not in type_tracker.discovered_types:
                                continue
                        # Check for NamedTuple
                        if issubclass(obj, tuple) and hasattr(obj, "_fields"):
                            additional_types.append(name)
                            # Mark as already imported
                            type_tracker.already_imported.add(name)
                        # Check for Enum
                        elif issubclass(obj, Enum) and name != "Enum":
                            additional_types.append(name)
                            # Mark as already imported
                            type_tracker.already_imported.add(name)
            if additional_types:
                type_imports = ", ".join([async_class.__name__] + additional_types)
                imports.append(f"from {async_class.__module__} import {type_imports}")
            else:
                imports.append(
                    f"from {async_class.__module__} import {async_class.__name__}"
                )
        # Add imports for all discovered types.
        # Pass the main module name to avoid duplicate imports.
        imports.extend(
            type_tracker.get_imports(main_module_name=async_class.__module__)
        )
        # Add the base module import if needed
        if hasattr(inspect.getmodule(async_class), "__name__"):
            module_name = inspect.getmodule(async_class).__name__
            if "." in module_name:
                base_module = module_name.split(".")[0]
                # Only add it if we're not already importing from it
                if not any(imp.startswith(f"from {base_module}") for imp in imports):
                    imports.append(f"import {base_module}")
        return imports

    @classmethod
    def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
        """Extract class attributes that are classes themselves."""
        class_attributes = []
        # Look for class attributes that are classes
        for name, attr in sorted(inspect.getmembers(async_class)):
            if isinstance(attr, type) and not name.startswith("_"):
                class_attributes.append((name, attr))
            elif (
                hasattr(async_class, "__annotations__")
                and name in async_class.__annotations__
            ):
                annotation = async_class.__annotations__[name]
                if isinstance(annotation, type):
                    class_attributes.append((name, annotation))
        return class_attributes

    @classmethod
    def _generate_inner_class_stub(
        cls,
        name: str,
        attr: Type,
        indent: str = "    ",
        type_tracker: Optional[TypeTracker] = None,
    ) -> list[str]:
        """Generate the stub for an inner class."""
        stub_lines = []
        stub_lines.append(f"{indent}class {name}Sync:")
        # Add a docstring if available
        if hasattr(attr, "__doc__") and attr.__doc__:
            stub_lines.extend(
                cls._format_docstring_for_stub(attr.__doc__, f"{indent}    ")
            )
        # Add __init__ if it exists
        if hasattr(attr, "__init__"):
            try:
                init_method = getattr(attr, "__init__")
                init_sig = inspect.signature(init_method)
                # Try to get type hints
                try:
                    from typing import get_type_hints

                    init_hints = get_type_hints(init_method)
                except Exception:
                    init_hints = {}
                # Format parameters
                params_str = cls._format_method_parameters(
                    init_sig, type_hints=init_hints, type_tracker=type_tracker
                )
                # Add the __init__ docstring if available (before the method)
                if hasattr(init_method, "__doc__") and init_method.__doc__:
                    stub_lines.extend(
                        cls._format_docstring_for_stub(
                            init_method.__doc__, f"{indent}    "
                        )
                    )
                stub_lines.append(
                    f"{indent}    def __init__({params_str}) -> None: ..."
                )
            except (ValueError, TypeError):
                stub_lines.append(
                    f"{indent}    def __init__(self, *args, **kwargs) -> None: ..."
                )
        # Add methods to the inner class
        has_methods = False
        for method_name, method in sorted(
            inspect.getmembers(attr, predicate=inspect.isfunction)
        ):
            if method_name.startswith("_"):
                continue
            has_methods = True
            try:
                # Add the method docstring if available (before the method signature)
                if method.__doc__:
                    stub_lines.extend(
                        cls._format_docstring_for_stub(method.__doc__, f"{indent}    ")
                    )
                method_sig = cls._generate_method_signature(
                    method_name, method, is_async=True, type_tracker=type_tracker
                )
                stub_lines.append(f"{indent}    {method_sig}")
            except (ValueError, TypeError):
                stub_lines.append(
                    f"{indent}    def {method_name}(self, *args, **kwargs): ..."
                )
        if not has_methods:
            stub_lines.append(f"{indent}    pass")
        return stub_lines

    @classmethod
    def _format_docstring_for_stub(
        cls, docstring: str, indent: str = "    "
    ) -> list[str]:
        """Format a docstring for inclusion in a stub file with proper indentation."""
        if not docstring:
            return []
        # First, dedent the docstring to remove any existing indentation
        dedented = textwrap.dedent(docstring).strip()
        # Split into lines
        lines = dedented.split("\n")
        # Build the properly indented docstring
        result = []
        result.append(f'{indent}"""')
        for line in lines:
            if line.strip():  # Non-empty line
                result.append(f"{indent}{line}")
            else:  # Empty line
                result.append("")
        result.append(f'{indent}"""')
        return result

    @classmethod
    def _post_process_stub_content(cls, stub_content: list[str]) -> list[str]:
        """Post-process stub content to fix any remaining issues."""
        processed = []
        for line in stub_content:
            # Skip processing imports
            if line.startswith(("from ", "import ")):
                processed.append(line)
                continue
            # Fix method signatures missing return types
            if (
                line.strip().startswith("def ")
                and line.strip().endswith(": ...")
                and ") -> " not in line
            ):
                # Add -> None for methods without a return annotation
                line = line.replace(": ...", " -> None: ...")
            processed.append(line)
        return processed

    @classmethod
    def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
        """
        Generate a .pyi stub file for the sync class to help IDEs with type checking.
        """
        try:
            # Only generate a stub if we can determine the module path
            if async_class.__module__ == "__main__":
                return

            module = inspect.getmodule(async_class)
            if not module:
                return

            module_path = module.__file__
            if not module_path:
                return

            # Create the stub file path in a 'generated' subdirectory
            module_dir = os.path.dirname(module_path)
            stub_dir = os.path.join(module_dir, "generated")
            # Ensure the generated directory exists
            os.makedirs(stub_dir, exist_ok=True)

            module_name = os.path.basename(module_path)
            if module_name.endswith(".py"):
                module_name = module_name[:-3]

            sync_stub_path = os.path.join(stub_dir, f"{sync_class.__name__}.pyi")

            # Create a type tracker for this stub generation
            type_tracker = TypeTracker()

            stub_content = []

            # Imports are generated after processing all methods so that every
            # discovered type is captured; leave a placeholder for them.
            imports_placeholder_index = len(stub_content)
            stub_content.append("")  # Will be replaced with imports later

            # Class definition
            stub_content.append(f"class {sync_class.__name__}:")

            # Docstring
            if async_class.__doc__:
                stub_content.extend(
                    cls._format_docstring_for_stub(async_class.__doc__, "    ")
                )

            # Generate __init__
            try:
                init_method = async_class.__init__
                init_signature = inspect.signature(init_method)
                # Try to get type hints for __init__
                try:
                    from typing import get_type_hints

                    init_hints = get_type_hints(init_method)
                except Exception:
                    init_hints = {}
                # Format parameters
                params_str = cls._format_method_parameters(
                    init_signature, type_hints=init_hints, type_tracker=type_tracker
                )
                # Add the __init__ docstring if available (before the method)
                if hasattr(init_method, "__doc__") and init_method.__doc__:
                    stub_content.extend(
                        cls._format_docstring_for_stub(init_method.__doc__, "    ")
                    )
                stub_content.append(f"    def __init__({params_str}) -> None: ...")
            except (ValueError, TypeError):
                stub_content.append(
                    "    def __init__(self, *args, **kwargs) -> None: ..."
                )

            stub_content.append("")  # Add a newline after __init__

            # Get class attributes
            class_attributes = cls._get_class_attributes(async_class)

            # Generate inner classes
            for name, attr in class_attributes:
                inner_class_stub = cls._generate_inner_class_stub(
                    name, attr, type_tracker=type_tracker
                )
                stub_content.extend(inner_class_stub)
                stub_content.append("")  # Add a newline after the inner class

            # Add methods to the main class
            processed_methods = set()  # Keep track of methods we've processed
            for name, method in sorted(
                inspect.getmembers(async_class, predicate=inspect.isfunction)
            ):
                if name.startswith("_") or name in processed_methods:
                    continue
                processed_methods.add(name)
                try:
                    method_sig = cls._generate_method_signature(
                        name, method, is_async=True, type_tracker=type_tracker
                    )
                    # Add the docstring if available (before the method signature for proper formatting)
                    if method.__doc__:
                        stub_content.extend(
                            cls._format_docstring_for_stub(method.__doc__, "    ")
                        )
                    stub_content.append(f"    {method_sig}")
                    stub_content.append("")  # Add a newline after each method
                except (ValueError, TypeError):
                    # If we can't get the signature, just add a simple stub
                    stub_content.append(f"    def {name}(self, *args, **kwargs): ...")
                    stub_content.append("")  # Add a newline

            # Add properties
            for name, prop in sorted(
                inspect.getmembers(async_class, lambda x: isinstance(x, property))
            ):
                stub_content.append("    @property")
                stub_content.append(f"    def {name}(self) -> Any: ...")
                if prop.fset:
                    stub_content.append(f"    @{name}.setter")
                    stub_content.append(
                        f"    def {name}(self, value: Any) -> None: ..."
                    )
                stub_content.append("")  # Add a newline after each property

            # Add placeholders for the nested class instances.
            # Check the actual attribute names from class annotations and attributes.
            attribute_mappings = {}

            # First check annotations for typed attributes (including from parent classes).
            # Collect all annotations from the class hierarchy.
            all_annotations = {}
            for base_class in reversed(inspect.getmro(async_class)):
                if hasattr(base_class, "__annotations__"):
                    all_annotations.update(base_class.__annotations__)

            for attr_name, attr_type in sorted(all_annotations.items()):
                for class_name, class_type in class_attributes:
                    # If the class type matches the annotated type
                    if (
                        attr_type == class_type
                        or (hasattr(attr_type, "__name__") and attr_type.__name__ == class_name)
                        or (isinstance(attr_type, str) and attr_type == class_name)
                    ):
                        attribute_mappings[class_name] = attr_name

            # Add the attribute declarations with proper names
            for class_name, class_type in class_attributes:
                # Check if there's a mapping from an annotation
                attr_name = attribute_mappings.get(class_name, class_name)
                # Use the annotation name if it exists, even if the attribute doesn't exist yet,
                # because the attribute might be created at runtime.
                stub_content.append(f"    {attr_name}: {class_name}Sync")

            stub_content.append("")  # Add a final newline

            # Now generate imports with all discovered types
            imports = cls._generate_imports(async_class, type_tracker)

            # Deduplicate imports while preserving order
            seen = set()
            unique_imports = []
            for imp in imports:
                if imp not in seen:
                    seen.add(imp)
                    unique_imports.append(imp)
                else:
                    logging.warning(f"Duplicate import detected: {imp}")

            # Replace the placeholder with the actual imports
            stub_content[imports_placeholder_index : imports_placeholder_index + 1] = (
                unique_imports
            )

            # Post-process the stub content
            stub_content = cls._post_process_stub_content(stub_content)

            # Write the stub file
            with open(sync_stub_path, "w") as f:
                f.write("\n".join(stub_content))

            logging.info(f"Generated stub file: {sync_stub_path}")
        except Exception as e:
            # If stub generation fails, log the error but don't break the main functionality
            logging.error(
                f"Error generating stub file for {sync_class.__name__}: {str(e)}"
            )
            import traceback

            logging.error(traceback.format_exc())


def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
    """
    Creates a sync version of an async class.

    Args:
        async_class: The async class to convert
        thread_pool_size: Size of thread pool to use

    Returns:
        A new class with sync versions of all async methods
    """
    return AsyncToSyncConverter.create_sync_class(async_class, thread_pool_size)
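

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hedged example of how create_sync_class might be used. The DemoAPI
# class and its add() method are hypothetical and exist only to exercise the
# converter; real callers would pass their own async class.
if __name__ == "__main__":
    class DemoAPI:
        """Tiny async API used only for demonstration."""

        async def add(self, a: int, b: int) -> int:
            await asyncio.sleep(0)  # Stand-in for real async work
            return a + b

    # Build a blocking wrapper; its add() runs the coroutine on the shared thread pool.
    DemoAPISync = create_sync_class(DemoAPI)
    api = DemoAPISync()
    print(api.add(1, 2))  # Expected to print 3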