Spaces:
Sleeping
Sleeping
| # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/xai.ipynb. | |
| # %% auto 0 | |
| __all__ = ['get_embeddings', 'get_dataset', 'umap_parameters', 'get_prjs', 'plot_projections', 'plot_projections_clusters', | |
| 'calculate_cluster_stats', 'anomaly_score', 'detector', 'plot_anomaly_scores_distribution', | |
| 'plot_clusters_with_anomalies', 'update_plot', 'plot_clusters_with_anomalies_interactive_plot', | |
| 'get_df_selected', 'shift_datetime', 'get_dateformat', 'get_anomalies', 'get_anomaly_styles', | |
| 'InteractiveAnomalyPlot', 'plot_save', 'plot_initial_config', 'merge_overlapping_windows', | |
| 'InteractiveTSPlot', 'add_selected_features', 'add_windows', 'setup_style', 'toggle_trace', | |
| 'set_features_buttons', 'move_left', 'move_right', 'move_down', 'move_up', 'delta_x_bigger', | |
| 'delta_y_bigger', 'delta_x_lower', 'delta_y_lower', 'add_movement_buttons', 'setup_boxes', 'initial_plot', | |
| 'show'] | |
| # %% ../nbs/xai.ipynb 1 | |
| #Weight & Biases | |
| import wandb | |
| #Yaml | |
| from yaml import load, FullLoader | |
| #Embeddings | |
| from .all import * | |
| from tsai.data.preparation import prepare_forecasting_data | |
| from tsai.data.validation import get_forecasting_splits | |
| from fastcore.all import * | |
| #Dimensionality reduction | |
| from tsai.imports import * | |
| #Clustering | |
| import hdbscan | |
| import time | |
| from .dr import get_PCA_prjs, get_UMAP_prjs, get_TSNE_prjs | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import ipywidgets as widgets | |
| from IPython.display import display | |
| from functools import partial | |
| from IPython.display import display, clear_output, HTML as IPHTML | |
| from ipywidgets import Button, Output, VBox, HBox, HTML, Layout, FloatSlider | |
| import plotly.graph_objs as go | |
| import plotly.offline as py | |
| import plotly.io as pio | |
| #! pip install kaleido | |
| import kaleido | |
| # %% ../nbs/xai.ipynb 4 | |
def get_embeddings(config_lrp, run_lrp, api, print_flag = False):
    """Fetch the embeddings artifact named in `config_lrp` and unpack it.

    The artifact is resolved through `run_lrp.use_artifact` when
    `config_lrp.use_wandb` is truthy, otherwise through `api.artifact`.

    Returns a tuple `(embeddings, artifact, producer_config)`: the result of
    `artifact.to_obj()`, the artifact itself, and the `config` of the run
    returned by `artifact.logged_by()`.
    """
    if config_lrp.use_wandb:
        fetch_artifact = run_lrp.use_artifact
    else:
        fetch_artifact = api.artifact
    artifact = fetch_artifact(config_lrp.emb_artifact, type='embeddings')
    if print_flag:
        print(artifact.name)
    producer_config = artifact.logged_by().config
    embeddings = artifact.to_obj()
    return embeddings, artifact, producer_config
| # %% ../nbs/xai.ipynb 5 | |
def get_dataset(
    config_lrp,
    config_emb,
    config_dr,
    run_lrp,
    api,
    print_flag = False
):
    """Load the encoder learner and the DataFrame to be projected.

    Artifacts are resolved through `run_lrp.use_artifact` when
    `config_lrp.use_wandb` is truthy, otherwise through `api.artifact`.

    Parameters
    ----------
    config_lrp : object exposing `use_wandb`.
    config_emb : mapping with key `'enc_artifact'` (encoder artifact name).
    config_dr : mapping with key `'dr_artifact'`; when it is `None` the
        encoder's training dataset artifact is used instead.
    run_lrp, api : objects providing `use_artifact` / `artifact`.
    print_flag : print debugging information when True.

    Returns
    -------
    (df, dr_artifact, enc_artifact, enc_learner)
    """
    # Workaround so artifacts can also be resolved without an active run.
    artifacts_gettr = run_lrp.use_artifact if config_lrp.use_wandb else api.artifact
    enc_artifact = artifacts_gettr(config_emb['enc_artifact'], type='learner')
    if print_flag: print(enc_artifact.name)
    # TODO: loading the learner has been observed to fail on the first
    # attempt and succeed on the second, so it is retried once. Only
    # `Exception` is caught so Ctrl-C / SystemExit are not swallowed.
    try:
        enc_learner = enc_artifact.to_obj()
    except Exception:
        enc_learner = enc_artifact.to_obj()
    # Recover the datasets the encoder was trained / validated on.
    enc_logger = enc_artifact.logged_by()
    enc_artifact_train = artifacts_gettr(enc_logger.config['train_artifact'], type='dataset')
    if enc_logger.config['valid_artifact'] is not None:
        # Fetched for parity with the original code; the result is only printed.
        enc_artifact_valid = artifacts_gettr(enc_logger.config['valid_artifact'], type='dataset')
        if print_flag: print("enc_artifact_valid:", enc_artifact_valid.name)
    if print_flag: print("enc_artifact_train: ", enc_artifact_train.name)
    if config_dr['dr_artifact'] is not None:
        if print_flag: print("Is not none")
        # Load the artifact that was just checked (the original looked up
        # 'enc_artifact' here, which is not the key tested above).
        dr_artifact = artifacts_gettr(config_dr['dr_artifact'])
    else:
        dr_artifact = enc_artifact_train
        if print_flag: print("DR artifact train: ", dr_artifact.name)
    if print_flag: print("--> DR artifact name", dr_artifact.name)
    df = dr_artifact.to_df()
    if print_flag: print("--> DR After to df", df.shape)
    if print_flag: display(df.head())
    return df, dr_artifact, enc_artifact, enc_learner
| # %% ../nbs/xai.ipynb 6 | |
def umap_parameters(config_dr, config):
    """Return the UMAP keyword arguments for the CPU or GPU code path.

    `config_dr` supplies `n_neighbors`, `min_dist`, `metric` and `cpu_flag`.
    When `cpu_flag` is truthy the shorter CPU parameter set is returned,
    otherwise the GPU parameter set. `config` is accepted but not used.
    """
    neighbors = config_dr.n_neighbors
    min_dist = config_dr.min_dist
    metric = config_dr.metric

    if config_dr.cpu_flag:
        # Optional keys such as 'a', 'b', 'metric_kwds', 'output_metric'
        # and 'n_epochs' are intentionally left at their defaults on CPU.
        return {
            'n_neighbors': neighbors,
            'min_dist': min_dist,
            'random_state': np.uint64(822569775),
            'metric': metric,
            'verbose': 4,
        }

    return {
        'n_neighbors': neighbors,
        'min_dist': min_dist,
        'random_state': np.uint64(1234),
        'metric': metric,
        'a': 1.5769434601962196,
        'b': 0.8950608779914887,
        'target_metric': 'euclidean',
        'target_n_neighbors': neighbors,
        'verbose': 4,
        'n_epochs': 200 * 3 * 2,
        'init': 'random',
        'hash_input': True,
    }
| # %% ../nbs/xai.ipynb 7 | |
def get_prjs(embs_no_nan, config_dr, config, print_flag = False):
    """Project embeddings with `get_PCA_prjs` followed by `get_UMAP_prjs`.

    The keyword arguments produced by `umap_parameters(config_dr, config)`
    are forwarded to both projection helpers. The PCA step is always called
    with `cpu=False`; the UMAP step uses `config_dr.cpu_flag`.

    Returns whatever `get_UMAP_prjs` returns for the PCA output.
    """
    umap_params = umap_parameters(config_dr, config)
    prjs_pca = get_PCA_prjs(
        X = embs_no_nan,
        cpu = False,
        print_flag = print_flag,
        **umap_params
    )
    if print_flag:
        print(prjs_pca.shape)
    prjs_umap = get_UMAP_prjs(
        input_data = prjs_pca,
        cpu = config_dr.cpu_flag,
        print_flag = print_flag,
        **umap_params
    )
    # The original evaluated `prjs_umap.shape` without printing it.
    if print_flag:
        print(prjs_umap.shape)
    return prjs_umap
| # %% ../nbs/xai.ipynb 9 | |
def plot_projections(prjs, umap_params, fig_size = (25,25)):
    "Plot 2D projections through a connected scatter plot"
    coords = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    xs, ys = coords['x1'], coords['x2']
    axes = plt.figure(figsize=(fig_size[0], fig_size[1])).add_subplot(111)
    # Hollow markers for the points, plus a line joining them in row order.
    axes.scatter(xs, ys, marker='o', facecolors='none', edgecolors='b', alpha=0.1)
    axes.plot(xs, ys, alpha=0.5, picker=1)
    title = 'DR params - n_neighbors:{:d} min_dist:{:f}'.format(
        umap_params['n_neighbors'], umap_params['min_dist'])
    plt.title(title)
    return axes
| # %% ../nbs/xai.ipynb 10 | |
def plot_projections_clusters(prjs, clusters_labels, umap_params, fig_size = (25,25)):
    "Plot 2D projections as one scatter series per cluster label"
    points = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    points['cluster'] = clusters_labels
    axes = plt.figure(figsize=(fig_size[0], fig_size[1])).add_subplot(111)
    # One scatter call per label so each cluster gets its own colour.
    cluster_ids = points['cluster'].unique()
    print(cluster_ids)
    for cluster_id in cluster_ids:
        members = points.loc[points['cluster'] == cluster_id]
        axes.scatter(members['x1'], members['x2'], label=f'Cluster {cluster_id}')
    title = 'DR params - n_neighbors:{:d} min_dist:{:f}'.format(
        umap_params['n_neighbors'], umap_params['min_dist'])
    plt.title(title)
    return axes
| # %% ../nbs/xai.ipynb 11 | |
def calculate_cluster_stats(data, labels):
    """Return a `{label: (mean, std)}` mapping with one entry per unique label.

    NOTE(review): the statistics are computed over the whole of `data`, not
    over each cluster's members (the per-cluster selection was disabled in
    the original code), so every label maps to the same global mean and
    standard deviation. Confirm whether per-cluster statistics are wanted.

    The mean/std are computed once and shared by all entries instead of
    being recomputed for every label.
    """
    unique_labels = np.unique(labels)
    if unique_labels.size == 0:
        return {}
    mean = np.mean(data, axis = 0)
    std = np.std(data, axis = 0)
    return {label: (mean, std) for label in unique_labels}
| # %% ../nbs/xai.ipynb 12 | |
def anomaly_score(point, cluster_stats, label):
    """Return the norm of `point` standardised by the stats stored for `label`.

    `cluster_stats[label]` must be a `(mean, std)` pair; the score is
    `np.linalg.norm((point - mean) / std)`.
    """
    center, spread = cluster_stats[label]
    standardized = (point - center) / spread
    return np.linalg.norm(standardized)
| # %% ../nbs/xai.ipynb 13 | |
def detector(data, labels):
    """Score every row of `data` against the statistics of its label.

    Statistics come from `calculate_cluster_stats(data, labels)`; each score
    is `anomaly_score(row, stats, label)`. Returns the scores as a NumPy array.
    """
    stats = calculate_cluster_stats(data, labels)
    return np.array([
        anomaly_score(sample, stats, cluster)
        for sample, cluster in zip(data, labels)
    ])
| # %% ../nbs/xai.ipynb 15 | |
def plot_anomaly_scores_distribution(anomaly_scores):
    "Show a 30-bin histogram (with KDE overlay) of the anomaly scores"
    figure_size = (10, 6)
    plt.figure(figsize=figure_size)
    sns.histplot(anomaly_scores, kde=True, bins=30)
    # Axis labels and title are independent of each other.
    plt.xlabel("Anomaly Score")
    plt.ylabel("Frecuencia")
    plt.title("Distribución de Anomaly Scores")
    plt.show()
| # %% ../nbs/xai.ipynb 16 | |
def plot_clusters_with_anomalies(prjs, clusters_labels, anomaly_scores, threshold, fig_size=(25, 25)):
    "Scatter the 2D projections per cluster and overlay points scoring above `threshold`"
    points = pd.DataFrame(prjs, columns=['x1', 'x2'])
    points['cluster'] = clusters_labels
    points['anomaly'] = anomaly_scores > threshold
    axes = plt.figure(figsize=(fig_size[0], fig_size[1])).add_subplot(111)
    # One series per cluster so each gets its own colour and legend entry.
    for cluster_id in points['cluster'].unique():
        members = points.loc[points['cluster'] == cluster_id]
        axes.scatter(members['x1'], members['x2'], label=f'Cluster {cluster_id}', alpha=0.7)
    # Anomalies are drawn last so they sit on top of the cluster markers.
    flagged = points.loc[points['anomaly']]
    axes.scatter(flagged['x1'], flagged['x2'], color='red', label='Anomalies', edgecolor='k', s=50)
    plt.title('Clusters and anomalies')
    plt.legend()
    plt.show()
def update_plot(threshold, prjs_umap, clusters_labels, anomaly_scores, fig_size):
    "Redraw the cluster/anomaly plot for a new `threshold` (widget callback)"
    plot_clusters_with_anomalies(
        prjs = prjs_umap,
        clusters_labels = clusters_labels,
        anomaly_scores = anomaly_scores,
        threshold = threshold,
        fig_size = fig_size,
    )
def plot_clusters_with_anomalies_interactive_plot(threshold, prjs_umap, clusters_labels, anomaly_scores, fig_size):
    """Display the cluster/anomaly plot with a slider controlling the threshold.

    `threshold` is the slider's initial value (slider range 0.001 to 3, step
    0.001). Every slider change calls `update_plot` with the new threshold
    and the fixed remaining arguments.

    `fig_size` is now forwarded to the plot; the original ignored it and
    always used (25, 25).
    """
    threshold_slider = widgets.FloatSlider(value=threshold, min=0.001, max=3, step=0.001, description='Threshold')
    interactive_plot = widgets.interactive(
        update_plot,
        threshold = threshold_slider,
        prjs_umap = widgets.fixed(prjs_umap),
        clusters_labels = widgets.fixed(clusters_labels),
        anomaly_scores = widgets.fixed(anomaly_scores),
        fig_size = widgets.fixed(fig_size),
    )
    display(interactive_plot)
| # %% ../nbs/xai.ipynb 18 | |
| import plotly.express as px | |
| from datetime import timedelta | |
| # %% ../nbs/xai.ipynb 19 | |
def get_df_selected(df, selected_indices, w, stride = 1):  # careful with stride
    """Link selected window ids back to rows of the original DataFrame.

    Each selected id `i` maps to the positional range
    `(i * stride, i * stride + w)`. The rows of every range, end position
    included, are concatenated (original index kept) into `df_selected`.

    Returns `(window_ranges, n_windows, df_selected)`. An empty selection
    returns `([], 0, <empty frame with df's columns>)` instead of failing
    inside `pd.concat`.
    """
    n_windows = len(selected_indices)
    window_ranges = [
        (window_id * stride, window_id * stride + w)
        for window_id in selected_indices
    ]
    if not window_ranges:
        return window_ranges, n_windows, df.iloc[0:0]
    segments = [df.iloc[start:end + 1] for start, end in window_ranges]
    df_selected = pd.concat(segments, ignore_index=False)
    return window_ranges, n_windows, df_selected
| # %% ../nbs/xai.ipynb 20 | |
def shift_datetime(dt, seconds, sign, dateformat="%Y-%m-%d %H:%M:%S.%f", print_flag = False):
    """Shift the datetime string `dt` by `seconds` and return it as a string.

    `dt` is parsed with `dateformat`; if that fails it is parsed as a plain
    date (`%Y-%m-%d`). The value moves into the future when `sign` is '+'
    and into the past for any other `sign`. The result is always formatted
    as `%Y-%m-%d %H:%M:%S.%f`.

    A value whose time is exactly 00:00:00 has its microseconds cleared,
    both before and after the shift (kept from the original behaviour).

    Raises
    ------
    ValueError
        If `dt` matches neither format. The original code failed later with
        an UnboundLocalError in this situation.

    If formatting or shifting raises a ValueError, its message is returned
    as the result string (kept for backward compatibility).
    """
    from datetime import datetime, timedelta

    output_format = "%Y-%m-%d %H:%M:%S.%f"
    date_only_format = "%Y-%m-%d"

    if print_flag: print("dt ", dt, "seconds", seconds, "sign", sign, "format", dateformat)

    new_dt = None
    for candidate in (dateformat, date_only_format):
        try:
            new_dt = datetime.strptime(dt, candidate)
            break
        except ValueError as e:
            if print_flag: print("Error: ", e)
    if new_dt is None:
        raise ValueError(
            f"Cannot parse {dt!r} with {dateformat!r} or {date_only_format!r}"
        )

    def _clear_midnight_microseconds(value):
        # Same effect as the original replace(hour=0, minute=0, second=0,
        # microsecond=0) guarded by hour == minute == second == 0.
        if value.hour == 0 and value.minute == 0 and value.second == 0:
            return value.replace(microsecond=0)
        return value

    try:
        new_dt = _clear_midnight_microseconds(new_dt)
        delta = timedelta(seconds = seconds)
        new_dt = new_dt + delta if sign == '+' else new_dt - delta
        new_dt = _clear_midnight_microseconds(new_dt)
        new_dt_str = new_dt.strftime(output_format)
    except ValueError as e:
        if print_flag: print("Error: ", e)
        return str(e)
    if print_flag: print("new dt ", new_dt)
    return new_dt_str
| # %% ../nbs/xai.ipynb 21 | |
def get_dateformat(text_date):
    """Guess the strptime format of `text_date` from its shape.

    - one whitespace-separated token            -> "%Y-%m-%d"
    - two tokens, time `H:M:S`                  -> "%Y-%m-%d %H:%M:%S"
    - two tokens, time `H:M:S.f` (one dot)      -> "%Y-%m-%d %H:%M:%S.%f"
    - two tokens, time without three ':' fields -> "unknown format 1"
    - any other number of tokens                -> "unknown format 2"
    """
    tokens = text_date.split()
    if len(tokens) == 1:
        return "%Y-%m-%d"
    if len(tokens) != 2:
        return "unknown format 2"

    clock_fields = tokens[1].split(':')
    if len(clock_fields) != 3:
        return "unknown format 1"

    has_fraction = len(clock_fields[2].split('.')) == 2
    if has_fraction:
        return "%Y-%m-%d %H:%M:%S.%f"
    return "%Y-%m-%d %H:%M:%S"
| # %% ../nbs/xai.ipynb 23 | |
def get_anomalies(df, threshold, flag):
    """Write an 'anomaly' column into `df` (in place).

    Each entry is `(score > threshold) and flag` for the corresponding value
    of `df['anomaly_score']`. Nothing is returned.
    """
    marks = []
    for score in df['anomaly_score']:
        marks.append((score > threshold) and flag)
    df['anomaly'] = marks
def get_anomaly_styles(df, threshold, anomaly_scores, flag = False, print_flag = False):
    """Return per-point marker symbols and outline colours for anomalies.

    `df['anomaly']` is (re)written in place:
    - `flag` truthy: True where the matching entry of `anomaly_scores` is
      greater than `threshold`;
    - `flag` falsy: False for every row.

    Returns `(symbols, line_colors)`, two lists with one entry per row:
    anomalies get symbol 'x' with a 'black' outline, every other point gets
    'circle' with a fully transparent outline.

    The original computed the 'anomaly' column up to three times and took
    its debug frame before the final column was written; the column is now
    computed once.
    """
    if print_flag: print("Threshold: ", threshold)
    if print_flag: print("Flag", flag)
    if print_flag: print("df ~", df.shape)

    transparent = 'rgba(0,0,0,0)'
    if flag:
        df['anomaly'] = [bool(score > threshold) for score in anomaly_scores]
        symbols = ['x' if is_anomaly else 'circle' for is_anomaly in df['anomaly']]
        line_colors = ['black' if is_anomaly else transparent for is_anomaly in df['anomaly']]
    else:
        df['anomaly'] = False
        symbols = ['circle'] * len(df)
        line_colors = [transparent] * len(df)

    if print_flag: print(df[df['anomaly']])
    return symbols, line_colors
### Example of use
#prjs_df = pd.DataFrame(prjs_umap, columns = ['x1', 'x2'])
#prjs_df['anomaly_score'] = anomaly_scores
#s, l = get_anomaly_styles(prjs_df, 1, anomaly_scores, True)
| # %% ../nbs/xai.ipynb 24 | |
class InteractiveAnomalyPlot():
    """Interactive scatter plot of 2D projections with lasso selection and
    optional anomaly highlighting.

    Attributes set by the constructor:
    - selected_indices / selected_indices_tmp: confirmed and pending selections.
    - threshold / threshold_: initial and current anomaly threshold.
    - anomaly_flag: whether anomalies are highlighted.
    - w, name, path: window size, the label "w=<w>" and "<path>/w=<w>.png".
    - interaction_enabled: toggled by the pause/resume buttons.
    """
    def __init__(
        self, selected_indices = None,
        threshold = 0.15,
        anomaly_flag = False,
        path = "../imgs", w = 0
    ):
        # A fresh list per instance instead of a shared mutable default.
        selected_indices = [] if selected_indices is None else selected_indices
        self.selected_indices = selected_indices
        self.selected_indices_tmp = selected_indices
        self.threshold = threshold
        self.threshold_ = threshold
        self.anomaly_flag = anomaly_flag
        self.w = w
        self.name = f"w={self.w}"
        # The original concatenated directory and file name with no
        # separator ("../imgsw=0.png").
        self.path = f"{str(path).rstrip('/')}/{self.name}.png"
        self.interaction_enabled = True

    def plot_projections_clusters_interactive(
        self, prjs, cluster_labels, umap_params, anomaly_scores=None, fig_size=(7,7), print_flag = False
    ):
        """Build and display the interactive projection plot.

        Parameters
        ----------
        prjs, cluster_labels, anomaly_scores : forwarded to
            `plot_initial_config` to build the plotting DataFrame and the
            cluster colour map. `anomaly_scores=None` is treated as `[]`.
        umap_params : mapping with 'n_neighbors' and 'min_dist' (shown in the title).
        fig_size : accepted but not used; the figure is fixed at 700x500.
        print_flag : print debugging information into the output widgets.

        Side effects: saves the figure through `plot_save(fig, self.w)` and
        displays the widget box. Lasso selections are stored in
        `self.selected_indices_tmp` and copied to `self.selected_indices`
        when "Update selected_indices" is clicked.
        """
        anomaly_scores = [] if anomaly_scores is None else anomaly_scores
        self.selected_indices_tmp = self.selected_indices
        py.init_notebook_mode()
        prjs_df, cluster_colors = plot_initial_config(prjs, cluster_labels, anomaly_scores)
        marker_colors = prjs_df['cluster'].map(cluster_colors)
        symbols, line_colors = get_anomaly_styles(
            prjs_df, self.threshold_, anomaly_scores, self.anomaly_flag, print_flag
        )
        fig = go.FigureWidget(
            [
                go.Scatter(
                    x=prjs_df['x1'], y=prjs_df['x2'],
                    mode="markers",
                    marker= {
                        'color': marker_colors,
                        'line': { 'color': line_colors, 'width': 1 },
                        'symbol': symbols
                    },
                    text = prjs_df.index
                )
            ]
        )
        # Second trace: grey line joining the points in row order.
        line_trace = go.Scatter(
            x=prjs_df['x1'],
            y=prjs_df['x2'],
            mode="lines",
            line=dict(color='rgba(128, 128, 128, 0.5)', width=1)
        )
        fig.add_trace(line_trace)
        sca = fig.data[0]
        fig.update_layout(
            dragmode='lasso',
            width=700,
            height=500,
            title={
                'text': '<span style="font-weight:bold">DR params - n_neighbors:{:d} min_dist:{:f}</span>'.format(
                    umap_params['n_neighbors'], umap_params['min_dist']),
                'y':0.98,
                'x':0.5,
                'xanchor': 'center',
                'yanchor': 'top'
            },
            plot_bgcolor='white',
            paper_bgcolor='#f0f0f0',
            xaxis=dict(gridcolor='lightgray', zerolinecolor='black', title = 'x'),
            yaxis=dict(gridcolor='lightgray', zerolinecolor='black', title = 'y'),
            margin=dict(l=10, r=20, t=30, b=10)
        )
        # Output widgets receiving the debug prints of each callback.
        output_tmp = Output()
        output_button = Output()
        output_anomaly = Output()
        output_threshold = Output()
        output_width = Output()

        def select_action(trace, points, selector):
            self.selected_indices_tmp = points.point_inds
            with output_tmp:
                output_tmp.clear_output(wait=True)
                if print_flag: print("Selected indices tmp:", self.selected_indices_tmp)

        def button_action(b):
            self.selected_indices = self.selected_indices_tmp
            with output_button:
                output_button.clear_output(wait = True)
                if print_flag: print("Selected indices:", self.selected_indices)

        def update_anomalies():
            if print_flag: print("About to update anomalies")
            symbols, line_colors = get_anomaly_styles(
                prjs_df, self.threshold_, anomaly_scores, self.anomaly_flag, print_flag
            )
            if print_flag: print("Anomaly styles got")
            with fig.batch_update():
                fig.data[0].marker.symbol = symbols
                fig.data[0].marker.line.color = line_colors
            if print_flag: print("Anomalies updated")
            if print_flag: print("Threshold: ", self.threshold_)
            if print_flag: print("Scores: ", anomaly_scores)

        def anomaly_action(b):
            with output_anomaly:
                output_anomaly.clear_output(wait=True)
                # The original referenced the misspelled name `print_fllag`
                # here, raising NameError before the flag was toggled.
                if print_flag: print("Negate anomaly flag")
                self.anomaly_flag = not self.anomaly_flag
                if print_flag: print("Show anomalies:", self.anomaly_flag)
                update_anomalies()

        sca.on_selection(select_action)
        layout = widgets.Layout(width='auto', height='40px')
        button = Button(
            description="Update selected_indices",
            style = {'button_color': 'lightblue'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )
        anomaly_button = Button(
            description = "Show anomalies",
            style = {'button_color': 'lightgray'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )
        button.on_click(button_action)
        anomaly_button.on_click(anomaly_action)
        ##### Reactivity buttons
        pause_button = Button(
            description = "Pause interactiveness",
            style = {'button_color': 'pink'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )
        resume_button = Button(
            description = "Resume interactiveness",
            style = {'button_color': 'lightgreen'},
            display = 'flex',
            flex_row = 'column',
            align_items = 'stretch',
            layout = layout
        )
        threshold_slider = FloatSlider(
            value=self.threshold_,
            min=0.0,
            max=float(np.ceil(self.threshold+5)),
            step=0.0001,
            description='Anomaly threshold:',
            continuous_update=False
        )

        def pause_interaction(b):
            self.interaction_enabled = False
            fig.update_layout(dragmode='pan')

        def resume_interaction(b):
            self.interaction_enabled = True
            fig.update_layout(dragmode='lasso')

        def update_threshold(change):
            with output_threshold:
                output_threshold.clear_output(wait = True)
                if print_flag: print("Update threshold")
                self.threshold_ = change.new
                if print_flag: print("Update anomalies threshold = ", self.threshold_)
                update_anomalies()

        #### Width
        width_slider = FloatSlider(
            value = 0.5,
            min = 0.0,
            max = 1.0,
            step = 0.0001,
            description = 'Line width:',
            continuous_update = False
        )

        def update_width(change):
            with output_width:
                try:
                    output_width.clear_output(wait = True)
                    if print_flag:
                        print("Change line width")
                        print("Trace to update:", fig.data[1])
                    with fig.batch_update():
                        # Update the width of the connecting line.
                        fig.data[1].line.width = change.new
                    if print_flag: print("ChangeD line width")
                except Exception as e:
                    print("Error updating line width:", e)

        pause_button.on_click(pause_interaction)
        resume_button.on_click(resume_interaction)
        threshold_slider.observe(update_threshold, 'value')
        width_slider.observe(update_width, names = 'value')

        space = HTML(" ")
        vbox = VBox((output_tmp, output_button, output_anomaly, output_threshold, fig))
        hbox = HBox((space, button, space, pause_button, space, resume_button, anomaly_button))
        # Centre the two boxes horizontally inside the VBox.
        box_layout = widgets.Layout(display='flex',
                                    flex_flow='column',
                                    align_items='center',
                                    width='100%')
        # The threshold slider is only shown when anomalies are enabled at build time.
        if self.anomaly_flag:
            box = VBox((hbox, threshold_slider, width_slider, output_width, vbox), layout = box_layout)
        else:
            box = VBox((hbox, width_slider, output_width, vbox), layout = box_layout)
        box.add_class("layout")
        plot_save(fig, self.w)
        display(box)
| # %% ../nbs/xai.ipynb 25 | |
def plot_save(fig, w, path = "../imgs"):
    """Render `fig` to PNG and write it to `<path>/w=<w>.png`.

    `path` defaults to the directory the original code hard-coded. The
    directory is created when missing, so the write no longer fails on a
    fresh checkout.
    """
    import os

    image_bytes = pio.to_image(fig, format='png')
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, f"w={w}.png"), 'wb') as f:
        f.write(image_bytes)
| # %% ../nbs/xai.ipynb 26 | |
def plot_initial_config(prjs, cluster_labels, anomaly_scores, palette = None):
    """Build the plotting DataFrame and a colour for every cluster label.

    Parameters
    ----------
    prjs : 2-column data, stored as columns 'x1' and 'x2'.
    cluster_labels : stored as column 'cluster'.
    anomaly_scores : stored as column 'anomaly_score'.
    palette : optional sequence of colours; defaults to
        `px.colors.qualitative.Set1`.

    Returns
    -------
    (prjs_df, cluster_colors) where `cluster_colors` maps each distinct
    label, in order of first appearance, to a palette colour.

    The palette is cycled when there are more clusters than colours. The
    original sliced the palette and failed with a length mismatch in that
    case.

    Raises
    ------
    ValueError
        If `palette` is empty.
    """
    prjs_df = pd.DataFrame(prjs, columns = ['x1', 'x2'])
    prjs_df['cluster'] = cluster_labels
    prjs_df['anomaly_score'] = anomaly_scores

    colors = list(px.colors.qualitative.Set1 if palette is None else palette)
    if not colors:
        raise ValueError("palette must contain at least one colour")

    distinct_clusters = pd.Series(cluster_labels).drop_duplicates()
    cluster_colors = {
        cluster: colors[position % len(colors)]
        for position, cluster in enumerate(distinct_clusters)
    }
    return prjs_df, cluster_colors
| # %% ../nbs/xai.ipynb 27 | |
def merge_overlapping_windows(windows):
    """Merge `(start, end)` windows that overlap or touch.

    Windows are sorted by start; a window whose start is less than or equal
    to the end of the previously kept window is folded into it. Returns the
    merged windows in ascending start order (`[]` for empty input).
    """
    if not windows:
        return []
    ordered = sorted(windows, key=lambda span: span[0])
    merged = ordered[:1]
    for current in ordered[1:]:
        previous_start, previous_end = merged[-1][0], merged[-1][1]
        if current[0] > previous_end:
            # Gap between the windows: start a new merged window.
            merged.append(current)
            continue
        merged[-1] = (previous_start, max(current[1], previous_end))
    return merged
| # %% ../nbs/xai.ipynb 29 | |
class InteractiveTSPlot:
    """Interactive time-series plot highlighting the selected windows.

    The constructor only prepares state; plotting methods are attached to
    the class after its definition.
    """
    def __init__(
        self,
        df,
        selected_indices,
        meaningful_features_subset_ids,
        w,
        stride=1,
        print_flag=False,
        num_points=10000,
        dateformat='%Y-%m-%d %H:%M:%S',
        delta_x = 10,
        delta_y = 0.1
    ):
        """Store the configuration and compute the windows to highlight.

        Parameters
        ----------
        df : DataFrame with one column per feature.
        selected_indices : window ids, mapped to row ranges by `get_df_selected`.
        meaningful_features_subset_ids : column positions treated as meaningful.
        w, stride : window length and stride used to map ids to rows.
        print_flag : print debugging information.
        num_points : stored and printed only; not used to limit the data here.
        dateformat : format used for the x axis and for shifting its range.
        delta_x, delta_y : step sizes for the movement buttons.
        """
        self.df = df
        self.selected_indices = selected_indices
        self.meaningful_features_subset_ids = meaningful_features_subset_ids
        self.w = w
        self.stride = stride
        self.print_flag = print_flag
        self.num_points = num_points
        self.dateformat = dateformat
        self.buttons = []
        self.delta_x = delta_x
        self.delta_y = delta_y
        self.window_ranges, self.n_windows, self.df_selected = get_df_selected(
            self.df, self.selected_indices, self.w, self.stride
        )
        # Plot as few windows as possible (like in the R Shiny app).
        self.window_ranges = merge_overlapping_windows(self.window_ranges)
        if self.print_flag:
            print("windows: ", self.n_windows, self.window_ranges)
            print("selected id: ", self.df_selected.index)
            print("points: ", self.num_points)
        # Shallow copy so casting the index to str does not modify the
        # caller's DataFrame (the original reassigned its index in place).
        self.df = self.df.copy(deep=False)
        self.df.index = self.df.index.astype(str)
        # Created once; the original built the widget twice.
        self.fig = go.FigureWidget()
        # One random colour per selected window (at least as many as merged windows).
        self.colors = [
            f'rgb({np.random.randint(0, 256)}, {np.random.randint(0, 256)}, {np.random.randint(0, 256)})'
            for _ in range(self.n_windows)
        ]
        ##############################
        # Outputs for debug printing #
        ##############################
        self.output_windows = Output()
        self.output_move = Output()
        self.output_delta_x = Output()
        self.output_delta_y = Output()
| # %% ../nbs/xai.ipynb 30 | |
def add_selected_features(self: InteractiveTSPlot):
    """Add one line trace per DataFrame column to the figure.

    A trace starts visible only when its column position is contained in
    `self.meaningful_features_subset_ids`.
    """
    timestamps = self.df.index
    for column_name in self.df.columns:
        position = self.df.columns.get_loc(column_name)
        is_meaningful = position in self.meaningful_features_subset_ids
        self.fig.add_trace(
            go.Scatter(
                x = timestamps,
                y = self.df[column_name],
                mode='lines',
                name=column_name,
                visible=is_meaningful,
                text=timestamps
            )
        )
InteractiveTSPlot.add_selected_features = add_selected_features
| # %% ../nbs/xai.ipynb 31 | |
def add_windows(self: InteractiveTSPlot):
    """Draw one shaded rectangle per entry of `self.window_ranges`.

    Each rectangle spans the full plot height (`yref="paper"`) between the
    index values at the window's start and end positions, and its bounds
    are printed into `self.output_windows`.

    Changes from the original:
    - the end position is clamped to the last row, because window ends are
      computed as `start + w` and can be one past the end of the frame;
    - the printed label uses the shape name `w_<i>` with balanced brackets.
      After merging, position `i` no longer corresponds to
      `selected_indices[i]`, which the original label used.
    """
    last_position = len(self.df.index) - 1
    for i, (start, end) in enumerate(self.window_ranges):
        x0 = self.df.index[start]
        x1 = self.df.index[min(end, last_position)]
        self.fig.add_shape(
            type="rect",
            x0=x0,
            x1=x1,
            y0= 0,
            y1= 1,
            yref = "paper",
            fillcolor=self.colors[i],
            opacity=0.25,
            layer="below",
            line=dict(color=self.colors[i], width=1),
            name = f"w_{i}"
        )
        with self.output_windows:
            print(f"w_{i}=[{x0}, {x1}]")
InteractiveTSPlot.add_windows = add_windows
| # %% ../nbs/xai.ipynb 32 | |
def setup_style(self: InteractiveTSPlot):
    """Apply titles, margins and background colour, and lock the y axis.

    The x-axis tick format is '%d-' followed by `self.dateformat`.
    """
    margins = dict(l=10, r=10, t=30, b=10)
    x_axis_options = dict(tickformat = '%d-' + self.dateformat)
    self.fig.update_layout(
        title='Time Series with time window plot',
        xaxis_title='Datetime',
        yaxis_title='Value',
        legend_title='Variables',
        margin=margins,
        xaxis=x_axis_options,
        paper_bgcolor='#f0f0f0'
    )
    # The y axis cannot be zoomed or panned with the mouse.
    self.fig.update_yaxes(fixedrange=True)
InteractiveTSPlot.setup_style = setup_style
| # %% ../nbs/xai.ipynb 34 | |
def toggle_trace(self : InteractiveTSPlot, button : Button):
    """Flip the visibility of the trace named by the clicked button.

    The button description is looked up in `self.df.columns`; the resulting
    position selects the trace in `self.fig.data`.
    """
    column_name = button.description
    position = self.df.columns.get_loc(column_name)
    target = self.fig.data[position]
    target.visible = not target.visible
InteractiveTSPlot.toggle_trace = toggle_trace
| # %% ../nbs/xai.ipynb 35 | |
def set_features_buttons(self):
    """Create one toggle button per DataFrame column.

    Buttons of columns whose position is in
    `self.meaningful_features_subset_ids` use the 'success' style; every
    button calls `self.toggle_trace` when clicked.
    """
    feature_buttons = []
    for column_name in self.df.columns:
        position = self.df.columns.get_loc(column_name)
        if position in self.meaningful_features_subset_ids:
            style_name = 'success'
        else:
            style_name = ''
        feature_buttons.append(
            Button(description=str(column_name), button_style=style_name)
        )
    self.buttons = feature_buttons
    for feature_button in self.buttons:
        feature_button.on_click(self.toggle_trace)
InteractiveTSPlot.set_features_buttons = set_features_buttons
| # %% ../nbs/xai.ipynb 36 | |
def _shift_x_range(plot, sign):
    """Shift the x-axis range of `plot.fig` by `plot.delta_x` seconds ('+' or '-')."""
    with plot.output_move:
        plot.output_move.clear_output(wait=True)
        current_range = plot.fig.layout.xaxis.range
        if current_range is None:
            # No explicit range stored on the layout yet; nothing to shift.
            if plot.print_flag: print("x axis range is not set; nothing to move")
            return
        start_date, end_date = current_range
        new_start_date = shift_datetime(start_date, plot.delta_x, sign, plot.dateformat, plot.print_flag)
        new_end_date = shift_datetime(end_date, plot.delta_x, sign, plot.dateformat, plot.print_flag)
        with plot.fig.batch_update():
            plot.fig.layout.xaxis.range = [new_start_date, new_end_date]

def _shift_y_range(plot, offset):
    """Add `offset` to both ends of the y-axis range of `plot.fig`."""
    with plot.output_move:
        plot.output_move.clear_output(wait=True)
        current_range = plot.fig.layout.yaxis.range
        if current_range is None:
            # No explicit range stored on the layout yet; nothing to shift.
            if plot.print_flag: print("y axis range is not set; nothing to move")
            return
        start_y, end_y = current_range
        with plot.fig.batch_update():
            plot.fig.layout.yaxis.range = [start_y + offset, end_y + offset]

def move_left(self : InteractiveTSPlot, button : Button):
    "Move the visible x range `self.delta_x` seconds into the past"
    _shift_x_range(self, '-')

def move_right(self : InteractiveTSPlot, button : Button):
    "Move the visible x range `self.delta_x` seconds into the future"
    _shift_x_range(self, '+')

def move_down(self: InteractiveTSPlot, button : Button):
    "Move the visible y range down by `self.delta_y`"
    # The original wrote to `self.ig`, raising AttributeError on every click.
    _shift_y_range(self, -self.delta_y)

def move_up(self: InteractiveTSPlot, button : Button):
    "Move the visible y range up by `self.delta_y`"
    _shift_y_range(self, self.delta_y)

InteractiveTSPlot.move_left = move_left
InteractiveTSPlot.move_right = move_right
InteractiveTSPlot.move_down = move_down
InteractiveTSPlot.move_up = move_up
| # %% ../nbs/xai.ipynb 37 | |
def delta_x_bigger(self: InteractiveTSPlot, button = None):
    """Multiply the horizontal step `self.delta_x` by 10.

    `button` is the widget passed by `Button.on_click`; it is optional so
    direct calls without arguments keep working. The original handlers did
    not accept it, so every click failed and the step never changed.
    """
    with self.output_delta_x:
        self.output_delta_x.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_x)
        self.delta_x *= 10
        if self.print_flag: print("delta_x:", self.delta_x)

def delta_y_bigger(self: InteractiveTSPlot, button = None):
    "Multiply the vertical step `self.delta_y` by 10 (`button` as in `delta_x_bigger`)"
    with self.output_delta_y:
        self.output_delta_y.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_y)
        self.delta_y *= 10
        if self.print_flag: print("delta_y:", self.delta_y)

def delta_x_lower(self: InteractiveTSPlot, button = None):
    "Divide the horizontal step `self.delta_x` by 10 (`button` as in `delta_x_bigger`)"
    with self.output_delta_x:
        self.output_delta_x.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_x)
        self.delta_x /= 10
        if self.print_flag: print("delta_x:", self.delta_x)

def delta_y_lower(self: InteractiveTSPlot, button = None):
    "Divide the vertical step `self.delta_y` by 10 (`button` as in `delta_x_bigger`)"
    with self.output_delta_y:
        self.output_delta_y.clear_output(wait = True)
        if self.print_flag: print("Delta before", self.delta_y)
        # The original multiplied by 10 here, making "lower" grow the step.
        self.delta_y /= 10
        if self.print_flag: print("delta_y:", self.delta_y)

InteractiveTSPlot.delta_x_bigger = delta_x_bigger
InteractiveTSPlot.delta_y_bigger = delta_y_bigger
InteractiveTSPlot.delta_x_lower = delta_x_lower
InteractiveTSPlot.delta_y_lower = delta_y_lower
| # %% ../nbs/xai.ipynb 38 | |
def add_movement_buttons(self: InteractiveTSPlot):
    """Create the arrow and step-size buttons and connect their handlers.

    The eight buttons are stored as `self.button_*` attributes.
    """
    # Arrow buttons move the visible range.
    self.button_left = Button(description="←")
    self.button_right = Button(description="→")
    self.button_up = Button(description="↑")
    self.button_down = Button(description="↓")
    # Step buttons change how far each arrow press moves.
    self.button_step_x_up = Button(description="dx ↑")
    self.button_step_x_down = Button(description="dx ↓")
    self.button_step_y_up = Button(description="dy↑")
    self.button_step_y_down = Button(description="dy↓")
    # TODO: fix changing the step size used when moving. The output is not
    # shown and the value is not modified.
    wiring = (
        (self.button_step_x_up, self.delta_x_bigger),
        (self.button_step_x_down, self.delta_x_lower),
        (self.button_step_y_up, self.delta_y_bigger),
        (self.button_step_y_down, self.delta_y_lower),
        (self.button_left, self.move_left),
        (self.button_right, self.move_right),
        (self.button_up, self.move_up),
        (self.button_down, self.move_down),
    )
    for widget, handler in wiring:
        widget.on_click(handler)
InteractiveTSPlot.add_movement_buttons = add_movement_buttons
| # %% ../nbs/xai.ipynb 40 | |
def setup_boxes(self: InteractiveTSPlot):
    """Arrange buttons, figure and output widgets into `self.box`.

    The debug outputs for movement and step changes are only included when
    `self.print_flag` is set.
    """
    self.steps_x = VBox([self.button_step_x_up, self.button_step_x_down])
    self.steps_y = VBox([self.button_step_y_up, self.button_step_y_down])
    navigation_row = HBox([
        self.button_left, self.button_right, self.button_up, self.button_down,
        self.steps_x, self.steps_y
    ])
    features_row = HBox(
        self.buttons,
        layout=widgets.Layout(display='flex', flex_flow='row wrap', align_items='flex-start')
    )
    column_layout = widgets.Layout(
        display='flex',
        flex_flow='column',
        align_items='center',
        width='100%'
    )
    children = [features_row, navigation_row]
    if self.print_flag:
        children += [self.output_move, self.output_delta_x, self.output_delta_y]
    children += [self.fig, self.output_windows]
    self.box = VBox(children, layout=column_layout)
InteractiveTSPlot.setup_boxes = setup_boxes
| # %% ../nbs/xai.ipynb 41 | |
def initial_plot(self: InteractiveTSPlot):
    "Run every build step, in order, to assemble the figure and its controls"
    build_steps = (
        self.add_selected_features,
        self.add_windows,
        self.setup_style,
        self.set_features_buttons,
        self.add_movement_buttons,
        self.setup_boxes,
    )
    for build_step in build_steps:
        build_step()
InteractiveTSPlot.initial_plot = initial_plot
| # %% ../nbs/xai.ipynb 42 | |
def show(self : InteractiveTSPlot):
    "Build the plot with `initial_plot` and display the resulting widget box"
    self.initial_plot()
    widget_box = self.box
    display(widget_box)
InteractiveTSPlot.show = show