misantamaria commited on
Commit
a458f3b
·
1 Parent(s): 17629d1

separado en dos local_exec docker y local_build_docker. Eliminado xai.py, que no es necesario aqui y da problemas

Browse files
dvats/xai.py CHANGED
@@ -1,964 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/xai.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['get_embeddings', 'get_dataset', 'umap_parameters', 'get_prjs', 'plot_projections', 'plot_projections_clusters',
5
- 'calculate_cluster_stats', 'anomaly_score', 'detector', 'plot_anomaly_scores_distribution',
6
- 'plot_clusters_with_anomalies', 'update_plot', 'plot_clusters_with_anomalies_interactive_plot',
7
- 'get_df_selected', 'shift_datetime', 'get_dateformat', 'get_anomalies', 'get_anomaly_styles',
8
- 'InteractiveAnomalyPlot', 'plot_save', 'plot_initial_config', 'merge_overlapping_windows',
9
- 'InteractiveTSPlot', 'add_selected_features', 'add_windows', 'setup_style', 'toggle_trace',
10
- 'set_features_buttons', 'move_left', 'move_right', 'move_down', 'move_up', 'delta_x_bigger',
11
- 'delta_y_bigger', 'delta_x_lower', 'delta_y_lower', 'add_movement_buttons', 'setup_boxes', 'initial_plot',
12
- 'show']
13
-
14
- # %% ../nbs/xai.ipynb 1
15
- #Weight & Biases
16
- import wandb
17
-
18
- #Yaml
19
- from yaml import load, FullLoader
20
-
21
- #Embeddings
22
- from .all import *
23
- from tsai.data.preparation import prepare_forecasting_data
24
- from tsai.data.validation import get_forecasting_splits
25
- from fastcore.all import *
26
-
27
- #Dimensionality reduction
28
- from tsai.imports import *
29
-
30
- #Clustering
31
- import hdbscan
32
- import time
33
- from .dr import get_PCA_prjs, get_UMAP_prjs, get_TSNE_prjs
34
-
35
- import seaborn as sns
36
- import matplotlib.pyplot as plt
37
- import pandas as pd
38
- import ipywidgets as widgets
39
- from IPython.display import display
40
- from functools import partial
41
-
42
- from IPython.display import display, clear_output, HTML as IPHTML
43
- from ipywidgets import Button, Output, VBox, HBox, HTML, Layout, FloatSlider
44
-
45
- import plotly.graph_objs as go
46
- import plotly.offline as py
47
- import plotly.io as pio
48
- #! pip install kaleido
49
- import kaleido
50
-
51
-
52
- # %% ../nbs/xai.ipynb 4
53
- def get_embeddings(config_lrp, run_lrp, api, print_flag = False):
54
- artifacts_gettr = run_lrp.use_artifact if config_lrp.use_wandb else api.artifact
55
- emb_artifact = artifacts_gettr(config_lrp.emb_artifact, type='embeddings')
56
- if print_flag: print(emb_artifact.name)
57
- emb_config = emb_artifact.logged_by().config
58
- return emb_artifact.to_obj(), emb_artifact, emb_config
59
-
60
- # %% ../nbs/xai.ipynb 5
61
- def get_dataset(
62
- config_lrp,
63
- config_emb,
64
- config_dr,
65
- run_lrp,
66
- api,
67
- print_flag = False
68
- ):
69
- # Botch to use artifacts offline
70
- artifacts_gettr = run_lrp.use_artifact if config_lrp.use_wandb else api.artifact
71
- enc_artifact = artifacts_gettr(config_emb['enc_artifact'], type='learner')
72
- if print_flag: print (enc_artifact.name)
73
- ## TODO: This only works when you run it two timeS! WTF?
74
- try:
75
- enc_learner = enc_artifact.to_obj()
76
- except:
77
- enc_learner = enc_artifact.to_obj()
78
-
79
- ## Restore artifact
80
- enc_logger = enc_artifact.logged_by()
81
- enc_artifact_train = artifacts_gettr(enc_logger.config['train_artifact'], type='dataset')
82
- #cfg_.show_attrdict(enc_logger.config)
83
- if enc_logger.config['valid_artifact'] is not None:
84
- enc_artifact_valid = artifacts_gettr(enc_logger.config['valid_artifact'], type='dataset')
85
- if print_flag: print("enc_artifact_valid:", enc_artifact_valid.name)
86
- if print_flag: print("enc_artifact_train: ", enc_artifact_train.name)
87
-
88
- if config_dr['dr_artifact'] is not None:
89
- print("Is not none")
90
- dr_artifact = artifacts_gettr(config_dr['enc_artifact'])
91
- else:
92
- dr_artifact = enc_artifact_train
93
- if print_flag: print("DR artifact train: ", dr_artifact.name)
94
- if print_flag: print("--> DR artifact name", dr_artifact.name)
95
- dr_artifact
96
- df = dr_artifact.to_df()
97
- if print_flag: print("--> DR After to df", df.shape)
98
- if print_flag: display(df.head())
99
- return df, dr_artifact, enc_artifact, enc_learner
100
-
101
- # %% ../nbs/xai.ipynb 6
102
- def umap_parameters(config_dr, config):
103
- umap_params_cpu = {
104
- 'n_neighbors' : config_dr.n_neighbors,
105
- 'min_dist' : config_dr.min_dist,
106
- 'random_state': np.uint64(822569775),
107
- 'metric': config_dr.metric,
108
- #'a': 1.5769434601962196,
109
- #'b': 0.8950608779914887,
110
- #'metric_kwds': {'p': 2}, #No debería ser necesario, just in case
111
- #'output_metric': 'euclidean',
112
- 'verbose': 4,
113
- #'n_epochs': 200
114
- }
115
- umap_params_gpu = {
116
- 'n_neighbors' : config_dr.n_neighbors,
117
- 'min_dist' : config_dr.min_dist,
118
- 'random_state': np.uint64(1234),
119
- 'metric': config_dr.metric,
120
- 'a': 1.5769434601962196,
121
- 'b': 0.8950608779914887,
122
- 'target_metric': 'euclidean',
123
- 'target_n_neighbors': config_dr.n_neighbors,
124
- 'verbose': 4, #6, #CUML_LEVEL_TRACE
125
- 'n_epochs': 200*3*2,
126
- 'init': 'random',
127
- 'hash_input': True
128
- }
129
- if config_dr.cpu_flag:
130
- umap_params = umap_params_cpu
131
- else:
132
- umap_params = umap_params_gpu
133
- return umap_params
134
-
135
- # %% ../nbs/xai.ipynb 7
136
- def get_prjs(embs_no_nan, config_dr, config, print_flag = False):
137
- umap_params = umap_parameters(config_dr, config)
138
- prjs_pca = get_PCA_prjs(
139
- X = embs_no_nan,
140
- cpu = False,
141
- print_flag = print_flag,
142
- **umap_params
143
- )
144
- if print_flag:
145
- print(prjs_pca.shape)
146
- prjs_umap = get_UMAP_prjs(
147
- input_data = prjs_pca,
148
- cpu = config_dr.cpu_flag, #config_dr.cpu,
149
- print_flag = print_flag,
150
- **umap_params
151
- )
152
- if print_flag: prjs_umap.shape
153
- return prjs_umap
154
-
155
- # %% ../nbs/xai.ipynb 9
156
- def plot_projections(prjs, umap_params, fig_size = (25,25)):
157
- "Plot 2D projections thorugh a connected scatter plot"
158
- df_prjs = pd.DataFrame(prjs, columns = ['x1', 'x2'])
159
- fig = plt.figure(figsize=(fig_size[0],fig_size[1]))
160
- ax = fig.add_subplot(111)
161
- ax.scatter(df_prjs['x1'], df_prjs['x2'], marker='o', facecolors='none', edgecolors='b', alpha=0.1)
162
- ax.plot(df_prjs['x1'], df_prjs['x2'], alpha=0.5, picker=1)
163
- plt.title('DR params - n_neighbors:{:d} min_dist:{:f}'.format(
164
- umap_params['n_neighbors'],umap_params['min_dist']))
165
- return ax
166
-
167
- # %% ../nbs/xai.ipynb 10
168
- def plot_projections_clusters(prjs, clusters_labels, umap_params, fig_size = (25,25)):
169
- "Plot 2D projections thorugh a connected scatter plot"
170
- df_prjs = pd.DataFrame(prjs, columns = ['x1', 'x2'])
171
- df_prjs['cluster'] = clusters_labels
172
-
173
- fig = plt.figure(figsize=(fig_size[0],fig_size[1]))
174
- ax = fig.add_subplot(111)
175
-
176
- # Create a scatter plot for each cluster with different colors
177
- unique_labels = df_prjs['cluster'].unique()
178
- print(unique_labels)
179
- for label in unique_labels:
180
- cluster_data = df_prjs[df_prjs['cluster'] == label]
181
- ax.scatter(cluster_data['x1'], cluster_data['x2'], label=f'Cluster {label}')
182
- #ax.scatter(df_prjs['x1'], df_prjs['x2'], marker='o', facecolors='none', edgecolors='b', alpha=0.1)
183
-
184
- #ax.plot(df_prjs['x1'], df_prjs['x2'], alpha=0.5, picker=1)
185
- plt.title('DR params - n_neighbors:{:d} min_dist:{:f}'.format(
186
- umap_params['n_neighbors'],umap_params['min_dist']))
187
- return ax
188
-
189
- # %% ../nbs/xai.ipynb 11
190
- def calculate_cluster_stats(data, labels):
191
- """Computes the media and the standard deviation for every cluster."""
192
- cluster_stats = {}
193
- for label in np.unique(labels):
194
- #members = data[labels == label]
195
- members = data
196
- mean = np.mean(members, axis = 0)
197
- std = np.std(members, axis = 0)
198
- cluster_stats[label] = (mean, std)
199
- return cluster_stats
200
-
201
- # %% ../nbs/xai.ipynb 12
202
- def anomaly_score(point, cluster_stats, label):
203
- """Computes an anomaly score for each point."""
204
- mean, std = cluster_stats[label]
205
- return np.linalg.norm((point - mean) / std)
206
-
207
- # %% ../nbs/xai.ipynb 13
208
- def detector(data, labels):
209
- """Anomaly detection function."""
210
- cluster_stats = calculate_cluster_stats(data, labels)
211
- scores = []
212
- for point, label in zip(data, labels):
213
- score = anomaly_score(point, cluster_stats, label)
214
- scores.append(score)
215
- return np.array(scores)
216
-
217
- # %% ../nbs/xai.ipynb 15
218
- def plot_anomaly_scores_distribution(anomaly_scores):
219
- "Plot the distribution of anomaly scores to check for normality"
220
- plt.figure(figsize=(10, 6))
221
- sns.histplot(anomaly_scores, kde=True, bins=30)
222
- plt.title("Distribución de Anomaly Scores")
223
- plt.xlabel("Anomaly Score")
224
- plt.ylabel("Frecuencia")
225
- plt.show()
226
-
227
- # %% ../nbs/xai.ipynb 16
228
- def plot_clusters_with_anomalies(prjs, clusters_labels, anomaly_scores, threshold, fig_size=(25, 25)):
229
- "Plot 2D projections of clusters and superimpose anomalies"
230
- df_prjs = pd.DataFrame(prjs, columns=['x1', 'x2'])
231
- df_prjs['cluster'] = clusters_labels
232
- df_prjs['anomaly'] = anomaly_scores > threshold
233
-
234
- fig = plt.figure(figsize=(fig_size[0], fig_size[1]))
235
- ax = fig.add_subplot(111)
236
-
237
- # Plot each cluster with different colors
238
- unique_labels = df_prjs['cluster'].unique()
239
- for label in unique_labels:
240
- cluster_data = df_prjs[df_prjs['cluster'] == label]
241
- ax.scatter(cluster_data['x1'], cluster_data['x2'], label=f'Cluster {label}', alpha=0.7)
242
-
243
- # Superimpose anomalies
244
- anomalies = df_prjs[df_prjs['anomaly']]
245
- ax.scatter(anomalies['x1'], anomalies['x2'], color='red', label='Anomalies', edgecolor='k', s=50)
246
-
247
- plt.title('Clusters and anomalies')
248
- plt.legend()
249
- plt.show()
250
-
251
- def update_plot(threshold, prjs_umap, clusters_labels, anomaly_scores, fig_size):
252
- plot_clusters_with_anomalies(prjs_umap, clusters_labels, anomaly_scores, threshold, fig_size)
253
-
254
- def plot_clusters_with_anomalies_interactive_plot(threshold, prjs_umap, clusters_labels, anomaly_scores, fig_size):
255
- threshold_slider = widgets.FloatSlider(value=threshold, min=0.001, max=3, step=0.001, description='Threshold')
256
- interactive_plot = widgets.interactive(update_plot, threshold = threshold_slider,
257
- prjs_umap = widgets.fixed(prjs_umap),
258
- clusters_labels = widgets.fixed(clusters_labels),
259
- anomaly_scores = widgets.fixed(anomaly_scores),
260
- fig_size = widgets.fixed((25,25)))
261
- display(interactive_plot)
262
-
263
-
264
- # %% ../nbs/xai.ipynb 18
265
- import plotly.express as px
266
- from datetime import timedelta
267
-
268
- # %% ../nbs/xai.ipynb 19
269
- def get_df_selected(df, selected_indices, w, stride = 1): #Cuidado con stride
270
- '''Links back the selected points to the original dataframe and returns the associated windows indices'''
271
- n_windows = len(selected_indices)
272
- window_ranges = [(id*stride, (id*stride)+w) for id in selected_indices]
273
- #window_ranges = [(id*w, (id+1)*w+1) for id in selected_indices]
274
- #window_ranges = [(id*stride, (id*stride)+w) for id in selected_indices]
275
- #print(window_ranges)
276
- valores_tramos = [df.iloc[inicio:fin+1] for inicio, fin in window_ranges]
277
- df_selected = pd.concat(valores_tramos, ignore_index=False)
278
- return window_ranges, n_windows, df_selected
279
-
280
- # %% ../nbs/xai.ipynb 20
281
- def shift_datetime(dt, seconds, sign, dateformat="%Y-%m-%d %H:%M:%S.%f", print_flag = False):
282
- """
283
- This function gets a datetime dt, a number of seconds,
284
- a sign and moves the date such number of seconds to the future
285
- if sign is '+' and to the past if sing is '-'.
286
- """
287
-
288
- if print_flag: print(dateformat)
289
- dateformat2= "%Y-%m-%d %H:%M:%S.%f"
290
- dateformat3 = "%Y-%m-%d"
291
- ok = False
292
-
293
- try:
294
- if print_flag: print("dt ", dt, "seconds", seconds, "sign", sign)
295
- new_dt = datetime.strptime(dt, dateformat)
296
- if print_flag: print("ndt", new_dt)
297
- ok = True
298
- except ValueError as e:
299
- if print_flag:
300
- print("Error: ", e)
301
-
302
- if (not ok):
303
- try:
304
- if print_flag: print("Parsing alternative dataformat", dt, "seconds", seconds, "sign", sign, dateformat2)
305
- new_dt = datetime.strptime(dt, dateformat3)
306
- if print_flag: print("2ndt", new_dt)
307
- except ValueError as e:
308
- print("Error: ", e)
309
- if print_flag: print(new_dt)
310
- try:
311
-
312
- if new_dt.hour == 0 and new_dt.minute == 0 and new_dt.second == 0:
313
- if print_flag: "Aqui"
314
- new_dt = new_dt.replace(hour=0, minute=0, second=0, microsecond=0)
315
- if print_flag: print(new_dt)
316
-
317
- if print_flag: print("ndt", new_dt)
318
-
319
- if (sign == '+'):
320
- if print_flag: print("Aqui")
321
- new_dt = new_dt + timedelta(seconds = seconds)
322
- if print_flag: print(new_dt)
323
- else:
324
- if print_flag: print(sign, type(dt))
325
- new_dt = new_dt - timedelta(seconds = seconds)
326
- if print_flag: print(new_dt)
327
- if new_dt.hour == 0 and new_dt.minute == 0 and new_dt.second == 0:
328
- if print_flag: print("replacing")
329
- new_dt = new_dt.replace(hour=0, minute=0, second=0, microsecond=0)
330
-
331
- new_dt_str = new_dt.strftime(dateformat2)
332
- if print_flag: print("new dt ", new_dt)
333
- except ValueError as e:
334
- if print_flag: print("Aqui3")
335
- shift_datetime(dt, 0, sign, dateformat = "%Y-%m-%d", print_flag = False)
336
- return str(e)
337
- return new_dt_str
338
-
339
-
340
-
341
- # %% ../nbs/xai.ipynb 21
342
- def get_dateformat(text_date):
343
- dateformat1 = "%Y-%m-%d %H:%M:%S"
344
- dateformat2 = "%Y-%m-%d %H:%M:%S.%f"
345
- dateformat3 = "%Y-%m-%d"
346
- dateformat = ""
347
- parts = text_date.split()
348
-
349
- if len(parts) == 2:
350
- time_parts = parts[1].split(':')
351
- if len(time_parts) == 3:
352
- sec_parts = time_parts[2].split('.')
353
- if len(sec_parts) == 2:
354
- dateformat = dateformat2
355
- else:
356
- dateformat = dateformat1
357
- else:
358
- dateformat = "unknown format 1"
359
- elif len(parts) == 1:
360
- dateformat = dateformat3
361
- else:
362
- dateformat = "unknown format 2"
363
-
364
- return dateformat
365
-
366
- # %% ../nbs/xai.ipynb 23
367
- def get_anomalies(df, threshold, flag):
368
- df['anomaly'] = [ (score > threshold) and flag for score in df['anomaly_score']]
369
-
370
- def get_anomaly_styles(df, threshold, anomaly_scores, flag = False, print_flag = False):
371
- if print_flag: print("Threshold: ", threshold)
372
- if print_flag: print("Flag", flag)
373
- if print_flag: print("df ~", df.shape)
374
- df['anomaly'] = [ (score > threshold) and flag for score in df['anomaly_score'] ]
375
- if print_flag: print(df)
376
- get_anomalies(df, threshold, flag)
377
- anomalies = df[df['anomaly']]
378
- if flag:
379
- df['anomaly'] = [
380
- (score > threshold) and flag
381
- for score in anomaly_scores
382
- ]
383
- symbols = [
384
- 'x' if is_anomaly else 'circle'
385
- for is_anomaly in df['anomaly']
386
- ]
387
- line_colors = [
388
- 'black'
389
- if (is_anomaly and flag) else 'rgba(0,0,0,0)'
390
- for is_anomaly in df['anomaly']
391
- ]
392
- else:
393
- symbols = ['circle' for _ in df['x1']]
394
- line_colors = ['rgba(0,0,0,0)' for _ in df['x1']]
395
- if print_flag: print(anomalies)
396
- return symbols, line_colors
397
- ### Example of use
398
- #prjs_df = pd.DataFrame(prjs_umap, columns = ['x1', 'x2'])
399
- #prjs_df['anomaly_score'] = anomaly_scores
400
- #s, l = get_anomaly_styles(prjs_df, 1, True)
401
-
402
- # %% ../nbs/xai.ipynb 24
403
- class InteractiveAnomalyPlot():
404
- def __init__(
405
- self, selected_indices = [],
406
- threshold = 0.15,
407
- anomaly_flag = False,
408
- path = "../imgs", w = 0
409
- ):
410
- self.selected_indices = selected_indices
411
- self.selected_indices_tmp = selected_indices
412
- self.threshold = threshold
413
- self.threshold_ = threshold
414
- self.anomaly_flag = anomaly_flag
415
- self.w = w
416
- self.name = f"w={self.w}"
417
- self.path = f"{path}{self.name}.png"
418
- self.interaction_enabled = True
419
-
420
-
421
- def plot_projections_clusters_interactive(
422
- self, prjs, cluster_labels, umap_params, anomaly_scores=[], fig_size=(7,7), print_flag = False
423
- ):
424
- self.selected_indices_tmp = self.selected_indices
425
- py.init_notebook_mode()
426
-
427
- prjs_df, cluster_colors = plot_initial_config(prjs, cluster_labels, anomaly_scores)
428
- legend_items = [widgets.HTML(f'<b>Cluster {cluster}:</b> <span style="color:{color};">■</span>')
429
- for cluster, color in cluster_colors.items()]
430
- legend = widgets.VBox(legend_items)
431
-
432
- marker_colors = prjs_df['cluster'].map(cluster_colors)
433
-
434
- symbols, line_colors = get_anomaly_styles(prjs_df, self.threshold_, anomaly_scores, self.anomaly_flag, print_flag)
435
-
436
- fig = go.FigureWidget(
437
- [
438
- go.Scatter(
439
- x=prjs_df['x1'], y=prjs_df['x2'],
440
- mode="markers",
441
- marker= {
442
- 'color': marker_colors,
443
- 'line': { 'color': line_colors, 'width': 1 },
444
- 'symbol': symbols
445
- },
446
- text = prjs_df.index
447
- )
448
- ]
449
- )
450
-
451
- line_trace = go.Scatter(
452
- x=prjs_df['x1'],
453
- y=prjs_df['x2'],
454
- mode="lines",
455
- line=dict(color='rgba(128, 128, 128, 0.5)', width=1)#,
456
- #showlegend=False # Puedes configurar si deseas mostrar esta línea en la leyenda
457
- )
458
-
459
- fig.add_trace(line_trace)
460
-
461
- sca = fig.data[0]
462
-
463
- fig.update_layout(
464
- dragmode='lasso',
465
- width=700,
466
- height=500,
467
- title={
468
- 'text': '<span style="font-weight:bold">DR params - n_neighbors:{:d} min_dist:{:f}</span>'.format(
469
- umap_params['n_neighbors'], umap_params['min_dist']),
470
- 'y':0.98,
471
- 'x':0.5,
472
- 'xanchor': 'center',
473
- 'yanchor': 'top'
474
- },
475
- plot_bgcolor='white',
476
- paper_bgcolor='#f0f0f0',
477
- xaxis=dict(gridcolor='lightgray', zerolinecolor='black', title = 'x'),
478
- yaxis=dict(gridcolor='lightgray', zerolinecolor='black', title = 'y'),
479
- margin=dict(l=10, r=20, t=30, b=10)
480
-
481
-
482
- )
483
-
484
- output_tmp = Output()
485
- output_button = Output()
486
- output_anomaly = Output()
487
- output_threshold = Output()
488
- output_width = Output()
489
-
490
- def select_action(trace, points, selector):
491
- self.selected_indices_tmp = points.point_inds
492
- with output_tmp:
493
- output_tmp.clear_output(wait=True)
494
- if print_flag: print("Selected indices tmp:", self.selected_indices_tmp)
495
-
496
- def button_action(b):
497
- self.selected_indices = self.selected_indices_tmp
498
- with output_button:
499
- output_button.clear_output(wait = True)
500
- if print_flag: print("Selected indices:", self.selected_indices)
501
-
502
-
503
- def update_anomalies():
504
- if print_flag: print("About to update anomalies")
505
-
506
- symbols, line_colors = get_anomaly_styles(prjs_df, self.threshold_, anomaly_scores, self.anomaly_flag, print_flag)
507
-
508
- if print_flag: print("Anomaly styles got")
509
-
510
- with fig.batch_update():
511
- fig.data[0].marker.symbol = symbols
512
- fig.data[0].marker.line.color = line_colors
513
- if print_flag: print("Anomalies updated")
514
- if print_flag: print("Threshold: ", self.threshold_)
515
- if print_flag: print("Scores: ", anomaly_scores)
516
-
517
-
518
- def anomaly_action(b):
519
- with output_anomaly: # Cambia output_flag a output_anomaly
520
- output_anomaly.clear_output(wait=True)
521
- if print_fllag: print("Negate anomaly flag")
522
- self.anomaly_flag = not self.anomaly_flag
523
- if print_flag: print("Show anomalies:", self.anomaly_flag)
524
- update_anomalies()
525
-
526
- sca.on_selection(select_action)
527
- layout = widgets.Layout(width='auto', height='40px')
528
- button = Button(
529
- description="Update selected_indices",
530
- style = {'button_color': 'lightblue'},
531
- display = 'flex',
532
- flex_row = 'column',
533
- align_items = 'stretch',
534
- layout = layout
535
- )
536
- anomaly_button = Button(
537
- description = "Show anomalies",
538
- style = {'button_color': 'lightgray'},
539
- display = 'flex',
540
- flex_row = 'column',
541
- align_items = 'stretch',
542
- layout = layout
543
- )
544
-
545
- button.on_click(button_action)
546
- anomaly_button.on_click(anomaly_action)
547
-
548
- ##### Reactivity buttons
549
- pause_button = Button(
550
- description = "Pause interactiveness",
551
- style = {'button_color': 'pink'},
552
- display = 'flex',
553
- flex_row = 'column',
554
- align_items = 'stretch',
555
- layout = layout
556
- )
557
- resume_button = Button(
558
- description = "Resume interactiveness",
559
- style = {'button_color': 'lightgreen'},
560
- display = 'flex',
561
- flex_row = 'column',
562
- align_items = 'stretch',
563
- layout = layout
564
- )
565
-
566
-
567
- threshold_slider = FloatSlider(
568
- value=self.threshold_,
569
- min=0.0,
570
- max=float(np.ceil(self.threshold+5)),
571
- step=0.0001,
572
- description='Anomaly threshold:',
573
- continuous_update=False
574
- )
575
-
576
- def pause_interaction(b):
577
- self.interaction_enabled = False
578
- fig.update_layout(dragmode='pan')
579
-
580
- def resume_interaction(b):
581
- self.interaction_enabled = True
582
- fig.update_layout(dragmode='lasso')
583
-
584
-
585
- def update_threshold(change):
586
- with output_threshold:
587
- output_threshold.clear_output(wait = True)
588
- if print_flag: print("Update threshold")
589
- self.threshold_ = change.new
590
- if print_flag: print("Update anomalies threshold = ", self.threshold_)
591
- update_anomalies()
592
-
593
- #### Width
594
- width_slider = FloatSlider(
595
- value = 0.5,
596
- min = 0.0,
597
- max = 1.0,
598
- step = 0.0001,
599
- description = 'Line width:',
600
- continuous_update = False
601
- )
602
-
603
- def update_width(change):
604
- with output_width:
605
- try:
606
- output_width.clear_output(wait = True)
607
- if print_flag:
608
- print("Change line width")
609
- print("Trace to update:", fig.data[1])
610
- with fig.batch_update():
611
- fig.data[1].line.width = change.new # Actualiza la opacidad de la línea
612
- if print_flag: print("ChangeD line width")
613
- except Exception as e:
614
- print("Error updating line width:", e)
615
-
616
-
617
-
618
- pause_button.on_click(pause_interaction)
619
- resume_button.on_click(resume_interaction)
620
-
621
- threshold_slider.observe(update_threshold, 'value')
622
-
623
- ####
624
- width_slider.observe(update_width, names = 'value')
625
-
626
- #####
627
- space = HTML("&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;")
628
-
629
- vbox = VBox((output_tmp, output_button, output_anomaly, output_threshold, fig))
630
- hbox = HBox((space, button, space, pause_button, space, resume_button, anomaly_button))
631
-
632
- # Centrar las dos cajas horizontalmente en el VBox
633
-
634
- box_layout = widgets.Layout(display='flex',
635
- flex_flow='column',
636
- align_items='center',
637
- width='100%')
638
-
639
- if self.anomaly_flag:
640
- box = VBox((hbox,threshold_slider,width_slider, output_width, vbox), layout = box_layout)
641
- else:
642
- box = VBox((hbox, width_slider, output_width, vbox), layout = box_layout)
643
- box.add_class("layout")
644
- plot_save(fig, self.w)
645
-
646
- display(box)
647
-
648
-
649
- # %% ../nbs/xai.ipynb 25
650
- def plot_save(fig, w):
651
- image_bytes = pio.to_image(fig, format='png')
652
- with open(f"../imgs/w={w}.png", 'wb') as f:
653
- f.write(image_bytes)
654
-
655
-
656
- # %% ../nbs/xai.ipynb 26
657
- def plot_initial_config(prjs, cluster_labels, anomaly_scores):
658
- prjs_df = pd.DataFrame(prjs, columns = ['x1', 'x2'])
659
- prjs_df['cluster'] = cluster_labels
660
- prjs_df['anomaly_score'] = anomaly_scores
661
-
662
- cluster_colors_df = pd.DataFrame({'cluster': cluster_labels}).drop_duplicates()
663
- cluster_colors_df['color'] = px.colors.qualitative.Set1[:len(cluster_colors_df)]
664
- cluster_colors = dict(zip(cluster_colors_df['cluster'], cluster_colors_df['color']))
665
- return prjs_df, cluster_colors
666
-
667
- # %% ../nbs/xai.ipynb 27
668
- def merge_overlapping_windows(windows):
669
- if not windows:
670
- return []
671
-
672
- # Order
673
- sorted_windows = sorted(windows, key=lambda x: x[0])
674
-
675
- merged_windows = [sorted_windows[0]]
676
-
677
- for window in sorted_windows[1:]:
678
- if window[0] <= merged_windows[-1][1]:
679
- # Merge!
680
- merged_windows[-1] = (merged_windows[-1][0], max(window[1], merged_windows[-1][1]))
681
- else:
682
- merged_windows.append(window)
683
-
684
- return merged_windows
685
-
686
- # %% ../nbs/xai.ipynb 29
687
- class InteractiveTSPlot:
688
- def __init__(
689
- self,
690
- df,
691
- selected_indices,
692
- meaningful_features_subset_ids,
693
- w,
694
- stride=1,
695
- print_flag=False,
696
- num_points=10000,
697
- dateformat='%Y-%m-%d %H:%M:%S',
698
- delta_x = 10,
699
- delta_y = 0.1
700
- ):
701
- self.df = df
702
- self.selected_indices = selected_indices
703
- self.meaningful_features_subset_ids = meaningful_features_subset_ids
704
- self.w = w
705
- self.stride = stride
706
- self.print_flag = print_flag
707
- self.num_points = num_points
708
- self.dateformat = dateformat
709
- self.fig = go.FigureWidget()
710
- self.buttons = []
711
- self.print_flag = print_flag
712
-
713
- self.delta_x = delta_x
714
- self.delta_y = delta_y
715
-
716
- self.window_ranges, self.n_windows, self.df_selected = get_df_selected(
717
- self.df, self.selected_indices, self.w, self.stride
718
- )
719
- # Ensure the small possible number of windows to plot (like in R Shiny App)
720
- self.window_ranges = merge_overlapping_windows(self.window_ranges)
721
-
722
- #Num points no va bien...
723
- #num_points = min(df_selected.shape[0], num_points)
724
-
725
- if self.print_flag:
726
- print("windows: ", self.n_windows, self.window_ranges)
727
- print("selected id: ", self.df_selected.index)
728
- print("points: ", self.num_points)
729
-
730
- self.df.index = self.df.index.astype(str)
731
- self.fig = go.FigureWidget()
732
- self.colors = [
733
- f'rgb({np.random.randint(0, 256)}, {np.random.randint(0, 256)}, {np.random.randint(0, 256)})'
734
- for _ in range(self.n_windows)
735
- ]
736
-
737
- ##############################
738
- # Outputs for debug printing #
739
- ##############################
740
- self.output_windows = Output()
741
- self.output_move = Output()
742
- self.output_delta_x = Output()
743
- self.output_delta_y = Output()
744
-
745
-
746
-
747
-
748
-
749
- # %% ../nbs/xai.ipynb 30
750
- def add_selected_features(self: InteractiveTSPlot):
751
- # Add features time series
752
- for feature_id in self.df.columns:
753
- feature_pos = self.df.columns.get_loc(feature_id)
754
- trace = go.Scatter(
755
- #x=df.index[:num_points],
756
- #y=df[feature_id][:num_points],
757
- x = self.df.index,
758
- y = self.df[feature_id],
759
- mode='lines',
760
- name=feature_id,
761
- visible=feature_pos in self.meaningful_features_subset_ids,
762
- text=self.df.index
763
- #text=[f'{i}-{val}' for i, val in enumerate(df.index)]
764
- )
765
- self.fig.add_trace(trace)
766
-
767
- InteractiveTSPlot.add_selected_features = add_selected_features
768
-
769
- # %% ../nbs/xai.ipynb 31
770
- def add_windows(self: InteractiveTSPlot):
771
- for i, (start, end) in enumerate(self.window_ranges):
772
- self.fig.add_shape(
773
- type="rect",
774
- x0=self.df.index[start],
775
- x1=self.df.index[end],
776
- y0= 0,
777
- y1= 1,
778
- yref = "paper",
779
- fillcolor=self.colors[i], #"LightSalmon",
780
- opacity=0.25,
781
- layer="below",
782
- line=dict(color=self.colors[i], width=1),
783
- name = f"w_{i}"
784
- )
785
- with self.output_windows:
786
- print("w[" + str( self.selected_indices[i] )+ "]="+str(self.df.index[start])+", "+str(self.df.index[end])+")")
787
-
788
- InteractiveTSPlot.add_windows = add_windows
789
-
790
- # %% ../nbs/xai.ipynb 32
791
- def setup_style(self: InteractiveTSPlot):
792
- self.fig.update_layout(
793
- title='Time Series with time window plot',
794
- xaxis_title='Datetime',
795
- yaxis_title='Value',
796
- legend_title='Variables',
797
- margin=dict(l=10, r=10, t=30, b=10),
798
- xaxis=dict(
799
- tickformat = '%d-' + self.dateformat,
800
- #tickvals=list(range(len(df.index))),
801
- #ticktext = [f'{i}-{val}' for i, val in enumerate(df.index)]
802
- #grid_color = 'lightgray', zerolinecolor='black', title = 'x'
803
- ),
804
- #yaxis = dict(grid_color = 'lightgray', zerolinecolor='black', title = 'y'),
805
- #plot_color = 'white',
806
- paper_bgcolor='#f0f0f0'
807
- )
808
- self.fig.update_yaxes(fixedrange=True)
809
-
810
- InteractiveTSPlot.setup_style = setup_style
811
-
812
- # %% ../nbs/xai.ipynb 34
813
- def toggle_trace(self : InteractiveTSPlot, button : Button):
814
- idx = button.description
815
- trace = self.fig.data[self.df.columns.get_loc(idx)]
816
- trace.visible = not trace.visible
817
-
818
- InteractiveTSPlot.toggle_trace = toggle_trace
819
-
820
- # %% ../nbs/xai.ipynb 35
821
- def set_features_buttons(self):
822
- self.buttons = [
823
- Button(
824
- description=str(feature_id),
825
- button_style='success' if self.df.columns.get_loc(feature_id) in self.meaningful_features_subset_ids else ''
826
- )
827
- for feature_id in self.df.columns
828
- ]
829
- for button in self.buttons:
830
- button.on_click(self.toggle_trace)
831
- InteractiveTSPlot.set_features_buttons = set_features_buttons
832
-
833
- # %% ../nbs/xai.ipynb 36
834
- def move_left(self : InteractiveTSPlot, button : Button):
835
- with self.output_move:
836
- self.output_move.clear_output(wait=True)
837
- start_date, end_date = self.fig.layout.xaxis.range
838
- new_start_date = shift_datetime(start_date, self.delta_x, '-', self.dateformat, self.print_flag)
839
- new_end_date = shift_datetime(end_date, self.delta_x, '-', self.dateformat, self.print_flag)
840
- with self.fig.batch_update():
841
- self.fig.layout.xaxis.range = [new_start_date, new_end_date]
842
-
843
- def move_right(self : InteractiveTSPlot, button : Button):
844
- self.output_move.clear_output(wait=True)
845
- with self.output_move:
846
- start_date, end_date = self.fig.layout.xaxis.range
847
- new_start_date = shift_datetime(start_date, self.delta_x, '+', self.dateformat, self.print_flag)
848
- new_end_date = shift_datetime(end_date, self.delta_x, '+', self.dateformat, self.print_flag)
849
- with self.fig.batch_update():
850
- self.fig.layout.xaxis.range = [new_start_date, new_end_date]
851
-
852
- def move_down(self: InteractiveTSPlot, button : Button):
853
- with self.output_move:
854
- self.output_move.clear_output(wait=True)
855
- start_y, end_y = self.fig.layout.yaxis.range
856
- with self.fig.batch_update():
857
- self.ig.layout.yaxis.range = [start_y-self.delta_y, end_y-self.delta_y]
858
- def move_up(self: InteractiveTSPlot, button : Button):
859
- with self.output_move:
860
- self.output_move.clear_output(wait=True)
861
- start_y, end_y = self.fig.layout.yaxis.range
862
- with self.fig.batch_update():
863
- self.fig.layout.yaxis.range = [start_y+self.delta_y, end_y+self.delta_y]
864
-
865
- InteractiveTSPlot.move_left = move_left
866
- InteractiveTSPlot.move_right = move_right
867
- InteractiveTSPlot.move_down = move_down
868
- InteractiveTSPlot.move_up = move_up
869
-
870
- # %% ../nbs/xai.ipynb 37
871
- def delta_x_bigger(self: InteractiveTSPlot):
872
- with self.output_delta_x:
873
- self.output_delta_x.clear_output(wait = True)
874
- if self.print_flag: print("Delta before", self.delta_x)
875
- self.delta_x *= 10
876
- if self.print_flag: print("delta_x:", self.delta_x)
877
-
878
- def delta_y_bigger(self: InteractiveTSPlot):
879
- with self.output_delta_y:
880
- self.output_delta_y.clear_output(wait = True)
881
- if self.print_flag: print("Delta before", self.delta_y)
882
- self.delta_y *= 10
883
- if self.print_flag: print("delta_y:", self.delta_y)
884
-
885
- def delta_x_lower(self:InteractiveTSPlot):
886
- with self.output_delta_x:
887
- self.output_delta_x.clear_output(wait = True)
888
- if self.print_flag: print("Delta before", self.delta_x)
889
- self.delta_x /= 10
890
- if self.print_flag: print("delta_x:", self.delta_x)
891
-
892
- def delta_y_lower(self:InteractiveTSPlot):
893
- with self.output_delta_y:
894
- self.output_delta_y.clear_output(wait = True)
895
- print("Delta before", self.delta_y)
896
- self.delta_y = self.delta_y * 10
897
- print("delta_y:", self.delta_y)
898
- InteractiveTSPlot.delta_x_bigger = delta_x_bigger
899
- InteractiveTSPlot.delta_y_bigger = delta_y_bigger
900
- InteractiveTSPlot.delta_x_lower = delta_x_lower
901
- InteractiveTSPlot.delta_y_lower = delta_y_lower
902
-
903
- # %% ../nbs/xai.ipynb 38
904
- def add_movement_buttons(self: InteractiveTSPlot):
905
- self.button_left = Button(description="←")
906
- self.button_right = Button(description="→")
907
- self.button_up = Button(description="↑")
908
- self.button_down = Button(description="↓")
909
-
910
- self.button_step_x_up = Button(description="dx ↑")
911
- self.button_step_x_down = Button(description="dx ↓")
912
- self.button_step_y_up = Button(description="dy↑")
913
- self.button_step_y_down = Button(description="dy↓")
914
-
915
-
916
- # TODO: Arreglar que se pueda modificar el paso con el que se avanza. No se ve el output y no se modifica el valor
917
- self.button_step_x_up.on_click(self.delta_x_bigger)
918
- self.button_step_x_down.on_click(self.delta_x_lower)
919
- self.button_step_y_up.on_click(self.delta_y_bigger)
920
- self.button_step_y_down.on_click(self.delta_y_lower)
921
-
922
- self.button_left.on_click(self.move_left)
923
- self.button_right.on_click(self.move_right)
924
- self.button_up.on_click(self.move_up)
925
- self.button_down.on_click(self.move_down)
926
-
927
- InteractiveTSPlot.add_movement_buttons = add_movement_buttons
928
-
929
- # %% ../nbs/xai.ipynb 40
930
- def setup_boxes(self: InteractiveTSPlot):
931
- self.steps_x = VBox([self.button_step_x_up, self.button_step_x_down])
932
- self.steps_y = VBox([self.button_step_y_up, self.button_step_y_down])
933
- arrow_buttons = HBox([self.button_left, self.button_right, self.button_up, self.button_down, self.steps_x, self.steps_y])
934
- hbox_layout = widgets.Layout(display='flex', flex_flow='row wrap', align_items='flex-start')
935
- hbox = HBox(self.buttons, layout=hbox_layout)
936
- box_layout = widgets.Layout(
937
- display='flex',
938
- flex_flow='column',
939
- align_items='center',
940
- width='100%'
941
- )
942
- if self.print_flag:
943
- self.box = VBox([hbox, arrow_buttons, self.output_move, self.output_delta_x, self.output_delta_y, self.fig, self.output_windows], layout=box_layout)
944
- else:
945
- self.box = VBox([hbox, arrow_buttons, self.fig, self.output_windows], layout=box_layout)
946
-
947
- InteractiveTSPlot.setup_boxes = setup_boxes
948
-
949
-
950
- # %% ../nbs/xai.ipynb 41
951
- def initial_plot(self: InteractiveTSPlot):
952
- self.add_selected_features()
953
- self.add_windows()
954
- self.setup_style()
955
- self.set_features_buttons()
956
- self.add_movement_buttons()
957
- self.setup_boxes()
958
- InteractiveTSPlot.initial_plot = initial_plot
959
-
960
- # %% ../nbs/xai.ipynb 42
961
- def show(self : InteractiveTSPlot):
962
- self.initial_plot()
963
- display(self.box)
964
- InteractiveTSPlot.show = show
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
local_build_docker.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Inicializa un array vacío
2
+ args=()
3
+
4
+ # Lee el archivo .env línea por línea
5
+ while IFS='=' read -r key value; do
6
+ if [[ $key != \#* && $key != '' ]]; then # Excluye comentarios y líneas vacías
7
+ args+=(--build-arg "$key=$value") # Agrega --build-arg y la variable como un elemento
8
+ fi
9
+ done < .env
10
+
11
+ echo "args: ${args[@]}"
12
+ read -p "Press enter to continue"
13
+ # Ejecuta docker build con los argumentos
14
+ docker build "${args[@]}" . -t dvatshf
local_exec_docker.sh CHANGED
@@ -4,10 +4,25 @@ args=()
4
  # Lee el archivo .env línea por línea
5
  while IFS='=' read -r key value; do
6
  if [[ $key != \#* && $key != '' ]]; then # Excluye comentarios y líneas vacías
7
- args+=(--build-arg "$key=$value") # Agrega --build-arg y la variable como un elemento
8
  fi
9
  done < .env
10
 
11
  echo "args: ${args[@]}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Ejecuta docker build con los argumentos
13
- docker build "${args[@]}" . -t dvatshf
 
4
  # Lee el archivo .env línea por línea
5
  while IFS='=' read -r key value; do
6
  if [[ $key != \#* && $key != '' ]]; then # Excluye comentarios y líneas vacías
7
+ args+=(-e "$key=$value") # Agrega --build-arg y la variable como un elemento
8
  fi
9
  done < .env
10
 
11
  echo "args: ${args[@]}"
12
+ #read -p "Press enter to continue"
13
+
14
+ INTER=$1
15
+
16
+ flags=()
17
+ if((INTER == 1)); then
18
+ echo "INTERACTIVE"
19
+ flags+=(-it --entrypoint /bin/bash)
20
+ fi
21
+
22
+ flags+=("--gpus" "all")
23
+
24
+ echo "${flags[@]}"
25
+
26
+ #read -p "Press enter to continue"
27
  # Ejecuta docker build con los argumentos
28
+ docker run "${flags[@]}" "${args[@]}" -t dvatshf
r_shiny_app/global.R CHANGED
@@ -52,7 +52,7 @@ if(torch$cuda$is_available()){
52
 
53
  # Python dependencies
54
  print("--> py dependences | Tsai")
55
- Sys.setenv(MPLCONFIGDIR = "/tmp/").
56
  tsai_data = reticulate::import("tsai.data.all")
57
  print("--> py dependences | Wandb")
58
  wandb = reticulate::import("wandb")
@@ -85,9 +85,11 @@ DEFAULT_VALUES = list(metric_hdbscan = "euclidean",
85
  path_alpha = 5/10,
86
  point_alpha = 1/10,
87
  point_size = 1)
 
88
  WANDB_ENTITY = Sys.getenv("WANDB_ENTITY")
89
  WANDB_PROJECT = Sys.getenv("WANDB_PROJECT")
90
 
 
91
 
92
  ####################
93
  # HELPER FUNCTIONS #
 
52
 
53
  # Python dependencies
54
  print("--> py dependences | Tsai")
55
+ Sys.setenv(MPLCONFIGDIR = "/tmp/")
56
  tsai_data = reticulate::import("tsai.data.all")
57
  print("--> py dependences | Wandb")
58
  wandb = reticulate::import("wandb")
 
85
  path_alpha = 5/10,
86
  point_alpha = 1/10,
87
  point_size = 1)
88
+
89
  WANDB_ENTITY = Sys.getenv("WANDB_ENTITY")
90
  WANDB_PROJECT = Sys.getenv("WANDB_PROJECT")
91
 
92
+ print("Wandb API Key -->", WANDB_API_KEY)
93
 
94
  ####################
95
  # HELPER FUNCTIONS #