Spaces:
Running
Running
import time | |
from .ansi_utils import ansi_color_str | |
from .attr_utils import get_prev_caller_info | |
from .log_utils import log_info, LogSpinner | |
from .table_utils import format_table | |
# Module-level colour setting, empty by default (not referenced in the visible
# part of this module).
time_color = ""
# helper funcs
class ProfileSection:
    """Accumulating wall-clock timer identified by a message string.

    Once ``ProfileSection.Registry`` holds an entry for ``msg`` (``stop()``
    creates one through ``create_or_get_profile``), every later instance built
    with the same ``msg`` aliases that entry's attribute dict. Elapsed, total and
    average time and the call count therefore accumulate across instances. Note
    that the aliasing also means ``enable``/``do_print`` come from the registry
    entry, not from the constructor arguments.

    Usage:
      * context manager: ``with ProfileSection("name"): ...``
      * "auto destruct": ``p = ProfileSection("name", True)`` starts timing in
        ``__init__`` and stops when the object is garbage-collected.

    Args:
        newMsg: Key/label of the section; defaults to ``get_prev_caller_info()``.
        enable: Start timing in ``__init__``, and let ``__del__`` stop a timing
            that is still running.
        do_print: Log the accumulated stats with ``log_info`` on every ``stop()``.
    """

    Registry = {}  # Class variable: msg -> ProfileSection holding the accumulated stats.

    def __init__(self, newMsg=None, enable=True, do_print=False):
        self.msg = newMsg if newMsg else get_prev_caller_info()
        self.enable = enable
        self.startTime = 0.0
        self.endTime = 0.0
        self.elapsedTime = 0.0
        self.totalTime = 0.0
        self.avgTime = 0.0
        self.numCalls = 0
        self.do_print = do_print
        self._running = False  # True between start() and the next stop()
        if self.msg in ProfileSection.Registry:
            # Alias the registered entry's attribute dict: all fields now read
            # from, and write through to, that shared entry.
            self.__dict__ = ProfileSection.Registry[self.msg].__dict__
        self.numCalls += 1
        if self.enable:
            self.start()

    def __del__(self):
        # Auto-stop only while a timing is in progress. After __exit__ or an
        # explicit stop(), calling stop() again would add the elapsed time a
        # second time (registry entries would also be "stopped" again at
        # interpreter shutdown). getattr guards against a failed __init__.
        if getattr(self, "enable", False) and getattr(self, "_running", False):
            self.stop()

    def __enter__(self):
        # Timing always starts on entry, regardless of `enable`.
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always stop on exit; exceptions are not suppressed (returns None).
        self.stop()

    def start(self):
        """Record the start timestamp and mark the section as running."""
        self.startTime = time.time()
        self._running = True

    def stop(self):
        """Finish the current timing, fold it into the stats and publish them.

        The instance's state is copied into the registry entry for ``self.msg``
        (created via ``create_or_get_profile`` if missing). When ``do_print`` is
        set, that entry is logged.

        Returns:
            The formatted stats string of this instance.
        """
        # Cleared before publishing so the registry copy is marked stopped too.
        self._running = False
        self.endTime = time.time()
        self.update_stats()
        stop_msg = str(self)
        profile = create_or_get_profile(self.msg)
        # No-op when this instance already aliases the registry entry's dict.
        profile.__dict__.update(self.__dict__)
        if self.do_print:
            log_info(str(profile))
        return stop_msg

    def update_stats(self):
        """Recompute elapsed time and fold it into the total and average."""
        self.elapsedTime = self.endTime - self.startTime
        self.totalTime += self.elapsedTime
        # numCalls is at least 1 after __init__, so no division by zero here.
        self.avgTime = self.totalTime / float(self.numCalls)

    def __str__(self, use_color=True):
        """One-line summary: message, elapsed/total/avg time and call count.

        Args:
            use_color: Colour the time values (bright cyan) and the call count
                (yellow) with ``ansi_color_str``.
        """
        msg_str = self.msg
        elapsed_time_str = make_time_str('elapsed', self.elapsedTime)
        total_time_str = make_time_str('total', self.totalTime)
        avg_time_str = make_time_str('avg', self.avgTime)
        calls_str = f'calls={self.numCalls}'
        if use_color:
            elapsed_time_str = ansi_color_str(elapsed_time_str, fg='bright_cyan')
            total_time_str = ansi_color_str(total_time_str, fg='bright_cyan')
            avg_time_str = ansi_color_str(avg_time_str, fg='bright_cyan')
            calls_str = ansi_color_str(calls_str, fg='yellow')
        return f"{msg_str} ~ Elapsed Time: {elapsed_time_str} | Total Time: {total_time_str} | Avg Time: {avg_time_str} | {calls_str}"
from functools import wraps | |
from threading import Thread | |
from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, TaskProgressColumn | |
from rich.console import Console | |
from rich.style import Style | |
from rich.panel import Panel | |
def profile_function(func):
    """
    Decorator to track the progress of a function with a spinner.
    The decorated function should not require explicit progress updates.

    Each call runs inside a ``LogSpinner`` and a ``ProfileSection`` keyed by
    ``func.__name__``, so timing stats accumulate across calls in
    ``ProfileSection.Registry``. After the call returns, a rich panel with the
    elapsed, total and average time and the call count is printed. The wrapped
    function's return value is passed through unchanged; exceptions propagate
    and no panel is printed for that call.
    """
    @wraps(func)  # preserve func's __name__, __doc__, etc. on the wrapper
    def wrapper(*args, **kwargs):
        # enable=False: the `with` block below owns start/stop. Don't also start
        # the timer in __init__ or let __del__ stop it a second time, which
        # would double-count the call.
        profile_section = ProfileSection(func.__name__, enable=False, do_print=False)
        console = Console()
        with LogSpinner(func.__name__):
            with profile_section:
                result = func(*args, **kwargs)
        # Display timing information using rich
        panel = Panel.fit(
            f"[bold green]{profile_section.msg}[/bold green]\n"
            f"[cyan]Elapsed Time:[/cyan] {make_time_str('elapsed', profile_section.elapsedTime)}\n"
            f"[cyan]Total Time:[/cyan] {make_time_str('total', profile_section.totalTime)}\n"
            f"[cyan]Avg Time:[/cyan] {make_time_str('avg', profile_section.avgTime)}\n"
            f"[yellow]Calls:[/yellow] {profile_section.numCalls}",
            title="Profile Report",
            border_style="bright_blue",
        )
        console.print(panel)
        return result
    return wrapper
def make_time_str(msg, value):
    """Format a duration given in seconds as ``"<msg>=<amount> <unit>"``.

    The unit is picked by magnitude:
    - minutes when ``value >= 60``;
    - milliseconds when ``value < 0.01``;
    - seconds otherwise.

    The amount is always rendered with two decimals.

    >>> make_time_str('elapsed', 90)
    'elapsed=1.50 min'
    """
    if value >= 60:
        amount, unit = value / 60, 'min'
    elif value < 0.01:
        amount, unit = value * 1000, 'ms'
    else:
        amount, unit = value, 's'
    # Whole amounts pass through int() first (this also turns -0.0 into 0).
    if amount % 1 == 0:
        amount = int(amount)
    return f"{msg}={amount:.2f} {unit}"
def create_or_get_profile(key, enable=False, do_print=False):
    """Return the ProfileSection registered under ``key``, creating it if needed.

    A missing entry is built as ``ProfileSection(key, enable, do_print)``. With
    the default ``enable=False`` its timer is not started. The new entry is
    stored in ``ProfileSection.Registry`` before being returned.
    """
    registry = ProfileSection.Registry
    try:
        return registry[key]
    except KeyError:
        section = ProfileSection(key, enable, do_print)
        registry[key] = section
        return section
def profile_start(msg, enable=True, do_print=False):
    """Start (or restart) the registered profile section named ``msg``.

    A missing section is created through
    ``create_or_get_profile(msg, enable, do_print)``. An existing section is
    restarted and its call count incremented, so repeated
    ``profile_start``/``profile_stop`` pairs accumulate stats the same way
    repeated ``ProfileSection`` instances do.

    Args:
        msg: Registry key of the section.
        enable: Used only when the section is created here.
        do_print: Used only when the section is created here.

    Returns:
        The ProfileSection stored in the registry for ``msg``.
    """
    if msg in ProfileSection.Registry:
        profile = ProfileSection.Registry[msg]
        profile.numCalls += 1  # one more timed interval for avgTime
        profile.start()
        return profile
    profile = create_or_get_profile(msg, enable, do_print)
    if not enable:
        # With enable=True, ProfileSection.__init__ has already called start().
        profile.start()
    return profile
def profile_stop(msg):
    """Stop the registered profile section named ``msg``.

    Nothing is created when ``msg`` is not registered.

    Returns:
        The stats string returned by ``ProfileSection.stop()``, or None when no
        section is registered under ``msg``.
    """
    # Previously tested an undefined name `key`, which raised NameError.
    if msg not in ProfileSection.Registry:
        return None
    return create_or_get_profile(msg).stop()
def get_profile_registry():
    """Return the live ``ProfileSection.Registry`` dict (not a copy)."""
    registry = ProfileSection.Registry
    return registry
from loguru import logger | |
def get_profile_reports():
    """Return every registered ProfileSection as a new list, slowest first.

    Entries are ordered by ``totalTime`` and then ``avgTime``, both descending.
    The sort is stable, so ties keep the registry's iteration order.
    """
    def _by_total_then_avg(section):
        return section.totalTime, section.avgTime

    return sorted(ProfileSection.Registry.values(), key=_by_total_then_avg, reverse=True)
def log_profile_registry(use_color=True):
    """Print the formatted profile-registry report to stdout.

    Args:
        use_color: Passed through to ``format_profile_registry``.
    """
    print(format_profile_registry(use_color=use_color))
def allow_curly_braces(original_string):
    """Double every ``{`` and ``}`` so the text can sit literally in a str.format() template.

    >>> allow_curly_braces("a{b}")
    'a{{b}}'
    """
    escaped = original_string
    for brace in "{}":
        escaped = escaped.replace(brace, brace * 2)
    return escaped
def format_profile_registry(use_color=True):
    """Return all registered profile reports as one multi-line string.

    Reports appear in the order returned by ``get_profile_reports()``, one per
    line. They are framed by a ``=== Profile Reports ===`` header and a
    ``===>_<===`` footer.

    Args:
        use_color: Forwarded to ``ProfileSection.__str__``; pass False for plain
            text without colour codes.
    """
    lines = ['=== Profile Reports ===\n']
    # str(report) would always colour; call __str__ directly to honour use_color.
    lines.extend(report.__str__(use_color=use_color) + '\n' for report in get_profile_reports())
    lines.append('===>_<===\n')
    return ''.join(lines)
# import random | |
# def do_this(y): | |
# p = ProfileSection("do_this", True) #only way to use the auto destruct method | |
# x = random.randint(0, (y+1)*2) | |
# print(f'do_this: {x} - {y}') | |
# for index in range(1000): | |
# do_this(index) | |
# log_profile_registry() | |