Update src/pipeline.py
Browse files- src/pipeline.py +5 -2
src/pipeline.py
CHANGED
@@ -23,6 +23,7 @@ def error_handler(func: Callable):
|
|
23 |
return func(*args, **kwargs)
|
24 |
except Exception as e:
|
25 |
print(f"Error in {func.__name__}: {str(e)}")
|
|
|
26 |
return wrapper
|
27 |
|
28 |
class TorchOptimizer:
|
@@ -107,8 +108,10 @@ class PipelineManager:
|
|
107 |
pipe.to("cuda")
|
108 |
|
109 |
# Optimize pipeline
|
110 |
-
|
111 |
-
|
|
|
|
|
112 |
# Trigger compilation
|
113 |
print("Running torch compilation...")
|
114 |
pipe(
|
|
|
23 |
return func(*args, **kwargs)
|
24 |
except Exception as e:
|
25 |
print(f"Error in {func.__name__}: {str(e)}")
|
26 |
+
return None
|
27 |
return wrapper
|
28 |
|
29 |
class TorchOptimizer:
|
|
|
108 |
pipe.to("cuda")
|
109 |
|
110 |
# Optimize pipeline
|
111 |
+
pipe_ops = self.optimize_pipeline(pipe)
|
112 |
+
if pipe_ops!=None:
|
113 |
+
pipe = pipe_ops
|
114 |
+
|
115 |
# Trigger compilation
|
116 |
print("Running torch compilation...")
|
117 |
pipe(
|