additional features
Browse files
@@ -173,53 +173,65 @@ def explore_dataset():
173 |
174 |
df = pd.read_parquet(parquet_path)
175 |
176 |
# Generate dataset summary
177 |
summary = df.describe(include='all').T
178 |
summary["missing_values"] = df.isnull().sum()
179 |
summary["unique_values"] = df.nunique()
180 |
summary_text = summary.to_markdown()
181 |
182 |
# Log the Markdown dataset summary to Weights & Biases, wrapped in wandb.Html
183 |
wandb.log({"Dataset Summary": wandb.Html(summary_text)})
184 |
185 |
186 |
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
187 |
fig.suptitle("Dataset Overview", fontsize=16)
188 |
189 |
190 |
data_types = df.dtypes.value_counts()
191 |
sns.barplot(x=data_types.index.astype(str), y=data_types.values, ax=axes[0])
192 |
axes[0].set_title("Column Count by Data Type")
193 |
194 |
195 |
196 |
num_cols = df.select_dtypes(include=['number']).columns
197 |
if len(num_cols) > 0:
198 |
mean_values = df[num_cols].mean()
199 |
sns.barplot(x=mean_values.index, y=mean_values.values, ax=axes[1])
200 |
axes[1].set_title("Mean Values of Numeric Columns")
201 |
202 |
203 |
204 |
for col, mean_val in mean_values.items():
205 |
wandb.log({f"Mean Values/{col}": mean_val})
206 |
207 |
208 |
buf = io.BytesIO()
209 |
210 |
plt.savefig(buf, format='png', bbox_inches='tight')
211 |
212 |
213 |
214 |
# Convert figure to NumPy array
215 |
image =
216 |
image_array = np.array(image)
217 |
218 |
# Log image to Weights & Biases
219 |
wandb.log({"Dataset Overview": wandb.Image(image)})
220 |
221 |
return summary_text, image_array
222 |
223 |
except Exception as e:
224 |
return f"Error loading data: {str(e)}", None
225 |
173 |
174 |
df = pd.read_parquet(parquet_path)
175 |
176 |
summary = df.describe(include='all').T
177 |
summary["missing_values"] = df.isnull().sum()
178 |
summary["unique_values"] = df.nunique()
179 |
summary_text = summary.to_markdown()
180 |
181 |
wandb.log({"Dataset Summary": wandb.Html(summary_text)})
182 |
183 |
fig, axes = plt.subplots(3, 2, figsize=(14, 15))
184 |
fig.suptitle("Dataset Overview", fontsize=16)
185 |
186 |
# Column Count by Data Type
187 |
data_types = df.dtypes.value_counts()
188 |
sns.barplot(x=data_types.index.astype(str), y=data_types.values, ax=axes[0, 0])
189 |
axes[0, 0].set_title("Column Count by Data Type")
190 |
axes[0, 0].set_ylabel("Count")
191 |
axes[0, 0].set_xlabel("Column Type")
192 |
193 |
# Mean Values of Numeric Columns
194 |
num_cols = df.select_dtypes(include=['number']).columns
195 |
if len(num_cols) > 0:
196 |
mean_values = df[num_cols].mean()
197 |
sns.barplot(x=mean_values.index, y=mean_values.values, ax=axes[0, 1])
198 |
axes[0, 1].set_title("Mean Values of Numeric Columns")
199 |
axes[0, 1].set_xlabel("Column Name")
200 |
axes[0, 1].tick_params(axis='x', rotation=45)
201 |
202 |
for col, mean_val in mean_values.items():
203 |
wandb.log({f"Mean Values/{col}": mean_val})
204 |
205 |
# Correlation Heatmap (numeric columns only)
206 |
if len(num_cols) > 0:
207 |
corr_matrix = df[num_cols].corr()
208 |
sns.heatmap(corr_matrix, annot=True, cmap="coolwarm", ax=axes[1, 0])
209 |
axes[1, 0].set_title("Correlation Heatmap")
210 |
211 |
# Missing Value Heatmap
212 |
sns.heatmap(df.isnull(), cmap="viridis", cbar=False, ax=axes[1, 1])
213 |
axes[1, 1].set_title("Missing Value Heatmap")
214 |
215 |
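# Pairplots for Feature Relationships (NOTE(review): sns.pairplot is figure-level and draws on its own new figure, not on `axes` - confirm intended)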
216 |
sns.pairplot(df[num_cols].sample(500), diag_kind='kde') # Sampling for performance
217 |
218 |
219 |
# Outlier Detection (box plots of the numeric columns)
220 |
df[num_cols].plot(kind='box', subplots=True, layout=(2, 3), figsize=(14, 10), ax=axes[2, :])
221 |
axes[2, 0].set_title("Outlier Detection - Boxplots")
222 |
223 |
buf = io.BytesIO()
224 |
225 |
plt.savefig(buf, format='png', bbox_inches='tight')
226 |
227 |
228 |
229 |
image =
230 |
image_array = np.array(image)
231 |
232 |
wandb.log({"Dataset Overview": wandb.Image(image)})
233 |
234 |
return summary_text, image_array
235 |
except Exception as e:
236 |
return f"Error loading data: {str(e)}", None
237 |