Commit · a6b1d00
1 Parent(s): a49379f
Add subgroup-harm-assessor
- .gitignore +3 -0
- .gitmodules +3 -0
- Dockerfile +12 -0
- LICENSE +21 -0
- README.md +75 -5
- app.py +1175 -0
- constants.py +20 -0
- dash_app/assets/base.css +414 -0
- dash_app/assets/style.css +85 -0
- dash_app/config.py +1 -0
- dash_app/main.py +11 -0
- dash_app/views/confusion_matrix.py +37 -0
- dash_app/views/menu.py +61 -0
- fairsd/.gitignore +138 -0
- fairsd/LICENSE +201 -0
- fairsd/README.md +20 -0
- fairsd/fairsd/__init__.py +12 -0
- fairsd/fairsd/algorithms.py +769 -0
- fairsd/fairsd/discretization.py +355 -0
- fairsd/fairsd/qualitymeasures.py +36 -0
- fairsd/fairsd/searchspace.py +208 -0
- fairsd/fairsd/sgdescription.py +252 -0
- fairsd/main.py +32 -0
- fairsd/notebooks/fairsd_settings.ipynb +494 -0
- fairsd/notebooks/fairsd_usage.ipynb +358 -0
- fairsd/requirements.txt +5 -0
- fairsd/tests/__init__.py +0 -0
- fairsd/tests/age.csv +0 -0
- fairsd/tests/discretizationtest.py +68 -0
- fairsd/tests/sgdiscoverytest.py +397 -0
- images/screenshot.png +0 -0
- load.py +352 -0
- metrics.py +211 -0
- plot.py +911 -0
- requirements.txt +19 -0
- utils.py +118 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+__pycache__
+.zed
+venv
.gitmodules
ADDED
@@ -0,0 +1,3 @@
+[submodule "fairsd"]
+	path = fairsd
+	url = https://github.com/adubowski/fairsd.git
Dockerfile
ADDED
@@ -0,0 +1,12 @@
+FROM python:3.10
+
+WORKDIR /app
+
+COPY requirements.txt ./
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .
+
+EXPOSE 7860
+
+CMD ["python", "app.py"]
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Adam Dubowski
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -1,11 +1,81 @@
 ---
-title: Subgroup Harm Assessor
-emoji:
-colorFrom:
-colorTo:
+title: Subgroup Harm Assessor tool
+emoji: 🐢
+colorFrom: blue
+colorTo: yellow
 sdk: docker
 pinned: false
 license: mit
 ---
 
-
+# Subgroup Harm Assessor tool
+
+## Description
+This project aims to assess the fairness of machine learning models in terms of subgroup harms and predictive bias. The tool provides guidance and metrics to evaluate the performance of models across different impacted subgroups and to identify potential biases based on model loss explanations. Note that, due to SHAP limitations, only tree-based models (e.g. XGBoost, LightGBM, CatBoost) are supported.
+
+## Requirements
+To run the project, you need Python 3.7 or higher installed (3.8 recommended). You can install the required packages with the following command:
+```bash
+pip install -r requirements.txt
+```
+
+## Run
+To run the project, use the following command:
+```bash
+python app.py
+```
+
+To see all available configuration options, use the CLI:
+```bash
+python app.py --help
+```
+
+## Usage
+The project provides a web interface to interact with the model. You can access it by opening a browser and navigating to `http://127.0.0.1:8050/` (or the address specified in the console). Step 0 is optional, as the thesis argues for the average log loss difference as the default quality metric for subgroup harm assessment of risk scoring systems.
+
+
+
+With the 6-step process, we implement the Subgroup Harm and Predictive Bias assessment framework discussed in the thesis project. The steps are as follows:
+1. **Select subgroup**: Choose the subgroup to analyze. A subgroup is defined by a (set of) descriptors based on specific feature values or ranges, and subgroups are ordered by the magnitude of the selected quality metric.
+2. **Analyze the profile of underperformance**: We analyze the discriminatory and calibration abilities of the model. The subgroup is compared to the overall population, and we can analyze the distribution of predicted probabilities.
+3. **Analyze performance and calibration**: Next, we analyze the performance and calibration of the model in detail, to see whether we can identify potential predictive bias or whether we should consider per-group calibration (for the feature of the selected subgroup).
+4. **Model loss feature importance**: We analyze the feature importance of the model to see whether certain features are less informative for the model.
+5. **Feature value model loss contributions**: We analyze the feature value contributions to the model loss to see whether certain feature values are less informative for the model.
+6. **Class imbalances**: We analyze the class imbalances in the subgroup to see whether the model is biased towards the majority class (indicating representation or aggregation bias).
+
+### Experiments
+To show the utility of the tool for identifying features that are less informative, we can run experiments on a random subgroup that follows the distribution of the total population. We can load an overview with the random subgroup with the following command:
+```bash
+python app.py --random_subgroup
+```
+or simply:
+```bash
+python app.py -r
+```
+
+Additionally, if we want to add artificial bias, we can use the following command:
+```bash
+python app.py --bias [bias_type]
+```
+where `[bias_type]` can be one of the following:
+```
+swap: Swap the feature values of a certain categorical feature of the subgroup
+random: Add random noise to the feature values of a certain continuous feature of the subgroup
+mean: Swap the continuous feature values of the subgroup with the mean of the total population
+```
+Other options can be found in the `load.py` file.
+
+## Python modules
+
+### Loading dataset and model
+In the current version of the tool, we load one of the selected datasets from OpenML and train a model on it.
+To load your own dataset, please edit the `import_dataset` function or its parent `load_data` function in `load.py`, retaining the same method signatures.
+To load your own model, please edit the `get_classifier` function in `load.py`, retaining the same function signature (returning the trained model and the predicted probabilities of the positive class).
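For illustration, a minimal `get_classifier` replacement consistent with this contract could look like the sketch below. The function name and argument order follow the call in `app.py` (`classifier, y_pred_prob = get_classifier(onehot_X_train, y_true_train, onehot_X_test, model=...)`); the choice of a scikit-learn random forest is only an assumption, not the shipped `load.py`:

```python
# Illustrative sketch only (not the shipped load.py): a drop-in get_classifier
# matching the call site in app.py and returning the trained tree-based model
# plus the predicted probabilities of the positive class.
from sklearn.ensemble import RandomForestClassifier


def get_classifier(onehot_X_train, y_true_train, onehot_X_test, model="rf"):
    # Swap in your own tree-based classifier here (SHAP TreeExplainer compatible).
    clf = RandomForestClassifier(n_estimators=100, random_state=0)
    clf.fit(onehot_X_train, y_true_train)
    # Probability of the positive class for each test sample.
    y_pred_prob = clf.predict_proba(onehot_X_test)[:, 1]
    return clf, y_pred_prob
```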
+
+### Metrics
+We specify the metrics used in the thesis in the `metrics.py` file. In case further metrics are needed, they can be added to the methods in that file.
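As a hypothetical illustration (not the shipped `metrics.py`), subgroup metrics are invoked in `app.py` as `fn(y_true, y_pred_or_prob)`, so an additional metric could be sketched like this:

```python
# Hypothetical addition, shown only to illustrate the fn(y_true, y_pred_prob)
# call shape that app.py uses when scoring a subgroup.
from sklearn.metrics import log_loss


def average_log_loss(y_true, y_pred_prob):
    # Mean log loss over the given samples; labels fixed so that
    # single-class subgroups do not raise an error.
    return log_loss(y_true, y_pred_prob, labels=[0, 1])
```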
+
+## Known issues
+Currently, there are some non-critical issues with the project:
+- When switching tabs, the height of the plots can reset to the default 500px. This can be fixed by reselecting the subgroup (which regenerates the plots).
+- XGBoost appears to fail to train on the German Credit dataset, likely because the dataset contains special characters in some feature names.
app.py
ADDED
@@ -0,0 +1,1175 @@
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import time
|
4 |
+
from typing import Tuple, Union
|
5 |
+
import warnings
|
6 |
+
|
7 |
+
# Ignore UserWarning from shap
|
8 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
9 |
+
log = logging.getLogger("werkzeug")
|
10 |
+
log.setLevel(logging.ERROR)
|
11 |
+
|
12 |
+
import dash_bootstrap_components as dbc
|
13 |
+
import numpy as np
|
14 |
+
import pandas as pd
|
15 |
+
from dash import dcc, html
|
16 |
+
from dash.dependencies import Input, Output
|
17 |
+
from dash.exceptions import PreventUpdate
|
18 |
+
|
19 |
+
from load import (
|
20 |
+
add_bias,
|
21 |
+
get_classifier,
|
22 |
+
get_fairsd_result_set,
|
23 |
+
load_data,
|
24 |
+
get_shap_logloss,
|
25 |
+
)
|
26 |
+
from dash_app.config import APP_NAME
|
27 |
+
from dash_app.main import app
|
28 |
+
from dash_app.views.confusion_matrix import CMchart
|
29 |
+
from dash_app.views.menu import get_subgroup_dropdown_options
|
30 |
+
from plot import (
|
31 |
+
get_data_distr_charts,
|
32 |
+
get_data_table,
|
33 |
+
get_feat_bar,
|
34 |
+
get_feat_box,
|
35 |
+
get_feat_val_violin_plot,
|
36 |
+
get_feat_val_violin_plots,
|
37 |
+
get_feat_val_box,
|
38 |
+
get_feat_val_bar,
|
39 |
+
get_feat_table,
|
40 |
+
get_sg_hist,
|
41 |
+
plot_calibration_curve,
|
42 |
+
plot_roc_curves,
|
43 |
+
)
|
44 |
+
from metrics import (
|
45 |
+
Y_PRED_METRICS,
|
46 |
+
get_qf_from_str,
|
47 |
+
get_quality_metric_from_str,
|
48 |
+
sort_quality_metrics_df,
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
def prepare_app(
|
53 |
+
n_samples=0, dataset="adult", bias=False, test_split=0.3, model="rf", sensitive_features=None
|
54 |
+
) -> Tuple[pd.DataFrame, np.ndarray, np.ndarray, np.ndarray, np.ndarray, pd.Index]:
|
55 |
+
"""Loads the data and trains the classifier
|
56 |
+
|
57 |
+
Args:
|
58 |
+
n_samples (int, optional): Number of samples to load. Defaults to 0.
|
59 |
+
dataset (str, optional): Name of the dataset to load. Defaults to "adult".
|
60 |
+
bias (Union[str, bool], optional): Type of bias to add to the dataset. Defaults to False.
|
61 |
+
If set to "random", adds random noise to a random feature for a random subset of the data.
|
62 |
+
If set to "swap", swaps the values of a random feature for a random subset of the data.
|
63 |
+
Returns:
|
64 |
+
Tuple: X_test, y_true_test, y_pred, y_pred_prob, shap_logloss_df, y_df
|
65 |
+
"""
|
66 |
+
|
67 |
+
# Loading and training
|
68 |
+
(
|
69 |
+
X_test,
|
70 |
+
y_true_train,
|
71 |
+
y_true_test,
|
72 |
+
onehot_X_train,
|
73 |
+
onehot_X_test,
|
74 |
+
cat_features,
|
75 |
+
) = load_data(n_samples=n_samples, dataset=dataset, test_split=test_split, sensitive_features=sensitive_features)
|
76 |
+
|
77 |
+
random_subgroup = pd.Series(
|
78 |
+
np.random.choice([True, False], size=len(X_test), p=[0.5, 0.5])
|
79 |
+
)
|
80 |
+
|
81 |
+
if dataset == "adult" and bias:
|
82 |
+
add_bias(bias, X_test, onehot_X_test, random_subgroup)
|
83 |
+
|
84 |
+
classifier, y_pred_prob = get_classifier(
|
85 |
+
onehot_X_train, y_true_train, onehot_X_test, model=model
|
86 |
+
)
|
87 |
+
|
88 |
+
shap_logloss_df = get_shap_logloss(
|
89 |
+
classifier, onehot_X_test, y_true_test, X_test, cat_features
|
90 |
+
)
|
91 |
+
|
92 |
+
return X_test, y_true_test, y_pred_prob, shap_logloss_df, random_subgroup
|
93 |
+
|
94 |
+
|
95 |
+
def run_app(
|
96 |
+
n_samples: int,
|
97 |
+
dataset: str,
|
98 |
+
bias: Union[str, bool] = False,
|
99 |
+
random_subgroup=False,
|
100 |
+
test_split=True,
|
101 |
+
model="rf",
|
102 |
+
depth=1,
|
103 |
+
min_support=100,
|
104 |
+
min_support_ratio=0.1,
|
105 |
+
min_quality=0.01,
|
106 |
+
sensitive_features=None,
|
107 |
+
):
|
108 |
+
"""Runs the app with the given qf_metric"""
|
109 |
+
use_random_subgroup = (
|
110 |
+
random_subgroup or bias
|
111 |
+
) # When evaluating bias, we want to evaluate against a random subgroup
|
112 |
+
start = time.time()
|
113 |
+
(
|
114 |
+
X_test_global,
|
115 |
+
y_true_global_test,
|
116 |
+
y_pred_prob_global,
|
117 |
+
shap_logloss_df_global,
|
118 |
+
random_subgroup_global,
|
119 |
+
) = prepare_app(
|
120 |
+
n_samples=n_samples,
|
121 |
+
dataset=dataset,
|
122 |
+
bias=bias,
|
123 |
+
test_split=test_split,
|
124 |
+
model=model,
|
125 |
+
sensitive_features=sensitive_features,
|
126 |
+
)
|
127 |
+
|
128 |
+
app.layout = html.Div(
|
129 |
+
id="app-container",
|
130 |
+
children=[
|
131 |
+
dcc.Store(id="result-set-dict"),
|
132 |
+
# Header
|
133 |
+
dbc.Row(
|
134 |
+
[
|
135 |
+
dbc.Col(
|
136 |
+
id="left-column",
|
137 |
+
children=[
|
138 |
+
dbc.Row(
|
139 |
+
[
|
140 |
+
dbc.Col(
|
141 |
+
html.H5(APP_NAME),
|
142 |
+
style={
|
143 |
+
"align-items": "center",
|
144 |
+
"height": "fit-content",
|
145 |
+
"white-space": "nowrap",
|
146 |
+
"width": 4,
|
147 |
+
},
|
148 |
+
),
|
149 |
+
dbc.Col(
|
150 |
+
html.H6(
|
151 |
+
"0. Subgroup Discovery Metric: ",
|
152 |
+
),
|
153 |
+
style={
|
154 |
+
"display": "flex",
|
155 |
+
"align-items": "center",
|
156 |
+
"height": "fit-content",
|
157 |
+
"justify-content": "right",
|
158 |
+
"width": 4,
|
159 |
+
},
|
160 |
+
),
|
161 |
+
dbc.Col(
|
162 |
+
dcc.Dropdown(
|
163 |
+
id="fairness-metric-dropdown",
|
164 |
+
options=[
|
165 |
+
{
|
166 |
+
"label": "Equalized Odds Difference",
|
167 |
+
"value": "equalized_odds_diff",
|
168 |
+
},
|
169 |
+
{
|
170 |
+
"label": "Avg Log Loss Difference",
|
171 |
+
"value": "average_log_loss_diff",
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"label": "AUROC (ROC AUC) Difference",
|
175 |
+
"value": "auroc_diff",
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"label": "Miscalibration Difference",
|
179 |
+
"value": "miscalibration_diff",
|
180 |
+
},
|
181 |
+
{
|
182 |
+
"label": "Brier Score Difference",
|
183 |
+
"value": "brier_score_diff",
|
184 |
+
},
|
185 |
+
{
|
186 |
+
"label": "Accuracy Difference",
|
187 |
+
"value": "acc_diff",
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"label": "F1-score Difference",
|
191 |
+
"value": "f1_diff",
|
192 |
+
},
|
193 |
+
{
|
194 |
+
"label": "False Positive Rate Difference",
|
195 |
+
"value": "fpr_diff",
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"label": "True Positive Rate Difference",
|
199 |
+
"value": "tpr_diff",
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"label": "False Negative Rate Difference",
|
203 |
+
"value": "fnr_diff",
|
204 |
+
},
|
205 |
+
],
|
206 |
+
value="average_log_loss_diff",
|
207 |
+
# Disable clearing
|
208 |
+
clearable=False,
|
209 |
+
)
|
210 |
+
),
|
211 |
+
]
|
212 |
+
),
|
213 |
+
],
|
214 |
+
),
|
215 |
+
dbc.Col(
|
216 |
+
id="right-column",
|
217 |
+
children=[
|
218 |
+
dbc.Row(
|
219 |
+
[
|
220 |
+
dbc.Col(
|
221 |
+
html.Center(
|
222 |
+
html.H6("1. Select Subgroup: "),
|
223 |
+
),
|
224 |
+
width=3,
|
225 |
+
style={
|
226 |
+
"align-items": "right",
|
227 |
+
"height": "fit-content",
|
228 |
+
},
|
229 |
+
),
|
230 |
+
dbc.Col(
|
231 |
+
dcc.Dropdown(
|
232 |
+
id="subgroup-dropdown",
|
233 |
+
options=[],
|
234 |
+
style={
|
235 |
+
"align-items": "left",
|
236 |
+
"height": "fit-content",
|
237 |
+
},
|
238 |
+
# Disable clearing
|
239 |
+
clearable=False,
|
240 |
+
),
|
241 |
+
width=9,
|
242 |
+
style={
|
243 |
+
"align-items": "left",
|
244 |
+
"height": "fit-content",
|
245 |
+
},
|
246 |
+
),
|
247 |
+
],
|
248 |
+
)
|
249 |
+
],
|
250 |
+
),
|
251 |
+
],
|
252 |
+
style={"height": "5vh"},
|
253 |
+
),
|
254 |
+
dcc.Tabs(
|
255 |
+
id="tabs",
|
256 |
+
value="impact",
|
257 |
+
children=[
|
258 |
+
dcc.Tab(
|
259 |
+
id="impact",
|
260 |
+
label="2. Underperformance Overview",
|
261 |
+
value="impact",
|
262 |
+
children=[
|
263 |
+
# Split the tab into two columns
|
264 |
+
html.Div(
|
265 |
+
className="row",
|
266 |
+
children=[
|
267 |
+
html.Div(
|
268 |
+
className="six columns",
|
269 |
+
children=[
|
270 |
+
html.H6(
|
271 |
+
"Full Dataset Baseline",
|
272 |
+
style={
|
273 |
+
"border-bottom": "3px solid #d3d3d3",
|
274 |
+
"font-weight": "normal",
|
275 |
+
},
|
276 |
+
),
|
277 |
+
dbc.Row(
|
278 |
+
[
|
279 |
+
dbc.Col(
|
280 |
+
dbc.Row(
|
281 |
+
[
|
282 |
+
dbc.Col(
|
283 |
+
html.Div(
|
284 |
+
id="simple-baseline-table",
|
285 |
+
className="six-columns",
|
286 |
+
children="Wait for the baseline to load...",
|
287 |
+
style={
|
288 |
+
"align-items": "center",
|
289 |
+
"height": "fit-content",
|
290 |
+
},
|
291 |
+
),
|
292 |
+
),
|
293 |
+
dbc.Col(
|
294 |
+
dcc.Graph(
|
295 |
+
id="simple-baseline-conf",
|
296 |
+
style={
|
297 |
+
"align-items": "center",
|
298 |
+
"height": "fit-content",
|
299 |
+
"height": "20vh",
|
300 |
+
"font-size": "0.8rem",
|
301 |
+
},
|
302 |
+
)
|
303 |
+
),
|
304 |
+
]
|
305 |
+
),
|
306 |
+
),
|
307 |
+
]
|
308 |
+
),
|
309 |
+
html.Br(),
|
310 |
+
dcc.Graph(id="simple-baseline-hist"),
|
311 |
+
# html.H6(
|
312 |
+
# "Select decision threshold for the model:"
|
313 |
+
# ),
|
314 |
+
# dcc.Slider(
|
315 |
+
# 0.1,
|
316 |
+
# 0.9,
|
317 |
+
# 0.1,
|
318 |
+
# value=0.5,
|
319 |
+
# id="simple-baseline-threshold-slider",
|
320 |
+
# ),
|
321 |
+
],
|
322 |
+
style={
|
323 |
+
"textAlign": "center",
|
324 |
+
"margin-right": "0.5",
|
325 |
+
},
|
326 |
+
),
|
327 |
+
html.Div(
|
328 |
+
className="six columns",
|
329 |
+
children=[
|
330 |
+
html.H6(
|
331 |
+
"Selected Subgroup",
|
332 |
+
style={
|
333 |
+
"border-bottom": "3px solid #d3d3d3",
|
334 |
+
"font-weight": "normal",
|
335 |
+
},
|
336 |
+
),
|
337 |
+
dbc.Row(
|
338 |
+
[
|
339 |
+
dbc.Col(
|
340 |
+
[
|
341 |
+
html.Div(
|
342 |
+
id="simple-subgroup-col",
|
343 |
+
className="six-columns",
|
344 |
+
children="Select subgroup and wait for the visualizations to load. ",
|
345 |
+
style={
|
346 |
+
"align-items": "center",
|
347 |
+
"height": "fit-content",
|
348 |
+
},
|
349 |
+
),
|
350 |
+
]
|
351 |
+
),
|
352 |
+
dbc.Col(
|
353 |
+
dcc.Graph(
|
354 |
+
id="simple-subgroup-conf",
|
355 |
+
style={
|
356 |
+
"height": "20vh",
|
357 |
+
"font-size": "0.4rem",
|
358 |
+
},
|
359 |
+
)
|
360 |
+
),
|
361 |
+
]
|
362 |
+
),
|
363 |
+
html.Br(),
|
364 |
+
dcc.Graph(id="simple-subgroup-hist"),
|
365 |
+
],
|
366 |
+
style={
|
367 |
+
"textAlign": "center",
|
368 |
+
# "border-left": "3px solid #d3d3d3",
|
369 |
+
"margin-left": "1",
|
370 |
+
},
|
371 |
+
),
|
372 |
+
],
|
373 |
+
),
|
374 |
+
],
|
375 |
+
),
|
376 |
+
dcc.Tab(
|
377 |
+
label="3. Performance and Calibration",
|
378 |
+
value="performance_tab",
|
379 |
+
children=[
|
380 |
+
html.Div(
|
381 |
+
className="row",
|
382 |
+
children=[
|
383 |
+
dbc.Row(
|
384 |
+
[
|
385 |
+
dbc.Col(
|
386 |
+
[
|
387 |
+
dcc.Graph(
|
388 |
+
id="perf-roc",
|
389 |
+
),
|
390 |
+
]
|
391 |
+
),
|
392 |
+
dbc.Col(
|
393 |
+
[
|
394 |
+
dcc.Graph(
|
395 |
+
id="calibration_curve",
|
396 |
+
),
|
397 |
+
html.H6(
|
398 |
+
"Select number of bins for the calibration plot:"
|
399 |
+
),
|
400 |
+
dcc.Slider(
|
401 |
+
4,
|
402 |
+
20,
|
403 |
+
1,
|
404 |
+
value=10,
|
405 |
+
id="calibration-slider",
|
406 |
+
),
|
407 |
+
]
|
408 |
+
),
|
409 |
+
]
|
410 |
+
),
|
411 |
+
],
|
412 |
+
style={"align-items": "center"},
|
413 |
+
),
|
414 |
+
],
|
415 |
+
),
|
416 |
+
dcc.Tab(
|
417 |
+
label="4. Loss contributions per feature",
|
418 |
+
value="feature_contributions_tab",
|
419 |
+
children=[
|
420 |
+
dbc.Row(
|
421 |
+
[
|
422 |
+
dcc.Graph(
|
423 |
+
id="feat-bar",
|
424 |
+
style={
|
425 |
+
"align-items": "center",
|
426 |
+
"height": "fit-content",
|
427 |
+
"font-size": "0.8rem",
|
428 |
+
},
|
429 |
+
),
|
430 |
+
]
|
431 |
+
),
|
432 |
+
dbc.Row(
|
433 |
+
[
|
434 |
+
dbc.Col(
|
435 |
+
[
|
436 |
+
html.H6(
|
437 |
+
"Select significance level alpha for the Kolgomorov Smirnoff (KS) test - Data table rows in bold are considered significant: "
|
438 |
+
),
|
439 |
+
dcc.Slider(
|
440 |
+
0.01,
|
441 |
+
0.1,
|
442 |
+
0.01,
|
443 |
+
value=0.05,
|
444 |
+
id="feat-alpha-slider",
|
445 |
+
),
|
446 |
+
html.H6(
|
447 |
+
"Select rounding level (number of decimal points) for SHAP values in the KS test: "
|
448 |
+
),
|
449 |
+
dcc.Slider(
|
450 |
+
0,
|
451 |
+
7,
|
452 |
+
1,
|
453 |
+
value=2,
|
454 |
+
id="feat-sensitivity-slider",
|
455 |
+
),
|
456 |
+
html.H6(
|
457 |
+
"Rounding level is the decimal point level at which SHAP values should be rounded because they are considered 'distinct'," +
|
458 |
+
" in order to avoid detecting statistically significant but minor differences between distributions in the KS test."
|
459 |
+
),
|
460 |
+
dbc.Row([
|
461 |
+
dbc.Col([
|
462 |
+
html.H6("Select aggregation method for feature contributions:"),
|
463 |
+
]),
|
464 |
+
dbc.Col([
|
465 |
+
dcc.Dropdown(
|
466 |
+
id="feat-agg-dropdown",
|
467 |
+
options=[
|
468 |
+
{
|
469 |
+
"label": "Statistical summary (box)",
|
470 |
+
"value": "box",
|
471 |
+
},
|
472 |
+
{
|
473 |
+
"label": "Distribution details (violin)",
|
474 |
+
"value": "violin",
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"label": "Mean (bar)",
|
478 |
+
"value": "mean",
|
479 |
+
},
|
480 |
+
{
|
481 |
+
"label": "Median (bar)",
|
482 |
+
"value": "median",
|
483 |
+
},
|
484 |
+
{
|
485 |
+
"label": "Mean difference (bar)",
|
486 |
+
"value": "mean_diff",
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"label": "Sum", # Weighted sum in this context is literally the mean
|
490 |
+
"value": "sum",
|
491 |
+
},
|
492 |
+
],
|
493 |
+
value="mean",
|
494 |
+
style={
|
495 |
+
"align-items": "center",
|
496 |
+
"text-align": "center",
|
497 |
+
},
|
498 |
+
# Disable clearing
|
499 |
+
clearable=False,
|
500 |
+
),
|
501 |
+
]),
|
502 |
+
])
|
503 |
+
|
504 |
+
]
|
505 |
+
),
|
506 |
+
html.Br(),
|
507 |
+
dbc.Col(
|
508 |
+
[
|
509 |
+
html.Div(
|
510 |
+
id="feat-table-col",
|
511 |
+
className="six-columns",
|
512 |
+
children="Data table for feature contributions. Select a subgroup to update the table.",
|
513 |
+
style={
|
514 |
+
"align-items": "center",
|
515 |
+
"height": "fit-content",
|
516 |
+
},
|
517 |
+
),
|
518 |
+
]
|
519 |
+
),
|
520 |
+
]
|
521 |
+
),
|
522 |
+
],
|
523 |
+
),
|
524 |
+
dcc.Tab(
|
525 |
+
label="5. Loss contributions per feature value",
|
526 |
+
value="feature_value_contributions_tab",
|
527 |
+
children=[
|
528 |
+
html.Div(
|
529 |
+
className="row",
|
530 |
+
children=[
|
531 |
+
dbc.Row(
|
532 |
+
[
|
533 |
+
dbc.Col(
|
534 |
+
[
|
535 |
+
html.H6(
|
536 |
+
"Select feature for value contributions:"
|
537 |
+
),
|
538 |
+
dcc.Dropdown(
|
539 |
+
id="feat-val-feature-dropdown",
|
540 |
+
options=[
|
541 |
+
{"label": col, "value": col}
|
542 |
+
for col in X_test_global.columns
|
543 |
+
],
|
544 |
+
value=X_test_global.columns[0],
|
545 |
+
style={
|
546 |
+
"align-items": "center",
|
547 |
+
"width": "50%",
|
548 |
+
"text-align": "center",
|
549 |
+
},
|
550 |
+
# Disable clearing
|
551 |
+
clearable=False,
|
552 |
+
),
|
553 |
+
]
|
554 |
+
),
|
555 |
+
dbc.Col([
|
556 |
+
html.H6("Select aggregation method for feature value contributions:"),
|
557 |
+
dcc.Dropdown(
|
558 |
+
id="feat-val-agg-dropdown",
|
559 |
+
options=[
|
560 |
+
{
|
561 |
+
"label": "Mean",
|
562 |
+
"value": "mean",
|
563 |
+
},
|
564 |
+
{
|
565 |
+
"label": "Median",
|
566 |
+
"value": "median",
|
567 |
+
},
|
568 |
+
{
|
569 |
+
"label": "Sum (Total)",
|
570 |
+
"value": "sum",
|
571 |
+
},
|
572 |
+
{
|
573 |
+
"label": "Sum (weighted)",
|
574 |
+
"value": "sum_weighted",
|
575 |
+
},
|
576 |
+
# {
|
577 |
+
# "label": "Statistical summary (box plot)",
|
578 |
+
# "value": "box",
|
579 |
+
# }, # FIXME: Box plots should be next to each other not on top
|
580 |
+
{
|
581 |
+
"label": "Distribution details (violin)",
|
582 |
+
"value": "violin",
|
583 |
+
},
|
584 |
+
],
|
585 |
+
value="sum_weighted",
|
586 |
+
style={
|
587 |
+
"align-items": "center",
|
588 |
+
"text-align": "center",
|
589 |
+
"width": "50%",
|
590 |
+
},
|
591 |
+
# Disable clearing
|
592 |
+
clearable=False,
|
593 |
+
),
|
594 |
+
]),
|
595 |
+
]
|
596 |
+
),
|
597 |
+
html.Br(),
|
598 |
+
dbc.Row([
|
599 |
+
dcc.Graph(id="feat-val-plot"),
|
600 |
+
]),
|
601 |
+
html.Br(),
|
602 |
+
dcc.Slider(
|
603 |
+
2, 20, 2, value=8, id="feat-val-hist-slider"
|
604 |
+
),
|
605 |
+
],
|
606 |
+
style={"align-items": "center"},
|
607 |
+
),
|
608 |
+
],
|
609 |
+
),
|
610 |
+
dcc.Tab(
|
611 |
+
label="6. Class Imbalances",
|
612 |
+
value="data_tab",
|
613 |
+
children=[
|
614 |
+
# Split into two equal width columns with headers
|
615 |
+
html.Div(
|
616 |
+
className="row",
|
617 |
+
children=[
|
618 |
+
dbc.Row(
|
619 |
+
[
|
620 |
+
dbc.Col(
|
621 |
+
[
|
622 |
+
html.H6(
|
623 |
+
"Select feature for distribution plots:"
|
624 |
+
),
|
625 |
+
dcc.Dropdown(
|
626 |
+
id="data-feature-dropdown",
|
627 |
+
options=[
|
628 |
+
{"label": col, "value": col}
|
629 |
+
for col in X_test_global.columns
|
630 |
+
],
|
631 |
+
value=X_test_global.columns[0],
|
632 |
+
style={
|
633 |
+
"align-items": "center",
|
634 |
+
"text-align": "center",
|
635 |
+
"border-bottom": "3px solid #d3d3d3"
|
636 |
+
},
|
637 |
+
# Disable clearing
|
638 |
+
clearable=False,
|
639 |
+
),
|
640 |
+
]
|
641 |
+
),
|
642 |
+
dbc.Col(
|
643 |
+
[
|
644 |
+
html.H6("Select class of samples to plot:"),
|
645 |
+
dcc.Dropdown(
|
646 |
+
id="data-label-dropdown",
|
647 |
+
options=[
|
648 |
+
{"label": "All", "value": "all"},
|
649 |
+
{"label": "Positive samples only", "value": 1},
|
650 |
+
{"label": "Negative samples only", "value": 0},
|
651 |
+
],
|
652 |
+
value="all",
|
653 |
+
style={
|
654 |
+
"align-items": "center",
|
655 |
+
"text-align": "center",
|
656 |
+
"border-bottom": "3px solid #d3d3d3"
|
657 |
+
},
|
658 |
+
# Disable clearing
|
659 |
+
clearable=False,
|
660 |
+
),
|
661 |
+
]
|
662 |
+
),
|
663 |
+
dbc.Col(
|
664 |
+
[
|
665 |
+
# Dropdown for percentage/absolute values
|
666 |
+
html.H6(
|
667 |
+
"Select aggregation method for distribution plots:"
|
668 |
+
),
|
669 |
+
dcc.Dropdown(
|
670 |
+
id="data-agg-dropdown",
|
671 |
+
options=[
|
672 |
+
{
|
673 |
+
"label": "Percentage",
|
674 |
+
"value": "percentage",
|
675 |
+
},
|
676 |
+
{
|
677 |
+
"label": "Count",
|
678 |
+
"value": "count",
|
679 |
+
},
|
680 |
+
],
|
681 |
+
value="percentage",
|
682 |
+
style={
|
683 |
+
"align-items": "center",
|
684 |
+
"text-align": "center",
|
685 |
+
"border-bottom": "3px solid #d3d3d3"
|
686 |
+
},
|
687 |
+
# Disable clearing
|
688 |
+
clearable=False,
|
689 |
+
),
|
690 |
+
]
|
691 |
+
),
|
692 |
+
]
|
693 |
+
),
|
694 |
+
html.Br(),
|
695 |
+
dbc.Row(
|
696 |
+
[
|
697 |
+
dbc.Col([
|
698 |
+
dcc.Graph(id="data-class-dist-plot"),
|
699 |
+
]),
|
700 |
+
dbc.Col([
|
701 |
+
dcc.Graph(id="data-pred-dist-plot"),
|
702 |
+
]),
|
703 |
+
]
|
704 |
+
),
|
705 |
+
|
706 |
+
html.H6(
|
707 |
+
"Select (max) number of bins (numerical features only):"
|
708 |
+
),
|
709 |
+
dcc.Slider(
|
710 |
+
5, 30, 5, value=10, id="data-hist-slider"
|
711 |
+
),
|
712 |
+
],
|
713 |
+
style={"align-items": "center"},
|
714 |
+
),
|
715 |
+
],
|
716 |
+
),
|
717 |
+
],
|
718 |
+
# Set headers to bold font
|
719 |
+
style={"font-weight": "bold"},
|
720 |
+
),
|
721 |
+
],
|
722 |
+
)
|
723 |
+
|
724 |
+
@app.callback(
|
725 |
+
Output("simple-baseline-table", "children"),
|
726 |
+
Output("simple-baseline-conf", "figure"),
|
727 |
+
Output("simple-baseline-hist", "figure"),
|
728 |
+
Output("subgroup-dropdown", "options"),
|
729 |
+
Output("result-set-dict", "data"),
|
730 |
+
Input("fairness-metric-dropdown", "value"),
|
731 |
+
# Input("simple-baseline-threshold-slider", "value"),
|
732 |
+
)
|
733 |
+
def get_baseline_stats_and_subgroups(metric, threshold=0.5):
|
734 |
+
if not metric:
|
735 |
+
raise PreventUpdate
|
736 |
+
|
737 |
+
|
738 |
+
y_true = y_true_global_test.copy()
|
739 |
+
y_pred_prob = y_pred_prob_global.copy()
|
740 |
+
if metric in Y_PRED_METRICS:
|
741 |
+
y_pred = (y_pred_prob >= threshold).astype(int)
|
742 |
+
else:
|
743 |
+
y_pred = y_pred_prob.copy()
|
744 |
+
|
745 |
+
y_df = pd.DataFrame({"y_true": y_true, "probability": y_pred_prob})
|
746 |
+
y_df["category"] = y_df.apply(
|
747 |
+
lambda row: "TP"
|
748 |
+
if row["y_true"] == 1 and row["probability"] >= threshold
|
749 |
+
else "FP"
|
750 |
+
if row["y_true"] == 0 and row["probability"] >= threshold
|
751 |
+
else "FN"
|
752 |
+
if row["y_true"] == 1 and row["probability"] < threshold
|
753 |
+
else "TN",
|
754 |
+
axis=1,
|
755 |
+
)
|
756 |
+
|
757 |
+
baseline_descr = "Full dataset baseline"
|
758 |
+
|
759 |
+
baseline_data_table = get_data_table(
|
760 |
+
baseline_descr,
|
761 |
+
y_true,
|
762 |
+
y_pred,
|
763 |
+
y_pred_prob,
|
764 |
+
qf_metric=metric,
|
765 |
+
sg_feature=pd.Series([True] * y_true.shape[0]),
|
766 |
+
# We do not update number of bins here to prevent rerunning DSSD. Default number of bins should be enough for the baseline generally
|
767 |
+
)
|
768 |
+
baseline_conf_mat = CMchart(
|
769 |
+
"Confusion Matrix", y_true, (y_pred_prob >= threshold).astype(int)
|
770 |
+
).fig
|
771 |
+
baseline_hist = get_sg_hist(
|
772 |
+
y_df, title="Histogram of prediction probabilities on the full dataset"
|
773 |
+
)
|
774 |
+
|
775 |
+
if use_random_subgroup:
|
776 |
+
sg_feature = random_subgroup_global.copy()
|
777 |
+
# Replace the result_set_df with a synthetic random subgroup
|
778 |
+
sg_y_pred = (
|
779 |
+
y_pred[sg_feature]
|
780 |
+
if metric in Y_PRED_METRICS
|
781 |
+
else y_pred_prob[sg_feature]
|
782 |
+
)
|
783 |
+
sg_y_true = y_true[sg_feature]
|
784 |
+
name = "Random subgroup"
|
785 |
+
if bias:
|
786 |
+
name += f" with bias"
|
787 |
+
result_set_df = pd.DataFrame(
|
788 |
+
{
|
789 |
+
"quality": [None],
|
790 |
+
"description": [name],
|
791 |
+
"size": [sum(sg_feature)],
|
792 |
+
"proportion": [sum(sg_feature) / len(sg_feature)],
|
793 |
+
"metric_score": [
|
794 |
+
get_quality_metric_from_str(metric)(sg_y_true, sg_y_pred)
|
795 |
+
],
|
796 |
+
}
|
797 |
+
)
|
798 |
+
result_set_json = {
|
799 |
+
"descriptions": ["Random subgroup"],
|
800 |
+
"sg_features": [sg_feature.to_json()],
|
801 |
+
"metric": metric,
|
802 |
+
}
|
803 |
+
return (
|
804 |
+
baseline_data_table,
|
805 |
+
baseline_conf_mat,
|
806 |
+
baseline_hist,
|
807 |
+
get_subgroup_dropdown_options(result_set_df, metric),
|
808 |
+
result_set_json,
|
809 |
+
)
|
810 |
+
else:
|
811 |
+
result_set = get_fairsd_result_set(
|
812 |
+
X_test_global,
|
813 |
+
y_true_global_test,
|
814 |
+
y_pred,
|
815 |
+
qf=get_qf_from_str(metric),
|
816 |
+
# method="between_groups",
|
817 |
+
method="to_overall",
|
818 |
+
depth=depth,
|
819 |
+
min_support=min_support,
|
820 |
+
min_support_ratio=min_support_ratio,
|
821 |
+
max_support_ratio=0.5, # To prevent finding majority subgroups
|
822 |
+
logging_level=logging.INFO,
|
823 |
+
sensitive_features=sensitive_features,
|
824 |
+
min_quality=min_quality,
|
825 |
+
)
|
826 |
+
result_set_df = result_set.to_dataframe()
|
827 |
+
metrics = []
|
828 |
+
for idx in range(len(result_set_df)):
|
829 |
+
description = result_set.get_description(idx)
|
830 |
+
sg_feature = description.to_boolean_array(X_test_global)
|
831 |
+
sg_y_pred = (
|
832 |
+
y_pred[sg_feature]
|
833 |
+
if metric in Y_PRED_METRICS
|
834 |
+
else y_pred_prob[sg_feature]
|
835 |
+
)
|
836 |
+
sg_y_true = y_true[sg_feature]
|
837 |
+
metrics.append(
|
838 |
+
get_quality_metric_from_str(metric)(sg_y_true, sg_y_pred)
|
839 |
+
)
|
840 |
+
|
841 |
+
result_set_df["metric_score"] = metrics
|
842 |
+
result_set_df = sort_quality_metrics_df(result_set_df, metric)
|
843 |
+
return (
|
844 |
+
baseline_data_table,
|
845 |
+
baseline_conf_mat,
|
846 |
+
baseline_hist,
|
847 |
+
get_subgroup_dropdown_options(result_set_df, metric),
|
848 |
+
result_set.to_json(
|
849 |
+
X_test_global, metric, result_set_df
|
850 |
+
), # Store the result set representation in the data store
|
851 |
+
)
|
852 |
+
|
853 |
+
# Get feature value plot based on subgroup and feature selection
|
854 |
+
@app.callback(
|
855 |
+
Output("feat-val-plot", "figure"),
|
856 |
+
Input("feat-val-feature-dropdown", "value"),
|
857 |
+
Input("subgroup-dropdown", "value"),
|
858 |
+
Input("result-set-dict", "data"),
|
859 |
+
Input("feat-val-hist-slider", "value"),
|
860 |
+
Input("feat-val-agg-dropdown", "value"),
|
861 |
+
)
|
862 |
+
def get_feat_val_plot(feature, subgroup, data, nbins, agg):
|
863 |
+
"""Produces a violin chart or line plot with the feature value contributions for the selected subgroup"""
|
864 |
+
if not feature:
|
865 |
+
raise PreventUpdate
|
866 |
+
if subgroup is None:
|
867 |
+
raise PreventUpdate
|
868 |
+
if not nbins:
|
869 |
+
print("Error: No bins selected. This should not happen.")
|
870 |
+
raise PreventUpdate
|
871 |
+
|
872 |
+
if len(data["descriptions"]) == 0:
|
873 |
+
print("Error: No subgroups found. This should not happen.")
|
874 |
+
raise PreventUpdate
|
875 |
+
|
876 |
+
description = data["descriptions"][subgroup]
|
877 |
+
sg_feature = pd.read_json(data["sg_features"][subgroup], typ="series")
|
878 |
+
if agg == "violin":
|
879 |
+
return get_feat_val_violin_plot(
|
880 |
+
X_test_global.copy(),
|
881 |
+
shap_logloss_df_global.copy(),
|
882 |
+
sg_feature,
|
883 |
+
feature,
|
884 |
+
description,
|
885 |
+
nbins=nbins,
|
886 |
+
)
|
887 |
+
elif agg == "box":
|
888 |
+
return get_feat_val_box(
|
889 |
+
X_test_global.copy(),
|
890 |
+
shap_logloss_df_global.copy(),
|
891 |
+
sg_feature=sg_feature,
|
892 |
+
feature=feature,
|
893 |
+
description=description,
|
894 |
+
nbins=nbins,
|
895 |
+
)
|
896 |
+
else:
|
897 |
+
return get_feat_val_bar(
|
898 |
+
X_test_global.copy(),
|
899 |
+
shap_logloss_df_global.copy(),
|
900 |
+
sg_feature=sg_feature,
|
901 |
+
feature=feature,
|
902 |
+
description=description,
|
903 |
+
nbins=nbins,
|
904 |
+
agg=agg,
|
905 |
+
)
|
906 |
+
|
907 |
+
# Get calibration plot based on subgroup and slider selection
|
908 |
+
@app.callback(
|
909 |
+
Output("calibration_curve", "figure"),
|
910 |
+
Input("calibration-slider", "value"),
|
911 |
+
Input("subgroup-dropdown", "value"),
|
912 |
+
Input("result-set-dict", "data"),
|
913 |
+
)
|
914 |
+
def get_calibration_plot(slider_value, subgroup, data):
|
915 |
+
"""Produces a calibration plot for the selected subgroup"""
|
916 |
+
if not slider_value:
|
917 |
+
raise PreventUpdate
|
918 |
+
if subgroup is None:
|
919 |
+
raise PreventUpdate
|
920 |
+
if len(data["sg_features"]) == 0:
|
921 |
+
print("Error: No subgroups found. This should not happen.")
|
922 |
+
raise PreventUpdate
|
923 |
+
sg_feature = pd.read_json(data["sg_features"][subgroup], typ="series")
|
924 |
+
y_true = y_true_global_test.copy()
|
925 |
+
y_pred_prob = y_pred_prob_global.copy()
|
926 |
+
|
927 |
+
return plot_calibration_curve(
|
928 |
+
y_true, y_pred_prob, sg_feature, n_bins=slider_value
|
929 |
+
)
|
930 |
+
|
931 |
+
# Get data distributions based on subgroup selection
|
932 |
+
@app.callback(
|
933 |
+
Output("data-class-dist-plot", "figure"),
|
934 |
+
Output("data-pred-dist-plot", "figure"),
|
935 |
+
Input("data-feature-dropdown", "value"),
|
936 |
+
Input("data-agg-dropdown", "value"),
|
937 |
+
Input("subgroup-dropdown", "value"),
|
938 |
+
Input("result-set-dict", "data"),
|
939 |
+
Input("data-hist-slider", "value"),
|
940 |
+
Input("data-label-dropdown", "value"),
|
941 |
+
# Input("simple-baseline-threshold-slider", "value"),
|
942 |
+
)
|
943 |
+
def get_data_feat_distr(feature, agg, subgroup, data, bins, class_label, threshold=0.5):
|
944 |
+
"""Produces a bar chart or line plot with the data feature values counts for the selected subgroup"""
|
945 |
+
if not feature:
|
946 |
+
raise PreventUpdate
|
947 |
+
if not agg:
|
948 |
+
raise PreventUpdate
|
949 |
+
if subgroup is None:
|
950 |
+
raise PreventUpdate
|
951 |
+
if not bins:
|
952 |
+
logging.error("Error: No bins selected. This should not happen.")
|
953 |
+
raise PreventUpdate
|
954 |
+
try:
|
955 |
+
description = data["descriptions"][subgroup]
|
956 |
+
sg_feature = pd.read_json(data["sg_features"][subgroup], typ="series")
|
957 |
+
except IndexError:
|
958 |
+
print("Subgroup not found. This should not happen.")
|
959 |
+
raise PreventUpdate
|
960 |
+
|
961 |
+
y_pred_prob = y_pred_prob_global.copy()
|
962 |
+
y_pred = (y_pred_prob >= threshold).astype(int)
|
963 |
+
y_true = y_true_global_test.copy()
|
964 |
+
X_test = X_test_global.copy()
|
965 |
+
if class_label in (0, 1):
|
966 |
+
y_pred = y_pred[y_true == class_label]
|
967 |
+
X_test = X_test[y_true == class_label]
|
968 |
+
|
969 |
+
class_plot, pred_plot = get_data_distr_charts(
|
970 |
+
X_test, y_pred, sg_feature, feature, description, bins, agg
|
971 |
+
)
|
972 |
+
return class_plot, pred_plot
|
973 |
+
|
974 |
+
|
975 |
+
# Get feat-table-col
|
976 |
+
@app.callback(
|
977 |
+
Output("feat-table-col", "children"),
|
978 |
+
Input("subgroup-dropdown", "value"),
|
979 |
+
Input("result-set-dict", "data"),
|
980 |
+
Input("feat-alpha-slider", "value"),
|
981 |
+
Input("feat-sensitivity-slider", "value"),
|
982 |
+
)
|
983 |
+
def get_feat_table_col(subgroup, data, alpha, sensitivity):
|
984 |
+
"""Returns the feature contributions table for the selected subgroup"""
|
985 |
+
if subgroup is None:
|
986 |
+
raise PreventUpdate
|
987 |
+
if not alpha:
|
988 |
+
alpha = 0.05
|
989 |
+
if len(data["descriptions"]) == 0:
|
990 |
+
print("Error: No subgroups found. This should not happen.")
|
991 |
+
raise PreventUpdate
|
992 |
+
|
993 |
+
sg_feature = pd.read_json(data["sg_features"][subgroup], typ="series")
|
994 |
+
shap_df = shap_logloss_df_global.copy()
|
995 |
+
return get_feat_table(
|
996 |
+
shap_values_df=shap_df,
|
997 |
+
sg_feature=sg_feature,
|
998 |
+
sensitivity=sensitivity,
|
999 |
+
alpha=alpha,
|
1000 |
+
)
|
1001 |
+
|
1002 |
+
# Get plots based on the subgroup selection
|
1003 |
+
@app.callback(
|
1004 |
+
Output("simple-subgroup-col", "children"),
|
1005 |
+
Output("simple-subgroup-conf", "figure"),
|
1006 |
+
Output("simple-subgroup-hist", "figure"),
|
1007 |
+
Output("perf-roc", "figure"),
|
1008 |
+
Output("feat-bar", "figure"),
|
1009 |
+
Input("result-set-dict", "data"),
|
1010 |
+
Input("subgroup-dropdown", "value"),
|
1011 |
+
Input("calibration-slider", "value"),
|
1012 |
+
Input("feat-agg-dropdown", "value"),
|
1013 |
+
# Input("simple-baseline-threshold-slider", "value"),
|
1014 |
+
)
|
1015 |
+
def get_subgroup_stats(data, subgroup, nbins, agg, threshold=0.5):
|
1016 |
+
"""Returns the group description and updates the charts of the selected subgroup"""
|
1017 |
+
if subgroup is None:
|
1018 |
+
# TODO: Return baseline-only plots when no subgroup is selected
|
1019 |
+
raise PreventUpdate
|
1020 |
+
if len(data["descriptions"]) == 0:
|
1021 |
+
print("Error: No subgroups found. This should not happen.")
|
1022 |
+
raise PreventUpdate
|
1023 |
+
if not nbins:
|
1024 |
+
print("Error: No bins selected. This should not happen.")
|
1025 |
+
raise PreventUpdate
|
1026 |
+
|
1027 |
+
sg_feature = pd.read_json(data["sg_features"][subgroup], typ="series")
|
1028 |
+
description = data["descriptions"][subgroup]
|
1029 |
+
subgroup_description = str(description).replace(" ", "")
|
1030 |
+
subgroup_description = subgroup_description.replace("AND", " AND ")
|
1031 |
+
metric = data["metric"]
|
1032 |
+
|
1033 |
+
y_true = y_true_global_test.copy()
|
1034 |
+
|
1035 |
+
y_pred_prob = y_pred_prob_global.copy()
|
1036 |
+
y_pred = (y_pred_prob >= threshold).astype(int)
|
1037 |
+
|
1038 |
+
y_df = pd.DataFrame({"y_true": y_true, "probability": y_pred_prob})
|
1039 |
+
y_df["category"] = y_df.apply(
|
1040 |
+
lambda row: "TP"
|
1041 |
+
if row["y_true"] == 1 and row["probability"] >= threshold
|
1042 |
+
else "FP"
|
1043 |
+
if row["y_true"] == 0 and row["probability"] >= threshold
|
1044 |
+
else "FN"
|
1045 |
+
if row["y_true"] == 1 and row["probability"] < threshold
|
1046 |
+
else "TN",
|
1047 |
+
axis=1,
|
1048 |
+
)
|
1049 |
+
shap_values_df = shap_logloss_df_global.copy()
|
1050 |
+
|
1051 |
+
sg_hist = get_sg_hist(y_df[sg_feature])
|
1052 |
+
|
1053 |
+
sg_data_table = get_data_table(
|
1054 |
+
subgroup_description,
|
1055 |
+
y_true,
|
1056 |
+
y_pred,
|
1057 |
+
y_pred_prob,
|
1058 |
+
qf_metric=metric,
|
1059 |
+
sg_feature=sg_feature,
|
1060 |
+
n_bins=nbins,
|
1061 |
+
)
|
1062 |
+
|
1063 |
+
roc_fig = plot_roc_curves(
|
1064 |
+
y_true, y_pred_prob, sg_feature, title="ROC for subgroup and baseline"
|
1065 |
+
)
|
1066 |
+
|
1067 |
+
sg_conf_mat = CMchart(
|
1068 |
+
"Confusion Matrix", y_true_global_test[sg_feature], y_pred[sg_feature]
|
1069 |
+
).fig
|
1070 |
+
|
1071 |
+
if agg == "box":
|
1072 |
+
# Get a box plot with the feature importances for sg and baseline
|
1073 |
+
feat_plot = get_feat_box(shap_values_df, sg_feature=sg_feature)
|
1074 |
+
elif agg == "violin":
|
1075 |
+
# Get a violin chart with the feature importances for sg and baseline
|
1076 |
+
feat_plot = get_feat_val_violin_plots(
|
1077 |
+
shap_values_df, sg_feature=sg_feature
|
1078 |
+
)
|
1079 |
+
else:
|
1080 |
+
feat_plot = get_feat_bar(shap_values_df, sg_feature=sg_feature, agg=agg)
|
1081 |
+
|
1082 |
+
return (sg_data_table, sg_conf_mat, sg_hist, roc_fig, feat_plot)
|
1083 |
+
|
1084 |
+
print("App startup time (s): ", time.time() - start)
|
1085 |
+
app.run(host='0.0.0.0', port=7860, debug=False)
|
1086 |
+
|
1087 |
+
|
1088 |
+
if __name__ == "__main__":
|
1089 |
+
parser = argparse.ArgumentParser()
|
1090 |
+
parser.add_argument(
|
1091 |
+
"-n",
|
1092 |
+
"--n_samples",
|
1093 |
+
type=int,
|
1094 |
+
default=0,
|
1095 |
+
help="Number of samples to use for the app. Use 0 to load the entire dataset.",
|
1096 |
+
)
|
1097 |
+
parser.add_argument(
|
1098 |
+
"--dataset",
|
1099 |
+
type=str,
|
1100 |
+
default="adult",
|
1101 |
+
help="Dataset to be used in the evaluation. Available options are: 'adult', 'credit_g', 'heloc'",
|
1102 |
+
)
|
1103 |
+
parser.add_argument(
|
1104 |
+
"-r",
|
1105 |
+
"--random_subgroup",
|
1106 |
+
action="store_true",
|
1107 |
+
default=False,
|
1108 |
+
help="Flag whether to use a random subgroup for evaluation",
|
1109 |
+
)
|
1110 |
+
parser.add_argument(
|
1111 |
+
"-b",
|
1112 |
+
"--bias",
|
1113 |
+
type=str,
|
1114 |
+
default=False,
|
1115 |
+
help="Type of bias to add to the dataset",
|
1116 |
+
)
|
1117 |
+
parser.add_argument(
|
1118 |
+
"-s",
|
1119 |
+
"--test_split",
|
1120 |
+
default=0.3,
|
1121 |
+
help="Ratio of samples selected for the test set used for visualizations. 0 and 1 will use the full dataset for the training and test set (useful with smaller models).",
|
1122 |
+
)
|
1123 |
+
parser.add_argument(
|
1124 |
+
"-m",
|
1125 |
+
"--model",
|
1126 |
+
type=str,
|
1127 |
+
default="rf",
|
1128 |
+
help="Model to use for the evaluation. Available options are: 'rf', 'dt', 'xgb'",
|
1129 |
+
)
|
1130 |
+
parser.add_argument(
|
1131 |
+
"-d",
|
1132 |
+
"--depth",
|
1133 |
+
type=int,
|
1134 |
+
default=1,
|
1135 |
+
help="Depth of the subgroup discovery algorithm search",
|
1136 |
+
)
|
1137 |
+
parser.add_argument(
|
1138 |
+
"--min_support",
|
1139 |
+
type=int,
|
1140 |
+
default=100,
|
1141 |
+
help="Minimum support (subgroup size) for the subgroup discovery algorithm",
|
1142 |
+
)
|
1143 |
+
parser.add_argument(
|
1144 |
+
"--min_support_ratio",
|
1145 |
+
type=float,
|
1146 |
+
default=0.1,
|
1147 |
+
help="Min support ratio for the subgroup discovery algorithm",
|
1148 |
+
)
|
1149 |
+
parser.add_argument(
|
1150 |
+
"--min_quality",
|
1151 |
+
type=float,
|
1152 |
+
default=0.01,
|
1153 |
+
help="Minimum quality for the subgroup discovery algorithm",
|
1154 |
+
)
|
1155 |
+
parser.add_argument(
|
1156 |
+
"--sensitive_features",
|
1157 |
+
type=str,
|
1158 |
+
nargs="+",
|
1159 |
+
help="List of sensitive features to use for the subgroup discovery algorithm",
|
1160 |
+
)
|
1161 |
+
|
1162 |
+
args = parser.parse_args()
|
1163 |
+
run_app(
|
1164 |
+
args.n_samples,
|
1165 |
+
args.dataset,
|
1166 |
+
args.bias,
|
1167 |
+
args.random_subgroup,
|
1168 |
+
args.test_split,
|
1169 |
+
args.model,
|
1170 |
+
args.depth,
|
1171 |
+
args.min_support,
|
1172 |
+
args.min_support_ratio,
|
1173 |
+
args.min_quality,
|
1174 |
+
args.sensitive_features,
|
1175 |
+
)
|
constants.py
ADDED
@@ -0,0 +1,20 @@
+# Selected colors which have a dark version (so that they can be used for both light and dark accents of the same feature)
+colors = [
+    "blue",
+    "cyan",
+    "goldenrod",
+    "grey",
+    "green",
+    "khaki",
+    "magenta",
+    "orange",
+    "orchid",
+    "red",
+    "salmon",
+    "seagreen",
+    "slateblue",
+    "slategray",
+    "slategrey",
+    "turquoise",
+    "violet",
+]
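As a brief illustration of how such a palette can be paired with its dark variants (a hypothetical helper, not part of this commit; it relies only on the fact that every listed color has a CSS/Plotly "dark" counterpart such as "blue"/"darkblue"):

```python
# Hypothetical helper: pick a light/dark accent pair for a feature index.
from constants import colors


def accent_pair(feature_idx: int):
    base = colors[feature_idx % len(colors)]
    return base, f"dark{base}"
```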
dash_app/assets/base.css
ADDED
@@ -0,0 +1,414 @@
1 |
+
/* Table of contents
|
2 |
+
––––––––––––––––––––––––––––––––––––––––––––––––––
|
3 |
+
- Plotly.js
|
4 |
+
- Grid
|
5 |
+
- Base Styles
|
6 |
+
- Typography
|
7 |
+
- Links
|
8 |
+
- Buttons
|
9 |
+
- Forms
|
10 |
+
- Lists
|
11 |
+
- Code
|
12 |
+
- Tables
|
13 |
+
- Spacing
|
14 |
+
- Utilities
|
15 |
+
- Clearing
|
16 |
+
- Media Queries
|
17 |
+
*/
|
18 |
+
|
19 |
+
/* PLotly.js
|
20 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
21 |
+
/* plotly.js's modebar's z-index is 1001 by default
|
22 |
+
* https://github.com/plotly/plotly.js/blob/7e4d8ab164258f6bd48be56589dacd9bdd7fded2/src/css/_modebar.scss#L5
|
23 |
+
* In case a dropdown is above the graph, the dropdown's options
|
24 |
+
* will be rendered below the modebar
|
25 |
+
* Increase the select option's z-index
|
26 |
+
*/
|
27 |
+
|
28 |
+
/* This was actually not quite right -
|
29 |
+
dropdowns were overlapping each other (edited October 26)
|
30 |
+
|
31 |
+
.Select {
|
32 |
+
z-index: 1002;
|
33 |
+
}*/
|
34 |
+
|
35 |
+
/* Grid
|
36 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
37 |
+
.container {
|
38 |
+
position: relative;
|
39 |
+
width: 100%;
|
40 |
+
max-width: 960px;
|
41 |
+
margin: 0 auto;
|
42 |
+
padding: 0 20px;
|
43 |
+
box-sizing: border-box; }
|
44 |
+
.column,
|
45 |
+
.columns {
|
46 |
+
width: 100%;
|
47 |
+
float: left;
|
48 |
+
box-sizing: border-box; }
|
49 |
+
|
50 |
+
/* For devices larger than 400px */
|
51 |
+
@media (min-width: 400px) {
|
52 |
+
.container {
|
53 |
+
width: 85%;
|
54 |
+
padding: 0; }
|
55 |
+
}
|
56 |
+
|
57 |
+
/* For devices larger than 550px */
|
58 |
+
@media (min-width: 550px) {
|
59 |
+
.container {
|
60 |
+
width: 80%; }
|
61 |
+
.column,
|
62 |
+
.columns {
|
63 |
+
margin-left: 2%; }
|
64 |
+
.column:first-child,
|
65 |
+
.columns:first-child {
|
66 |
+
margin-left: 1%; }
|
67 |
+
|
68 |
+
.one.column,
|
69 |
+
.one.columns { width: 4.66666666667%; }
|
70 |
+
.two.columns { width: 13.3333333333%; }
|
71 |
+
.three.columns { width: 22%; }
|
72 |
+
.four.columns { width: 30.6666666667%; }
|
73 |
+
.five.columns { width: 39.3333333333%; }
|
74 |
+
.six.columns { width: 48%; }
|
75 |
+
.seven.columns { width: 56.6666666667%; }
|
76 |
+
.eight.columns { width: 65.3333333333%; }
|
77 |
+
.nine.columns { width: 74.0%; }
|
78 |
+
.ten.columns { width: 82.6666666667%; }
|
79 |
+
.eleven.columns { width: 91.3333333333%; }
|
80 |
+
.twelve.columns { width: 100%; margin-left: 0; }
|
81 |
+
|
82 |
+
.one-third.column { width: 30.6666666667%; }
|
83 |
+
.two-thirds.column { width: 65.3333333333%; }
|
84 |
+
|
85 |
+
.one-half.column { width: 48%; }
|
86 |
+
|
87 |
+
/* Offsets */
|
88 |
+
.offset-by-one.column,
|
89 |
+
.offset-by-one.columns { margin-left: 8.66666666667%; }
|
90 |
+
.offset-by-two.column,
|
91 |
+
.offset-by-two.columns { margin-left: 17.3333333333%; }
|
92 |
+
.offset-by-three.column,
|
93 |
+
.offset-by-three.columns { margin-left: 26%; }
|
94 |
+
.offset-by-four.column,
|
95 |
+
.offset-by-four.columns { margin-left: 34.6666666667%; }
|
96 |
+
.offset-by-five.column,
|
97 |
+
.offset-by-five.columns { margin-left: 43.3333333333%; }
|
98 |
+
.offset-by-six.column,
|
99 |
+
.offset-by-six.columns { margin-left: 52%; }
|
100 |
+
.offset-by-seven.column,
|
101 |
+
.offset-by-seven.columns { margin-left: 60.6666666667%; }
|
102 |
+
.offset-by-eight.column,
|
103 |
+
.offset-by-eight.columns { margin-left: 69.3333333333%; }
|
104 |
+
.offset-by-nine.column,
|
105 |
+
.offset-by-nine.columns { margin-left: 78.0%; }
|
106 |
+
.offset-by-ten.column,
|
107 |
+
.offset-by-ten.columns { margin-left: 86.6666666667%; }
|
108 |
+
.offset-by-eleven.column,
|
109 |
+
.offset-by-eleven.columns { margin-left: 95.3333333333%; }
|
110 |
+
|
111 |
+
.offset-by-one-third.column,
|
112 |
+
.offset-by-one-third.columns { margin-left: 34.6666666667%; }
|
113 |
+
.offset-by-two-thirds.column,
|
114 |
+
.offset-by-two-thirds.columns { margin-left: 69.3333333333%; }
|
115 |
+
|
116 |
+
.offset-by-one-half.column,
|
117 |
+
.offset-by-one-half.columns { margin-left: 52%; }
|
118 |
+
|
119 |
+
}
|
120 |
+
|
121 |
+
|
122 |
+
/* Base Styles
|
123 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
124 |
+
/* NOTE
|
125 |
+
html is set to 62.5% so that all the REM measurements throughout Skeleton
|
126 |
+
are based on 10px sizing. So basically 1.5rem = 15px :) */
|
127 |
+
html {
|
128 |
+
font-size: 62.5%; }
|
129 |
+
body {
|
130 |
+
font-size: 1.5em; /* currently ems cause chrome bug misinterpreting rems on body element */
|
131 |
+
line-height: 1.6;
|
132 |
+
font-weight: 400;
|
133 |
+
font-family: "Open Sans", "HelveticaNeue", "Helvetica Neue", Helvetica, Arial, sans-serif;
|
134 |
+
color: rgb(50, 50, 50); }
|
135 |
+
|
136 |
+
|
137 |
+
/* Typography
|
138 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
139 |
+
h1, h2, h3, h4, h5, h6 {
|
140 |
+
margin-top: 0;
|
141 |
+
margin-bottom: 0;
|
142 |
+
font-weight: 300; }
|
143 |
+
h1 { font-size: 4.5rem; line-height: 1.2; letter-spacing: -.1rem; margin-bottom: 2rem; }
|
144 |
+
h2 { font-size: 3.6rem; line-height: 1.25; letter-spacing: -.1rem; margin-bottom: 1.8rem; margin-top: 1.8rem;}
|
145 |
+
h3 { font-size: 3.0rem; line-height: 1.3; letter-spacing: -.1rem; margin-bottom: 1.5rem; margin-top: 1.5rem;}
|
146 |
+
h4 { font-size: 2.6rem; line-height: 1.35; letter-spacing: -.08rem; margin-bottom: 1.2rem; margin-top: 1.2rem;}
|
147 |
+
h5 { font-size: 2.2rem; line-height: 1.5; letter-spacing: -.05rem; margin-bottom: 0.6rem; margin-top: 0.6rem;}
|
148 |
+
h6 { font-size: 1.7rem; line-height: 1.6; letter-spacing: 0; margin-bottom: 0.75rem; margin-top: 0.75rem;}
|
149 |
+
|
150 |
+
p {
|
151 |
+
margin-top: 0; }
|
152 |
+
|
153 |
+
|
154 |
+
/* Blockquotes
|
155 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
156 |
+
blockquote {
|
157 |
+
border-left: 4px lightgrey solid;
|
158 |
+
padding-left: 1rem;
|
159 |
+
margin-top: 2rem;
|
160 |
+
margin-bottom: 2rem;
|
161 |
+
margin-left: 0rem;
|
162 |
+
}
|
163 |
+
|
164 |
+
|
165 |
+
/* Links
|
166 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
167 |
+
a {
|
168 |
+
color: #1EAEDB;
|
169 |
+
text-decoration: underline;
|
170 |
+
cursor: pointer;}
|
171 |
+
a:hover {
|
172 |
+
color: #0FA0CE; }
|
173 |
+
|
174 |
+
|
175 |
+
/* Buttons
|
176 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
177 |
+
.button,
|
178 |
+
button,
|
179 |
+
input[type="submit"],
|
180 |
+
input[type="reset"],
|
181 |
+
input[type="button"] {
|
182 |
+
display: inline-block;
|
183 |
+
height: 38px;
|
184 |
+
padding: 0 30px;
|
185 |
+
color: #555;
|
186 |
+
text-align: center;
|
187 |
+
font-size: 11px;
|
188 |
+
font-weight: 600;
|
189 |
+
line-height: 38px;
|
190 |
+
letter-spacing: .1rem;
|
191 |
+
text-transform: uppercase;
|
192 |
+
text-decoration: none;
|
193 |
+
white-space: nowrap;
|
194 |
+
background-color: transparent;
|
195 |
+
border-radius: 4px;
|
196 |
+
border: 1px solid #bbb;
|
197 |
+
cursor: pointer;
|
198 |
+
box-sizing: border-box; }
|
199 |
+
.button:hover,
|
200 |
+
button:hover,
|
201 |
+
input[type="submit"]:hover,
|
202 |
+
input[type="reset"]:hover,
|
203 |
+
input[type="button"]:hover,
|
204 |
+
.button:focus,
|
205 |
+
button:focus,
|
206 |
+
input[type="submit"]:focus,
|
207 |
+
input[type="reset"]:focus,
|
208 |
+
input[type="button"]:focus {
|
209 |
+
color: #333;
|
210 |
+
border-color: #888;
|
211 |
+
outline: 0; }
|
212 |
+
.button.button-primary,
|
213 |
+
button.button-primary,
|
214 |
+
input[type="submit"].button-primary,
|
215 |
+
input[type="reset"].button-primary,
|
216 |
+
input[type="button"].button-primary {
|
217 |
+
color: #FFF;
|
218 |
+
background-color: #33C3F0;
|
219 |
+
border-color: #33C3F0; }
|
220 |
+
.button.button-primary:hover,
|
221 |
+
button.button-primary:hover,
|
222 |
+
input[type="submit"].button-primary:hover,
|
223 |
+
input[type="reset"].button-primary:hover,
|
224 |
+
input[type="button"].button-primary:hover,
|
225 |
+
.button.button-primary:focus,
|
226 |
+
button.button-primary:focus,
|
227 |
+
input[type="submit"].button-primary:focus,
|
228 |
+
input[type="reset"].button-primary:focus,
|
229 |
+
input[type="button"].button-primary:focus {
|
230 |
+
color: #FFF;
|
231 |
+
background-color: #1EAEDB;
|
232 |
+
border-color: #1EAEDB; }
|
233 |
+
|
234 |
+
|
235 |
+
/* Forms
|
236 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
237 |
+
input[type="email"],
|
238 |
+
input[type="number"],
|
239 |
+
input[type="search"],
|
240 |
+
input[type="text"],
|
241 |
+
input[type="tel"],
|
242 |
+
input[type="url"],
|
243 |
+
input[type="password"],
|
244 |
+
textarea,
|
245 |
+
select {
|
246 |
+
height: 38px;
|
247 |
+
padding: 6px 10px; /* The 6px vertically centers text on FF, ignored by Webkit */
|
248 |
+
background-color: #fff;
|
249 |
+
border: 1px solid #D1D1D1;
|
250 |
+
border-radius: 4px;
|
251 |
+
box-shadow: none;
|
252 |
+
box-sizing: border-box;
|
253 |
+
font-family: inherit;
|
254 |
+
font-size: inherit; /*https://stackoverflow.com/questions/6080413/why-doesnt-input-inherit-the-font-from-body*/}
|
255 |
+
/* Removes awkward default styles on some inputs for iOS */
|
256 |
+
input[type="email"],
|
257 |
+
input[type="number"],
|
258 |
+
input[type="search"],
|
259 |
+
input[type="text"],
|
260 |
+
input[type="tel"],
|
261 |
+
input[type="url"],
|
262 |
+
input[type="password"],
|
263 |
+
textarea {
|
264 |
+
-webkit-appearance: none;
|
265 |
+
-moz-appearance: none;
|
266 |
+
appearance: none; }
|
267 |
+
textarea {
|
268 |
+
min-height: 65px;
|
269 |
+
padding-top: 6px;
|
270 |
+
padding-bottom: 6px; }
|
271 |
+
input[type="email"]:focus,
|
272 |
+
input[type="number"]:focus,
|
273 |
+
input[type="search"]:focus,
|
274 |
+
input[type="text"]:focus,
|
275 |
+
input[type="tel"]:focus,
|
276 |
+
input[type="url"]:focus,
|
277 |
+
input[type="password"]:focus,
|
278 |
+
textarea:focus,
|
279 |
+
/*select:focus {*/
|
280 |
+
/* border: 1px solid #33C3F0;*/
|
281 |
+
/* outline: 0; }*/
|
282 |
+
label,
|
283 |
+
legend {
|
284 |
+
display: block;
|
285 |
+
margin-bottom: 0px; }
|
286 |
+
fieldset {
|
287 |
+
padding: 0;
|
288 |
+
border-width: 0; }
|
289 |
+
input[type="checkbox"],
|
290 |
+
input[type="radio"] {
|
291 |
+
display: inline; }
|
292 |
+
label > .label-body {
|
293 |
+
display: inline-block;
|
294 |
+
margin-left: .5rem;
|
295 |
+
font-weight: normal; }
|
296 |
+
|
297 |
+
|
298 |
+
/* Lists
|
299 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
300 |
+
ul {
|
301 |
+
list-style: circle inside; }
|
302 |
+
ol {
|
303 |
+
list-style: decimal inside; }
|
304 |
+
ol, ul {
|
305 |
+
padding-left: 0;
|
306 |
+
margin-top: 0; }
|
307 |
+
ul ul,
|
308 |
+
ul ol,
|
309 |
+
ol ol,
|
310 |
+
ol ul {
|
311 |
+
margin: 1.5rem 0 1.5rem 3rem;
|
312 |
+
font-size: 90%; }
|
313 |
+
li {
|
314 |
+
margin-bottom: 1rem; }
|
315 |
+
|
316 |
+
|
317 |
+
/* Tables
|
318 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
319 |
+
table {
|
320 |
+
border-collapse: collapse;
|
321 |
+
}
|
322 |
+
th,
|
323 |
+
td {
|
324 |
+
padding: 12px 15px;
|
325 |
+
text-align: left;
|
326 |
+
border-bottom: 1px solid #E1E1E1; }
|
327 |
+
th:first-child,
|
328 |
+
td:first-child {
|
329 |
+
padding-left: 0; }
|
330 |
+
th:last-child,
|
331 |
+
td:last-child {
|
332 |
+
padding-right: 0; }
|
333 |
+
|
334 |
+
|
335 |
+
/* Spacing
|
336 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
337 |
+
button,
|
338 |
+
.button {
|
339 |
+
margin-bottom: 0rem; }
|
340 |
+
input,
|
341 |
+
textarea,
|
342 |
+
select,
|
343 |
+
fieldset {
|
344 |
+
margin-bottom: 0rem; }
|
345 |
+
pre,
|
346 |
+
dl,
|
347 |
+
figure,
|
348 |
+
table,
|
349 |
+
form {
|
350 |
+
margin-bottom: 0rem; }
|
351 |
+
p,
|
352 |
+
ul,
|
353 |
+
ol {
|
354 |
+
margin-bottom: 0.75rem; }
|
355 |
+
|
356 |
+
/* Utilities
|
357 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
358 |
+
.u-full-width {
|
359 |
+
width: 100%;
|
360 |
+
box-sizing: border-box; }
|
361 |
+
.u-max-full-width {
|
362 |
+
max-width: 100%;
|
363 |
+
box-sizing: border-box; }
|
364 |
+
.u-pull-right {
|
365 |
+
float: right; }
|
366 |
+
.u-pull-left {
|
367 |
+
float: left; }
|
368 |
+
|
369 |
+
|
370 |
+
/* Misc
|
371 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
372 |
+
hr {
|
373 |
+
margin-top: 3rem;
|
374 |
+
margin-bottom: 3.5rem;
|
375 |
+
border-width: 0;
|
376 |
+
border-top: 1px solid #E1E1E1; }
|
377 |
+
|
378 |
+
|
379 |
+
/* Clearing
|
380 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
381 |
+
|
382 |
+
/* Self Clearing Goodness */
|
383 |
+
.container:after,
|
384 |
+
.row:after,
|
385 |
+
.u-cf {
|
386 |
+
content: "";
|
387 |
+
display: table;
|
388 |
+
clear: both; }
|
389 |
+
|
390 |
+
|
391 |
+
/* Media Queries
|
392 |
+
–––––––––––––––––––––––––––––––––––––––––––––––––– */
|
393 |
+
/*
|
394 |
+
Note: The best way to structure the use of media queries is to create the queries
|
395 |
+
near the relevant code. For example, if you wanted to change the styles for buttons
|
396 |
+
on small devices, paste the mobile query code up in the buttons section and style it
|
397 |
+
there.
|
398 |
+
*/
|
399 |
+
|
400 |
+
|
401 |
+
/* Larger than mobile */
|
402 |
+
@media (min-width: 400px) {}
|
403 |
+
|
404 |
+
/* Larger than phablet (also point when grid becomes active) */
|
405 |
+
@media (min-width: 550px) {}
|
406 |
+
|
407 |
+
/* Larger than tablet */
|
408 |
+
@media (min-width: 750px) {}
|
409 |
+
|
410 |
+
/* Larger than desktop */
|
411 |
+
@media (min-width: 1000px) {}
|
412 |
+
|
413 |
+
/* Larger than Desktop HD */
|
414 |
+
@media (min-width: 1200px) {}
|
dash_app/assets/style.css
ADDED
@@ -0,0 +1,85 @@
body {
  background-color: #f9f9f9;
  color: #333333;
  font-family: "Helvetica Neue", Arial, sans-serif;
  font-size: 15px;
  margin: 5px;
}
.zoom {
  zoom: 70%;
}
#left-column {
  padding: 2px 0.6rem;
  display: flex;
  flex-direction: column;
  min-height: 100vh;
}

#right-column {
  padding: 0 0.2rem;
  min-height: 100vh;
}

h5 {
  color: #2c8cff;
  font-weight: 300;
  margin: 10px;
}
h6 {
  color: #333333;
  font-size: 300;
  font-weight: bold;
  margin: 10px;
}

#intro {
  margin: 20px 0px;
  text-align: justify;
}

#control-card label {
  font-weight: bold;
}
.tabs {
  display: flex;
  flex-direction: row;
  justify-content: space-between;
  margin: 0.6rem;
  height: 1vh;
}
.graph_card {
  background-color: white;
  border-radius: 0.675rem;
  padding: 0.6rem;
  margin: 0.6rem;
  margin-top: 10px;
}

.graph_card h6 {
  margin: 0 7px;
  padding-bottom: 7px;
  border-bottom: 1px solid #ccc;
}

.modal {
  position: fixed;
  z-index: 1002; /* Sit on top, including modebar which has z=1001 */
  left: 0;
  top: 0;
  width: 30%; /* 30% of the viewport width */
  height: 100%; /* Full height */
  background-color: rgba(0, 0, 0, 0.6); /* Black w/ opacity */
}

.modal-content {
  margin: 10px;
  height: 600px;
  padding: 10px;
  background-color: #c58797;
}

.column_left {
  border-right: 1px solid #333333;
}
dash_app/config.py
ADDED
@@ -0,0 +1 @@
APP_NAME = "Subgroup Harm Assessor"
dash_app/main.py
ADDED
@@ -0,0 +1,11 @@
from dash import Dash
from .config import APP_NAME
import dash_bootstrap_components as dbc


app = Dash(
    __name__,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
    meta_tags=[{"name": "viewport", "content": "width=device-width"}],
)
app.title = APP_NAME
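dash_app/main.py only constructs the shared Dash instance; the layout and callbacks are registered elsewhere in this commit. As a hedged sketch (the placeholder layout, host, and port below are illustrative assumptions, not values from this repository), the instance would typically be wired up and served like this:

    # hypothetical entry point, for illustration only
    from dash import html
    from dash_app.main import app  # assumes dash_app is importable as a package

    app.layout = html.Div([html.H5(app.title)])  # placeholder layout

    if __name__ == "__main__":
        # app.run replaces run_server in recent Dash releases
        app.run(host="0.0.0.0", port=8050, debug=False)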
dash_app/views/confusion_matrix.py
ADDED
@@ -0,0 +1,37 @@
from dash import dcc, html
import plotly.express as px
from explainerdashboard.explainer_plots import plotly_confusion_matrix
from sklearn.metrics import confusion_matrix


class CMchart(html.Div):
    def __init__(self, name, y_true, y_pred):
        """
        :param name: name of the plot
        :param y_true: array-like of ground-truth labels
        :param y_pred: array-like of predicted labels
        """
        self.html_id = name.lower().replace(" ", "-")
        self.name = name
        self.cm = confusion_matrix(y_true, y_pred)
        self.fig = plotly_confusion_matrix(self.cm)
        self.title_id = self.html_id + "-t"

        # Equivalent to `html.Div([...])`
        super().__init__()

    def update(self):
        self.fig = plotly_confusion_matrix(self.cm)

        self.fig.update_layout(
            yaxis_zeroline=False, xaxis_zeroline=False, dragmode="select"
        )
        self.fig.update_xaxes(fixedrange=True)
        self.fig.update_yaxes(fixedrange=True)

        # update axis titles with generic confusion-matrix labels
        self.fig.update_layout(
            xaxis_title="Predicted label",
            yaxis_title="Actual label",
        )

        return self.fig
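CMchart is an html.Div subclass that mainly carries a ready-made Plotly confusion-matrix figure. A minimal, hypothetical usage sketch (toy labels; the component id handling is an assumption, not taken from app.py):

    from dash import dcc, html
    from dash_app.views.confusion_matrix import CMchart

    y_true = [0, 0, 1, 1, 1, 0]  # toy data
    y_pred = [0, 1, 1, 1, 0, 0]

    cm_chart = CMchart("Confusion Matrix", y_true, y_pred)
    layout = html.Div([dcc.Graph(id=cm_chart.html_id, figure=cm_chart.fig)])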
dash_app/views/menu.py
ADDED
@@ -0,0 +1,61 @@
import pandas as pd
from metrics import get_name_from_metric_str
import plotly.express as px


def get_subgroup_dropdown_options(
    result_set_df: pd.DataFrame, quality_metric: str = "Quality"
):
    """Get the options for the subgroup dropdown. They consist of the subgroup description and the subgroup index."""

    if result_set_df.empty:
        return [
            {
                "label": "No subgroups found. Check your subgroup search criteria.",
                "value": -1,
            }
        ]

    # For each description, get the size and metric score of the corresponding subgroup
    return [
        {
            "label": str(result_set_df["description"][idx])
            + "; Size: "
            + str(result_set_df["size"][idx])
            # + f"; {quality_metric}: "
            # + str(result_set_df["quality"][idx].round(3))
            + f"; {get_name_from_metric_str(quality_metric)}: "
            + str(result_set_df["metric_score"][idx]),
            "value": idx,
        }
        for idx in range(len(result_set_df))
    ]


def get_shap_barchart(
    shap_values_df, group_feature="group", title="Feature contributions"
):
    # Average the SHAP values per group
    agg_df = shap_values_df.groupby(group_feature).mean()
    # Create the fig
    fig = px.bar(
        agg_df,
        y=agg_df.index,
        x=agg_df.values,
        color=agg_df.index,
        barmode="group",
        title=title,
        orientation="h",
    )

    # Update the fig
    fig.update_layout(
        title=title,
        xaxis_title="Contribution (logloss SHAP)",
        yaxis_title="",
    )
    # Tilt the y-axis labels
    fig.update_layout(yaxis_tickangle=-35)
    # Set x axis range
    fig.update_xaxes(range=[-0.1, 0.1])
    return fig
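get_subgroup_dropdown_options expects one row per discovered subgroup, with description, size, and metric_score columns. A hedged sketch of that input shape (the subgroup descriptions and the metric key are illustrative; the keys accepted by metrics.get_name_from_metric_str are defined in metrics.py, not shown here):

    import pandas as pd
    from dash_app.views.menu import get_subgroup_dropdown_options  # assumes the app root is on sys.path

    result_set_df = pd.DataFrame(
        {
            "description": ["age < 30 AND sex = 'Female'", "hours-per-week > 45"],
            "size": [412, 598],
            "metric_score": [0.182, 0.104],
        }
    )
    # One {"label": ..., "value": <row index>} dict per subgroup, ready for a dcc.Dropdown
    options = get_subgroup_dropdown_options(result_set_df, quality_metric="equalized_odds_difference")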
fairsd/.gitignore
ADDED
@@ -0,0 +1,138 @@
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
98 |
+
__pypackages__/
|
99 |
+
|
100 |
+
# Celery stuff
|
101 |
+
celerybeat-schedule
|
102 |
+
celerybeat.pid
|
103 |
+
|
104 |
+
# SageMath parsed files
|
105 |
+
*.sage.py
|
106 |
+
|
107 |
+
# Environments
|
108 |
+
.env
|
109 |
+
.venv
|
110 |
+
env/
|
111 |
+
venv/
|
112 |
+
ENV/
|
113 |
+
env.bak/
|
114 |
+
venv.bak/
|
115 |
+
|
116 |
+
# Spyder project settings
|
117 |
+
.spyderproject
|
118 |
+
.spyproject
|
119 |
+
|
120 |
+
# Rope project settings
|
121 |
+
.ropeproject
|
122 |
+
|
123 |
+
# mkdocs documentation
|
124 |
+
/site
|
125 |
+
|
126 |
+
# mypy
|
127 |
+
.mypy_cache/
|
128 |
+
.dmypy.json
|
129 |
+
dmypy.json
|
130 |
+
|
131 |
+
# Pyre type checker
|
132 |
+
.pyre/
|
133 |
+
|
134 |
+
# pytype static type analyzer
|
135 |
+
.pytype/
|
136 |
+
|
137 |
+
# Cython debug symbols
|
138 |
+
cython_debug/
|
fairsd/LICENSE
ADDED
@@ -0,0 +1,201 @@
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
fairsd/README.md
ADDED
@@ -0,0 +1,20 @@
# FairSD

FairSD is a package that implements top-k subgroup discovery algorithms for identifying subgroups that may be treated unfairly by a machine learning model.<br/>

The package has been designed to offer the user the possibility to use different notions of fairness as quality measures. Integration with the [Fairlearn](https://fairlearn.github.io/) package allows the user to use all the [fairlearn metrics](https://fairlearn.github.io/v0.6.0/api_reference/fairlearn.metrics.html) as quality measures. The user can also define custom quality measures by extending the QualityFunction class present in the [fairsd.qualitymeasures](https://github.com/MaurizioPulizzi/fairsd/blob/main/fairsd/qualitymeasures.py) module.


## Usage
For common usage refer to the [Jupyter notebooks](https://github.com/MaurizioPulizzi/fairsd/tree/main/notebooks). In particular:
* [Quick start - FairSD usage](https://github.com/MaurizioPulizzi/fairsd/blob/main/notebooks/fairsd_usage.ipynb).
* [FairSD settings](https://github.com/MaurizioPulizzi/fairsd/blob/main/notebooks/fairsd_settings.ipynb), for a detailed explanation of how to initialize the SubgroupDiscoveryTask object and use the implemented subgroup discovery algorithms.


## Contributors
* [Maurizio Pulizzi](https://github.com/MaurizioPulizzi)
* [Hilde Weerts](https://github.com/hildeweerts)


## Acknowledgements
Some parts of the code are an adaptation of the [pysubgroup package](https://github.com/flemmerich/pysubgroup). These parts are indicated in the code.
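As a concrete illustration of the flow this Space builds on, here is a minimal sketch using only the classes added below in fairsd/fairsd/__init__.py and algorithms.py (the synthetic data and the chosen settings are illustrative assumptions):

    import numpy as np
    import pandas as pd
    import fairsd

    # Toy stand-ins for a real dataset and a trained classifier's predictions
    rng = np.random.default_rng(0)
    n = 2000
    X = pd.DataFrame({
        "age": rng.integers(18, 70, n),
        "hours_per_week": rng.integers(10, 60, n),
        "sex": rng.choice(["Male", "Female"], n),
    })
    y_true = pd.Series(rng.integers(0, 2, n))
    y_pred = pd.Series(rng.integers(0, 2, n))

    task = fairsd.SubgroupDiscoveryTask(
        X, y_true, y_pred=y_pred,
        qf="equalized_odds_difference",  # one of the metric names listed in algorithms.py
        result_set_size=5, depth=2,
    )
    result_set = fairsd.BeamSearch(beam_width=20).execute(task)
    print(result_set.to_dataframe())  # columns: quality, description, size, proportion

DSSD can be swapped in for BeamSearch in the same way when a less redundant subgroup set is wanted.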
fairsd/fairsd/__init__.py
ADDED
@@ -0,0 +1,12 @@
from .sgdescription import Description
from .sgdescription import Descriptor

# from .sgdescription import BinaryTarget
from .searchspace import SearchSpace
from .searchspace import Discretizer
from .qualitymeasures import QualityFunction

from .algorithms import SubgroupDiscoveryTask
from .algorithms import BeamSearch
from .algorithms import DSSD
from .algorithms import ResultSet
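Besides the metric names listed in algorithms.py, SubgroupDiscoveryTask also accepts a callable quality measure, as long as it exposes the fairlearn-style keyword parameters y_true, y_pred, and sensitive_features (the search algorithms pass the boolean subgroup-membership series as sensitive_features, plus a method keyword). A hedged sketch of such a custom callable; the statistic it computes is illustrative and not part of this commit:

    import numpy as np

    def fnr_gap(y_true, y_pred, sensitive_features, **kwargs):
        # sensitive_features: boolean mask marking subgroup membership
        # **kwargs absorbs the extra `method` keyword passed by BeamSearch/DSSD
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        in_sg = np.asarray(sensitive_features, dtype=bool)

        def fnr(mask):
            positives = (y_true == 1) & mask
            if positives.sum() == 0:
                return 0.0
            return ((y_pred == 0) & positives).sum() / positives.sum()

        return abs(fnr(in_sg) - fnr(~in_sg))

    # task = fairsd.SubgroupDiscoveryTask(X, y_true, y_pred=y_pred, qf=fnr_gap)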
fairsd/fairsd/algorithms.py
ADDED
@@ -0,0 +1,769 @@
1 |
+
from typing import List
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import fairlearn.metrics as flm
|
5 |
+
import inspect
|
6 |
+
import logging
|
7 |
+
from .sgdescription import Description
|
8 |
+
from .searchspace import SearchSpace
|
9 |
+
from .searchspace import Discretizer
|
10 |
+
|
11 |
+
"""
|
12 |
+
The class SubgroupDiscoveryTask is an adaptation of the homonymous class of the pysubgroup library.
|
13 |
+
"""
|
14 |
+
|
15 |
+
quality_function_options = [
|
16 |
+
"equalized_odds_difference",
|
17 |
+
"equalized_odds_ratio",
|
18 |
+
"demographic_parity_difference",
|
19 |
+
"demographic_parity_ratio",
|
20 |
+
]
|
21 |
+
|
22 |
+
quality_function_parameters = ["y_true", "y_pred", "sensitive_features"]
|
23 |
+
|
24 |
+
|
25 |
+
class SubgroupDiscoveryTask:
|
26 |
+
"""This is an interface class and will contain all the parameters useful for the sg discovery algorithms."""
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
X, # pandas dataframe or numpy array with features
|
31 |
+
y_true, # numpy array, pandas dataframe, or pandas Series with ground truth labels
|
32 |
+
y_pred=None, # numpy array, pandas dataframe, or pandas Series with classifier's predicted labels
|
33 |
+
feature_names=None, # optional, list with column names in case users supply a numpy array X
|
34 |
+
sensitive_features=None, # list of sensitive features names (str)
|
35 |
+
nominal_features=None, # optional, list of nominal features
|
36 |
+
numeric_features=None,  # optional, list of numeric features
|
37 |
+
qf="equalized_odds_difference", # str or callable object
|
38 |
+
discretizer="equalfreq", # str
|
39 |
+
num_bins=6,
|
40 |
+
dynamic_discretization=True, # boolean
|
41 |
+
result_set_size=5, # int
|
42 |
+
depth=3, # int
|
43 |
+
min_quality=0, # float
|
44 |
+
min_support=200, # int
|
45 |
+
min_support_ratio=0.1, # float
|
46 |
+
max_support_ratio=0.5, # float
|
47 |
+
logging_level=logging.INFO,
|
48 |
+
):
|
49 |
+
"""
|
50 |
+
Parameters
|
51 |
+
----------
|
52 |
+
X : pandas dataframe or numpy array
|
53 |
+
y_true : numpy array, pandas dataframe, or pandas Series
|
54 |
+
represents the ground truth
|
55 |
+
y_pred : numpy array, pandas dataframe, or pandas Series
|
56 |
+
contains the predicted values
|
57 |
+
feature_names : list of string
|
58 |
+
this parameter is necessary if the user supplies X as a numpy array
|
59 |
+
sensitive_features: list of string
|
60 |
+
this list contains the names of the sensitive features
|
61 |
+
nominal_features : optional, list of strings
|
62 |
+
list of nominal features
|
63 |
+
numeric_features : optional, list of strings
list of numeric features
|
65 |
+
qf : string or callable object
|
66 |
+
discretizer : string
|
67 |
+
can be "mdlp", "equalfrequency" or "equalwidth"
|
68 |
+
num_bins : int
|
69 |
+
maximum number of bins that a numerical feature discretization operation will produce
|
70 |
+
dynamic_discretization : boolean
|
71 |
+
result_set_size : int
|
72 |
+
depth : int
|
73 |
+
maximum number of descriptors in a description
|
74 |
+
min_quality : float
|
75 |
+
min_support : int
|
76 |
+
minimum size of a subgroup
|
77 |
+
min_support_ratio : float
|
78 |
+
minimum proportion of a subgroup compared to the whole dataset size
|
79 |
+
max_support_ratio : float
|
80 |
+
maximum proportion of a subgroup compared to the whole dataset size
|
81 |
+
logging_level : int
|
82 |
+
logging level
|
83 |
+
"""
|
84 |
+
logging.basicConfig(level=logging_level)
|
85 |
+
self.inputChecking(
|
86 |
+
X,
|
87 |
+
y_true,
|
88 |
+
y_pred,
|
89 |
+
feature_names,
|
90 |
+
sensitive_features,
|
91 |
+
nominal_features,
|
92 |
+
numeric_features,
|
93 |
+
discretizer,
|
94 |
+
dynamic_discretization,
|
95 |
+
result_set_size,
|
96 |
+
depth,
|
97 |
+
min_quality,
|
98 |
+
min_support,
|
99 |
+
min_support_ratio,
|
100 |
+
)
|
101 |
+
if isinstance(X, np.ndarray):
|
102 |
+
self.data = pd.DataFrame(X, columns=feature_names)
|
103 |
+
else:
|
104 |
+
self.data = X.copy()
|
105 |
+
|
106 |
+
self.data["y_true"] = y_true
|
107 |
+
if y_pred is not None:
|
108 |
+
self.data["y_pred"] = y_pred
|
109 |
+
self.there_is_y_pred = True
|
110 |
+
else:
|
111 |
+
self.there_is_y_pred = False
|
112 |
+
self.sensitive_features = sensitive_features
|
113 |
+
self.discretizer = Discretizer(
|
114 |
+
discretization_type=discretizer, target="y_true", num_bins=num_bins
|
115 |
+
)
|
116 |
+
self.search_space = SearchSpace(
|
117 |
+
self.data,
|
118 |
+
["y_true", "y_pred"],
|
119 |
+
nominal_features,
|
120 |
+
numeric_features,
|
121 |
+
dynamic_discretization,
|
122 |
+
self.discretizer,
|
123 |
+
sensitive_features,
|
124 |
+
)
|
125 |
+
|
126 |
+
self.qf = self.set_qualityfuntion(qf)
|
127 |
+
|
128 |
+
self.result_set_size = result_set_size
|
129 |
+
self.depth = depth
|
130 |
+
self.min_quality = min_quality
|
131 |
+
self.min_support = min_support
|
132 |
+
self.min_support_ratio = min_support_ratio
|
133 |
+
self.max_support_ratio = max_support_ratio
|
134 |
+
|
135 |
+
def set_qualityfuntion(self, qf):
|
136 |
+
if isinstance(qf, str):
|
137 |
+
if qf not in quality_function_options:
|
138 |
+
raise ValueError("Quality function not known")
|
139 |
+
else:
|
140 |
+
return getattr(flm, qf)
|
141 |
+
|
142 |
+
if not callable(qf):
|
143 |
+
RuntimeError("Supplied metric object must be callable or string")
|
144 |
+
sig = inspect.signature(qf).parameters
|
145 |
+
for par in quality_function_parameters:
|
146 |
+
if par not in sig:
|
147 |
+
raise ValueError(
|
148 |
+
"Please use the functions in the fairlearn.metrics package as quality functions or "
|
149 |
+
"other fuctions with the same interface"
|
150 |
+
)
|
151 |
+
return qf
|
152 |
+
|
153 |
+
def inputChecking(
|
154 |
+
self,
|
155 |
+
X, # pandas dataframe or numpy array
|
156 |
+
y_true, # numpy array, pandas dataframe, or pandas Series with ground truth labels
|
157 |
+
y_pred, # numpy array, pandas dataframe, or pandas Series with classifier's predicted labels
|
158 |
+
feature_names, # optional, list with column names in case users supply a numpy array X
|
159 |
+
sensitive_features,
|
160 |
+
nominal_features, # optional, list of nominal features
|
161 |
+
numeric_features,  # optional, list of numeric features
|
162 |
+
discretizer, # str
|
163 |
+
dynamic_discretization, # boolean
|
164 |
+
result_set_size, # int
|
165 |
+
depth, # int
|
166 |
+
min_quality, # float
|
167 |
+
min_support, # int
|
168 |
+
min_support_ratio, # float
|
169 |
+
):
|
170 |
+
if not (isinstance(X, pd.DataFrame) or isinstance(X, np.ndarray)):
|
171 |
+
raise TypeError("X must be of type numpy.ndarray or pandas.DataFrame")
|
172 |
+
if not (
|
173 |
+
isinstance(y_true, pd.DataFrame)
|
174 |
+
or isinstance(y_true, np.ndarray)
|
175 |
+
or isinstance(y_true, pd.Series)
|
176 |
+
):
|
177 |
+
raise TypeError(
|
178 |
+
"y_true must be of type numpy.ndarray, pandas.Series or pandas.DataFrame"
|
179 |
+
)
|
180 |
+
if X.shape[0] != y_true.size:
|
181 |
+
raise RuntimeError("X and y_true have two different dimensions")
|
182 |
+
if y_pred is not None:
|
183 |
+
if not (
|
184 |
+
isinstance(y_pred, pd.DataFrame)
|
185 |
+
or isinstance(y_pred, np.ndarray)
|
186 |
+
or isinstance(y_pred, pd.Series)
|
187 |
+
):
|
188 |
+
raise TypeError(
|
189 |
+
"y_pred must be of type numpy.ndarray, pandas.Series or pandas.DataFrame"
|
190 |
+
)
|
191 |
+
if y_pred.size != y_true.size:
|
192 |
+
raise RuntimeError("y_pred and y_true have two different dimensions")
|
193 |
+
if isinstance(X, np.ndarray):
|
194 |
+
if (not isinstance(feature_names, list)) or len(feature_names) != X.shape[
|
195 |
+
1
|
196 |
+
]:
|
197 |
+
raise RuntimeError(
|
198 |
+
"If X is a numpy.ndarray, feature_names must contain the names of the colums"
|
199 |
+
)
|
200 |
+
if sensitive_features is not None and not isinstance(sensitive_features, list):
|
201 |
+
raise RuntimeError("sensitive_features input must be of list type or None")
|
202 |
+
if nominal_features is not None and not isinstance(nominal_features, list):
|
203 |
+
raise RuntimeError("nominal_features input must be of list type or None")
|
204 |
+
if numeric_features is not None and not isinstance(numeric_features, list):
|
205 |
+
raise RuntimeError("numeric_features input must be of list type or None")
|
206 |
+
if not isinstance(discretizer, str):
|
207 |
+
raise TypeError("discretizer input must be of string type")
|
208 |
+
if discretizer == "mdlp":
|
209 |
+
t = pd.DataFrame(y_true).iloc[:, 0].unique()
|
210 |
+
if not (t == [1, 0]).all() and not (t == [0, 1]).all():
|
211 |
+
raise RuntimeError("MDLP discretization supports only binary target")
|
212 |
+
if not isinstance(dynamic_discretization, bool):
|
213 |
+
raise TypeError("dynamic_discretization input must be of bool type")
|
214 |
+
if not isinstance(dynamic_discretization, bool):
|
215 |
+
raise TypeError("dynamic_discretization input must be of bool type")
|
216 |
+
if not isinstance(result_set_size, int) or result_set_size < 1:
|
217 |
+
raise RuntimeError("result_set_size input must be greater than 0")
|
218 |
+
if not isinstance(depth, int):
|
219 |
+
raise RuntimeError("depth input must be greater than 0")
|
220 |
+
if not isinstance(min_support, int):
|
221 |
+
raise RuntimeError("min_support input must be greater than 0")
|
222 |
+
if not isinstance(min_support_ratio, float):
|
223 |
+
raise RuntimeError("min_support_ratio input must be of float type")
|
224 |
+
if min_quality > 1 or min_quality < 0:
|
225 |
+
raise RuntimeError("min_quality input must be between 0 and 1")
|
226 |
+
|
227 |
+
|
228 |
+
class ResultSet:
|
229 |
+
"""
|
230 |
+
This class is used to represent the subgroup set found by one of the
|
231 |
+
subgroup discovery algorithms implemented in this package.
|
232 |
+
"""
|
233 |
+
|
234 |
+
def __init__(self, descriptions_list, x_size):
|
235 |
+
"""
|
236 |
+
:param descriptions_list: list of Description objects
|
237 |
+
:param x_size: int, size of the dataset
|
238 |
+
"""
|
239 |
+
self.descriptions_list = descriptions_list
|
240 |
+
self.X_size = x_size
|
241 |
+
|
242 |
+
def to_dataframe(self):
|
243 |
+
"""Convert the result set into a dataframe
|
244 |
+
|
245 |
+
:return: pandas.Dataframe
|
246 |
+
"""
|
247 |
+
lod = list()
|
248 |
+
for d in self.descriptions_list:
|
249 |
+
row = [d.quality, d.__repr__(), d.support, d.support / self.X_size]
|
250 |
+
lod.append(row)
|
251 |
+
columns = ["quality", "description", "size", "proportion"]
|
252 |
+
index = [str(x) for x in range(len(self.descriptions_list))]
|
253 |
+
return pd.DataFrame(lod, index=index, columns=columns)
|
254 |
+
|
255 |
+
def get_description(self, sg_index) -> Description:
|
256 |
+
if sg_index >= len(self.descriptions_list) or sg_index < 0:
|
257 |
+
raise RuntimeError("The requested subgroup doesn't exists")
|
258 |
+
return self.descriptions_list[sg_index]
|
259 |
+
|
260 |
+
def sg_feature(self, sg_index, X):
|
261 |
+
"""
|
262 |
+
This method generates and returns the membership feature of the subgroup with index = sg_index in the current object.
|
263 |
+
The result is a boolean array with the same length as the dataset X. Each i-th element of this
|
264 |
+
array is true iff the i-th tuple of X belongs to the subgroup with index sg_index.
|
265 |
+
|
266 |
+
:param sg_index: int, number of the subgroup in the current object
|
267 |
+
:param X: pandas DataFrame or numpy array
|
268 |
+
:return: boolean list
|
269 |
+
"""
|
270 |
+
if sg_index >= len(self.descriptions_list) or sg_index < 0:
|
271 |
+
raise RuntimeError("The requested subgroup doesn't exists")
|
272 |
+
return pd.Series(
|
273 |
+
self.descriptions_list[sg_index].to_boolean_array(X),
|
274 |
+
name=str("sg" + str(sg_index)),
|
275 |
+
)
|
276 |
+
|
277 |
+
def __repr__(self):
|
278 |
+
res = ""
|
279 |
+
for desc in self.descriptions_list:
|
280 |
+
res += desc.__repr__() + "\n"
|
281 |
+
return res
|
282 |
+
|
283 |
+
def print(self):
|
284 |
+
print(self.__repr__())
|
285 |
+
|
286 |
+
def to_string(self):
|
287 |
+
return self.__repr__()
|
288 |
+
|
289 |
+
def to_json(self, X: pd.DataFrame, metric: str, result_set_df: pd.DataFrame):
|
290 |
+
"""Save string representation of descriptions and sg_feature boolean series to json
|
291 |
+
|
292 |
+
:param X: pandas DataFrame representing the dataset
|
293 |
+
:param metric: string, name of the metric used to evaluate the subgroups
|
294 |
+
:param result_set_df: pandas DataFrame used for ordering of the results
|
295 |
+
:return: dict
|
296 |
+
"""
|
297 |
+
descriptions = []
|
298 |
+
sg_features = []
|
299 |
+
result_set_df.reset_index(inplace=True)
|
300 |
+
for _, row in result_set_df.iterrows():
|
301 |
+
i = row[
|
302 |
+
"index"
|
303 |
+
] # The original subgroup index might not be the same as the index in the result set df
|
304 |
+
desc = self.descriptions_list[int(i)]
|
305 |
+
descriptions.append(desc.__repr__())
|
306 |
+
sg_feature = self.sg_feature(int(i), X).to_json()
|
307 |
+
sg_features.append(sg_feature)
|
308 |
+
|
309 |
+
return {
|
310 |
+
"descriptions": descriptions,
|
311 |
+
"sg_features": sg_features,
|
312 |
+
"metric": metric,
|
313 |
+
}
|
314 |
+
|
315 |
+
|
316 |
+
class BeamSearch:
|
317 |
+
"""This class is used to execute the Beam Search Algorithm."""
|
318 |
+
|
319 |
+
def __init__(self, beam_width=20):
|
320 |
+
"""
|
321 |
+
:param beam_width : int
|
322 |
+
"""
|
323 |
+
if beam_width < 1:
|
324 |
+
raise RuntimeError("beam_width must be greater than 0")
|
325 |
+
self.beam_width = beam_width
|
326 |
+
|
327 |
+
def execute(self, task, method="between_groups"):
|
328 |
+
"""
|
329 |
+
This method executes the Beam Search
|
330 |
+
|
331 |
+
:param task : SubgroupDiscoveryTask
|
332 |
+
:return: ResultSet object
|
333 |
+
|
334 |
+
Notes
|
335 |
+
-----
|
336 |
+
The list_of_beam variable is: a list of list of descriptions. The i-th element of list_of_beam, at the end,
|
337 |
+
will contain the most interesting descriptions formed by i descriptors.
|
338 |
+
"""
|
339 |
+
if self.beam_width < task.result_set_size:
|
340 |
+
raise RuntimeError("Beam width is smaller than the result set size!")
|
341 |
+
|
342 |
+
list_of_beam = list()
|
343 |
+
list_of_beam.append(list())
|
344 |
+
list_of_beam[0] = [Description()]
|
345 |
+
|
346 |
+
depth = 0
|
347 |
+
while depth < task.depth:
|
348 |
+
list_of_beam.append(list())
|
349 |
+
current_min_quality = 1
|
350 |
+
for last_sg in list_of_beam[depth]:
|
351 |
+
ss = task.search_space.extract_search_space(
|
352 |
+
task.data, task.discretizer, current_description=last_sg
|
353 |
+
)
|
354 |
+
for sel in ss:
|
355 |
+
new_Descriptors = list(last_sg.descriptors)
|
356 |
+
new_Descriptors.append(sel)
|
357 |
+
new_description = Description(new_Descriptors)
|
358 |
+
|
359 |
+
# check for duplicates
|
360 |
+
if new_description.is_present_in(list_of_beam[depth + 1]):
|
361 |
+
continue
|
362 |
+
|
363 |
+
sg_belonging_feature = new_description.to_boolean_array(
|
364 |
+
task.data, set_attributes=True
|
365 |
+
)
|
366 |
+
# check min support
|
367 |
+
if (
|
368 |
+
new_description.size(task.data) < task.min_support
|
369 |
+
or new_description.size(task.data)
|
370 |
+
< task.min_support_ratio * task.data.shape[0]
|
371 |
+
or new_description.size(task.data)
|
372 |
+
> task.max_support_ratio * task.data.shape[0]
|
373 |
+
):
|
374 |
+
continue
|
375 |
+
# evaluate subgroup
|
376 |
+
if task.there_is_y_pred:
|
377 |
+
quality = task.qf(
|
378 |
+
y_true=task.data["y_true"],
|
379 |
+
method=method,
|
380 |
+
y_pred=task.data["y_pred"],
|
381 |
+
sensitive_features=sg_belonging_feature,
|
382 |
+
)
|
383 |
+
else:
|
384 |
+
quality = task.qf(
|
385 |
+
y_true=task.data["y_true"],
|
386 |
+
method=method,
|
387 |
+
sensitive_features=sg_belonging_feature,
|
388 |
+
)
|
389 |
+
if quality < task.min_quality:
|
390 |
+
continue
|
391 |
+
new_description.set_quality(quality)
|
392 |
+
|
393 |
+
if len(list_of_beam[depth + 1]) < self.beam_width:
|
394 |
+
list_of_beam[depth + 1].append(new_description)
|
395 |
+
if current_min_quality > quality:
|
396 |
+
current_min_quality = quality
|
397 |
+
elif quality > current_min_quality:
|
398 |
+
i = 0
|
399 |
+
while list_of_beam[depth + 1][i].quality != current_min_quality:
|
400 |
+
i = i + 1
|
401 |
+
list_of_beam[depth + 1][i] = new_description
|
402 |
+
current_min_quality = 1
|
403 |
+
for d in list_of_beam[depth + 1]:
|
404 |
+
if d.quality < current_min_quality:
|
405 |
+
current_min_quality = d.quality
|
406 |
+
depth += 1
|
407 |
+
|
408 |
+
subgroups = list()
|
409 |
+
for l in list_of_beam[1:]:
|
410 |
+
subgroups.extend(l)
|
411 |
+
subgroups.sort(reverse=True)
|
412 |
+
return ResultSet(subgroups[: task.result_set_size], task.data.shape[0])
|
413 |
+
|
414 |
+
|
415 |
+
class DSSD:
|
416 |
+
"""
|
417 |
+
This class implements the Diverse Subgroup Set Discovery algorithm (DSSD).
|
418 |
+
This algorithm is a variant of the Beam Search Algorithm that also take into account
|
419 |
+
the redundancy of the generated subgroups.
|
420 |
+
In this implementation a cover-based redundancy definition is used: roughly, the more tuples two subgroups
|
421 |
+
have in common, the more they are considered redundant.
|
422 |
+
This algorithm is described in detail in Van Leeuwen and Knobbe's paper "Diverse Subgroup Set Discovery".
|
423 |
+
"""
|
424 |
+
|
425 |
+
def __init__(self, beam_width=20, a=0.9):
|
426 |
+
"""
|
427 |
+
:param beam_width: int
|
428 |
+
:param a: float
|
429 |
+
this parameter corresponds to the alpha parameter.
|
430 |
+
the higher a is, the less the subgroups' redundancy is taken into account.
|
431 |
+
"""
|
432 |
+
if beam_width < 1:
|
433 |
+
raise RuntimeError("beam_width must be greater than 0")
|
434 |
+
if a < 0 or a > 1:
|
435 |
+
raise RuntimeError("a-parameter must be between 0 and 1")
|
436 |
+
self.beam_width = beam_width
|
437 |
+
self.a = 1 - a # for future calculations it is more practical to memorize 1-a
|
438 |
+
|
439 |
+
def execute(self, task, method="between_groups"):
|
440 |
+
"""
|
441 |
+
:param task: SubgroupDiscoveryTask object
|
442 |
+
:param method: string, method of evaluation of the subgroups, can be "between_groups" or "to_overall", as defined in fairlearn.metrics
|
443 |
+
:return: ResultSet object
|
444 |
+
|
445 |
+
Notes
|
446 |
+
-----
|
447 |
+
The algorithm is divided into three phases:
|
448 |
+
Phase 1: a modified beam search algorithm is performed to find a first non-redundant subset
|
449 |
+
Phase 2: Dominance Pruning - the algorithm tries to generalize the subgroups found in the previous phase
|
450 |
+
Phase 3: The subgroups to return are chosen among the sg-set resulting from the previous phase. Again, the
|
451 |
+
subgroups to put in the result set are chosen by taking into account both quality and diversity.
|
452 |
+
"""
|
453 |
+
if self.beam_width < task.result_set_size:
|
454 |
+
raise RuntimeError("Beam width is smaller than the result set size!")
|
455 |
+
|
456 |
+
# PHASE 1 - MODIFIED BEAM SEARCH
|
457 |
+
list_of_beam = list()
|
458 |
+
self.redundancy_aware_beam_search(list_of_beam, task, method)
|
459 |
+
|
460 |
+
# PHASE 2 - DOMINANCE PRUNING
|
461 |
+
subgroups = list()
|
462 |
+
subgroups.extend(list_of_beam[1])
|
463 |
+
if len(list_of_beam) > 2:
|
464 |
+
for l in list_of_beam[2:]:
|
465 |
+
self.dominance_pruning(l, subgroups, task)
|
466 |
+
if len(subgroups) < task.result_set_size:
|
467 |
+
subgroups.sort(reverse=True)
|
468 |
+
return ResultSet(subgroups, task.data.shape[0])
|
469 |
+
|
470 |
+
# PHASE 3 - SUBGROUP SELECTION
|
471 |
+
tuples_sg_matrix = []
|
472 |
+
quality_array = []
|
473 |
+
support_array = list()
|
474 |
+
for descr in subgroups:
|
475 |
+
tuples_sg_matrix.append(descr.to_boolean_array(task.data))
|
476 |
+
quality_array.append(descr.get_quality())
|
477 |
+
support_array.append(descr.support)
|
478 |
+
support_array = np.array(support_array)
|
479 |
+
quality_array = np.array(quality_array)
|
480 |
+
tuples_sg_matrix = np.array(tuples_sg_matrix)
|
481 |
+
final_sgs = []
|
482 |
+
self.beam_creation(
|
483 |
+
tuples_sg_matrix,
|
484 |
+
support_array,
|
485 |
+
quality_array,
|
486 |
+
subgroups,
|
487 |
+
final_sgs,
|
488 |
+
task.result_set_size,
|
489 |
+
)
|
490 |
+
|
491 |
+
final_sgs.sort(reverse=True)
|
492 |
+
return ResultSet(final_sgs, task.data.shape[0])
|
493 |
+
|
494 |
+
def redundancy_aware_beam_search(self, list_of_beam, task, method="between_groups"):
|
495 |
+
"""
|
496 |
+
Parameters
|
497 |
+
----------
|
498 |
+
list_of_beam : list
|
499 |
+
This list is empty at the beginning and this method will fill it with the beams of each level.
|
500 |
+
The beam of level i will be a list of Description objects where each description is composed of
|
501 |
+
i Descriptors.
|
502 |
+
task : SubgroupDiscoveryTask
|
503 |
+
method : string, method of evaluation of the subgroups, can be "between_groups" or "to_overall", as defined in fairlearn.metrics
|
504 |
+
|
505 |
+
Notes
|
506 |
+
-----
|
507 |
+
Starting from the beam of the previous level, all the candidate subgroups (descriptions) are generated.
|
508 |
+
After this, the beam of the current level is generated by calling the beam_creation method.
|
509 |
+
"""
|
510 |
+
list_of_beam.append(list())
|
511 |
+
list_of_beam[0] = [Description()]
|
512 |
+
|
513 |
+
depth = 0
|
514 |
+
while depth < task.depth:
|
515 |
+
# Generation of the beam with number of descriptors = depth+1
|
516 |
+
|
517 |
+
list_of_beam.append(list())
|
518 |
+
logging.debug("DEPTH: " + str(depth + 1))
|
519 |
+
|
520 |
+
tuples_sg_matrix = (
|
521 |
+
[]
|
522 |
+
) # boolean matrix where rows are candidate subgroups and columns are tuples of the dataset
|
523 |
+
# tuples_sg_matrix[i][j] == true iff subgroup i contain tuple j
|
524 |
+
quality_array = [] # will contain the quality of each candidate subgroup
|
525 |
+
support_array = [] # will contain the support of each candidate subgroup
|
526 |
+
decriptions_list = (
|
527 |
+
list()
|
528 |
+
) # will contain the description object of each candidate subgroup
|
529 |
+
|
530 |
+
# generation of candidates subgroups
|
531 |
+
for last_sg in list_of_beam[
|
532 |
+
depth
|
533 |
+
]: # for each subgroup in the previous beam
|
534 |
+
ss = task.search_space.extract_search_space(
|
535 |
+
task.data, task.discretizer, current_description=last_sg
|
536 |
+
)
|
537 |
+
|
538 |
+
# generation of all the possible extensions of the description last_sg
|
539 |
+
for sel in ss:
|
540 |
+
new_Descriptors = list(last_sg.descriptors)
|
541 |
+
new_Descriptors.append(sel)
|
542 |
+
new_description = Description(new_Descriptors)
|
543 |
+
|
544 |
+
# check for duplicates
|
545 |
+
if new_description.is_present_in(decriptions_list):
|
546 |
+
continue
|
547 |
+
|
548 |
+
sg_belonging_feature = new_description.to_boolean_array(
|
549 |
+
task.data, set_attributes=True
|
550 |
+
)
|
551 |
+
support = new_description.support
|
552 |
+
# check min support
|
553 |
+
if (
|
554 |
+
support < task.min_support
|
555 |
+
or support < task.min_support_ratio * task.data.shape[0]
|
556 |
+
or support > task.max_support_ratio * task.data.shape[0]
|
557 |
+
):
|
558 |
+
continue
|
559 |
+
# comparison with new descriptor alone
|
560 |
+
sel_feature = Description([sel]).to_boolean_array(task.data)
|
561 |
+
# evaluate subgroup
|
562 |
+
if task.there_is_y_pred:
|
563 |
+
quality = task.qf(
|
564 |
+
y_true=task.data["y_true"],
|
565 |
+
method=method,
|
566 |
+
y_pred=task.data["y_pred"],
|
567 |
+
sensitive_features=sg_belonging_feature,
|
568 |
+
)
|
569 |
+
### to evaluate
|
570 |
+
sel_quality = task.qf(
|
571 |
+
y_true=task.data["y_true"],
|
572 |
+
method=method,
|
573 |
+
y_pred=task.data["y_pred"],
|
574 |
+
sensitive_features=sel_feature,
|
575 |
+
)
|
576 |
+
else:
|
577 |
+
quality = task.qf(
|
578 |
+
y_true=task.data["y_true"],
|
579 |
+
method=method,
|
580 |
+
sensitive_features=sg_belonging_feature,
|
581 |
+
)
|
582 |
+
### to evaluate
|
583 |
+
sel_quality = task.qf(
|
584 |
+
y_true=task.data["y_true"],
|
585 |
+
method=method,
|
586 |
+
sensitive_features=sel_feature,
|
587 |
+
)
|
588 |
+
|
589 |
+
# if the quality deteriorates when the new descriptor is merged
|
590 |
+
# with the current description, we do not add the new descriptor
|
591 |
+
### to evaluate
|
592 |
+
if quality < task.min_quality or quality < sel_quality:
|
593 |
+
continue
|
594 |
+
"""
|
595 |
+
# This is for allowing descriptions with negative quality in the first beam
|
596 |
+
if depth>0 and (quality < task.min_quality or quality < sel_quality):
|
597 |
+
continue
|
598 |
+
"""
|
599 |
+
|
600 |
+
# This code discards a priori those descriptions dominated by other descriptions not containing sensitive features
|
601 |
+
pruned_des = []
|
602 |
+
new_description.set_quality(quality)
|
603 |
+
self.dominance_pruning([new_description], pruned_des, task)
|
604 |
+
pruned_des.sort(reverse=True)
|
605 |
+
pruned_attr = pruned_des[0].get_attributes()
|
606 |
+
if task.sensitive_features is not None:
|
607 |
+
any_in = any(i in task.sensitive_features for i in pruned_attr)
|
608 |
+
if not any_in:
|
609 |
+
continue
|
610 |
+
|
611 |
+
# insert current subgroup in the candidates
|
612 |
+
tuples_sg_matrix.append(sg_belonging_feature)
|
613 |
+
support_array.append(support)
|
614 |
+
quality_array.append(quality)
|
615 |
+
decriptions_list.append(new_description)
|
616 |
+
|
617 |
+
if len(decriptions_list) == 0:
|
618 |
+
break
|
619 |
+
# CREATION OF THE BEAM
|
620 |
+
support_array = np.array(support_array)
|
621 |
+
quality_array = np.array(quality_array)
|
622 |
+
tuples_sg_matrix = np.array(tuples_sg_matrix)
|
623 |
+
self.beam_creation(
|
624 |
+
tuples_sg_matrix,
|
625 |
+
support_array,
|
626 |
+
quality_array,
|
627 |
+
decriptions_list,
|
628 |
+
list_of_beam[depth + 1],
|
629 |
+
self.beam_width,
|
630 |
+
)
|
631 |
+
|
632 |
+
for d in list_of_beam[depth + 1]:
|
633 |
+
logging.debug(d)
|
634 |
+
logging.debug(" ")
|
635 |
+
depth += 1
|
636 |
+
|
637 |
+
def dominance_pruning(self, subgroups, pruned_sgs, task, method="between_groups"):
|
638 |
+
"""
|
639 |
+
Parameters
|
640 |
+
----------
|
641 |
+
subgroups : list
|
642 |
+
list of Description objects to try to generalize
|
643 |
+
pruned_sgs : list
|
644 |
+
the generalized subgroups (description objects) are inserted in this list
|
645 |
+
task : SubgroupDiscoveryTask
|
646 |
+
method : string, method of evaluation of the subgroups, can be "between_groups" or "to_overall", as defined in fairlearn.metrics
|
647 |
+
"""
|
648 |
+
|
649 |
+
for desc in subgroups:
|
650 |
+
Descriptors = desc.get_Descriptors()
|
651 |
+
|
652 |
+
generalizable = False
|
653 |
+
for i in range(len(Descriptors)):
|
654 |
+
# creation of a generalized description by excluding the i-th descriptor
|
655 |
+
new_sel_list = []
|
656 |
+
for j in range(len(Descriptors)):
|
657 |
+
if i != j:
|
658 |
+
new_sel_list.append(Descriptors[j])
|
659 |
+
new_des = Description(new_sel_list)
|
660 |
+
|
661 |
+
sg_belonging_feature = new_des.to_boolean_array(
|
662 |
+
task.data, set_attributes=True
|
663 |
+
)
|
664 |
+
if task.there_is_y_pred:
|
665 |
+
quality = task.qf(
|
666 |
+
y_true=task.data["y_true"],
|
667 |
+
method=method,
|
668 |
+
y_pred=task.data["y_pred"],
|
669 |
+
sensitive_features=sg_belonging_feature,
|
670 |
+
)
|
671 |
+
else:
|
672 |
+
quality = task.qf(
|
673 |
+
y_true=task.data["y_true"],
|
674 |
+
method=method,
|
675 |
+
sensitive_features=sg_belonging_feature,
|
676 |
+
)
|
677 |
+
|
678 |
+
if quality >= desc.get_quality():
|
679 |
+
generalizable = True
|
680 |
+
if new_des.is_present_in(pruned_sgs):
|
681 |
+
continue
|
682 |
+
new_des.set_quality(quality)
|
683 |
+
pruned_sgs.append(new_des)
|
684 |
+
|
685 |
+
if generalizable == False:
|
686 |
+
pruned_sgs.append(desc)
|
687 |
+
|
688 |
+
def beam_creation(
|
689 |
+
self,
|
690 |
+
tuples_sg_matrix,
|
691 |
+
support_array,
|
692 |
+
quality_array,
|
693 |
+
decriptions_list,
|
694 |
+
beam,
|
695 |
+
beam_width,
|
696 |
+
):
|
697 |
+
"""
|
698 |
+
Parameters
|
699 |
+
----------
|
700 |
+
tuples_sg_matrix : numpy array
|
701 |
+
this is a boolean matrix. The rows are candidate subgroups and the columns are the instances of the dataset
|
702 |
+
support_array : numpy array
|
703 |
+
contains the support of each candidate subgroup
|
704 |
+
quality_array : numpy array
|
705 |
+
contains the quality of each candidate subgroup
|
706 |
+
decriptions_list : list
|
707 |
+
list of Description objects of the candidates subgroups
|
708 |
+
beam : list
|
709 |
+
list of Description objects. This parameter is empty at the beginning and this method will fill it with the
|
710 |
+
selected subgroups
|
711 |
+
beam_width : int
|
712 |
+
|
713 |
+
Notes
|
714 |
+
-----
|
715 |
+
In the code there is a variable called a_tothe_c_array. This variable represents a vector of weights,
|
716 |
+
each weight refers to a tuple of the dataset. Every time that a subgroup sg is selected (inserted to the beam),
|
717 |
+
all the weights relative to the tuples belonging to sg are updated. In particular, they are decreased by
|
718 |
+
multiplying them by a.
|
719 |
+
At each round of the for loop, the subgroup with the highest product between its quality and its weight is
|
720 |
+
selected. The weight of a subgroup is obtained by averaging the weights of all the tuples it contains.
|
721 |
+
"""
|
722 |
+
|
723 |
+
if len(decriptions_list) <= beam_width:
|
724 |
+
for i in range(len(decriptions_list)):
|
725 |
+
descr = decriptions_list[i]
|
726 |
+
descr.set_quality(quality_array[i])
|
727 |
+
# insertion of the description in the beam
|
728 |
+
beam.append(descr)
|
729 |
+
return
|
730 |
+
|
731 |
+
# sort in a way that, in case of equal quality, groups with higher support are preferred
|
732 |
+
sorted_index = np.argsort(support_array)[::-1]
|
733 |
+
support_array = support_array[sorted_index]
|
734 |
+
quality_array = quality_array[sorted_index]
|
735 |
+
tuples_sg_matrix = tuples_sg_matrix[sorted_index]
|
736 |
+
decriptions_list.sort(key=lambda x: x.support, reverse=True)
|
737 |
+
|
738 |
+
# selection of the sg with highest quality
|
739 |
+
index_of_max = np.argmax(quality_array)
|
740 |
+
descr = decriptions_list[index_of_max]
|
741 |
+
descr.set_quality(quality_array[index_of_max])
|
742 |
+
beam.append(descr)
|
743 |
+
quality_array[
|
744 |
+
index_of_max
|
745 |
+
] = 0 # the quality of the selected sg is set to 0 in the quality_array,
|
746 |
+
# in this way this subgroup will never be chosen again
|
747 |
+
|
748 |
+
a_tothe_c_array = np.ones(tuples_sg_matrix.shape[1])
|
749 |
+
|
750 |
+
num_iterations = min(beam_width, support_array.size)
|
751 |
+
for i in range(1, num_iterations):
|
752 |
+
# a_tothe_c updating
|
753 |
+
best_sg_arr = tuples_sg_matrix[index_of_max]
|
754 |
+
best_sg_arr = 1 - self.a * best_sg_arr
|
755 |
+
# updating
|
756 |
+
a_tothe_c_array = np.multiply(a_tothe_c_array, best_sg_arr)
|
757 |
+
|
758 |
+
# weight creation
|
759 |
+
alpha_matrix = np.multiply(a_tothe_c_array, tuples_sg_matrix)
|
760 |
+
weights = np.divide(np.sum(alpha_matrix, axis=1), support_array)
|
761 |
+
|
762 |
+
# selection of the sg with highest quality
|
763 |
+
weighted_quality_array = np.multiply(quality_array, weights)
|
764 |
+
index_of_max = np.argmax(weighted_quality_array)
|
765 |
+
descr = decriptions_list[index_of_max]
|
766 |
+
descr.set_quality(quality_array[index_of_max])
|
767 |
+
# insertion of the description in the beam
|
768 |
+
beam.append(descr)
|
769 |
+
quality_array[index_of_max] = 0
|
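
As a quick orientation for readers of this diff, the following is a minimal, hypothetical usage sketch of the DSSD search above. It mirrors the `fairsd/main.py` script added later in this commit, and it assumes that `DSSD` is re-exported from the package `__init__` in the same way as `BeamSearch`; the dataset and classifier choices are illustrative only, not part of the committed sources.

```python
# Hypothetical sketch, not part of the committed sources.
import fairlearn.metrics as fm
import pandas as pd
from sklearn.datasets import fetch_openml
from sklearn.tree import DecisionTreeClassifier

import fairsd as dsd  # assumes DSSD is re-exported like BeamSearch

# UCI Adult, as in fairsd/main.py; only the first 1000 rows to keep the run short
d = fetch_openml(data_id=1590, as_frame=True)
dataset = d.data.head(1000)
y_true = ((d.target == ">50K") * 1).head(1000)
X_enc = pd.get_dummies(d.data).head(1000)
y_pred = DecisionTreeClassifier(min_samples_leaf=10, max_depth=4).fit(X_enc, y_true).predict(X_enc)

task = dsd.SubgroupDiscoveryTask(dataset, y_true, y_pred, qf=fm.demographic_parity_difference)
# Diverse Subgroup Set Discovery: a larger `a` penalises redundancy less
result_set = dsd.DSSD(beam_width=20, a=0.9).execute(task, method="to_overall")
print(result_set.to_dataframe())
```
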
fairsd/fairsd/discretization.py
ADDED
@@ -0,0 +1,355 @@
1 |
+
import pandas as pd
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
class MDLP:
|
7 |
+
"""
|
8 |
+
This class finds the cut points to discretize a numerical feature using the Fayyad and Irani approach (MDL principle).
|
9 |
+
|
10 |
+
In this class the constructor allows specifying the minimum possible size for a group (min_groupSize).
|
11 |
+
This parameter is interpreted as a hard constraint: only cut points that produce buckets larger than
|
12 |
+
min_groupSize will be returned.
|
13 |
+
Another parameter that can be set is "force". If this parameter is set to True, the findCutPoints
|
14 |
+
method will return at least one cut point, unless not even one cut point can be found that produces two subgroups
|
15 |
+
with a size greater than min_groupSize and with class partition entropy less than the entropy of the entire set.
|
16 |
+
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, min_groupsize=1, force=False):
|
20 |
+
"""
|
21 |
+
:param min_groupsize: int
|
22 |
+
:param force: boolean
|
23 |
+
"""
|
24 |
+
self.min_groupsize = min_groupsize
|
25 |
+
self.force = force
|
26 |
+
|
27 |
+
def findCutPoints(self, x, y):
|
28 |
+
"""
|
29 |
+
:param x: list, numpy array or pandas Series that contains the values of the numeric feature to discretize
|
30 |
+
:param y: list, numpy array or pandas Series that contains binary values (0 or 1) representing the class label
|
31 |
+
|
32 |
+
:return: list of (ascending) ordered cut points
|
33 |
+
if the function returns the cut points [c1, c2, ..., cn], with c1<c2<...<cn, the feature can be discretized
|
34 |
+
by creating n+1 buckets: (-infinite, c1], (c1, c2], ..., (cn-1, cn], (cn, +infinite)
|
35 |
+
"""
|
36 |
+
df = pd.DataFrame({"x": x, "y": y})
|
37 |
+
total_size = df.shape[0]
|
38 |
+
|
39 |
+
df = df.groupby("x")["y"].agg(["sum", "count"])
|
40 |
+
df.reset_index(inplace=True)
|
41 |
+
total_sum = df["sum"].sum()
|
42 |
+
df["prop"] = df["sum"] / df["count"]
|
43 |
+
|
44 |
+
cut_points = self.find_partitions(df, total_size, total_sum, self.force)
|
45 |
+
cut_points.sort()
|
46 |
+
return cut_points
|
47 |
+
|
48 |
+
def find_partitions(self, df, total_size, total_sum, force=False):
|
49 |
+
"""
|
50 |
+
This is a private class function. It works in a recursive manner.
|
51 |
+
|
52 |
+
Parameters
|
53 |
+
----------
|
54 |
+
df: pandas.DataFrame
|
55 |
+
this dataframe contains, for each distinct values of the feature to discretize, the number of positive
|
56 |
+
instances (called sum), the total number of instances (count) and the proportion
|
57 |
+
positive_instances/total_num_of_instances (prop). See the findCutPoints method for details.
|
58 |
+
total_size: int
|
59 |
+
represent the total number of instances of the x feature in the current partition.
|
60 |
+
total_sum: int
|
61 |
+
represent the total number of positive instances of the x feature in the current partition.
|
62 |
+
force: boolean
|
63 |
+
force the method to return at least one cut point, except for the exceptional cases described in the class description.
|
64 |
+
|
65 |
+
Returns
|
66 |
+
-------
|
67 |
+
list of int:
|
68 |
+
list of cut points
|
69 |
+
"""
|
70 |
+
|
71 |
+
sum = 0
|
72 |
+
count = 0
|
73 |
+
min_cpe = total_size  # Class Partition Entropy (CPE); initialized to a sentinel larger than any candidate CPE
|
74 |
+
partition_index = 0
|
75 |
+
partition_x = 0
|
76 |
+
partition_sum = 0
|
77 |
+
partition_count = 0
|
78 |
+
|
79 |
+
# find best candidate cut point
|
80 |
+
for i in range(0, df.shape[0] - 1):
|
81 |
+
loc = df.iloc[i]
|
82 |
+
sum += loc["sum"]
|
83 |
+
count += loc["count"]
|
84 |
+
if (
|
85 |
+
loc["prop"] == df.iloc[i + 1]["prop"]
|
86 |
+
or count < self.min_groupsize
|
87 |
+
or (total_size - count) < self.min_groupsize
|
88 |
+
):
|
89 |
+
continue
|
90 |
+
|
91 |
+
# calculate the CPE (class partition entropy) of this candidate cut point
|
92 |
+
pc1s0 = sum / count # probability of class 1 in subgroup 0
|
93 |
+
pc0s0 = 1 - pc1s0
|
94 |
+
pc1s0_ = pc1s0 if pc1s0 != 0 else 1
|
95 |
+
pc0s0_ = pc0s0 if pc0s0 != 0 else 1
|
96 |
+
entS0 = -(pc1s0 * math.log2(pc1s0_) + pc0s0 * math.log2(pc0s0_))
|
97 |
+
|
98 |
+
pc1s1 = (total_sum - sum) / (total_size - count)
|
99 |
+
pc0s1 = 1 - pc1s1
|
100 |
+
pc1s1_ = pc1s1 if pc1s1 != 0 else 1
|
101 |
+
pc0s1_ = pc0s1 if pc0s1 != 0 else 1
|
102 |
+
entS1 = -(pc1s1 * math.log2(pc1s1_) + pc0s1 * math.log2(pc0s1_))
|
103 |
+
|
104 |
+
cpe = (count / total_size) * entS0 + (
|
105 |
+
(total_size - count) / total_size
|
106 |
+
) * entS1
|
107 |
+
|
108 |
+
if cpe < min_cpe:
|
109 |
+
min_cpe = cpe
|
110 |
+
partition_index = i
|
111 |
+
partition_x = loc["x"]
|
112 |
+
partition_sum = sum
|
113 |
+
partition_count = count
|
114 |
+
partition_entS0 = entS0
|
115 |
+
partition_entS1 = entS1
|
116 |
+
if min_cpe == total_size:
|
117 |
+
return []
|
118 |
+
|
119 |
+
# test MDLP condition
|
120 |
+
pc1 = total_sum / total_size
|
121 |
+
pc0 = 1 - pc1
|
122 |
+
pc1_ = pc1 if pc1 != 0 else 1
|
123 |
+
pc0_ = pc0 if pc0 != 0 else 1
|
124 |
+
entS = -(pc1 * math.log2(pc1_) + pc0 * math.log2(pc0_))
|
125 |
+
gain = entS - min_cpe
|
126 |
+
|
127 |
+
remained_pos = total_sum - partition_sum
|
128 |
+
total_rem = total_size - partition_count
|
129 |
+
c0 = 2 if (partition_sum == 0 or partition_sum == partition_count) else 1
|
130 |
+
c1 = 2 if (remained_pos == 0 or remained_pos == total_rem) else 1
|
131 |
+
# delta = math.log2(9) - c0 * partition_entS0 - c1 * partition_entS1  # dead assignment: immediately overwritten by the MDLP delta below
|
132 |
+
|
133 |
+
delta = math.log2(7) - (2 * entS - c0 * partition_entS0 - c1 * partition_entS1)
|
134 |
+
|
135 |
+
if gain <= ((math.log2(total_size - 1) + delta) / total_size):
|
136 |
+
if force:
|
137 |
+
return [partition_x]
|
138 |
+
return []
|
139 |
+
|
140 |
+
# recursive splitting
|
141 |
+
left_partitions = self.find_partitions(
|
142 |
+
df.iloc[: (partition_index + 1)], partition_count, partition_sum
|
143 |
+
)
|
144 |
+
right_partitions = self.find_partitions(
|
145 |
+
df.iloc[(partition_index + 1) :],
|
146 |
+
(total_size - partition_count),
|
147 |
+
(total_sum - partition_sum),
|
148 |
+
)
|
149 |
+
a = [partition_x] + left_partitions + right_partitions
|
150 |
+
return a
|
151 |
+
|
152 |
+
|
153 |
+
class EqualFrequency:
|
154 |
+
"""
|
155 |
+
This class finds the cut points to discretize a numerical feature using an approximate equal frequency discretization.
|
156 |
+
"""
|
157 |
+
|
158 |
+
def __init__(self, min_bin_size=1, num_bins=0):
|
159 |
+
"""
|
160 |
+
Parameters
|
161 |
+
__________
|
162 |
+
num_bins : int
|
163 |
+
this number is to interpret as the maximum number of bins that will be generated.
|
164 |
+
If this parameter is not specified (0 by default), it will be automatically determined
|
165 |
+
based on min_bin_size parameter.
|
166 |
+
min_bin_size : int
|
167 |
+
Represent the minimum size that a bin can have.
|
168 |
+
|
169 |
+
Notes
|
170 |
+
-----
|
171 |
+
If the number of bins has to be automatically determined, this will be chosen so that each bin has an average
|
172 |
+
size of 1.2 * min_bin_size. Let's call this number automatic_nbin
|
173 |
+
( automatic_nbin = int(x.size/(self.min_group_size*1.2)) )
|
174 |
+
If instead the bin number is specified in the constructor (num_bins parameter), the number of bins that will
|
175 |
+
actually be generated will be, at most, equal to the minimum between num_bins and automatic_nbin. This is
|
176 |
+
to ensure that the constraint given by parameter min_bin_size is respected.
|
177 |
+
"""
|
178 |
+
self.min_group_size = min_bin_size
|
179 |
+
self.num_bins = num_bins
|
180 |
+
|
181 |
+
def findCutPoints(self, x):
|
182 |
+
"""
|
183 |
+
:param x: numpy array or pandas series
|
184 |
+
:return: list of (ascending) ordered cut points
|
185 |
+
"""
|
186 |
+
# determination of the number of bins
|
187 |
+
if self.num_bins > 1:
|
188 |
+
num_bins = min(self.num_bins, int(x.size / (self.min_group_size * 1.2)))
|
189 |
+
else:
|
190 |
+
num_bins = int(x.size / (self.min_group_size * 1.2))
|
191 |
+
if num_bins < 2:
|
192 |
+
return []
|
193 |
+
|
194 |
+
avg_group_size = x.size / num_bins
|
195 |
+
|
196 |
+
if isinstance(x, pd.Series):
|
197 |
+
x = x.to_numpy()
|
198 |
+
val, counts = np.unique(x, return_counts=True)
|
199 |
+
|
200 |
+
quantiles = []  # actually this array will contain quantiles * x.size
|
201 |
+
sum = 0
|
202 |
+
for c in counts:
|
203 |
+
sum = sum + c
|
204 |
+
quantiles.append(sum)
|
205 |
+
|
206 |
+
up_index = 0
|
207 |
+
low_index = 0
|
208 |
+
current_quantile = avg_group_size # again, is quantile * x.size
|
209 |
+
|
210 |
+
cut_indexes = []
|
211 |
+
# for each expected quantile, find its approximation
|
212 |
+
while current_quantile < (sum - avg_group_size / 2): # sum is equal to x.size
|
213 |
+
up_index, low_index = self.findApproximationIndexex(
|
214 |
+
up_index, low_index, current_quantile, quantiles
|
215 |
+
)
|
216 |
+
"""
|
217 |
+
# this commented code uses the cut point that best approximates the expected quantile
|
218 |
+
num_up = quantiles[up_index] - current_quantile
|
219 |
+
num_low = current_quantile - quantiles[low_index]
|
220 |
+
|
221 |
+
if num_up < num_low:
|
222 |
+
if up_index not in cut_indexes:
|
223 |
+
cut_indexes.append(up_index)
|
224 |
+
else:
|
225 |
+
if low_index not in cut_indexes:
|
226 |
+
cut_indexes.append(low_index)
|
227 |
+
"""
|
228 |
+
|
229 |
+
### here we always choose an approximation that exceeds the expected quantile.
|
230 |
+
# This consistency can help create more similarly sized bins
|
231 |
+
if up_index not in cut_indexes:
|
232 |
+
cut_indexes.append(up_index)
|
233 |
+
current_quantile += avg_group_size
|
234 |
+
###
|
235 |
+
|
236 |
+
cut_points = []
|
237 |
+
bins_size = []
|
238 |
+
last_quantile = 0
|
239 |
+
for i in cut_indexes:
|
240 |
+
cut_points.append(val[i])
|
241 |
+
bins_size.append(quantiles[i] - last_quantile)
|
242 |
+
last_quantile = quantiles[i]
|
243 |
+
bins_size.append(sum - last_quantile)
|
244 |
+
|
245 |
+
# Bins with size smaller than min_group_size will be merged with one of the adjacent bins
|
246 |
+
self.mergeSmallBins(cut_points, bins_size)
|
247 |
+
return cut_points
|
248 |
+
|
249 |
+
def findApproximationIndexex(
|
250 |
+
self, up_index, low_index, current_quantile, quantiles
|
251 |
+
):
|
252 |
+
"""
|
253 |
+
:param up_index: int
|
254 |
+
:param low_index: int
|
255 |
+
:param current_quantile: float
|
256 |
+
:param quantiles: list of float
|
257 |
+
|
258 |
+
:return: (int, int)
|
259 |
+
"""
|
260 |
+
|
261 |
+
new_up_i = up_index
|
262 |
+
while new_up_i < len(quantiles):
|
263 |
+
if quantiles[new_up_i] < current_quantile:
|
264 |
+
new_up_i += 1
|
265 |
+
else:
|
266 |
+
break
|
267 |
+
new_low_i = low_index
|
268 |
+
if new_up_i > up_index:
|
269 |
+
new_low_i = new_up_i - 1
|
270 |
+
return new_up_i, new_low_i
|
271 |
+
|
272 |
+
def mergeSmallBins(self, cut_points, bins_size):
|
273 |
+
"""
|
274 |
+
:param cut_points: list
|
275 |
+
:param bins_size: list
|
276 |
+
:return: void
|
277 |
+
"""
|
278 |
+
# find the smallest bin
|
279 |
+
min_smallbin = self.min_group_size
|
280 |
+
min_index = 0
|
281 |
+
for i in range(len(bins_size)):
|
282 |
+
if bins_size[i] < min_smallbin:
|
283 |
+
min_smallbin = bins_size[i]
|
284 |
+
min_index = i
|
285 |
+
|
286 |
+
# check whether the smallest bin found has size lower than min_group_size
|
287 |
+
if min_smallbin == self.min_group_size:
|
288 |
+
return
|
289 |
+
|
290 |
+
previous_size = 0
|
291 |
+
next_size = 0
|
292 |
+
if min_index > 0:
|
293 |
+
previous_size = bins_size[min_index - 1]
|
294 |
+
if min_index < (len(bins_size) - 1):
|
295 |
+
next_size = bins_size[min_index + 1]
|
296 |
+
|
297 |
+
if previous_size == 0 and next_size == 0:
|
298 |
+
return
|
299 |
+
|
300 |
+
if (previous_size == 0) or (next_size > 0 and next_size < previous_size):
|
301 |
+
bins_size[min_index + 1] = bins_size[min_index + 1] + bins_size[min_index]
|
302 |
+
cut_points.pop(min_index)
|
303 |
+
bins_size.pop(min_index)
|
304 |
+
else:
|
305 |
+
bins_size[min_index - 1] = bins_size[min_index - 1] + bins_size[min_index]
|
306 |
+
cut_points.pop(min_index - 1)
|
307 |
+
bins_size.pop(min_index)
|
308 |
+
|
309 |
+
self.mergeSmallBins(cut_points, bins_size)
|
310 |
+
|
311 |
+
|
312 |
+
class EqualWidth:
|
313 |
+
"""
|
314 |
+
This class finds the cut points to discretize a numerical feature using the equal width discretization.
|
315 |
+
"""
|
316 |
+
|
317 |
+
def __init__(self, min_bins_size=1, num_bins=0):
|
318 |
+
"""
|
319 |
+
Parameters
|
320 |
+
__________
|
321 |
+
num_bins : int
|
322 |
+
this number is to interpret as the maximum number of bins that will be generated.
|
323 |
+
If this parameter is not specified (0 by default), it will be automatically determined.
|
324 |
+
min_bins_size : int
|
325 |
+
|
326 |
+
Notes
|
327 |
+
-----
|
328 |
+
min_bins_size and num_bins parameters are to be interpreted as for the EqualFrequency class
|
329 |
+
"""
|
330 |
+
self.min_group_size = min_bins_size
|
331 |
+
self.num_bins = num_bins
|
332 |
+
|
333 |
+
def findCutPoints(self, x):
|
334 |
+
"""
|
335 |
+
|
336 |
+
:param x: numpy array or pandas series
|
337 |
+
:return: list of (ascending) ordered cut points
|
338 |
+
"""
|
339 |
+
if self.num_bins > 1:
|
340 |
+
num_bins = min(self.num_bins, int(x.size / (self.min_group_size * 1.2)))
|
341 |
+
else:
|
342 |
+
num_bins = int(x.size / (self.min_group_size * 1.2))
|
343 |
+
if num_bins < 2:
|
344 |
+
return []
|
345 |
+
|
346 |
+
if isinstance(x, pd.Series):
|
347 |
+
x = x.to_numpy()
|
348 |
+
minim = x.min()
|
349 |
+
bin_width = (x.max() - minim) / num_bins
|
350 |
+
cut_points = []
|
351 |
+
current_cut = minim + bin_width
|
352 |
+
for i in range(1, num_bins):
|
353 |
+
cut_points.append(current_cut)
|
354 |
+
current_cut = current_cut + bin_width
|
355 |
+
return cut_points
|
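
A small, hypothetical sketch of how the discretizers above can be called directly; the toy data and the import path (`fairsd.discretization`) are assumptions for illustration, not part of the committed code.

```python
# Hypothetical sketch, not part of the committed sources.
import numpy as np
import pandas as pd
from fairsd.discretization import MDLP, EqualFrequency  # import path assumed

# Supervised MDLP cut points: force=True returns at least the best candidate split
age = pd.Series([22, 25, 31, 38, 45, 52, 60, 67])
label = pd.Series([0, 0, 0, 1, 1, 1, 0, 0])
print(MDLP(min_groupsize=2, force=True).findCutPoints(age, label))

# Unsupervised, approximately equal-frequency bins of at least 3 elements each
values = np.array([18, 21, 22, 25, 28, 30, 33, 35, 40, 44, 50, 58, 63, 70])
print(EqualFrequency(min_bin_size=3).findCutPoints(values))
```
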
fairsd/fairsd/qualitymeasures.py
ADDED
@@ -0,0 +1,36 @@
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class QualityFunction(ABC):
|
5 |
+
"""Abstract class.
|
6 |
+
|
7 |
+
If the user wants to create a customized quality function, it is recommended to extend this class
|
8 |
+
"""
|
9 |
+
|
10 |
+
@abstractmethod
|
11 |
+
def evaluate(self, y_true, y_pred, sensitive_features):
|
12 |
+
"""Evaluate the quality of a description.
|
13 |
+
|
14 |
+
Parameters
|
15 |
+
----------
|
16 |
+
y_true : list, pandas.series or numpy array
|
17 |
+
y_pred : list, pandas.series or numpy array
|
18 |
+
sensitive_features : list, pandas.series or numpy array
|
19 |
+
|
20 |
+
Returns
|
21 |
+
-------
|
22 |
+
double
|
23 |
+
Real number indicating the calculated quality.
|
24 |
+
"""
|
25 |
+
pass
|
26 |
+
|
27 |
+
|
28 |
+
class EqualOpportunityDiff(QualityFunction):
|
29 |
+
def evaluate(self, y_true, y_pred, sensitive_features):
|
30 |
+
s0y_true = (y_true & ~sensitive_features).sum()
|
31 |
+
s0y_true = 1 if s0y_true == 0 else s0y_true
|
32 |
+
p_s0 = (y_true & y_pred & ~sensitive_features).sum() / s0y_true
|
33 |
+
s1y_true = (y_true & sensitive_features).sum()
|
34 |
+
s1y_true = 1 if s1y_true == 0 else s1y_true
|
35 |
+
p_s1 = (y_true & y_pred & sensitive_features).sum() / s1y_true
|
36 |
+
return p_s0 - p_s1
|
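
The abstract QualityFunction interface above is the extension point for custom metrics. Below is a hypothetical example of such an extension (a simple statistical parity difference between a subgroup and its complement); the class name and the exact definition are illustrative, not part of this commit, and the sketch assumes QualityFunction from the module above is in scope.

```python
# Hypothetical sketch, not part of the committed sources.
import numpy as np


class StatisticalParityDiff(QualityFunction):  # QualityFunction defined in the module above
    """Difference in positive-prediction rate between the subgroup and the rest."""

    def evaluate(self, y_true, y_pred, sensitive_features):
        sg = np.asarray(sensitive_features, dtype=bool)
        pred = np.asarray(y_pred)
        rate_in = pred[sg].mean() if sg.any() else 0.0        # positive rate inside the subgroup
        rate_out = pred[~sg].mean() if (~sg).any() else 0.0   # positive rate outside it
        return rate_in - rate_out
```
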
fairsd/fairsd/searchspace.py
ADDED
@@ -0,0 +1,208 @@
1 |
+
from .discretization import MDLP, EqualFrequency, EqualWidth
|
2 |
+
from .sgdescription import Description
|
3 |
+
from .sgdescription import Descriptor
|
4 |
+
|
5 |
+
|
6 |
+
class SearchSpace:
|
7 |
+
"""Will contain the set of all the descriptors to take into consideration for creating a Description.
|
8 |
+
|
9 |
+
Attributes
|
10 |
+
----------
|
11 |
+
nominal_Descriptors : list of Descriptor(s)
|
12 |
+
will contain all possible non numeric Descriptors.
|
13 |
+
numeric_Descriptors : list of Descriptor(s)
|
14 |
+
will contain all possible numeric Descriptors.
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
dataset,
|
20 |
+
ignore=None,
|
21 |
+
nominal_features=None,
|
22 |
+
numeric_features=None,
|
23 |
+
dynamic_discretization=False,
|
24 |
+
discretizer=None,
|
25 |
+
sensitiive_features=None,
|
26 |
+
):
|
27 |
+
"""
|
28 |
+
:param dataset : pandas.DataFrame
|
29 |
+
:param ignore : list of String(s)
|
30 |
+
list the attributes to not take into consideration for creating the search space.
|
31 |
+
:param nominal_features : list of Strings that contain a subgroup of nominal features
|
32 |
+
:param numeric_features : list of Strings that contain a subgroup of numeric features
|
33 |
+
:param dynamic_discretization: boolean
|
34 |
+
if dynamic_discretization is true, the numerical features will be discretized in the extract_search_space()
|
35 |
+
method, otherwise the numerical features are discretized here, during initialization.
|
36 |
+
:param discretizer: Discretizer object
|
37 |
+
:param sensitiive_features: List of str
|
38 |
+
|
39 |
+
Notes
|
40 |
+
-----
|
41 |
+
The type of the features not listed in numeric_features or nominal_features will be deduced from the column dtypes.
|
42 |
+
|
43 |
+
"""
|
44 |
+
self.sensitiive_features = sensitiive_features
|
45 |
+
if ignore is None:
|
46 |
+
ignore = []
|
47 |
+
if nominal_features is None:
|
48 |
+
nominal_features = []
|
49 |
+
if numeric_features is None:
|
50 |
+
numeric_features = []
|
51 |
+
self.nominal_Descriptors = []
|
52 |
+
self.numeric_Descriptors = []
|
53 |
+
# create nominal Descriptors
|
54 |
+
# nominal features with explicitly passed type
|
55 |
+
for col in nominal_features:
|
56 |
+
if col not in ignore:
|
57 |
+
values = dataset[col].unique()
|
58 |
+
for x in values:
|
59 |
+
self.nominal_Descriptors.append(Descriptor(col, attribute_value=x))
|
60 |
+
# nominal features without explicitly passed type
|
61 |
+
dtypes_subs = dataset.select_dtypes(exclude=["number"])
|
62 |
+
for col in dtypes_subs.columns:
|
63 |
+
if col not in ignore + nominal_features + numeric_features:
|
64 |
+
values = dataset[col].unique()
|
65 |
+
for x in values:
|
66 |
+
self.nominal_Descriptors.append(Descriptor(col, attribute_value=x))
|
67 |
+
|
68 |
+
# numerical Descriptors
|
69 |
+
if dynamic_discretization:
|
70 |
+
for col in numeric_features:
|
71 |
+
if col not in ignore:
|
72 |
+
self.numeric_Descriptors.append(
|
73 |
+
Descriptor(col, to_discretize=True, is_numeric=True)
|
74 |
+
)
|
75 |
+
dtypes_subs = dataset.select_dtypes(include=["number"])
|
76 |
+
for col in dtypes_subs.columns:
|
77 |
+
if col not in ignore + nominal_features + numeric_features:
|
78 |
+
self.numeric_Descriptors.append(
|
79 |
+
Descriptor(col, to_discretize=True, is_numeric=True)
|
80 |
+
)
|
81 |
+
else:
|
82 |
+
for col in numeric_features:
|
83 |
+
if col not in ignore:
|
84 |
+
self.numeric_Descriptors.extend(
|
85 |
+
discretizer.discretize(dataset, Description(), col)
|
86 |
+
)
|
87 |
+
dtypes_subs = dataset.select_dtypes(include=["number"])
|
88 |
+
for col in dtypes_subs.columns:
|
89 |
+
if col not in ignore + nominal_features + numeric_features:
|
90 |
+
self.numeric_Descriptors.extend(
|
91 |
+
discretizer.discretize(dataset, Description(), col)
|
92 |
+
)
|
93 |
+
|
94 |
+
def extract_search_space(self, dataset, discretizer, current_description=None):
|
95 |
+
"""This method return the subset of the search space to explore
|
96 |
+
for expanding the description "current_description".
|
97 |
+
|
98 |
+
All descriptors containing attributes present in the current_description will be removed
|
99 |
+
from the returned search space subset.
|
100 |
+
|
101 |
+
Parameters
|
102 |
+
----------
|
103 |
+
dataset : pandas.Dataframe
|
104 |
+
discretizer : Discretizer
|
105 |
+
current_description : Description
|
106 |
+
|
107 |
+
Returns
|
108 |
+
--------
|
109 |
+
list of Descriptors
|
110 |
+
the subset of the search space to explore
|
111 |
+
|
112 |
+
Notes
|
113 |
+
-----
|
114 |
+
"Descriptors" list will contain the subset of the search space to return. All the Descriptors not yet
|
115 |
+
discretized in the original search space, will be discretized before being inserted in the "Descriptors" list.
|
116 |
+
"""
|
117 |
+
if current_description is None:
|
118 |
+
current_description = Description()
|
119 |
+
to_exclude = current_description.get_attributes()
|
120 |
+
if len(to_exclude) == 0 and self.sensitiive_features is not None:
|
121 |
+
to_exclude = [
|
122 |
+
i for i in list(dataset.columns) if i not in self.sensitiive_features
|
123 |
+
]
|
124 |
+
Descriptors = []
|
125 |
+
for Descriptor in self.nominal_Descriptors:
|
126 |
+
if Descriptor.attribute_name not in to_exclude:
|
127 |
+
Descriptors.append(Descriptor)
|
128 |
+
for Descriptor in self.numeric_Descriptors:
|
129 |
+
if Descriptor.attribute_name not in to_exclude:
|
130 |
+
if Descriptor.is_to_discretize():
|
131 |
+
Descriptors.extend(
|
132 |
+
discretizer.discretize(
|
133 |
+
dataset,
|
134 |
+
current_description,
|
135 |
+
Descriptor.get_attribute_name(),
|
136 |
+
)
|
137 |
+
)
|
138 |
+
else:
|
139 |
+
Descriptors.append(Descriptor)
|
140 |
+
return Descriptors
|
141 |
+
|
142 |
+
|
143 |
+
class Discretizer:
|
144 |
+
"""Class for the discretization of the numeric attributes."""
|
145 |
+
|
146 |
+
def __init__(self, discretization_type, target=None, min_groupsize=1, num_bins=6):
|
147 |
+
"""
|
148 |
+
|
149 |
+
:param discretization_type : enumerated
|
150 |
+
can be "mdlp" or "equalfreq" or "equalwidth"
|
151 |
+
:param target: String, optional
|
152 |
+
this parameter is needed only for supervised discretizations (mdlp)
|
153 |
+
:param min_groupsize: int, optional
|
154 |
+
discretize() method will create only subgroups with size >= min_groupsize.
|
155 |
+
"""
|
156 |
+
self.discretization_type = discretization_type
|
157 |
+
if discretization_type == "mdlp":
|
158 |
+
self.supervised = True
|
159 |
+
self.discretizer = MDLP(min_groupsize, force=True)
|
160 |
+
self.target = target
|
161 |
+
elif discretization_type == "equalfreq":
|
162 |
+
self.supervised = False
|
163 |
+
self.discretizer = EqualFrequency(min_groupsize, num_bins)
|
164 |
+
elif discretization_type == "equalwidth":
|
165 |
+
self.supervised = False
|
166 |
+
self.discretizer = EqualWidth(min_groupsize, num_bins)
|
167 |
+
else:
|
168 |
+
raise RuntimeError(
|
169 |
+
'discretization_type must be "mdlp" OR "equalfreq" OR "equalwidth"'
|
170 |
+
)
|
171 |
+
|
172 |
+
def discretize(self, data, description, feature): #### to test
|
173 |
+
"""
|
174 |
+
Parameters
|
175 |
+
----------
|
176 |
+
data: pandas.DataFrame
|
177 |
+
The dataset.
|
178 |
+
description: Description
|
179 |
+
The discretization will be based only on those tuples of the dataset that match the description.
|
180 |
+
feature: String
|
181 |
+
Is the name of the numeric attribute of the dataset to discretize.
|
182 |
+
|
183 |
+
Returns
|
184 |
+
-------
|
185 |
+
list of Descriptor(s)
|
186 |
+
Will be created and returned (in a list) one Descriptor for each bin created in the discretization phase.
|
187 |
+
"""
|
188 |
+
subset = data[description.to_boolean_array(data)]
|
189 |
+
x = subset[feature]
|
190 |
+
if self.supervised:
|
191 |
+
y = subset[self.target]
|
192 |
+
cut_points = self.discretizer.findCutPoints(x, y)
|
193 |
+
else:
|
194 |
+
cut_points = self.discretizer.findCutPoints(x)
|
195 |
+
|
196 |
+
Descriptors = []
|
197 |
+
if len(cut_points) < 1:
|
198 |
+
return Descriptors
|
199 |
+
|
200 |
+
Descriptors.append(
|
201 |
+
Descriptor(feature, low_bound=None, up_bound=None, is_numeric=True)
|
202 |
+
)
|
203 |
+
for cp in cut_points:
|
204 |
+
Descriptors[-1].up_bound = cp
|
205 |
+
Descriptors.append(
|
206 |
+
Descriptor(feature, low_bound=cp, up_bound=None, is_numeric=True)
|
207 |
+
)
|
208 |
+
return Descriptors
|
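
A hypothetical sketch of using the Discretizer class above on a single numeric column; the DataFrame, column names and import paths are assumed for illustration only.

```python
# Hypothetical sketch, not part of the committed sources.
import pandas as pd
from fairsd.searchspace import Discretizer    # import paths assumed
from fairsd.sgdescription import Description

data = pd.DataFrame({"age": [23, 31, 35, 42, 47, 55, 61, 68],
                     "y_true": [0, 0, 1, 1, 1, 0, 0, 1]})

# Equal-width bins; an empty Description() means the whole dataset is used
discretizer = Discretizer("equalwidth", min_groupsize=2, num_bins=3)
for d in discretizer.discretize(data, Description(), "age"):
    print(d.attribute_name, d.low_bound, d.up_bound)
```
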
fairsd/fairsd/sgdescription.py
ADDED
@@ -0,0 +1,252 @@
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
# The following class is no more used in the current version
|
5 |
+
# class BinaryTarget:
|
6 |
+
# """Contains the target for the subgroup discovery task.
|
7 |
+
#
|
8 |
+
# The target can be boolean or Nominal. Only the value contained in the parameter "target_value" will be considered as
|
9 |
+
# the true value. In this way also if the target is nominal, it will still be treated as a boolean.
|
10 |
+
# """
|
11 |
+
#
|
12 |
+
# def __init__(self, y_true, y_pred=None, target_value=False):
|
13 |
+
# """
|
14 |
+
# Parameters
|
15 |
+
# ----------
|
16 |
+
# y_true : string
|
17 |
+
# Contains the label of the target.
|
18 |
+
# dataset: pandas.DataFrame
|
19 |
+
# The dataset is required because it will be checked if the others parameters are coherent
|
20 |
+
# (present inside the dataset).
|
21 |
+
# y_pred: String, optional
|
22 |
+
# Contains the label of the predicted attribute.
|
23 |
+
# target_value: bool or String, optional
|
24 |
+
# """
|
25 |
+
# self.y_true = y_true
|
26 |
+
# self.y_pred = y_pred
|
27 |
+
# self.target_value = target_value
|
28 |
+
|
29 |
+
|
30 |
+
class Descriptor:
|
31 |
+
"""
|
32 |
+
This object is formed by an attribute name and an attribute value (or a lower bound plus an upper bound
|
33 |
+
if the Descriptor is numeric).
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
attribute_name,
|
39 |
+
attribute_value=None,
|
40 |
+
up_bound=None,
|
41 |
+
low_bound=None,
|
42 |
+
to_discretize=False,
|
43 |
+
is_numeric=False,
|
44 |
+
):
|
45 |
+
"""
|
46 |
+
Parameters
|
47 |
+
----------
|
48 |
+
attribute_name : string
|
49 |
+
attribute_value : string or bool, default None
|
50 |
+
To set only if the Descriptor is not numeric.
|
51 |
+
up_bound : double or int, default None
|
52 |
+
To set iff the Descriptor is numeric and already discretized.
|
53 |
+
low_bound : double or int, default None
|
54 |
+
To set iff the Descriptor is numeric and already discretized.
|
55 |
+
to_discretize : bool, default False
|
56 |
+
To set to True iff the Descriptor is numeric and not yet discretized. In this case
|
57 |
+
the up_bound and low_bound attributes will be meaningless
|
58 |
+
is_numeric : bool, default False
|
59 |
+
To set to True iff the descriptor is numeric
|
60 |
+
"""
|
61 |
+
|
62 |
+
self.attribute_name = attribute_name
|
63 |
+
self.is_numeric = is_numeric
|
64 |
+
if is_numeric:
|
65 |
+
self.up_bound = up_bound
|
66 |
+
self.low_bound = low_bound
|
67 |
+
else:
|
68 |
+
self.attribute_value = attribute_value
|
69 |
+
self.to_discretize = to_discretize # The current implementation could work also without this parameter
|
70 |
+
|
71 |
+
def get_attribute_name(self):
|
72 |
+
return self.attribute_name
|
73 |
+
|
74 |
+
def is_to_discretize(self):
|
75 |
+
return self.to_discretize
|
76 |
+
|
77 |
+
def is_present_in(self, other_descriptors):
|
78 |
+
"""
|
79 |
+
:param: other_descriptors: list of Descriptors
|
80 |
+
:return: bool
|
81 |
+
"""
|
82 |
+
for other in other_descriptors:
|
83 |
+
if (
|
84 |
+
self.attribute_name == other.attribute_name
|
85 |
+
and self.is_numeric == other.is_numeric
|
86 |
+
):
|
87 |
+
if (
|
88 |
+
self.is_numeric
|
89 |
+
and self.up_bound == other.up_bound
|
90 |
+
and self.low_bound == other.low_bound
|
91 |
+
):
|
92 |
+
return True
|
93 |
+
elif (
|
94 |
+
self.is_numeric == False
|
95 |
+
and self.attribute_value == other.attribute_value
|
96 |
+
):
|
97 |
+
return True
|
98 |
+
return False
|
99 |
+
|
100 |
+
|
101 |
+
class Description:
|
102 |
+
"""List of Descriptors plus other description attributes.
|
103 |
+
|
104 |
+
Semantically it is to be interpreted as the conjunction of all the Descriptors contained in the list:
|
105 |
+
a dataset record will match the description if each single Descriptor of the description will match with this record.
|
106 |
+
"""
|
107 |
+
|
108 |
+
def __init__(self, descriptors: Descriptor = None):
|
109 |
+
"""
|
110 |
+
:param descriptors : list of Descriptor
|
111 |
+
"""
|
112 |
+
if descriptors == None:
|
113 |
+
self.descriptors = []
|
114 |
+
else:
|
115 |
+
self.descriptors = descriptors
|
116 |
+
self.support = None
|
117 |
+
|
118 |
+
def __repr__(self, opposite=False):
|
119 |
+
"""Represent the description as a string.
|
120 |
+
|
121 |
+
:return : String
|
122 |
+
"""
|
123 |
+
descr = ""
|
124 |
+
if opposite:
|
125 |
+
descr = descr + "NOT ("
|
126 |
+
for s in self.descriptors:
|
127 |
+
# If the descriptor is numeric, the string will be formed by the attribute name and the lower and upper bound
|
128 |
+
if s.is_numeric:
|
129 |
+
low = str(s.low_bound) if s.low_bound is not None else "-infinite"
|
130 |
+
up = str(s.up_bound) if s.up_bound is not None else "+infinite"
|
131 |
+
descr = descr + s.attribute_name + " = (" + low + ", " + up + "] AND "
|
132 |
+
# If the descriptor is not numeric, the string will be formed by the attribute name and the attribute value
|
133 |
+
else:
|
134 |
+
descr = (
|
135 |
+
descr
|
136 |
+
+ s.attribute_name
|
137 |
+
+ ' = "'
|
138 |
+
+ str(s.attribute_value)
|
139 |
+
+ '" AND '
|
140 |
+
)
|
141 |
+
if descr != "":
|
142 |
+
descr = descr[:-4]
|
143 |
+
if opposite:
|
144 |
+
descr = descr + ")"
|
145 |
+
return descr
|
146 |
+
|
147 |
+
def to_string(self, opposite=False):
|
148 |
+
return self.__repr__(opposite=opposite)
|
149 |
+
|
150 |
+
def __lt__(self, other):
|
151 |
+
"""Compare the current description (self) with another description (other).
|
152 |
+
|
153 |
+
:param other: Description
|
154 |
+
:return: bool
|
155 |
+
"""
|
156 |
+
if self.quality != other.quality:
|
157 |
+
return self.quality < other.quality
|
158 |
+
elif self.support != other.support:
|
159 |
+
return self.support < other.support
|
160 |
+
else:
|
161 |
+
return len(self.descriptors) > len(other.descriptors)
|
162 |
+
|
163 |
+
def to_boolean_array(self, dataset, set_attributes=False):
|
164 |
+
"""
|
165 |
+
Parameters
|
166 |
+
----------
|
167 |
+
dataset : pandas.DataFrame
|
168 |
+
set_attributes : bool, default False
|
169 |
+
if this input is True, this method will set also the support attribute
|
170 |
+
|
171 |
+
Returns
|
172 |
+
-------
|
173 |
+
pandas Series of boolean type:
|
174 |
+
The array will have the length of the passed dataset (number of rows).
|
175 |
+
Each element of the array will be true iff the description (self) matches the corresponding row of the dataset.
|
176 |
+
If a description is empty, the returned array will have all elements equal to True.
|
177 |
+
"""
|
178 |
+
s = np.full(dataset.shape[0], True)
|
179 |
+
s = pd.Series(s)
|
180 |
+
for i in range(0, len(self.descriptors)):
|
181 |
+
if self.descriptors[i].is_numeric:
|
182 |
+
if self.descriptors[i].low_bound is not None:
|
183 |
+
s = s & (
|
184 |
+
dataset[self.descriptors[i].attribute_name]
|
185 |
+
> self.descriptors[i].low_bound
|
186 |
+
)
|
187 |
+
if self.descriptors[i].up_bound is not None:
|
188 |
+
s = s & (
|
189 |
+
dataset[self.descriptors[i].attribute_name]
|
190 |
+
<= self.descriptors[i].up_bound
|
191 |
+
)
|
192 |
+
else:
|
193 |
+
s = (s) & (
|
194 |
+
dataset[self.descriptors[i].attribute_name]
|
195 |
+
== self.descriptors[i].attribute_value
|
196 |
+
)
|
197 |
+
|
198 |
+
if set_attributes:
|
199 |
+
# set size, relative size and target share
|
200 |
+
self.support = sum(s)
|
201 |
+
return s
|
202 |
+
|
203 |
+
def size(self, dataset=None):
|
204 |
+
"""Return the support of the description and set the support in case this parameters was not set.
|
205 |
+
|
206 |
+
:param dataset: pandas.DataFrame
|
207 |
+
:return: int
|
208 |
+
"""
|
209 |
+
if self.support is None:
|
210 |
+
if dataset is None:
|
211 |
+
raise RuntimeError("dataset argument required in Description.size()")
|
212 |
+
self.to_boolean_array(dataset, set_attributes=True)
|
213 |
+
return self.support
|
214 |
+
|
215 |
+
def get_attributes(self):
|
216 |
+
"""Return the list of the attribute names in the description.
|
217 |
+
|
218 |
+
:return: list of String
|
219 |
+
"""
|
220 |
+
attributes = []
|
221 |
+
for sel in self.descriptors:
|
222 |
+
attributes.append(sel.get_attribute_name())
|
223 |
+
return attributes
|
224 |
+
|
225 |
+
def get_Descriptors(self):
|
226 |
+
return self.descriptors
|
227 |
+
|
228 |
+
def is_present_in(self, beam):
|
229 |
+
"""
|
230 |
+
:param beam : list of Description objects
|
231 |
+
|
232 |
+
:return: bool
|
233 |
+
True if the current description (self) is present in the list (beam).
|
234 |
+
Return True iff the current object (Description) has the same parameters as at least one other object
|
235 |
+
present in the beam.
|
236 |
+
|
237 |
+
"""
|
238 |
+
for descr in beam:
|
239 |
+
equals = True
|
240 |
+
for sel in self.descriptors:
|
241 |
+
if sel.is_present_in(descr.descriptors) == False:
|
242 |
+
equals = False
|
243 |
+
break
|
244 |
+
if equals:
|
245 |
+
return True
|
246 |
+
return False
|
247 |
+
|
248 |
+
def set_quality(self, q):
|
249 |
+
self.quality = q
|
250 |
+
|
251 |
+
def get_quality(self):
|
252 |
+
return self.quality
|
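
To make the Descriptor/Description API above concrete, here is a hypothetical sketch of building a description by hand and matching it against a DataFrame; the columns and values are made up, and the import path is assumed.

```python
# Hypothetical sketch, not part of the committed sources.
import pandas as pd
from fairsd.sgdescription import Descriptor, Description  # import path assumed

df = pd.DataFrame({"sex": ["Male", "Female", "Female", "Male"],
                   "age": [25, 40, 33, 58]})

desc = Description([
    Descriptor("sex", attribute_value="Female"),
    Descriptor("age", low_bound=30, up_bound=50, is_numeric=True),
])
mask = desc.to_boolean_array(df, set_attributes=True)  # boolean Series, one entry per row
print(desc.to_string())   # sex = "Female" AND age = (30, 50]
print(desc.size())        # 2 rows match
```
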
fairsd/main.py
ADDED
@@ -0,0 +1,32 @@
1 |
+
from sklearn.datasets import fetch_openml
|
2 |
+
import pandas as pd
|
3 |
+
from sklearn.tree import DecisionTreeClassifier
|
4 |
+
from sklearn.metrics import accuracy_score
|
5 |
+
import fairlearn.metrics as fm
|
6 |
+
import fairsd as dsd
|
7 |
+
|
8 |
+
# Import dataset, training the classifier, producing y_pred
|
9 |
+
d = fetch_openml(data_id=1590, as_frame=True)
|
10 |
+
dataset = d.data
|
11 |
+
d_train = pd.get_dummies(dataset)
|
12 |
+
y_true = (d.target == ">50K") * 1
|
13 |
+
|
14 |
+
classifier = DecisionTreeClassifier(min_samples_leaf=10, max_depth=4)
|
15 |
+
classifier.fit(d_train, y_true)
|
16 |
+
y_pred = classifier.predict(d_train)
|
17 |
+
print("ACCURACY OF THE CLASSIFIER:")
|
18 |
+
print(accuracy_score(y_true, y_pred))
|
19 |
+
print("-------------------------------------------" + "\n")
|
20 |
+
|
21 |
+
dataset = dataset.head(1000)
|
22 |
+
y_pred = y_pred[:1000]
|
23 |
+
y_true = y_true[:1000]
|
24 |
+
|
25 |
+
|
26 |
+
task = dsd.SubgroupDiscoveryTask(
|
27 |
+
dataset, y_true, y_pred, qf=fm.demographic_parity_difference
|
28 |
+
)
|
29 |
+
result_set = dsd.BeamSearch(beam_width=10).execute(task)
|
30 |
+
|
31 |
+
print(result_set.to_dataframe())
|
32 |
+
print(result_set.extract_sg_feature(0, dataset))
|
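
As a hypothetical follow-up to the script above, the task can be restricted so that every returned subgroup involves at least one sensitive feature (per the sensitive_features parameter documented in the notebook below); the column names are assumed for the Adult dataset.

```python
# Hypothetical follow-up, not part of the committed sources; reuses the objects defined above.
task_sensitive = dsd.SubgroupDiscoveryTask(
    dataset, y_true, y_pred,
    qf=fm.demographic_parity_difference,
    sensitive_features=["sex", "race"],  # assumed Adult column names
)
print(dsd.BeamSearch(beam_width=10).execute(task_sensitive).to_dataframe())
```
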
fairsd/notebooks/fairsd_settings.ipynb
ADDED
@@ -0,0 +1,494 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "compliant-oregon",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# FairSD settings\n",
|
9 |
+
"For this example is used the UCI adult dataset where the objective is to predict whether a person makes more (label 1) or less (0) than $50,000 a year."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 1,
|
15 |
+
"id": "illegal-american",
|
16 |
+
"metadata": {},
|
17 |
+
"outputs": [],
|
18 |
+
"source": [
|
19 |
+
"import pandas as pd\n",
|
20 |
+
"import numpy as np\n",
|
21 |
+
"from sklearn.datasets import fetch_openml\n",
|
22 |
+
"from sklearn.tree import DecisionTreeClassifier\n",
|
23 |
+
"#Import dataset\n",
|
24 |
+
"d = fetch_openml(data_id=1590, as_frame=True)\n",
|
25 |
+
"X = d.data\n",
|
26 |
+
"d_train=pd.get_dummies(X)\n",
|
27 |
+
"y_true = (d.target == '>50K') * 1\n",
|
28 |
+
"#training the classifier\n",
|
29 |
+
"classifier = DecisionTreeClassifier(min_samples_leaf=10, max_depth=4)\n",
|
30 |
+
"classifier.fit(d_train, y_true)\n",
|
31 |
+
"#Producing y_pred\n",
|
32 |
+
"y_pred = classifier.predict(d_train)"
|
33 |
+
]
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"cell_type": "markdown",
|
37 |
+
"id": "private-origin",
|
38 |
+
"metadata": {},
|
39 |
+
"source": [
|
40 |
+
"## SubgroupDiscoveryTask Object\n",
|
41 |
+
"This class will contain all the parameters useful for the sg discovery algorithms.<br/>\n",
|
42 |
+
"**<u>Parameters</u>:**"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "markdown",
|
47 |
+
"id": "reverse-kinase",
|
48 |
+
"metadata": {},
|
49 |
+
"source": [
|
50 |
+
"|name |type |default value |description\n",
|
51 |
+
"|:-----|:-----|:-----|:----- \n",
|
52 |
+
"|X| pandas.Dataframe or <br/> numpy.ndarray| |dataset.\n",
|
53 |
+
"|y_true| pandas.Dataframe,<br/> pandas.Series or <br/> numpy.ndarray| |ground truth.\n",
|
54 |
+
"|y_pred| pandas.Dataframe,<br/> pandas.Series or <br/> numpy.ndarray| |predicted label.\n",
|
55 |
+
"|feature_names| list of strings| None| see below.\n",
|
56 |
+
"|nominal_features| list of strings| None| see below.\n",
|
57 |
+
"|numeric_features| list of strings| None| see below.\n",
|
58 |
+
"|qf | String or <br/>callable object| 'equalized_odds_difference'| quality function to use.\n",
|
59 |
+
"|discretizer| String| 'equalfreq'| see below.\n",
|
60 |
+
"|num_bins | int| 6| see below.\n",
|
61 |
+
"|dynamic_discretization| bool| True| see below.\n",
|
62 |
+
"|result_set_size| int| 5| maximum number of subgroups <br/>that the sg discovery will return.\n",
|
63 |
+
"|depth| int| 3| maximum number of descriptors<br/> that a description can contain.\n",
|
64 |
+
"|min_quality| float| 0.1| minimum quality<br/> that a subgroup needs to be selected. \n",
|
65 |
+
"|min_support| int| 200| minimum size <br/>that a subgroup needs to be selected. \n",
|
66 |
+
"|sensitive_features| list| None|list of sensitive features names (str). If this list is not none, the sg-discovery task will return only subgroups containing at least one fo the features in the list. "
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "markdown",
|
71 |
+
"id": "geographic-limitation",
|
72 |
+
"metadata": {},
|
73 |
+
"source": [
|
74 |
+
"### X and feature_names\n",
|
75 |
+
"X parameters can be a Pandas DataFrame or a Numpy array. If we pass a numpy array we must also pass the feature_names parameter, a list with the column names of X:"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": 2,
|
81 |
+
"id": "southwest-chemistry",
|
82 |
+
"metadata": {},
|
83 |
+
"outputs": [],
|
84 |
+
"source": [
|
85 |
+
"import fairsd as fsd\n",
|
86 |
+
"task=fsd.SubgroupDiscoveryTask(X, y_true, y_pred) # X is a pandas.Dataframe\n",
|
87 |
+
"\n",
|
88 |
+
"#X as numpy array\n",
|
89 |
+
"x_np = X.to_numpy()\n",
|
90 |
+
"columns = list(X.columns)\n",
|
91 |
+
"task=fsd.SubgroupDiscoveryTask(X, y_true, y_pred, feature_names = columns)"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "markdown",
|
96 |
+
"id": "about-storm",
|
97 |
+
"metadata": {},
|
98 |
+
"source": [
|
99 |
+
"### nominal_features and numeric_features\n",
|
100 |
+
"These two parameters (string list type) are used to specify which columns of X have nominal values and which ones have numeric values. For attributes that do not appear in either of the two list, the data type will be automatically inferred.<br/>\n",
|
101 |
+
"**Example:**<br/>\n",
|
102 |
+
" Let's analize the education-num attribute"
|
103 |
+
]
|
104 |
+
},
|
105 |
+
{
|
106 |
+
"cell_type": "code",
|
107 |
+
"execution_count": 3,
|
108 |
+
"id": "biological-liabilities",
|
109 |
+
"metadata": {},
|
110 |
+
"outputs": [
|
111 |
+
{
|
112 |
+
"name": "stdout",
|
113 |
+
"output_type": "stream",
|
114 |
+
"text": [
|
115 |
+
"[ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16.]\n"
|
116 |
+
]
|
117 |
+
}
|
118 |
+
],
|
119 |
+
"source": [
|
120 |
+
"educationnum_val = X['education-num'].unique()\n",
|
121 |
+
"print(np.sort(educationnum_val))"
|
122 |
+
]
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"cell_type": "markdown",
|
126 |
+
"id": "naval-leadership",
|
127 |
+
"metadata": {},
|
128 |
+
"source": [
|
129 |
+
"The values of this attribute are integer number from 1 to 16. <br/>\n",
|
130 |
+
"The package would treat this attribute as numeric by default, but if we want to treat it as a nominal attribute we can use the nominal_features parameter:"
|
131 |
+
]
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"cell_type": "code",
|
135 |
+
"execution_count": 4,
|
136 |
+
"id": "unauthorized-oriental",
|
137 |
+
"metadata": {},
|
138 |
+
"outputs": [],
|
139 |
+
"source": [
|
140 |
+
"task=fsd.SubgroupDiscoveryTask(X, y_true, y_pred, nominal_features = ['education-num'])"
|
141 |
+
]
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"cell_type": "markdown",
|
145 |
+
"id": "possible-wilderness",
|
146 |
+
"metadata": {},
|
147 |
+
"source": [
|
148 |
+
"### discretizer\n",
|
149 |
+
"This parameter determine the algorithm that will perform the numerical features discretization.\n",
|
150 |
+
"It is set to 'equalfreq' by default but other options are possible:<br/>"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "markdown",
|
155 |
+
"id": "manufactured-algorithm",
|
156 |
+
"metadata": {},
|
157 |
+
"source": [
|
158 |
+
"|dicretizer parameter |description \n",
|
159 |
+
"|:-----|:----- \n",
|
160 |
+
"|'eualfreq'|approximate equal frequency discretization.\n",
|
161 |
+
"|'equalwidth'|equal width discretization.\n",
|
162 |
+
"|'mdlp'|minimum description length principle discretization (Fayyad and Irani approach)."
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"execution_count": 5,
|
168 |
+
"id": "blocked-morning",
|
169 |
+
"metadata": {},
|
170 |
+
"outputs": [],
|
171 |
+
"source": [
|
172 |
+
"#example\n",
|
173 |
+
"task=fsd.SubgroupDiscoveryTask(X, y_true, y_pred, discretizer='mdlp')"
|
174 |
+
]
|
175 |
+
},
|
176 |
+
{
|
177 |
+
"cell_type": "markdown",
|
178 |
+
"id": "developmental-orbit",
|
179 |
+
"metadata": {},
|
180 |
+
"source": [
|
181 |
+
"### num_bins\n",
|
182 |
+
"This parameter determine the maximum number of bins that a numerical feature discretization operation will produce. The number of produced bins could also be less than num_bins due to the min_support constraint:<br/>\n",
|
183 |
+
"* the equal-frequency and mdlp discretizations will never produce bins with size lower than min_support, as we know in advance that these bins could never be used to create a description with large enough support.<br/>\n",
|
184 |
+
"\n",
|
185 |
+
"If a value lower than two is used, the number of bins will be automatically decided.\n"
|
186 |
+
]
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"cell_type": "code",
|
190 |
+
"execution_count": 6,
|
191 |
+
"id": "focused-patrick",
|
192 |
+
"metadata": {},
|
193 |
+
"outputs": [],
|
194 |
+
"source": [
|
195 |
+
"#example\n",
|
196 |
+
"task=fsd.SubgroupDiscoveryTask(X, y_true, y_pred, discretizer='equalwidth', num_bins = 4)"
|
197 |
+
]
|
198 |
+
},
|
199 |
+
{
|
200 |
+
"cell_type": "markdown",
|
201 |
+
"id": "periodic-glass",
|
202 |
+
"metadata": {},
|
203 |
+
"source": [
|
204 |
+
"### dynamic_discretization\n",
|
205 |
+
"If is set to False, the discretization of numeric features will be done only once for each numerical feature before starting the subgroup discovery algorithm.<br/>\n",
|
206 |
+
"If instead is set to True, the discretization of numerical featues will be integrated with the subgroup discovery algorithm: will be done each time the algorithm will try to expand a description with a numerical attribute."
|
207 |
+
]
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "code",
|
211 |
+
"execution_count": 7,
|
212 |
+
"id": "hispanic-boulder",
|
213 |
+
"metadata": {},
|
214 |
+
"outputs": [],
|
215 |
+
"source": [
|
216 |
+
"#example\n",
|
217 |
+
"task=fsd.SubgroupDiscoveryTask(X, y_true, y_pred, dynamic_discretization = False)"
|
218 |
+
]
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"cell_type": "markdown",
|
222 |
+
"id": "interstate-edinburgh",
|
223 |
+
"metadata": {},
|
224 |
+
"source": [
|
225 |
+
"## Qaulity Measures\n",
|
226 |
+
"### Fairlearn Quality measures\n",
|
227 |
+
"All the Fairlear metrics can be used as quality measure (qf parameter in SubgroupDiscoveryTask object). See the Fairlearn documentation [here](https://fairlearn.github.io/v0.6.0/api_reference/fairlearn.metrics.html).<br/>\n",
|
228 |
+
"**The predefined fairlearn metrics are:**\n"
|
229 |
+
]
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"cell_type": "markdown",
|
233 |
+
"id": "premium-charlotte",
|
234 |
+
"metadata": {},
|
235 |
+
"source": [
|
236 |
+
"|Metric name | Description\n",
|
237 |
+
"|:-----|:-----\n",
|
238 |
+
"| demographic_parity_difference | Defined as the absolute value of the difference in the **selection rates** between a subgroup and its negation.\n",
|
239 |
+
"|demographic_parity_ratio | Defined as the ratio between the smallest and the largest group-level **selection rate**, between a subgroup and its negation.\n",
|
240 |
+
"|equalized_odds_difference | The greater of two metrics: true_positive_rate_difference and false_positive_rate_difference. The former is the difference between the TPRs, between a subgroup and its negation. The latter is defined similarly, but for FPRs.\n",
|
241 |
+
"|equalized_odds_ratio | The smaller of two metrics: true_positive_rate_ratio and false_positive_rate_ratio. The former is the ratio between the TPRs, between a subgroup and its negation. The latter is defined similarly, but for FPRs."
|
242 |
+
]
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"cell_type": "markdown",
|
246 |
+
"id": "upset-broadcasting",
|
247 |
+
"metadata": {},
|
248 |
+
"source": [
|
249 |
+
"We can inizialize a SubgroupDiscoveryTask object in this way:"
|
250 |
+
]
|
251 |
+
},
|
252 |
+
{
|
253 |
+
"cell_type": "code",
|
254 |
+
"execution_count": 8,
|
255 |
+
"id": "labeled-retail",
|
256 |
+
"metadata": {},
|
257 |
+
"outputs": [],
|
258 |
+
"source": [
|
259 |
+
"import fairsd as fsd\n",
|
260 |
+
"from fairlearn.metrics import demographic_parity_ratio\n",
|
261 |
+
"task = fsd.SubgroupDiscoveryTask(X, y_true, y_pred, demographic_parity_ratio)"
|
262 |
+
]
|
263 |
+
},
|
264 |
+
{
|
265 |
+
"cell_type": "markdown",
|
266 |
+
"id": "personalized-munich",
|
267 |
+
"metadata": {},
|
268 |
+
"source": [
|
269 |
+
"Or, faster, for this four predefined fairlearn metrics, we can pass the same metric as a string:"
|
270 |
+
]
|
271 |
+
},
|
272 |
+
{
|
273 |
+
"cell_type": "code",
|
274 |
+
"execution_count": 9,
|
275 |
+
"id": "directed-robin",
|
276 |
+
"metadata": {},
|
277 |
+
"outputs": [],
|
278 |
+
"source": [
|
279 |
+
"task = fsd.SubgroupDiscoveryTask(X, y_true, y_pred, 'demographic_parity_ratio')"
|
280 |
+
]
|
281 |
+
},
|
282 |
+
{
|
283 |
+
"cell_type": "markdown",
|
284 |
+
"id": "subjective-carrier",
|
285 |
+
"metadata": {},
|
286 |
+
"source": [
|
287 |
+
"From the version 6.0, Fairlearn also offers the interesting possibility of \"create a scalar returning metric function based on aggregation of a disaggregated metric\".<br/>\n",
|
288 |
+
"**Example:**"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": 10,
|
294 |
+
"id": "efficient-tolerance",
|
295 |
+
"metadata": {},
|
296 |
+
"outputs": [],
|
297 |
+
"source": [
|
298 |
+
"from sklearn.metrics import accuracy_score\n",
|
299 |
+
"from fairlearn.metrics import make_derived_metric\n",
|
300 |
+
"derived_metric = make_derived_metric(metric = accuracy_score, transform = 'difference')\n",
|
301 |
+
"task = fsd.SubgroupDiscoveryTask(X, y_true, y_pred, derived_metric)"
|
302 |
+
]
|
303 |
+
},
|
304 |
+
{
|
305 |
+
"cell_type": "markdown",
|
306 |
+
"id": "solid-nutrition",
|
307 |
+
"metadata": {},
|
308 |
+
"source": [
|
309 |
+
"For more details see the Fairlearn documentation."
|
310 |
+
]
|
311 |
+
},
|
312 |
+
{
|
313 |
+
"cell_type": "markdown",
|
314 |
+
"id": "distinguished-tonight",
|
315 |
+
"metadata": {},
|
316 |
+
"source": [
|
317 |
+
"### Customized Quality Measures\n",
|
318 |
+
"It is possible to create a quality measure by estending the class [QualityFunction](https://github.com/MaurizioPulizzi/fairsd/blob/main/fairsd/qualitymeasures.py#L3):"
|
319 |
+
]
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"cell_type": "code",
|
323 |
+
"execution_count": 11,
|
324 |
+
"id": "eligible-british",
|
325 |
+
"metadata": {},
|
326 |
+
"outputs": [],
|
327 |
+
"source": [
|
328 |
+
"class MyQualityMeasure(fsd.QualityFunction):\n",
|
329 |
+
" def evaluate(self, y_true = None, y_pred=None, sensitive_features=None):\n",
|
330 |
+
" return 0.5\n",
|
331 |
+
" \n",
|
332 |
+
"task = fsd.SubgroupDiscoveryTask(X, y_true, y_pred, MyQualityMeasure.evaluate)"
|
333 |
+
]
|
334 |
+
},
|
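The evaluate method above always returns 0.5 and is only a placeholder. A slightly more realistic sketch, assuming (as the example suggests) that evaluate receives y_true, y_pred and a boolean subgroup-membership array in sensitive_features, could compute the absolute difference in false positive rates between the subgroup and its negation:

import numpy as np

class FPRDifference(fsd.QualityFunction):
    # Absolute difference in false positive rate between a subgroup and its negation.
    def evaluate(self, y_true=None, y_pred=None, sensitive_features=None):
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)
        in_sg = np.asarray(sensitive_features, dtype=bool)

        def fpr(mask):
            negatives = (y_true == 0) & mask  # ground-truth negatives inside the group
            if negatives.sum() == 0:
                return 0.0
            return ((y_pred == 1) & negatives).sum() / negatives.sum()

        return abs(fpr(in_sg) - fpr(~in_sg))

task = fsd.SubgroupDiscoveryTask(X, y_true, y_pred, FPRDifference().evaluate)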
335 |
+
{
|
336 |
+
"cell_type": "markdown",
|
337 |
+
"id": "abstract-crown",
|
338 |
+
"metadata": {},
|
339 |
+
"source": [
|
340 |
+
"## Descriptions and Quality Measures\n",
|
341 |
+
"A subgroup description is formed by the conjunction of zero or more descriptors.<br/>\n",
|
342 |
+
"A descriptor is a statement in the form \"attribute_name = attribute_value\" for nomilal attributes or \"attribute_name = range\" for numerical attributes.<br/>\n",
|
343 |
+
"Example of Description: \" sex = 'Male' AND age = (18, 30] \". <br/>\n",
|
344 |
+
"The Top-k subgroup discovery task in this package returns the k subgroup descriptions of the subgroups that exert the greatest disparity.<br>\n",
|
345 |
+
"There is no single definition of subgroup disparity, the meaning changes according to the used quality measure.\n",
|
346 |
+
"\n",
|
347 |
+
"**All metrics in the [fairlearn.metrics](https://fairlearn.github.io/v0.6.0/api_reference/fairlearn.metrics.html) module are symmetrical:** they always return a value between 0 and 1 and do not distinguish whether a subgroup is \\\"positively\\\" or \\\"negatively\\\" dissimilar. For example the descriptions \\\"married = True\\\" and \\\"married = False\\\" will always have the same quality.\n"
|
348 |
+
]
|
349 |
+
},
|
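A quick way to see this symmetry in practice is to evaluate a fairlearn metric once with a boolean subgroup indicator and once with its negation; both calls return the same value. The snippet below is only a sketch that reuses the X, y_true and y_pred defined earlier, with 'sex' as an illustrative subgroup attribute:

from fairlearn.metrics import demographic_parity_difference

in_subgroup = (X["sex"] == "Male")  # boolean indicator of the subgroup
print(demographic_parity_difference(y_true, y_pred, sensitive_features=in_subgroup))
print(demographic_parity_difference(y_true, y_pred, sensitive_features=~in_subgroup))
# Both printed values coincide: the metric does not distinguish a subgroup from its negation.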
350 |
+
{
|
351 |
+
"cell_type": "markdown",
|
352 |
+
"id": "round-dealer",
|
353 |
+
"metadata": {},
|
354 |
+
"source": [
|
355 |
+
"## Implemented Subgroup Disovery Algorithms\n",
|
356 |
+
"This package offers two Top-K Subgroup Discovery Algorithms: **BeamSearch** and **DSSD**."
|
357 |
+
]
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"cell_type": "markdown",
|
361 |
+
"id": "nonprofit-penetration",
|
362 |
+
"metadata": {},
|
363 |
+
"source": [
|
364 |
+
"### BeamSearch Algorithm\n",
|
365 |
+
"BeamSEarch is an euristic algorithm representing a good between the completeness of exhaustive search algorithms and the speed of greedy algorithms. <br/> This trade-off can be adjusted via the beam width. For a beam width that tends to infinity the algorithm becomes an exhaustive search algoritm, instead with a beam width equal to 1 the algorithm becomes a greedy one."
|
366 |
+
]
|
367 |
+
},
|
368 |
+
{
|
369 |
+
"cell_type": "code",
|
370 |
+
"execution_count": 12,
|
371 |
+
"id": "neutral-representative",
|
372 |
+
"metadata": {},
|
373 |
+
"outputs": [],
|
374 |
+
"source": [
|
375 |
+
"# Beam search algorithm usage\n",
|
376 |
+
"from fairsd import SubgroupDiscoveryTask\n",
|
377 |
+
"from fairsd import BeamSearch\n",
|
378 |
+
"task = SubgroupDiscoveryTask(X, y_true, y_pred)\n",
|
379 |
+
"resultset = BeamSearch().execute(task)"
|
380 |
+
]
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"cell_type": "markdown",
|
384 |
+
"id": "balanced-center",
|
385 |
+
"metadata": {},
|
386 |
+
"source": [
|
387 |
+
"The beam width is 20 by default but it is possible specify a different value:"
|
388 |
+
]
|
389 |
+
},
|
390 |
+
{
|
391 |
+
"cell_type": "code",
|
392 |
+
"execution_count": 13,
|
393 |
+
"id": "identical-journey",
|
394 |
+
"metadata": {},
|
395 |
+
"outputs": [
|
396 |
+
{
|
397 |
+
"name": "stdout",
|
398 |
+
"output_type": "stream",
|
399 |
+
"text": [
|
400 |
+
"education = \"Bachelors\" AND marital-status = \"Married-civ-spouse\" \n",
|
401 |
+
"education = \"Bachelors\" AND marital-status = \"Married-civ-spouse\" AND capital-loss = (-infinite, 0.0] \n",
|
402 |
+
"education = \"Bachelors\" AND marital-status = \"Married-civ-spouse\" AND capital-gain = (-infinite, 0.0] \n",
|
403 |
+
"education = \"Bachelors\" AND marital-status = \"Married-civ-spouse\" AND race = \"White\" \n",
|
404 |
+
"education = \"Bachelors\" AND marital-status = \"Married-civ-spouse\" AND sex = \"Male\" \n",
|
405 |
+
"\n"
|
406 |
+
]
|
407 |
+
}
|
408 |
+
],
|
409 |
+
"source": [
|
410 |
+
"resultset = BeamSearch(beam_width=10).execute(task)\n",
|
411 |
+
"print(resultset)"
|
412 |
+
]
|
413 |
+
},
|
414 |
+
{
|
415 |
+
"cell_type": "markdown",
|
416 |
+
"id": "seasonal-triumph",
|
417 |
+
"metadata": {},
|
418 |
+
"source": [
|
419 |
+
"#### Redundancy\n",
|
420 |
+
"As we can see, the returned result set is somehow redundant: many variants of, essentially, the same subgroup are presents in it. <br/>\n",
|
421 |
+
"Many top-k sg-discovery algorithms suffers of this problem.<br/>\n",
|
422 |
+
"To solve this problem it is suggested to use the DSSD algorithm."
|
423 |
+
]
|
424 |
+
},
|
425 |
+
{
|
426 |
+
"cell_type": "markdown",
|
427 |
+
"id": "hungarian-brass",
|
428 |
+
"metadata": {},
|
429 |
+
"source": [
|
430 |
+
"### DSSD (Diverse Subgroup Set Discovery) algorithm\n",
|
431 |
+
" This algorithm is a variant of the Beam Search Algorithm that also take into account the redundancy of the generated subgroups.<br/>\n",
|
432 |
+
"In this package a cover-based redundancy definition is used: roughly, the more tuples two subgroups have in common, the more they are considered redundant. <br/>\n",
|
433 |
+
"The degree to which redundancy is mitigated is determined by the alpha parameter. the more a is high, the less the subgroups redundancy is taken into account. Alpha must be between zero and 1, and, the more alpha is high, the less the subgroups redundancy is taken into account.<br>\n",
|
434 |
+
"\n",
|
435 |
+
"This algorithm is described in details in the Van Leeuwen and Knobbe's paper \"Diverse Subgroup Set Discovery\".\n"
|
436 |
+
]
|
437 |
+
},
|
438 |
+
{
|
439 |
+
"cell_type": "code",
|
440 |
+
"execution_count": 14,
|
441 |
+
"id": "pleased-elizabeth",
|
442 |
+
"metadata": {},
|
443 |
+
"outputs": [
|
444 |
+
{
|
445 |
+
"name": "stdout",
|
446 |
+
"output_type": "stream",
|
447 |
+
"text": [
|
448 |
+
"education = \"Bachelors\" AND marital-status = \"Married-civ-spouse\" \n",
|
449 |
+
"education = \"Bachelors\" AND relationship = \"Husband\" \n",
|
450 |
+
"education-num = (13.0, +infinite] AND marital-status = \"Married-civ-spouse\" \n",
|
451 |
+
"education-num = (13.0, +infinite] AND race = \"Asian-Pac-Islander\" \n",
|
452 |
+
"relationship = \"Husband\" AND capital-gain = (7298.0, +infinite] \n",
|
453 |
+
"\n"
|
454 |
+
]
|
455 |
+
}
|
456 |
+
],
|
457 |
+
"source": [
|
458 |
+
"# DSSD algorithm usage\n",
|
459 |
+
"from fairsd import DSSD\n",
|
460 |
+
"resultset = DSSD(beam_width=10, a = 0.9).execute(task) # \"a\" parameter represents alpha, 0.9 by default \n",
|
461 |
+
"print(resultset)"
|
462 |
+
]
|
463 |
+
},
|
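The effect of alpha can be inspected directly. A small sketch (the alpha values are chosen arbitrarily) that reruns the search with increasing alpha, so the resulting subgroup sets can be compared for redundancy:

for alpha in (0.3, 0.6, 0.9):
    print(f"alpha = {alpha}")
    print(DSSD(beam_width=10, a=alpha).execute(task))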
464 |
+
{
|
465 |
+
"cell_type": "markdown",
|
466 |
+
"id": "placed-people",
|
467 |
+
"metadata": {},
|
468 |
+
"source": [
|
469 |
+
"As we can see now, just by looking at the descriptions in the result set, the redundancy sems very attenuated."
|
470 |
+
]
|
471 |
+
}
|
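One way to make the reduced redundancy concrete is to measure how much the covers of the top subgroups overlap. The sketch below relies on the sg_feature helper of the ResultSet object (demonstrated in the usage notebook):

sg0 = resultset.sg_feature(sg_index=0, X=X)
sg1 = resultset.sg_feature(sg_index=1, X=X)
overlap = (sg0 & sg1).sum() / (sg0 | sg1).sum()  # Jaccard overlap of the two covers
print(overlap)  # lower values indicate less redundant subgroups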
472 |
+
],
|
473 |
+
"metadata": {
|
474 |
+
"kernelspec": {
|
475 |
+
"display_name": "Python 3",
|
476 |
+
"language": "python",
|
477 |
+
"name": "python3"
|
478 |
+
},
|
479 |
+
"language_info": {
|
480 |
+
"codemirror_mode": {
|
481 |
+
"name": "ipython",
|
482 |
+
"version": 3
|
483 |
+
},
|
484 |
+
"file_extension": ".py",
|
485 |
+
"mimetype": "text/x-python",
|
486 |
+
"name": "python",
|
487 |
+
"nbconvert_exporter": "python",
|
488 |
+
"pygments_lexer": "ipython3",
|
489 |
+
"version": "3.8.3"
|
490 |
+
}
|
491 |
+
},
|
492 |
+
"nbformat": 4,
|
493 |
+
"nbformat_minor": 5
|
494 |
+
}
|
fairsd/notebooks/fairsd_usage.ipynb
ADDED
@@ -0,0 +1,358 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "proper-mustang",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Quick start -- Use Case Example\n",
|
9 |
+
"For this example is used the [UCI adult dataset](https://archive.ics.uci.edu/ml/datasets/Adult) where the objective is to predict whether a person makes more (label 1) or less (0) than $50,000 a year."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 1,
|
15 |
+
"id": "authorized-better",
|
16 |
+
"metadata": {},
|
17 |
+
"outputs": [],
|
18 |
+
"source": [
|
19 |
+
"import pandas as pd\n",
|
20 |
+
"from sklearn.datasets import fetch_openml\n",
|
21 |
+
"from sklearn.tree import DecisionTreeClassifier\n",
|
22 |
+
"\n",
|
23 |
+
"#Import dataset\n",
|
24 |
+
"d = fetch_openml(data_id=1590, as_frame=True)\n",
|
25 |
+
"X = d.data\n",
|
26 |
+
"d_train=pd.get_dummies(X)\n",
|
27 |
+
"y_true = (d.target == '>50K') * 1\n",
|
28 |
+
"\n",
|
29 |
+
"#training the classifier\n",
|
30 |
+
"classifier = DecisionTreeClassifier(min_samples_leaf=10, max_depth=4)\n",
|
31 |
+
"classifier.fit(d_train, y_true)\n",
|
32 |
+
"\n",
|
33 |
+
"#Producing y_pred\n",
|
34 |
+
"y_pred = classifier.predict(d_train)"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "markdown",
|
39 |
+
"id": "adolescent-damage",
|
40 |
+
"metadata": {},
|
41 |
+
"source": [
|
42 |
+
"## Use of the FairSD package\n",
|
43 |
+
"Here we use the DSSD (Diverse Subgroup Set Discovery) algorithm and the demographic_parity_difference (from Fairlearn) to find the top-k (k = 5 by default) subgroups that exert the greatest disparity.<br/>\n",
|
44 |
+
"The execute method return a **ResultSet object**."
|
45 |
+
]
|
46 |
+
},
|
47 |
+
{
|
48 |
+
"cell_type": "code",
|
49 |
+
"execution_count": 2,
|
50 |
+
"id": "talented-dynamics",
|
51 |
+
"metadata": {},
|
52 |
+
"outputs": [],
|
53 |
+
"source": [
|
54 |
+
"import fairsd as fsd\n",
|
55 |
+
"task=fsd.SubgroupDiscoveryTask(X, y_true, y_pred, qf = \"demographic_parity_difference\")\n",
|
56 |
+
"result_set=fsd.DSSD().execute(task)"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "markdown",
|
61 |
+
"id": "liked-paradise",
|
62 |
+
"metadata": {},
|
63 |
+
"source": [
|
64 |
+
"### ResultSet object"
|
65 |
+
]
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"cell_type": "markdown",
|
69 |
+
"id": "blind-official",
|
70 |
+
"metadata": {},
|
71 |
+
"source": [
|
72 |
+
"We can transform the result set into a dataframe as shown below. Each row of this dataframe represents a subgroup."
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": 3,
|
78 |
+
"id": "european-relaxation",
|
79 |
+
"metadata": {
|
80 |
+
"scrolled": true
|
81 |
+
},
|
82 |
+
"outputs": [
|
83 |
+
{
|
84 |
+
"data": {
|
85 |
+
"text/html": [
|
86 |
+
"<div>\n",
|
87 |
+
"<style scoped>\n",
|
88 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
89 |
+
" vertical-align: middle;\n",
|
90 |
+
" }\n",
|
91 |
+
"\n",
|
92 |
+
" .dataframe tbody tr th {\n",
|
93 |
+
" vertical-align: top;\n",
|
94 |
+
" }\n",
|
95 |
+
"\n",
|
96 |
+
" .dataframe thead th {\n",
|
97 |
+
" text-align: right;\n",
|
98 |
+
" }\n",
|
99 |
+
"</style>\n",
|
100 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
101 |
+
" <thead>\n",
|
102 |
+
" <tr style=\"text-align: right;\">\n",
|
103 |
+
" <th></th>\n",
|
104 |
+
" <th>quality</th>\n",
|
105 |
+
" <th>description</th>\n",
|
106 |
+
" <th>size</th>\n",
|
107 |
+
" <th>proportion</th>\n",
|
108 |
+
" </tr>\n",
|
109 |
+
" </thead>\n",
|
110 |
+
" <tbody>\n",
|
111 |
+
" <tr>\n",
|
112 |
+
" <th>0</th>\n",
|
113 |
+
" <td>0.913502</td>\n",
|
114 |
+
" <td>education = \"Bachelors\" AND marital-status = \"...</td>\n",
|
115 |
+
" <td>4136</td>\n",
|
116 |
+
" <td>0.084681</td>\n",
|
117 |
+
" </tr>\n",
|
118 |
+
" <tr>\n",
|
119 |
+
" <th>1</th>\n",
|
120 |
+
" <td>0.913502</td>\n",
|
121 |
+
" <td>education-num = (12.0, 13.0] AND marital-statu...</td>\n",
|
122 |
+
" <td>4136</td>\n",
|
123 |
+
" <td>0.084681</td>\n",
|
124 |
+
" </tr>\n",
|
125 |
+
" <tr>\n",
|
126 |
+
" <th>2</th>\n",
|
127 |
+
" <td>0.866879</td>\n",
|
128 |
+
" <td>capital-gain = (6849.0, +infinite] AND fnlwgt ...</td>\n",
|
129 |
+
" <td>2036</td>\n",
|
130 |
+
" <td>0.041685</td>\n",
|
131 |
+
" </tr>\n",
|
132 |
+
" <tr>\n",
|
133 |
+
" <th>3</th>\n",
|
134 |
+
" <td>0.863130</td>\n",
|
135 |
+
" <td>education = \"Masters\" AND marital-status = \"Ma...</td>\n",
|
136 |
+
" <td>1527</td>\n",
|
137 |
+
" <td>0.031264</td>\n",
|
138 |
+
" </tr>\n",
|
139 |
+
" <tr>\n",
|
140 |
+
" <th>4</th>\n",
|
141 |
+
" <td>0.853604</td>\n",
|
142 |
+
" <td>education-num = (14.0, +infinite] AND marital-...</td>\n",
|
143 |
+
" <td>999</td>\n",
|
144 |
+
" <td>0.020454</td>\n",
|
145 |
+
" </tr>\n",
|
146 |
+
" </tbody>\n",
|
147 |
+
"</table>\n",
|
148 |
+
"</div>"
|
149 |
+
],
|
150 |
+
"text/plain": [
|
151 |
+
" quality description size \\\n",
|
152 |
+
"0 0.913502 education = \"Bachelors\" AND marital-status = \"... 4136 \n",
|
153 |
+
"1 0.913502 education-num = (12.0, 13.0] AND marital-statu... 4136 \n",
|
154 |
+
"2 0.866879 capital-gain = (6849.0, +infinite] AND fnlwgt ... 2036 \n",
|
155 |
+
"3 0.863130 education = \"Masters\" AND marital-status = \"Ma... 1527 \n",
|
156 |
+
"4 0.853604 education-num = (14.0, +infinite] AND marital-... 999 \n",
|
157 |
+
"\n",
|
158 |
+
" proportion \n",
|
159 |
+
"0 0.084681 \n",
|
160 |
+
"1 0.084681 \n",
|
161 |
+
"2 0.041685 \n",
|
162 |
+
"3 0.031264 \n",
|
163 |
+
"4 0.020454 "
|
164 |
+
]
|
165 |
+
},
|
166 |
+
"metadata": {},
|
167 |
+
"output_type": "display_data"
|
168 |
+
}
|
169 |
+
],
|
170 |
+
"source": [
|
171 |
+
"df=result_set.to_dataframe()\n",
|
172 |
+
"display(df)"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "markdown",
|
177 |
+
"id": "steady-version",
|
178 |
+
"metadata": {},
|
179 |
+
"source": [
|
180 |
+
"We can also print the result set or convert it into a string as shown below."
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": 4,
|
186 |
+
"id": "continuing-quarter",
|
187 |
+
"metadata": {},
|
188 |
+
"outputs": [
|
189 |
+
{
|
190 |
+
"name": "stdout",
|
191 |
+
"output_type": "stream",
|
192 |
+
"text": [
|
193 |
+
"education = \"Bachelors\" AND marital-status = \"Married-civ-spouse\" \n",
|
194 |
+
"education-num = (12.0, 13.0] AND marital-status = \"Married-civ-spouse\" \n",
|
195 |
+
"capital-gain = (6849.0, +infinite] AND fnlwgt = (24763.0, +infinite] \n",
|
196 |
+
"education = \"Masters\" AND marital-status = \"Married-civ-spouse\" \n",
|
197 |
+
"education-num = (14.0, +infinite] AND marital-status = \"Married-civ-spouse\" \n",
|
198 |
+
"\n"
|
199 |
+
]
|
200 |
+
}
|
201 |
+
],
|
202 |
+
"source": [
|
203 |
+
"resultset_string = result_set.to_string()\n",
|
204 |
+
"print(result_set)"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "markdown",
|
209 |
+
"id": "difficult-rendering",
|
210 |
+
"metadata": {},
|
211 |
+
"source": [
|
212 |
+
"### Generate a feature from a subgroup\n",
|
213 |
+
"ResultSet basically contains a list of subgroup descriptions ([Description](https://github.com/MaurizioPulizzi/fairsd/blob/main/fairsd/sgdescription.py#L80) object).<br/>\n",
|
214 |
+
"Another intresting method of Resultset object allow us to \n",
|
215 |
+
"**select a subgroup X from the result set and automatically generate the feature \"Belong to subgroup X\"**.This is very useful for deepening the analysis on the found subgroups, for example we can use the FairLearn library for this purpose.<br/>\n",
|
216 |
+
"An example is shown below:"
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "code",
|
221 |
+
"execution_count": 5,
|
222 |
+
"id": "covered-falls",
|
223 |
+
"metadata": {},
|
224 |
+
"outputs": [
|
225 |
+
{
|
226 |
+
"name": "stdout",
|
227 |
+
"output_type": "stream",
|
228 |
+
"text": [
|
229 |
+
"sg0\n",
|
230 |
+
"False 0.0864985\n",
|
231 |
+
"True 1\n",
|
232 |
+
"Name: selection_rate, dtype: object\n"
|
233 |
+
]
|
234 |
+
}
|
235 |
+
],
|
236 |
+
"source": [
|
237 |
+
"from fairlearn.metrics import MetricFrame\n",
|
238 |
+
"from fairlearn.metrics import selection_rate\n",
|
239 |
+
"\n",
|
240 |
+
"# Here we generate the feature \"Belong to subgroup n. 0\"\n",
|
241 |
+
"# The result is a pandas Series. The name of this Series is \"sg0\".\n",
|
242 |
+
"# This series contains an element for each instance of the dataset. Each element is True \n",
|
243 |
+
"# iff the istance belong to the subgroup sg0\n",
|
244 |
+
"sg_feature = result_set.sg_feature(sg_index=0, X=X)\n",
|
245 |
+
"\n",
|
246 |
+
"# Here we basically use the FairLearn library to further analyzing the subgroup sg0\n",
|
247 |
+
"selection_rate = MetricFrame(selection_rate, y_true, y_pred, sensitive_features=sg_feature)\n",
|
248 |
+
"print(selection_rate.by_group)"
|
249 |
+
]
|
250 |
+
},
|
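Since sg_feature is just a boolean indicator, it can also be passed directly to the scalar fairlearn metrics; a small sketch reusing the objects defined above:

from fairlearn.metrics import demographic_parity_difference

# Disparity in selection rates between the subgroup sg0 and the rest of the dataset
print(demographic_parity_difference(y_true, y_pred, sensitive_features=sg_feature))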
251 |
+
{
|
252 |
+
"cell_type": "markdown",
|
253 |
+
"id": "conceptual-battlefield",
|
254 |
+
"metadata": {},
|
255 |
+
"source": [
|
256 |
+
"### Description object\n",
|
257 |
+
"We can obtain the subgroup feature also retrieving the relative Description object first:"
|
258 |
+
]
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"cell_type": "code",
|
262 |
+
"execution_count": 6,
|
263 |
+
"id": "worth-service",
|
264 |
+
"metadata": {},
|
265 |
+
"outputs": [
|
266 |
+
{
|
267 |
+
"name": "stdout",
|
268 |
+
"output_type": "stream",
|
269 |
+
"text": [
|
270 |
+
"0 False\n",
|
271 |
+
"1 False\n",
|
272 |
+
"2 False\n",
|
273 |
+
"3 False\n",
|
274 |
+
"4 False\n",
|
275 |
+
" ... \n",
|
276 |
+
"48837 False\n",
|
277 |
+
"48838 False\n",
|
278 |
+
"48839 False\n",
|
279 |
+
"48840 False\n",
|
280 |
+
"48841 False\n",
|
281 |
+
"Length: 48842, dtype: bool\n"
|
282 |
+
]
|
283 |
+
}
|
284 |
+
],
|
285 |
+
"source": [
|
286 |
+
"description0 = result_set.get_description(0)\n",
|
287 |
+
"sg_feature = description0.to_boolean_array(dataset = X)\n",
|
288 |
+
"print(sg_feature)"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "markdown",
|
293 |
+
"id": "portuguese-eating",
|
294 |
+
"metadata": {},
|
295 |
+
"source": [
|
296 |
+
"Once we have the Description object of a subgroup, we can also extract other information of the subgroup.<br/>\n",
|
297 |
+
"We can:\n",
|
298 |
+
" * convert the Description object into a string\n",
|
299 |
+
" * retrieve the size of the subgroup\n",
|
300 |
+
" * retrieve the quality (fairness measure) of the subgroup\n",
|
301 |
+
" * retrieve the names of the attributes that compose the subgroup description"
|
302 |
+
]
|
303 |
+
},
|
304 |
+
{
|
305 |
+
"cell_type": "code",
|
306 |
+
"execution_count": 7,
|
307 |
+
"id": "pleased-chancellor",
|
308 |
+
"metadata": {},
|
309 |
+
"outputs": [
|
310 |
+
{
|
311 |
+
"name": "stdout",
|
312 |
+
"output_type": "stream",
|
313 |
+
"text": [
|
314 |
+
"education = \"Bachelors\" AND marital-status = \"Married-civ-spouse\" \n",
|
315 |
+
"4136\n",
|
316 |
+
"0.913501543416991\n",
|
317 |
+
"['education', 'marital-status']\n"
|
318 |
+
]
|
319 |
+
}
|
320 |
+
],
|
321 |
+
"source": [
|
322 |
+
"# String conversion\n",
|
323 |
+
"str_descr = description0.to_string()\n",
|
324 |
+
"print( str_descr ) # also print(description0) works\n",
|
325 |
+
"\n",
|
326 |
+
"# Size\n",
|
327 |
+
"print( description0.size() )\n",
|
328 |
+
"\n",
|
329 |
+
"# Quality\n",
|
330 |
+
"print( description0.get_quality() )\n",
|
331 |
+
"\n",
|
332 |
+
"# Attribute names\n",
|
333 |
+
"print( description0.get_attributes() )"
|
334 |
+
]
|
335 |
+
}
|
336 |
+
],
|
337 |
+
"metadata": {
|
338 |
+
"kernelspec": {
|
339 |
+
"display_name": "Python 3",
|
340 |
+
"language": "python",
|
341 |
+
"name": "python3"
|
342 |
+
},
|
343 |
+
"language_info": {
|
344 |
+
"codemirror_mode": {
|
345 |
+
"name": "ipython",
|
346 |
+
"version": 3
|
347 |
+
},
|
348 |
+
"file_extension": ".py",
|
349 |
+
"mimetype": "text/x-python",
|
350 |
+
"name": "python",
|
351 |
+
"nbconvert_exporter": "python",
|
352 |
+
"pygments_lexer": "ipython3",
|
353 |
+
"version": "3.8.3"
|
354 |
+
}
|
355 |
+
},
|
356 |
+
"nbformat": 4,
|
357 |
+
"nbformat_minor": 5
|
358 |
+
}
|
fairsd/requirements.txt
ADDED
@@ -0,0 +1,5 @@
1 |
+
numpy>=1.17.2
|
2 |
+
pandas>=0.25.1
|
3 |
+
scikit-learn>=0.22.1
|
4 |
+
fairlearn>=0.5.0
|
5 |
+
# note: "math" is part of the Python standard library and is not a pip-installable package
|
fairsd/tests/__init__.py
ADDED
File without changes
|
fairsd/tests/age.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
fairsd/tests/discretizationtest.py
ADDED
@@ -0,0 +1,68 @@
1 |
+
import unittest
|
2 |
+
from unittest import TestCase
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import sys
|
6 |
+
|
7 |
+
sys.path.append("../")
|
8 |
+
import fairsd.discretization as disc
|
9 |
+
|
10 |
+
|
11 |
+
class TestMDLP(TestCase):
|
12 |
+
def test_findCutPoints(self):
|
13 |
+
a = pd.read_csv("age.csv")
|
14 |
+
x = a["age"]
|
15 |
+
y = a["y_true"]
|
16 |
+
discretizer = disc.MDLP()
|
17 |
+
correct_result = [21.0, 23.0, 24.0, 27.0, 30.0, 35.0, 41.0, 54.0, 61.0, 67.0]
|
18 |
+
self.assertEqual(discretizer.findCutPoints(x, y), correct_result)
|
19 |
+
|
20 |
+
|
21 |
+
class TestEqualFrequency(TestCase):
|
22 |
+
def test_findCutPoints(self):
|
23 |
+
discretizer = disc.EqualFrequency(2, 0)
|
24 |
+
|
25 |
+
x = np.ones(49)
|
26 |
+
x = np.concatenate(
|
27 |
+
(x, [2])
|
28 |
+
) # no cut points expected because of the min_bin_size
|
29 |
+
self.assertEqual(discretizer.findCutPoints(x), [])
|
30 |
+
x = np.concatenate((x, [2])) # expected one cut point
|
31 |
+
self.assertEqual(discretizer.findCutPoints(x), [1])
|
32 |
+
x = np.concatenate((x, [2])) # expected one cut point
|
33 |
+
self.assertEqual(discretizer.findCutPoints(x), [1])
|
34 |
+
|
35 |
+
x = np.ones(5)
|
36 |
+
correct_res = []
|
37 |
+
self.assertEqual(
|
38 |
+
discretizer.findCutPoints(x), correct_res
|
39 |
+
) # no cut points expected
|
40 |
+
|
41 |
+
x = np.concatenate((x, np.zeros(5)))
|
42 |
+
correct_res.append(0)
|
43 |
+
self.assertEqual(discretizer.findCutPoints(x), correct_res)
|
44 |
+
|
45 |
+
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
46 |
+
self.assertEqual(discretizer.findCutPoints(x), [3, 5, 8])
|
47 |
+
discretizer = disc.EqualFrequency(1, 4)
|
48 |
+
self.assertEqual(discretizer.findCutPoints(x), [3, 5, 8])
|
49 |
+
discretizer = disc.EqualFrequency(1, 0)
|
50 |
+
self.assertEqual(discretizer.findCutPoints(x), [2, 3, 4, 5, 7, 8, 9])
|
51 |
+
x = np.array([1, 2, 2, 2, 2, 6, 7, 8, 9, 10])
|
52 |
+
self.assertEqual(discretizer.findCutPoints(x), [2, 7, 8, 9])
|
53 |
+
|
54 |
+
discretizer = disc.EqualFrequency(2, 0)
|
55 |
+
x = np.array([1, 7, 7, 7, 7, 7, 7, 8, 7, 7])
|
56 |
+
self.assertEqual(discretizer.findCutPoints(x), [])
|
57 |
+
|
58 |
+
|
59 |
+
class TestEqualWidth(TestCase):
|
60 |
+
def test_findCutPoints(self):
|
61 |
+
discretizer = disc.EqualWidth(2, 20)
|
62 |
+
x = np.zeros(9)
|
63 |
+
x = np.concatenate((x, [10]))
|
64 |
+
self.assertEqual(discretizer.findCutPoints(x), [2.5, 5, 7.5])
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
unittest.main()
|
fairsd/tests/sgdiscoverytest.py
ADDED
@@ -0,0 +1,397 @@
1 |
+
import inspect
|
2 |
+
import unittest
|
3 |
+
from unittest import TestCase
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
+
import sys
|
7 |
+
|
8 |
+
sys.path.append("../")
|
9 |
+
import fairsd as dsd
|
10 |
+
import random
|
11 |
+
|
12 |
+
|
13 |
+
def list_of_descriptors(num_des=3):
|
14 |
+
# create list of Descriptor
|
15 |
+
los = []
|
16 |
+
# nominal descriptors
|
17 |
+
for i in range(0, num_des):
|
18 |
+
attribute_name = "a" + str(i)
|
19 |
+
los.append(dsd.Descriptor(attribute_name=attribute_name, attribute_value=True))
|
20 |
+
# numeric descriptors
|
21 |
+
for i in range(0, num_des):
|
22 |
+
attribute_name = "b" + str(i)
|
23 |
+
los.append(
|
24 |
+
dsd.Descriptor(
|
25 |
+
attribute_name=attribute_name, up_bound=3, low_bound=1, is_numeric=True
|
26 |
+
)
|
27 |
+
)
|
28 |
+
return los
|
29 |
+
|
30 |
+
|
31 |
+
def istantiate_dataset():
|
32 |
+
data = {
|
33 |
+
"a0": [True, True, False, True, False, False],
|
34 |
+
"b0": [2, 2, 3, 4, 5, 6],
|
35 |
+
"y_true": [0, 0, 0, 1, 1, 1],
|
36 |
+
"y_pred": [0, 0, 1, 0, 1, 1],
|
37 |
+
}
|
38 |
+
return pd.DataFrame(data)
|
39 |
+
|
40 |
+
|
41 |
+
class TestDescriptorMethods(TestCase):
|
42 |
+
def test_is_present_in(self):
|
43 |
+
los = list_of_descriptors()
|
44 |
+
|
45 |
+
self.assertTrue(
|
46 |
+
dsd.Descriptor("a2", attribute_value=True).is_present_in(los), "error1"
|
47 |
+
)
|
48 |
+
self.assertFalse(
|
49 |
+
dsd.Descriptor("a1", attribute_value=False).is_present_in(los), "error2"
|
50 |
+
)
|
51 |
+
self.assertFalse(
|
52 |
+
dsd.Descriptor("b1", attribute_value=False).is_present_in(los), "error3"
|
53 |
+
)
|
54 |
+
self.assertTrue(
|
55 |
+
dsd.Descriptor(
|
56 |
+
"b1", up_bound=3, low_bound=1, is_numeric=True
|
57 |
+
).is_present_in(los),
|
58 |
+
"error4",
|
59 |
+
)
|
60 |
+
self.assertFalse(
|
61 |
+
dsd.Descriptor(
|
62 |
+
"b1", up_bound=3, low_bound=0, is_numeric=True
|
63 |
+
).is_present_in(los),
|
64 |
+
"error5",
|
65 |
+
)
|
66 |
+
self.assertFalse(
|
67 |
+
dsd.Descriptor(
|
68 |
+
"a1", up_bound=3, low_bound=1, is_numeric=True
|
69 |
+
).is_present_in(los),
|
70 |
+
"error6",
|
71 |
+
)
|
72 |
+
self.assertFalse(
|
73 |
+
dsd.Descriptor("c1", attribute_value=False).is_present_in(los), "error7"
|
74 |
+
)
|
75 |
+
self.assertFalse(
|
76 |
+
dsd.Descriptor(
|
77 |
+
"d1", up_bound=3, low_bound=1, is_numeric=True
|
78 |
+
).is_present_in(los),
|
79 |
+
"error8",
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
class TestDescriptionMethods(TestCase):
|
84 |
+
def setUp(self):
|
85 |
+
self.los = list_of_descriptors(3)
|
86 |
+
self.des = dsd.Description(self.los)
|
87 |
+
self.des.support = 5
|
88 |
+
self.des.set_quality(0.1)
|
89 |
+
|
90 |
+
def test__lt__(self):
|
91 |
+
# test with equal quality and different support
|
92 |
+
los1 = list_of_descriptors(1)
|
93 |
+
des1 = dsd.Description(los1)
|
94 |
+
des1.support = 5
|
95 |
+
des1.set_quality(0.1)
|
96 |
+
self.assertFalse(des1.__lt__(self.des))
|
97 |
+
des1.support = 4
|
98 |
+
des1.set_quality(0.1)
|
99 |
+
self.assertTrue(des1.__lt__(self.des))
|
100 |
+
des1.support = 6
|
101 |
+
des1.set_quality(0.1)
|
102 |
+
self.assertFalse(des1.__lt__(self.des))
|
103 |
+
|
104 |
+
# test with different qualities
|
105 |
+
des1.set_quality(0.2)
|
106 |
+
self.assertFalse(des1.__lt__(self.des))
|
107 |
+
des1.set_quality(0.01)
|
108 |
+
self.assertTrue(des1.__lt__(self.des))
|
109 |
+
des1.set_quality(0.1)
|
110 |
+
self.assertFalse(des1.__lt__(self.des))
|
111 |
+
|
112 |
+
def test_to_boolean_array(self):
|
113 |
+
los = list_of_descriptors(1)
|
114 |
+
des = dsd.Description(los)
|
115 |
+
self.assertTrue(
|
116 |
+
np.array_equal(
|
117 |
+
des.to_boolean_array(istantiate_dataset()),
|
118 |
+
np.array([True, True, False, False, False, False]),
|
119 |
+
)
|
120 |
+
)
|
121 |
+
|
122 |
+
des = dsd.Description()
|
123 |
+
self.assertTrue(
|
124 |
+
np.array_equal(
|
125 |
+
des.to_boolean_array(istantiate_dataset()),
|
126 |
+
np.array([True, True, True, True, True, True]),
|
127 |
+
)
|
128 |
+
)
|
129 |
+
|
130 |
+
def test_get_attributes(self):
|
131 |
+
self.assertEqual(
|
132 |
+
self.des.get_attributes(), ["a0", "a1", "a2", "b0", "b1", "b2"]
|
133 |
+
)
|
134 |
+
|
135 |
+
def test_is_present_in(self):
|
136 |
+
list_t = []
|
137 |
+
for i in range(3):
|
138 |
+
list_t.append(dsd.Description(list_of_descriptors(i + 1)))
|
139 |
+
|
140 |
+
self.assertTrue(dsd.Description(list_of_descriptors(2)).is_present_in(list_t))
|
141 |
+
self.assertFalse(dsd.Description(list_of_descriptors(4)).is_present_in(list_t))
|
142 |
+
|
143 |
+
|
144 |
+
class TestDiscretizerMethods(TestCase):
|
145 |
+
"""
|
146 |
+
This test assumes that the classes Discretizer, Description and Descriptor works correctly.
|
147 |
+
These three classes are tested above.
|
148 |
+
"""
|
149 |
+
|
150 |
+
def test_discretize_mdlp(self):
|
151 |
+
discretizer = dsd.Discretizer(discretization_type="mdlp", target="y_true")
|
152 |
+
|
153 |
+
dataset = istantiate_dataset()
|
154 |
+
description = dsd.Description()
|
155 |
+
res = discretizer.discretize(dataset, description, "b0")
|
156 |
+
|
157 |
+
d0 = dsd.Descriptor(
|
158 |
+
"b0", up_bound=3, low_bound=None, to_discretize=False, is_numeric=True
|
159 |
+
)
|
160 |
+
d1 = dsd.Descriptor(
|
161 |
+
"b0", up_bound=None, low_bound=3, to_discretize=False, is_numeric=True
|
162 |
+
)
|
163 |
+
correct_result = [d0, d1]
|
164 |
+
self.assertEqual(len(res), 2)
|
165 |
+
self.assertTrue(correct_result[0].is_present_in(res))
|
166 |
+
self.assertTrue(correct_result[1].is_present_in(res))
|
167 |
+
|
168 |
+
def test_discretize_equalfreq(self):
|
169 |
+
discretizer = dsd.Discretizer(discretization_type="equalfreq", num_bins=2)
|
170 |
+
|
171 |
+
dataset = istantiate_dataset()
|
172 |
+
description = dsd.Description()
|
173 |
+
res = discretizer.discretize(dataset, description, "b0")
|
174 |
+
|
175 |
+
d0 = dsd.Descriptor(
|
176 |
+
"b0", up_bound=3, low_bound=None, to_discretize=False, is_numeric=True
|
177 |
+
)
|
178 |
+
d1 = dsd.Descriptor(
|
179 |
+
"b0", up_bound=None, low_bound=3, to_discretize=False, is_numeric=True
|
180 |
+
)
|
181 |
+
correct_result = [d0, d1]
|
182 |
+
self.assertEqual(len(res), 2)
|
183 |
+
self.assertTrue(correct_result[0].is_present_in(res))
|
184 |
+
self.assertTrue(correct_result[1].is_present_in(res))
|
185 |
+
|
186 |
+
def test_discretize_equalwidth(self):
|
187 |
+
discretizer = dsd.Discretizer(discretization_type="equalwidth", num_bins=2)
|
188 |
+
|
189 |
+
dataset = istantiate_dataset()
|
190 |
+
description = dsd.Description()
|
191 |
+
res = discretizer.discretize(dataset, description, "b0")
|
192 |
+
|
193 |
+
d0 = dsd.Descriptor(
|
194 |
+
"b0", up_bound=4, low_bound=None, to_discretize=False, is_numeric=True
|
195 |
+
)
|
196 |
+
d1 = dsd.Descriptor(
|
197 |
+
"b0", up_bound=None, low_bound=4, to_discretize=False, is_numeric=True
|
198 |
+
)
|
199 |
+
correct_result = [d0, d1]
|
200 |
+
self.assertEqual(len(res), 2)
|
201 |
+
self.assertTrue(correct_result[0].is_present_in(res))
|
202 |
+
self.assertTrue(correct_result[1].is_present_in(res))
|
203 |
+
|
204 |
+
|
205 |
+
class TestSearchSpaceMethods(TestCase):
|
206 |
+
def test_extract_search_space_AND_init_methods(self):
|
207 |
+
dataset = istantiate_dataset()
|
208 |
+
ss = dsd.SearchSpace(
|
209 |
+
dataset,
|
210 |
+
ignore=["y_pred", "y_true"],
|
211 |
+
discretizer=dsd.Discretizer(discretization_type="equalfreq", num_bins=2),
|
212 |
+
)
|
213 |
+
|
214 |
+
# creation of correct result
|
215 |
+
d0 = dsd.Descriptor("a0", attribute_value=True, is_numeric=False)
|
216 |
+
d1 = dsd.Descriptor("a0", attribute_value=False, is_numeric=False)
|
217 |
+
d2 = dsd.Descriptor(
|
218 |
+
"b0", up_bound=3, low_bound=None, to_discretize=False, is_numeric=True
|
219 |
+
)
|
220 |
+
d3 = dsd.Descriptor(
|
221 |
+
"b0", up_bound=None, low_bound=3, to_discretize=False, is_numeric=True
|
222 |
+
)
|
223 |
+
correct_result = [d0, d1, d2, d3]
|
224 |
+
|
225 |
+
discretizer = dsd.Discretizer(discretization_type="mdlp", target="y_true")
|
226 |
+
result = ss.extract_search_space(dataset, discretizer)
|
227 |
+
|
228 |
+
# batch of tests 1
|
229 |
+
self.assertEqual(len(result), 4)
|
230 |
+
self.assertTrue(correct_result[0].is_present_in(result))
|
231 |
+
self.assertTrue(correct_result[1].is_present_in(result))
|
232 |
+
self.assertTrue(correct_result[2].is_present_in(result))
|
233 |
+
self.assertTrue(correct_result[3].is_present_in(result))
|
234 |
+
|
235 |
+
# batch of tests 2: with not null description
|
236 |
+
descr = dsd.Description(
|
237 |
+
[d2]
|
238 |
+
) # b0>3, This description excludes the feature b0 because all b0>3 have positive class
|
239 |
+
result = ss.extract_search_space(dataset, discretizer, descr)
|
240 |
+
|
241 |
+
self.assertEqual(len(result), 2)
|
242 |
+
self.assertTrue(correct_result[0].is_present_in(result))
|
243 |
+
self.assertTrue(correct_result[1].is_present_in(result))
|
244 |
+
|
245 |
+
# batch of tests 3: numeric attribute treated as a nominal one
|
246 |
+
ss = dsd.SearchSpace(
|
247 |
+
dataset, ignore=["y_pred", "y_true"], nominal_features=["b0"]
|
248 |
+
)
|
249 |
+
result = ss.extract_search_space(dataset, discretizer)
|
250 |
+
|
251 |
+
# creation of correct result
|
252 |
+
d0 = dsd.Descriptor("a0", attribute_value=True, is_numeric=False)
|
253 |
+
d1 = dsd.Descriptor("a0", attribute_value=False, is_numeric=False)
|
254 |
+
d2 = dsd.Descriptor("b0", attribute_value=2, is_numeric=False)
|
255 |
+
d3 = dsd.Descriptor("b0", attribute_value=3, is_numeric=False)
|
256 |
+
d4 = dsd.Descriptor("b0", attribute_value=4, is_numeric=False)
|
257 |
+
d5 = dsd.Descriptor("b0", attribute_value=5, is_numeric=False)
|
258 |
+
d6 = dsd.Descriptor("b0", attribute_value=6, is_numeric=False)
|
259 |
+
correct_result = [d0, d1, d2, d3, d4, d5, d6]
|
260 |
+
|
261 |
+
self.assertEqual(len(result), 7)
|
262 |
+
self.assertTrue(correct_result[0].is_present_in(result))
|
263 |
+
self.assertTrue(correct_result[1].is_present_in(result))
|
264 |
+
self.assertTrue(correct_result[2].is_present_in(result))
|
265 |
+
self.assertTrue(correct_result[3].is_present_in(result))
|
266 |
+
self.assertTrue(correct_result[4].is_present_in(result))
|
267 |
+
self.assertTrue(correct_result[5].is_present_in(result))
|
268 |
+
self.assertTrue(correct_result[6].is_present_in(result))
|
269 |
+
|
270 |
+
|
271 |
+
COSTANT_SEED = 3
|
272 |
+
|
273 |
+
|
274 |
+
class TestQF(dsd.QualityFunction):
|
275 |
+
"""
|
276 |
+
This quality function is created only for testing the sg_discovery algorithms.
|
277 |
+
The evaluate function returns a pseudo-random number each time it is called.
|
278 |
+
The seed is always 3, so the sequence of returned numbers is always reproducible.
|
279 |
+
"""
|
280 |
+
|
281 |
+
def __init__(self):
|
282 |
+
random.seed(a=COSTANT_SEED)
|
283 |
+
|
284 |
+
def evaluate(self, y_true=None, y_pred=None, sensitive_features=None):
|
285 |
+
return random.random()
|
286 |
+
|
287 |
+
|
288 |
+
class TestBeamSeacrchMethods(TestCase):
|
289 |
+
"""
|
290 |
+
A toy dataset is used for testing. A maximum of 8 descriptions can be generated from this dataset,
|
291 |
+
in this way the result of the algorithm can be calculated manually.
|
292 |
+
The quality function is the TestQF defined above.
|
293 |
+
"""
|
294 |
+
|
295 |
+
def test_execute(self):
|
296 |
+
data = {
|
297 |
+
"a0": [True, True, False, True, False, False],
|
298 |
+
"b0": ["a", "b", "a", "b", "a", "b"],
|
299 |
+
}
|
300 |
+
x = pd.DataFrame(data)
|
301 |
+
y_true = pd.Series([0, 0, 0, 1, 1, 1])
|
302 |
+
task = dsd.SubgroupDiscoveryTask(
|
303 |
+
x, y_true, qf=TestQF().evaluate, depth=6, min_quality=0, min_support=0
|
304 |
+
)
|
305 |
+
|
306 |
+
result_set = dsd.BeamSearch(beam_width=10).execute(task)
|
307 |
+
result_set = result_set.descriptions_list
|
308 |
+
self.assertEqual(len(result_set), 8)
|
309 |
+
|
310 |
+
# create correct result
|
311 |
+
correct_result = []
|
312 |
+
random.seed(a=COSTANT_SEED)
|
313 |
+
correct_result.append((random.random(), "a0 = 'True' "))
|
314 |
+
correct_result.append((random.random(), "a0 = 'False' "))
|
315 |
+
correct_result.append((random.random(), "b0 = 'a' "))
|
316 |
+
correct_result.append((random.random(), "b0 = 'b' "))
|
317 |
+
correct_result.append((random.random(), "a0 = 'True' AND b0 = 'a' "))
|
318 |
+
correct_result.append((random.random(), "a0 = 'True' AND b0 = 'b' "))
|
319 |
+
correct_result.append((random.random(), "a0 = 'False' AND b0 = 'a' "))
|
320 |
+
correct_result.append((random.random(), "a0 = 'False' AND b0 = 'b' "))
|
321 |
+
correct_result.sort(reverse=True)
|
322 |
+
|
323 |
+
for i in range(8):
|
324 |
+
self.assertEqual(correct_result[i][1], result_set[i].__repr__())
|
325 |
+
|
326 |
+
# test 2 with reduced result set size
|
327 |
+
task = dsd.SubgroupDiscoveryTask(
|
328 |
+
x,
|
329 |
+
y_true,
|
330 |
+
qf=TestQF().evaluate,
|
331 |
+
result_set_size=3,
|
332 |
+
min_quality=0,
|
333 |
+
min_support=0,
|
334 |
+
)
|
335 |
+
result_set = dsd.BeamSearch(beam_width=10).execute(task)
|
336 |
+
result_set = result_set.descriptions_list
|
337 |
+
|
338 |
+
self.assertEqual(len(result_set), 3)
|
339 |
+
for i in range(3):
|
340 |
+
self.assertEqual(correct_result[i][1], result_set[i].__repr__())
|
341 |
+
|
342 |
+
|
343 |
+
class TestDSSDMethods(TestCase):
|
344 |
+
"""
|
345 |
+
A toy dataset is used for testing. A maximum of 8 descriptions can be generated from this dataset,
|
346 |
+
in this way the result of the algorithm can be calculated manually.
|
347 |
+
The quality function is the TestQF defined above.
|
348 |
+
"""
|
349 |
+
|
350 |
+
def test_execute(self):
|
351 |
+
data = {
|
352 |
+
"a0": [True, True, False, True, False, False],
|
353 |
+
"b0": ["a", "b", "a", "b", "a", "b"],
|
354 |
+
}
|
355 |
+
x = pd.DataFrame(data)
|
356 |
+
y_true = pd.Series([0, 0, 0, 1, 1, 1])
|
357 |
+
task = dsd.SubgroupDiscoveryTask(
|
358 |
+
x, y_true, qf=TestQF().evaluate, depth=1, min_quality=0, min_support=0
|
359 |
+
)
|
360 |
+
|
361 |
+
result_set = dsd.DSSD(beam_width=10, a=1).execute(task)
|
362 |
+
|
363 |
+
result_set = result_set.descriptions_list
|
364 |
+
self.assertEqual(len(result_set), 4)
|
365 |
+
|
366 |
+
# create correct result
|
367 |
+
correct_result = []
|
368 |
+
random.seed(a=COSTANT_SEED)
|
369 |
+
correct_result.append((random.random(), "a0 = 'True' "))
|
370 |
+
correct_result.append((random.random(), "a0 = 'False' "))
|
371 |
+
correct_result.append((random.random(), "b0 = 'a' "))
|
372 |
+
correct_result.append((random.random(), "b0 = 'b' "))
|
373 |
+
correct_result.sort(reverse=True)
|
374 |
+
|
375 |
+
for i in range(4):
|
376 |
+
self.assertEqual(correct_result[i][1], result_set[i].__repr__())
|
377 |
+
|
378 |
+
# test 2 with reduced result set size
|
379 |
+
task = dsd.SubgroupDiscoveryTask(
|
380 |
+
x,
|
381 |
+
y_true,
|
382 |
+
qf=TestQF().evaluate,
|
383 |
+
depth=1,
|
384 |
+
result_set_size=3,
|
385 |
+
min_quality=0,
|
386 |
+
min_support=0,
|
387 |
+
)
|
388 |
+
result_set = dsd.DSSD(beam_width=10, a=1).execute(task)
|
389 |
+
result_set = result_set.descriptions_list
|
390 |
+
|
391 |
+
self.assertEqual(len(result_set), 3)
|
392 |
+
for i in range(3):
|
393 |
+
self.assertEqual(correct_result[i][1], result_set[i].__repr__())
|
394 |
+
|
395 |
+
|
396 |
+
if __name__ == "__main__":
|
397 |
+
unittest.main()
|
images/screenshot.png
ADDED
![]() |
load.py
ADDED
@@ -0,0 +1,352 @@
1 |
+
import logging
|
2 |
+
import time
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
from sklearn.ensemble import RandomForestClassifier
|
6 |
+
import shap
|
7 |
+
import pandas as pd
|
8 |
+
import numpy as np
|
9 |
+
from sklearn.model_selection import train_test_split
|
10 |
+
from sklearn.datasets import fetch_openml
|
11 |
+
from sklearn.tree import DecisionTreeClassifier
|
12 |
+
from utils import (
|
13 |
+
combine_one_hot,
|
14 |
+
combine_all_one_hot_shap_logloss,
|
15 |
+
)
|
16 |
+
|
17 |
+
from fairsd import fairsd
|
18 |
+
from fairsd.fairsd.algorithms import ResultSet
|
19 |
+
|
20 |
+
|
21 |
+
def import_dataset(dataset: str, sensitive_features: List[str] = None):
|
22 |
+
"""Import the dataset from OpenML and preprocess it.
|
23 |
+
Args:
|
24 |
+
dataset (str): Dataset to be imported. Supported datasets: adult, german_credit, heloc, credit.
|
25 |
+
sensitive_features (list, optional): Sensitive features to be used in the dataset. Defaults to None.
|
26 |
+
"""
|
27 |
+
|
28 |
+
print("Loading data...")
|
29 |
+
cols_to_drop = []
|
30 |
+
# Import dataset
|
31 |
+
if dataset == "adult":
|
32 |
+
dataset_id = 1590
|
33 |
+
target = ">50K"
|
34 |
+
cols_to_drop = ["fnlwgt", "education-num"]
|
35 |
+
if sensitive_features is not None:
|
36 |
+
cols_to_drop = [col for col in cols_to_drop if col not in sensitive_features]
|
37 |
+
elif dataset in ("credit_g", "german", "german_credit"):
|
38 |
+
dataset_id = 31
|
39 |
+
target = "good"
|
40 |
+
cols_to_drop = ["personal_status", "other_parties", "residence_since", "foreign_worker"]
|
41 |
+
if sensitive_features is not None:
|
42 |
+
cols_to_drop = [col for col in cols_to_drop if col not in sensitive_features]
|
43 |
+
elif dataset == "heloc":
|
44 |
+
dataset_id = 45023
|
45 |
+
target = "1"
|
46 |
+
elif dataset == "credit":
|
47 |
+
dataset_id = 43978
|
48 |
+
target = 1
|
49 |
+
else:
|
50 |
+
raise NotImplementedError(
|
51 |
+
"Only the following datasets are supported now: adult, german_credit, heloc, credit."
|
52 |
+
)
|
53 |
+
|
54 |
+
d = fetch_openml(data_id=dataset_id, as_frame=True, parser="auto")
|
55 |
+
X = d.data
|
56 |
+
if sensitive_features is not None:
|
57 |
+
incorrect_sensitive_features = [
|
58 |
+
col for col in sensitive_features if col not in X.columns
|
59 |
+
]
|
60 |
+
if incorrect_sensitive_features:
|
61 |
+
raise ValueError(
|
62 |
+
f"Sensitive features {incorrect_sensitive_features} not found in the dataset."
|
63 |
+
)
|
64 |
+
|
65 |
+
X = X.drop(cols_to_drop, axis=1, errors="ignore")
|
66 |
+
|
67 |
+
# Fill missing values - "missing" for categorical, mean for numerical
|
68 |
+
for col in X.columns:
|
69 |
+
if X[col].dtype.name == "category":
|
70 |
+
X[col] = X[col].cat.add_categories("missing").fillna("missing")
|
71 |
+
elif X[col].dtype.name == "object":
|
72 |
+
X[col] = X[col].fillna("missing")
|
73 |
+
else:
|
74 |
+
X[col] = X[col].fillna(X[col].mean())
|
75 |
+
|
76 |
+
# Get target as 1/0
|
77 |
+
y_true = (d.target == target) * 1
|
78 |
+
return X, y_true
|
79 |
+
|
80 |
+
|
81 |
+
def load_data(
|
82 |
+
dataset="adult",
|
83 |
+
n_samples=0,
|
84 |
+
test_split=0.3,
|
85 |
+
sensitive_features: List[str] = None,
|
86 |
+
) -> Tuple[pd.DataFrame, pd.Series, pd.Series, pd.DataFrame, pd.DataFrame, List[str]]:
|
87 |
+
"""Load data from UCI Adult dataset.
|
88 |
+
Args:
|
89 |
+
dataset (str): Dataset to be loaded. Supported datasets: adult, german_credit, heloc, credit.
|
90 |
+
n_samples (int, optional): Number of samples to load. Select 0 to load all.
|
91 |
+
test_split (bool, optional): Ratio of selected data samples for test set
|
92 |
+
sensitive_features (list, optional): Sensitive features to be used in the dataset. Defaults to None.
|
93 |
+
"""
|
94 |
+
X, y_true = import_dataset(dataset, sensitive_features)
|
95 |
+
|
96 |
+
# Get categorical feature names
|
97 |
+
cat_features = X.select_dtypes(include=["category"]).columns
|
98 |
+
|
99 |
+
# One-hot encoding
|
100 |
+
onehot_X = pd.get_dummies(X) * 1
|
101 |
+
|
102 |
+
if n_samples and X.shape[0] >= n_samples:
|
103 |
+
# Load only n examples
|
104 |
+
X = X.iloc[:n_samples]
|
105 |
+
y_true = y_true.iloc[:n_samples]
|
106 |
+
onehot_X = onehot_X.iloc[:n_samples]
|
107 |
+
|
108 |
+
if test_split and float(test_split) < 1.0:
|
109 |
+
(
|
110 |
+
X_train,
|
111 |
+
X_test,
|
112 |
+
y_true_train,
|
113 |
+
y_true_test,
|
114 |
+
onehot_X_train,
|
115 |
+
onehot_X_test,
|
116 |
+
) = train_test_split(X, y_true, onehot_X, test_size=test_split, random_state=0)
|
117 |
+
# Reset indices
|
118 |
+
X_train.reset_index(inplace=True, drop=True)
|
119 |
+
X_test.reset_index(inplace=True, drop=True)
|
120 |
+
y_true_train = y_true_train.reset_index(drop=True)
|
121 |
+
y_true_test = y_true_test.reset_index(drop=True)
|
122 |
+
onehot_X_train.reset_index(inplace=True, drop=True)
|
123 |
+
onehot_X_test.reset_index(inplace=True, drop=True)
|
124 |
+
|
125 |
+
X_test.reset_index(inplace=True, drop=True)
|
126 |
+
y_true_test.reset_index(inplace=True, drop=True)
|
127 |
+
|
128 |
+
return (
|
129 |
+
X_test,
|
130 |
+
y_true_train,
|
131 |
+
y_true_test,
|
132 |
+
onehot_X_train,
|
133 |
+
onehot_X_test,
|
134 |
+
cat_features,
|
135 |
+
)
|
136 |
+
|
137 |
+
else:
|
138 |
+
return X, y_true, y_true, onehot_X, onehot_X, cat_features
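# Example usage (the argument values below are hypothetical):
#   X_test, y_true_train, y_true_test, onehot_X_train, onehot_X_test, cat_features = load_data(
#       dataset="adult", n_samples=5000, test_split=0.3
#   )
# Note: the first returned element is the index-reset test split of X (X_train itself is not
# returned); the one-hot encoded train and test frames are the ones intended for model fitting.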
|
139 |
+
|
140 |
+
|
141 |
+
def add_bias(
|
142 |
+
bias: str, X_test: pd.DataFrame, onehot_X_test: pd.DataFrame, subgroup: pd.Series
|
143 |
+
) -> None:
|
144 |
+
"""Add bias to the dataset."""
|
145 |
+
|
146 |
+
if bias in ("random", "noise"):
|
147 |
+
feature = "capital-gain"
|
148 |
+
# Add random noise to the subset
|
149 |
+
std_val = X_test[feature].std()
|
150 |
+
mean_val = X_test[feature].mean()
|
151 |
+
X_test.loc[subgroup, feature] += np.random.normal(
|
152 |
+
mean_val, std_val, sum(subgroup)
|
153 |
+
)
|
154 |
+
onehot_X_test.loc[subgroup, feature] += np.random.normal(
|
155 |
+
mean_val, std_val, sum(subgroup)
|
156 |
+
)
|
157 |
+
elif bias == "mean":
|
158 |
+
feature = "age"
|
159 |
+
# Add the mean of the feature to the subset
|
160 |
+
X_test.loc[subgroup, feature] = X_test[feature].mean()
|
161 |
+
onehot_X_test.loc[subgroup, feature] = onehot_X_test[feature].mean()
|
162 |
+
elif bias == "median":
|
163 |
+
feature = "age"
|
164 |
+
# Add the median of the feature to the subset
|
165 |
+
X_test.loc[subgroup, feature] = X_test[feature].median()
|
166 |
+
onehot_X_test.loc[subgroup, feature] = onehot_X_test[feature].median()
|
167 |
+
elif bias in ("bin", "binning"):
|
168 |
+
feature = "age"
|
169 |
+
# Add the binning of the feature to the subset
|
170 |
+
X_test.loc[subgroup, feature] = X_test[subgroup].apply(
|
171 |
+
lambda x: x[feature] // 20 * 20, axis=1
|
172 |
+
)
|
173 |
+
onehot_X_test.loc[subgroup, feature] = onehot_X_test[subgroup].apply(
|
174 |
+
lambda x: x // 20 * 20, axis=1
|
175 |
+
)
|
176 |
+
elif bias == "sum_std":
|
177 |
+
feature = "age"
|
178 |
+
std_val = X_test[feature].std()
|
179 |
+
# Add the standard deviation of the feature to the subset
|
180 |
+
X_test.loc[subgroup, feature] = X_test[subgroup].apply(
|
181 |
+
lambda x: x[feature] + std_val, axis=1
|
182 |
+
)
|
183 |
+
onehot_X_test.loc[subgroup, feature] = onehot_X_test[feature].sum()
|
184 |
+
|
185 |
+
elif bias == "swap":
|
186 |
+
# Swap all values of the feature to another value in the same column
|
187 |
+
# feature = "education"
|
188 |
+
# value_selected = "Doctorate"
|
189 |
+
feature = "marital-status"
|
190 |
+
value_selected = "Married-civ-spouse"
|
191 |
+
X_test.loc[subgroup, feature] = value_selected
|
192 |
+
# Onehot_X_test is one hot encoded so we need to swap the entire column for the feature
|
193 |
+
onehot_X_test.loc[subgroup, onehot_X_test.columns.str.startswith(feature)] = 0
|
194 |
+
onehot_X_test.loc[subgroup, [feature + "_" + value_selected]] = 1
|
195 |
+
else:
|
196 |
+
raise ValueError(
|
197 |
+
f"Bias method '{bias}' not supported. Supported methods: random, mean, median, bin, sum_std, swap."
|
198 |
+
)
|
199 |
+
print(
|
200 |
+
f"Added bias to the dataset by method: {bias}. Feature {feature} was affected. Size of the subset impacted: {sum(subgroup)}."
|
201 |
+
)
|
202 |
+
|
203 |
+
|
204 |
+
def get_classifier(
|
205 |
+
onehot_X_train: pd.DataFrame,
|
206 |
+
y_true_train: pd.Series,
|
207 |
+
onehot_X_test: pd.DataFrame,
|
208 |
+
with_names=True,
|
209 |
+
model="rf",
|
210 |
+
) -> Tuple[object, np.ndarray]:
|
211 |
+
"""Get a decision tree classifier for the given dataset.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
onehot_X_train: one-hot encoded features dataframe used for training
|
215 |
+
y_true_train: true labels for training
|
216 |
+
onehot_X_test: one-hot encoded features dataframe used for testing
|
217 |
+
Returns:
|
218 |
+
classifier: trained classifier
|
219 |
+
y_pred_prob: predicted probabilities of the positive class
|
220 |
+
"""
|
221 |
+
if model == "rf":
|
222 |
+
classifier = RandomForestClassifier(
|
223 |
+
n_estimators=20, max_depth=8, random_state=0
|
224 |
+
)
|
225 |
+
elif model == "dt":
|
226 |
+
classifier = DecisionTreeClassifier(min_samples_leaf=30, max_depth=8)
|
227 |
+
elif model == "xgb":
|
228 |
+
from xgboost import XGBClassifier
|
229 |
+
classifier = XGBClassifier() # FIXME: Issue with feature names in german credit dataset
|
230 |
+
else:
|
231 |
+
raise ValueError(f"Model {model} not supported. Supported models: rf, dt, xgb.")
|
232 |
+
|
233 |
+
if not with_names:
|
234 |
+
onehot_X_train = onehot_X_train.values
|
235 |
+
onehot_X_test = onehot_X_test.values
|
236 |
+
|
237 |
+
print("Training the classifier...")
|
238 |
+
classifier.fit(onehot_X_train, y_true_train)
|
239 |
+
|
240 |
+
y_pred_prob = classifier.predict_proba(onehot_X_test)
|
241 |
+
|
242 |
+
return classifier, y_pred_prob[:, 1]
|
243 |
+
|
244 |
+
|
245 |
+
def combine_shap_one_hot(shap_values, X_columns, cat_features):
|
246 |
+
# Combine one-hot encoded cat_features
|
247 |
+
non_cat_features = [col for col in X_columns if col not in cat_features]
|
248 |
+
for cat_feat in cat_features:
|
249 |
+
# Get column masks for each cat feature
|
250 |
+
col_masks = [
|
251 |
+
col.startswith(cat_feat) and col not in non_cat_features
|
252 |
+
for col in shap_values.feature_names
|
253 |
+
]
|
254 |
+
|
255 |
+
shap_values = combine_one_hot(
|
256 |
+
shap_values, cat_feat, col_masks, return_original=False
|
257 |
+
)
|
258 |
+
return shap_values
|
259 |
+
|
260 |
+
|
261 |
+
def get_shap_values(classifier, d_train, X, cat_features, combine_cat_features=True):
|
262 |
+
"""Get shap values for a given classifier and dataset.
|
263 |
+
Combines one-hot encoded categorical features into original features."""
|
264 |
+
# Producing shap values
|
265 |
+
explainer = shap.TreeExplainer(classifier)
|
266 |
+
# TODO: Add and evaluate interation values
|
267 |
+
# shap_interaction = explainer.shap_interaction_values(X)
|
268 |
+
|
269 |
+
shap_values = explainer(d_train)
|
270 |
+
|
271 |
+
if combine_cat_features:
|
272 |
+
shap_values = combine_shap_one_hot(shap_values, X.columns, cat_features)
|
273 |
+
return shap_values
|
274 |
+
|
275 |
+
|
276 |
+
def get_shap_logloss(
|
277 |
+
classifier, d_train, y_true, X, cat_features, combine_cat_features=True
|
278 |
+
):
|
279 |
+
"""Get shap values of the model log loss for a given classifier and dataset.
|
280 |
+
Combines one-hot encoded categorical features into original features."""
|
281 |
+
explainer_bg_100 = shap.TreeExplainer(
|
282 |
+
classifier,
|
283 |
+
shap.sample(d_train, 100),
|
284 |
+
feature_perturbation="interventional",
|
285 |
+
model_output="log_loss",
|
286 |
+
)
|
287 |
+
shap_values_logloss_all = explainer_bg_100.shap_values(d_train, y_true)
|
288 |
+
if len(shap_values_logloss_all) == 2:
|
289 |
+
shap_values_logloss_all = shap_values_logloss_all[1]
|
290 |
+
shap_logloss_df = pd.DataFrame(shap_values_logloss_all, columns=d_train.columns)
|
291 |
+
if combine_cat_features:
|
292 |
+
shap_logloss_df = combine_all_one_hot_shap_logloss(
|
293 |
+
shap_logloss_df, X.columns, cat_features
|
294 |
+
)
|
295 |
+
return shap_logloss_df
|
296 |
+
|
297 |
+
|
298 |
+
def get_fairsd_result_set(
|
299 |
+
X,
|
300 |
+
y_true,
|
301 |
+
y_pred,
|
302 |
+
qf="equalized_odds_difference",
|
303 |
+
method="to_overall",
|
304 |
+
min_quality=0.01,
|
305 |
+
depth=1,
|
306 |
+
min_support=100,
|
307 |
+
result_set_size=30,
|
308 |
+
sensitive_features=None,
|
309 |
+
**kwargs,
|
310 |
+
) -> ResultSet:
|
311 |
+
"""Get result set from fairsd DSSD task
|
312 |
+
|
313 |
+
Args:
|
314 |
+
X (pd.DataFrame): Dataset
|
315 |
+
y_true (pd.Series): True labels
|
316 |
+
y_pred (pd.Series): Predicted labels
|
317 |
+
qf (str, optional): Quality function. Defaults to "equalized_odds_difference".
|
318 |
+
depth (int, optional): Depth of subgroup discovery.
|
319 |
+
min_support (int, optional): Minimum support.
|
320 |
+
result_set_size (int, optional): Size of result set.
|
321 |
+
kwargs: Additional arguments to pass to fairsd.SubgroupDiscoveryTask,
|
322 |
+
including result_set_ratio and logging_level (as defined by logging module)
|
323 |
+
"""
|
324 |
+
if sensitive_features is not None:
|
325 |
+
X = X[sensitive_features].copy()
|
326 |
+
task = fairsd.SubgroupDiscoveryTask(
|
327 |
+
X,
|
328 |
+
y_true,
|
329 |
+
y_pred,
|
330 |
+
qf=qf,
|
331 |
+
depth=depth,
|
332 |
+
result_set_size=result_set_size,
|
333 |
+
min_quality=min_quality,
|
334 |
+
min_support=min_support,
|
335 |
+
**kwargs,
|
336 |
+
)
|
337 |
+
if "logging_level" in kwargs:
|
338 |
+
logging_level = kwargs["logging_level"]
|
339 |
+
else:
|
340 |
+
logging_level = logging.WARNING
|
341 |
+
|
342 |
+
if logging_level < logging.WARNING:
|
343 |
+
# Only print if logging level is lower than WARNING
|
344 |
+
print(f"Running DSSD...")
|
345 |
+
start = time.time()
|
346 |
+
if method == "to_overall":
|
347 |
+
result_set = fairsd.DSSD(beam_width=30).execute(task, method="to_overall")
|
348 |
+
else:
|
349 |
+
result_set = fairsd.DSSD(beam_width=30).execute(task)
|
350 |
+
if logging_level < logging.WARNING:
|
351 |
+
print(f"DSSD Done! Total time: {time.time() - start}")
|
352 |
+
return result_set
|
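
Taken together, the helpers above form a small pipeline: load and split a dataset, optionally inject a synthetic bias into one subgroup of the test split, train a classifier, and run subgroup discovery on its predictions. The following is an illustrative sketch of how they might be wired together; the 0.5 decision threshold and the marital-status subgroup mask are assumptions made for the example, not part of the committed code.

# Illustrative usage sketch (assumed wiring, not part of the commit)
X_test, y_true_train, y_true_test, onehot_X_train, onehot_X_test, cat_features = load_data(
    dataset="adult", n_samples=0, test_split=0.3
)

# Optionally corrupt one subgroup to simulate a data quality issue (assumed example mask)
subgroup = X_test["marital-status"] == "Divorced"
add_bias("mean", X_test, onehot_X_test, subgroup)

# Train a random forest and get positive-class probabilities on the test split
classifier, y_pred_prob = get_classifier(
    onehot_X_train, y_true_train, onehot_X_test, model="rf"
)

# Search for subgroups with the largest equalized odds difference
result_set = get_fairsd_result_set(
    X_test,
    y_true_test,
    (y_pred_prob > 0.5) * 1,
    qf="equalized_odds_difference",
    depth=1,
)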
metrics.py
ADDED
@@ -0,0 +1,211 @@
import numpy as np
import pandas as pd
from typing import Callable, Union
from sklearn.metrics import (
    roc_auc_score,
    brier_score_loss,
    log_loss,
    accuracy_score,
    f1_score,
    average_precision_score,
)
from sklearn.calibration import calibration_curve
from fairlearn.metrics import make_derived_metric

true_positive_score = lambda y_true, y_pred: (y_true & y_pred).sum() / y_true.sum()
false_positive_score = (
    lambda y_true, y_pred: ((1 - y_true) & y_pred).sum() / ((1 - y_true)).sum()
)
false_negative_score = lambda y_true, y_pred: 1 - true_positive_score(y_true, y_pred)
Y_PRED_METRICS = (
    "auprc_diff",
    "auprc_ratio",
    "acc_diff",
    "f1_diff",
    "f1_ratio",
    "equalized_odds_diff",
    "equalized_odds_ratio",
)


def average_log_loss_score(y_true, y_pred):
    """Average log loss function."""
    return np.mean(log_loss(y_true, y_pred))


def miscalibration_score(y_true, y_pred, n_bins=10):
    """Miscalibration score. Calibration is the difference between the predicted and the true probability of the positive class."""
    prob_true, prob_pred = calibration_curve(y_true, y_pred, n_bins=n_bins)
    return np.mean(np.abs(prob_true - prob_pred))


def get_qf_from_str(
    metric: str, transform: str = "difference"
) -> Union[Callable[[pd.Series, pd.Series, pd.Series], float], str]:
    """Get the quality function from a string.

    Args:
        metric (str): Name of the metric. If None, the default metric is used.
        transform (str): Type of the metric. Can be "difference", "ratio" or other fairlearn supported transforms

    Returns: the quality function according to the selected metric - a defined function
        or a string (in case of equalized odds difference as it's already defined in fairsd)
    """
    # Preprocess metric string. If it ends with "diff" or "ratio", set transform accordingly
    if metric.split("_")[-1] == "diff":
        transform = "difference"
    elif metric.split("_")[-1] == "ratio":
        transform = "ratio"

    metric = trim_transform_from_str(metric).lower()

    if metric in ("equalized_odds", "eo", "eo_diff"):
        qf = (
            "equalized_odds_difference"
            if transform == "difference"
            else "equalized_odds_ratio"
        )
    elif metric in ("brier_score", "brier_score_loss"):
        qf = make_derived_metric(metric=brier_score_loss, transform=transform)
    elif metric in ("log_loss", "loss", "total_loss"):
        qf = make_derived_metric(metric=log_loss, transform=transform)
    elif metric in ("accuracy", "accuracy_score", "acc"):
        qf = make_derived_metric(metric=accuracy_score, transform=transform)
    elif metric in ("f1", "f1_score"):
        qf = make_derived_metric(metric=f1_score, transform=transform)
    elif metric in ("al", "average_loss", "average_log_loss"):
        qf = make_derived_metric(metric=average_log_loss_score, transform=transform)
    elif metric in ("roc_auc", "auroc", "auc_roc", "roc_auc_score"):
        qf = make_derived_metric(metric=roc_auc_score, transform=transform)
    elif metric in ("miscalibration", "miscal", "cal", "calibration"):
        qf = make_derived_metric(metric=miscalibration_score, transform=transform)
    elif metric in (
        "auprc",
        "pr_auc",
        "precision_recall_auc",
        "average_precision_score",
    ):
        qf = make_derived_metric(metric=average_precision_score, transform=transform)
    elif metric in ("false_positive_rate", "fpr"):
        qf = make_derived_metric(metric=false_positive_score, transform=transform)
    elif metric in ("true_positive_rate", "tpr"):
        qf = make_derived_metric(metric=true_positive_score, transform=transform)
    elif metric in ("fnr", "false_negative_rate"):
        qf = make_derived_metric(metric=false_negative_score, transform=transform)
    else:
        raise ValueError(
            f"Metric: {metric} not supported. "
            "Metric must be one of the following: "
            "equalized_odds, brier_score_loss, log_loss, accuracy_score, average_loss, "
            "roc_auc_diff, miscalibration_diff, auprc_diff, fpr_diff, tpr_diff"
        )

    return qf


def get_name_from_metric_str(metric: str) -> str:
    """Get the name of the metric from a string, nicely formatted."""
    metric = trim_transform_from_str(metric)
    if metric in ("equalized_odds", "eo"):
        return "Equalized Odds"
    # Split words and capitalize the first letters; keep known abbreviations in upper case
    return " ".join(
        [
            word.upper()
            if word in ("auprc", "auroc", "auc", "roc", "prc", "tpr", "fpr", "fnr")
            else word.capitalize()
            for word in metric.split("_")
        ]
    )


def trim_transform_from_str(metric: str) -> str:
    """Trim the transform suffix ("diff" or "ratio") from a metric string."""
    if metric.split("_")[-1] == "diff" or metric.split("_")[-1] == "ratio":
        metric = "_".join(metric.split("_")[:-1])
    return metric


def get_quality_metric_from_str(metric: str) -> Callable[[pd.Series, pd.Series], float]:
    """Get the quality metric from a string."""

    if metric.split("_")[-1] == "diff" or metric.split("_")[-1] == "ratio":
        metric = "_".join(metric.split("_")[:-1]).lower()

    if metric in ("equalized_odds", "eo"):
        # Report both TPR and FPR
        return (
            lambda y_true, y_pred:
            "TPR: "
            + str(true_positive_score(y_true, y_pred).round(3))
            + "; FPR: "
            + str(false_positive_score(y_true, y_pred).round(3))
        )
    elif metric in ("brier_score", "brier_score_loss"):
        quality_metric = brier_score_loss
    elif metric in ("log_loss", "loss", "total_loss"):
        quality_metric = log_loss
    elif metric in ("accuracy", "accuracy_score", "acc"):
        quality_metric = accuracy_score
    elif metric in ("f1", "f1_score"):
        quality_metric = f1_score
    elif metric in ("al", "average_loss", "average_log_loss"):
        quality_metric = average_log_loss_score
    elif metric in ("roc_auc", "auroc", "auc_roc", "roc_auc_score"):
        quality_metric = roc_auc_score
    elif metric in ("miscalibration", "miscal"):
        quality_metric = miscalibration_score
    elif metric in (
        "auprc",
        "pr_auc",
        "precision_recall_auc",
        "average_precision_score",
    ):
        quality_metric = average_precision_score
    elif metric in ("false_positive_rate", "fpr"):
        quality_metric = false_positive_score
    elif metric in ("true_positive_rate", "tpr"):
        quality_metric = true_positive_score
    elif metric in ("fnr", "false_negative_rate"):
        quality_metric = false_negative_score
    else:
        raise ValueError(
            f"Metric: {metric} not supported. "
            "Metric must be one of the following: "
            "equalized_odds, brier_score_loss, log_loss, accuracy_score, "
            "average_loss, roc_auc_diff, miscalibration_diff, auprc_diff, fpr_diff, tpr_diff"
        )

    return lambda y_true, y_pred: quality_metric(y_true, y_pred).round(3)


def sort_quality_metrics_df(
    result_set_df: pd.DataFrame, quality_metric: str
) -> pd.DataFrame:
    """Sort the result set dataframe by the quality metric."""
    # If quality_metric ends with ratio
    if quality_metric.split("_")[-1] == "ratio":
        # For ratio-based metrics, sort by the quality score (descending)
        result_set_df = result_set_df.sort_values(by="quality", ascending=False)
    elif quality_metric.split("_")[-1] in ("difference", "diff"):
        # For performance scores (higher is better), show the lowest-scoring subgroups first
        if (
            "acc" in quality_metric
            or "au" in quality_metric
            or "f1" in quality_metric
        ):
            # Sort the result_set_df in ascending order based on the metric_score
            result_set_df = result_set_df.sort_values(
                by="metric_score", ascending=True
            )
        # Otherwise show the largest metric scores (e.g. losses) first
        else:
            # Sort the result_set_df in descending order based on the metric_score
            result_set_df = result_set_df.sort_values(by="metric_score", ascending=False)
    else:
        raise ValueError(
            "Metric must be either a difference or a ratio! Provided metric:",
            quality_metric,
        )

    return result_set_df
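
For orientation, a brief sketch of how these helpers are typically combined: get_qf_from_str builds the quality function handed to the subgroup discovery task, while get_name_from_metric_str and get_quality_metric_from_str are used for reporting. The calls below are illustrative assumptions about usage, not part of the committed code.

# Illustrative sketch (assumed usage, not part of the commit)
qf = get_qf_from_str("miscalibration_diff")  # fairlearn derived metric with transform="difference"
metric_name = get_name_from_metric_str("auroc_diff")  # -> "AUROC"
quality_metric = get_quality_metric_from_str("acc_diff")
# quality_metric(y_true, y_pred) returns the subgroup accuracy rounded to 3 decimals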
plot.py
ADDED
@@ -0,0 +1,911 @@
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
from dash import dash_table
|
4 |
+
import plotly.express as px
|
5 |
+
import plotly.graph_objects as go
|
6 |
+
from sklearn.calibration import calibration_curve
|
7 |
+
from sklearn.metrics import (
|
8 |
+
auc,
|
9 |
+
average_precision_score,
|
10 |
+
log_loss,
|
11 |
+
precision_recall_curve,
|
12 |
+
roc_auc_score,
|
13 |
+
roc_curve,
|
14 |
+
)
|
15 |
+
from metrics import (
|
16 |
+
Y_PRED_METRICS,
|
17 |
+
get_quality_metric_from_str,
|
18 |
+
get_name_from_metric_str,
|
19 |
+
miscalibration_score,
|
20 |
+
)
|
21 |
+
from scipy.stats import ks_2samp #, wasserstein_distance
|
22 |
+
|
23 |
+
COLS_TO_SHOW = [
|
24 |
+
"quality",
|
25 |
+
"size",
|
26 |
+
"description",
|
27 |
+
"auc_diff",
|
28 |
+
"f1_diff",
|
29 |
+
"exp_js_div",
|
30 |
+
"exp_mi",
|
31 |
+
"exp_ks_test_p_val",
|
32 |
+
"stat_diff_exp",
|
33 |
+
]
|
34 |
+
|
35 |
+
|
36 |
+
def plot_pr_curves(y_true, y_pred, y_pred_prob, sg_feature, title=None):
|
37 |
+
"""Plots PRC curves for the subgroup and the baseline of the subgroup"""
|
38 |
+
|
39 |
+
baseline_size = len(y_true)
|
40 |
+
precision, recall, thresholds = precision_recall_curve(y_true, y_pred_prob)
|
41 |
+
auprc_baseline = average_precision_score(y_true, y_pred)
|
42 |
+
# Plot with plotly
|
43 |
+
fig = go.Figure()
|
44 |
+
# Sort the precision, recall and thresholds according to recall
|
45 |
+
precision, recall, thresholds = zip(
|
46 |
+
*sorted(zip(precision, recall, thresholds), key=lambda x: x[0])
|
47 |
+
)
|
48 |
+
fig.add_trace(
|
49 |
+
go.Scatter(
|
50 |
+
x=precision,
|
51 |
+
y=recall,
|
52 |
+
name="Baseline (area = %0.2f, n = %d)" % (auprc_baseline, baseline_size),
|
53 |
+
mode="lines+markers",
|
54 |
+
customdata=thresholds,
|
55 |
+
hovertemplate="Precision: %{x}<br>Recall: %{y}<br>Threshold: %{customdata}",
|
56 |
+
)
|
57 |
+
)
|
58 |
+
|
59 |
+
subgroup_size = sum(sg_feature)
|
60 |
+
sg_precision, sg_recall, sg_tresholds = precision_recall_curve(
|
61 |
+
y_true[sg_feature], y_pred_prob[sg_feature]
|
62 |
+
)
|
63 |
+
auprc_subgroup = average_precision_score(y_true[sg_feature], y_pred[sg_feature])
|
64 |
+
sg_precision, sg_recall, sg_tresholds = zip(
|
65 |
+
*sorted(zip(sg_precision, sg_recall, sg_tresholds), key=lambda x: x[0])
|
66 |
+
)
|
67 |
+
fig.add_trace(
|
68 |
+
go.Scatter(
|
69 |
+
x=sg_precision,
|
70 |
+
y=sg_recall,
|
71 |
+
name="Subgroup (area = %0.2f, n = %d)" % (auprc_subgroup, subgroup_size),
|
72 |
+
mode="lines+markers",
|
73 |
+
customdata=sg_tresholds,
|
74 |
+
hovertemplate="Precision: %{x}<br>Recall: %{y}<br>Threshold: %{customdata}",
|
75 |
+
)
|
76 |
+
)
|
77 |
+
# Mark axes
|
78 |
+
fig.update_xaxes(title="Recall")
|
79 |
+
fig.update_yaxes(title="Precision")
|
80 |
+
# Update title
|
81 |
+
if title:
|
82 |
+
fig.update_layout(title=title)
|
83 |
+
# Update height
|
84 |
+
fig.update_layout(height=550)
|
85 |
+
return fig
|
86 |
+
|
87 |
+
|
88 |
+
def plot_roc_curves(y_true, y_pred_prob, sg_feature, title=None):
|
89 |
+
"""Plots ROC curves for the subgroup and the baseline of the subgroup"""
|
90 |
+
|
91 |
+
fig = go.Figure()
|
92 |
+
|
93 |
+
# Plot ROC curve for baseline
|
94 |
+
baseline_size = len(y_true)
|
95 |
+
baseline_fpr, baseline_tpr, thresholds = roc_curve(y_true, y_pred_prob)
|
96 |
+
roc_auc_group2 = auc(baseline_fpr, baseline_tpr)
|
97 |
+
fig.add_trace(
|
98 |
+
go.Scatter(
|
99 |
+
x=baseline_fpr,
|
100 |
+
y=baseline_tpr,
|
101 |
+
name="Baseline (area = %0.3f, n = %d)" % (roc_auc_group2, baseline_size),
|
102 |
+
mode="lines+markers",
|
103 |
+
customdata=thresholds,
|
104 |
+
hovertemplate="False Positive Rate: %{x}<br>True Positive Rate: %{y}<br>Threshold: %{customdata}",
|
105 |
+
)
|
106 |
+
)
|
107 |
+
|
108 |
+
# Plot ROC curve for subgroup
|
109 |
+
group_size1 = sum(sg_feature)
|
110 |
+
sg_fpr, sg_tpr, sg_thresholds = roc_curve(
|
111 |
+
y_true[sg_feature], y_pred_prob[sg_feature]
|
112 |
+
)
|
113 |
+
auroc_subgroup = auc(sg_fpr, sg_tpr)
|
114 |
+
|
115 |
+
fig.add_trace(
|
116 |
+
go.Scatter(
|
117 |
+
x=sg_fpr,
|
118 |
+
y=sg_tpr,
|
119 |
+
name="Subgroup (area = %0.3f, n = %d)" % (auroc_subgroup, group_size1),
|
120 |
+
mode="lines+markers",
|
121 |
+
customdata=sg_thresholds,
|
122 |
+
hovertemplate="False Positive Rate: %{x}<br>True Positive Rate: %{y}<br>Threshold: %{customdata}",
|
123 |
+
)
|
124 |
+
)
|
125 |
+
# Mark axes
|
126 |
+
fig.update_xaxes(title="False Positive Rate")
|
127 |
+
fig.update_yaxes(title="True Positive Rate")
|
128 |
+
|
129 |
+
# Update title
|
130 |
+
if title:
|
131 |
+
fig.update_layout(title=title)
|
132 |
+
# Update height
|
133 |
+
fig.update_layout(height=550)
|
134 |
+
return fig
|
135 |
+
|
136 |
+
|
137 |
+
def plot_calibration_curve(
|
138 |
+
y_true, y_pred_prob, sg_feature, n_bins=10, strategy="uniform"
|
139 |
+
):
|
140 |
+
"""Plots calibration curve for a classifier for group and its opposite"""
|
141 |
+
|
142 |
+
fig = go.Figure()
|
143 |
+
for group in ["Baseline", "Subgroup"]:
|
144 |
+
group_filter = (
|
145 |
+
sg_feature if group == "Subgroup" else pd.Series([True] * len(y_true))
|
146 |
+
)
|
147 |
+
cal_curve = calibration_curve(
|
148 |
+
y_true=y_true[group_filter],
|
149 |
+
y_prob=y_pred_prob[group_filter],
|
150 |
+
n_bins=n_bins,
|
151 |
+
strategy=strategy,
|
152 |
+
)
|
153 |
+
# Write the calibration plot with plotly
|
154 |
+
fig.add_trace(
|
155 |
+
go.Scatter(
|
156 |
+
x=cal_curve[1],
|
157 |
+
y=cal_curve[0],
|
158 |
+
name=group
|
159 |
+
+ "<br> (miscalibration score = %0.3f, n = %d)"
|
160 |
+
% (
|
161 |
+
miscalibration_score(
|
162 |
+
y_true[group_filter], y_pred_prob[group_filter], n_bins=n_bins
|
163 |
+
),
|
164 |
+
sum(group_filter),
|
165 |
+
),
|
166 |
+
# Add number of datapoints included at each point
|
167 |
+
customdata=group_filter.sum() * range(1, n_bins + 1) // n_bins,
|
168 |
+
mode="lines+markers",
|
169 |
+
hovertemplate="Mean predicted probability: %{x}<br>Fraction of positives: %{y}<br>Subgroup size: %{customdata}",
|
170 |
+
)
|
171 |
+
)
|
172 |
+
# Add perfect calibration line
|
173 |
+
fig.add_trace(
|
174 |
+
go.Scatter(
|
175 |
+
x=[0, 1],
|
176 |
+
y=[0, 1],
|
177 |
+
mode='lines',
|
178 |
+
name='Perfect Calibration',
|
179 |
+
line=dict(dash='dash')
|
180 |
+
)
|
181 |
+
)
|
182 |
+
fig.update_layout(title="Calibration curves for the selected subgroup and baseline")
|
183 |
+
fig.update_xaxes(title_text="Mean predicted probability")
|
184 |
+
fig.update_yaxes(title_text="Fraction of positives")
|
185 |
+
# Update height
|
186 |
+
fig.update_layout(height=550)
|
187 |
+
return fig
|
188 |
+
|
189 |
+
|
190 |
+
def get_sg_hist(y_df_local, categories=["TN", "FN", "TP", "FP"], title=None):
|
191 |
+
"""Returns a histogram of the predictions for the subgroup
|
192 |
+
|
193 |
+
Args:
|
194 |
+
y_df_local (pd.DataFrame): A dataframe with the true labels and the predictions
|
195 |
+
"""
|
196 |
+
y_df_local = y_df_local[y_df_local["category"].isin(categories)]
|
197 |
+
sg_hist = px.histogram(
|
198 |
+
y_df_local,
|
199 |
+
x="probability",
|
200 |
+
color="category",
|
201 |
+
hover_data=y_df_local.columns,
|
202 |
+
category_orders={"category": categories},
|
203 |
+
)
|
204 |
+
# Hide TN and TP by default
|
205 |
+
# sg_hist.for_each_trace(lambda t: t.update(visible="legendonly") if t.name in ["TN", "TP"] else t)
|
206 |
+
|
207 |
+
sg_hist.update_xaxes(range=[0, 1])
|
208 |
+
sg_hist.update_traces(
|
209 |
+
xbins=dict(
|
210 |
+
start=0,
|
211 |
+
end=1,
|
212 |
+
size=0.1,
|
213 |
+
),
|
214 |
+
)
|
215 |
+
if title is None:
|
216 |
+
title = "Histogram of prediction probabilities for the selected subgroup"
|
217 |
+
sg_hist.update_layout(
|
218 |
+
title_text=title, legend_title_text="", modebar_remove=["zoom", "pan"]
|
219 |
+
)
|
220 |
+
sg_hist.layout.xaxis.fixedrange = True
|
221 |
+
sg_hist.layout.yaxis.fixedrange = True
|
222 |
+
# Update histogram height
|
223 |
+
sg_hist.update_layout(height=550)
|
224 |
+
return sg_hist
|
225 |
+
|
226 |
+
|
227 |
+
def get_data_table(
|
228 |
+
subgroup_description, y_true, y_pred, y_pred_prob, qf_metric, sg_feature, n_bins=10
|
229 |
+
):
|
230 |
+
"""Generates a data table with the subgroup description and the subgroup size"""
|
231 |
+
|
232 |
+
# tpr = true_positive_score(y_true[sg_feature], y_pred[sg_feature]).round(3)
|
233 |
+
# fpr = false_positive_score(y_true[sg_feature], y_pred[sg_feature]).round(3)
|
234 |
+
auroc = roc_auc_score(y_true[sg_feature], y_pred_prob[sg_feature]).round(3)
|
235 |
+
# auprc = average_precision_score(y_true[sg_feature], y_pred[sg_feature]).round(3)
|
236 |
+
cal_score = miscalibration_score(
|
237 |
+
y_true[sg_feature], y_pred_prob[sg_feature], n_bins=n_bins
|
238 |
+
).round(3)
|
239 |
+
# avg_loss = log_loss(y_true[sg_feature], y_pred_prob[sg_feature]).round(3)
|
240 |
+
brier_score = np.mean((y_true[sg_feature] - y_pred_prob[sg_feature]) ** 2).round(3)
|
241 |
+
|
242 |
+
metric_name = get_name_from_metric_str(qf_metric)
|
243 |
+
|
244 |
+
if qf_metric in Y_PRED_METRICS:
|
245 |
+
quality_score = get_quality_metric_from_str(qf_metric)(
|
246 |
+
y_true[sg_feature], y_pred[sg_feature]
|
247 |
+
)
|
248 |
+
else:
|
249 |
+
quality_score = get_quality_metric_from_str(qf_metric)(
|
250 |
+
y_true[sg_feature], y_pred_prob[sg_feature]
|
251 |
+
)
|
252 |
+
# fp = sum((y_true[sg_feature] == 0) & (y_pred[sg_feature] == 1))
|
253 |
+
# fn = sum((y_true[sg_feature] == 1) & (y_pred[sg_feature] == 0))
|
254 |
+
|
255 |
+
# Generate a data table with the subgroup description
|
256 |
+
table_content = {
|
257 |
+
"Statistic": [
|
258 |
+
"Description",
|
259 |
+
"Size",
|
260 |
+
"AUROC",
|
261 |
+
"Miscalibration score",
|
262 |
+
"Brier score",
|
263 |
+
metric_name
|
264 |
+
],
|
265 |
+
"Value": [
|
266 |
+
subgroup_description,
|
267 |
+
sg_feature.sum(),
|
268 |
+
auroc,
|
269 |
+
cal_score,
|
270 |
+
brier_score,
|
271 |
+
quality_score
|
272 |
+
],
|
273 |
+
}
|
274 |
+
# if metric_name != "Average Log Loss":
|
275 |
+
# table_content["Statistic"].append(metric_name)
|
276 |
+
# table_content["Value"].append(quality_score)
|
277 |
+
|
278 |
+
df = pd.DataFrame(table_content)
|
279 |
+
data_table = dash_table.DataTable(
|
280 |
+
style_table={"overflowX": "auto"},
|
281 |
+
style_data={"whiteSpace": "normal", "height": "auto"},
|
282 |
+
data=df.to_dict("records"),
|
283 |
+
style_cell_conditional=[
|
284 |
+
{"if": {"column_id": "Feature"}, "width": "35%"},
|
285 |
+
],
|
286 |
+
)
|
287 |
+
|
288 |
+
return data_table
|
289 |
+
|
290 |
+
|
291 |
+
def get_data_distr_charts(
|
292 |
+
X, y_pred, sg_feature, feature, description, nbins=20, agg="percentage"
|
293 |
+
):
|
294 |
+
"""For positive and negative labels and predictions, returns a figure with the histogram distribution across
|
295 |
+
feature values of the selected feature in the subgroup and the baseline"""
|
296 |
+
|
297 |
+
# Get the data distribution for the feature values of the selected feature in the subgroup and the baseline
|
298 |
+
class_plot = get_data_distr_chart(
|
299 |
+
X, sg_feature, feature, description, nbins, agg
|
300 |
+
)
|
301 |
+
class_plot.update_layout(
|
302 |
+
title="Data distribution for "
|
303 |
+
+ feature
|
304 |
+
+ f" in the subgroup ({description}) and the baseline"
|
305 |
+
)
|
306 |
+
# Get the predictions distribution for the feature values of the selected feature in the subgroup and the baseline
|
307 |
+
pred_plot = get_data_distr_chart(
|
308 |
+
y_pred, sg_feature, feature, description, nbins, agg
|
309 |
+
)
|
310 |
+
pred_plot.update_layout(
|
311 |
+
title="Predictions distribution for "
|
312 |
+
+ feature
|
313 |
+
+ f" in the subgroup ({description}) and the baseline"
|
314 |
+
)
|
315 |
+
if agg == "percentage":
|
316 |
+
# Update yaxis range to 1
|
317 |
+
class_plot.update_layout(yaxis=dict(range=[0, 100]))
|
318 |
+
pred_plot.update_layout(yaxis=dict(range=[0, 100]))
|
319 |
+
|
320 |
+
return class_plot, pred_plot
|
321 |
+
|
322 |
+
|
323 |
+
def get_data_distr_chart(
|
324 |
+
X, sg_feature, feature, description, nbins=20, agg="percentage"
|
325 |
+
):
|
326 |
+
"""Returns a figure with the data distribution for the feature values of the selected feature in the subgroup and the baseline"""
|
327 |
+
|
328 |
+
if type(X) == pd.DataFrame and feature in X.columns:
|
329 |
+
X = X[feature]
|
330 |
+
else:
|
331 |
+
X = pd.Series(X).astype("str")
|
332 |
+
X_sg = X[sg_feature].copy()
|
333 |
+
|
334 |
+
fig = go.Figure()
|
335 |
+
# Add trace for baseline
|
336 |
+
fig.add_trace(
|
337 |
+
go.Histogram(
|
338 |
+
x=X,
|
339 |
+
name="Baseline",
|
340 |
+
histnorm="percent" if agg == "percentage" else "",
|
341 |
+
nbinsx=nbins,
|
342 |
+
)
|
343 |
+
)
|
344 |
+
|
345 |
+
# Add trace for subgroup
|
346 |
+
fig.add_trace(
|
347 |
+
go.Histogram(
|
348 |
+
x=X_sg,
|
349 |
+
name="Subgroup",
|
350 |
+
histnorm="percent" if agg == "percentage" else "",
|
351 |
+
nbinsx=nbins,
|
352 |
+
)
|
353 |
+
)
|
354 |
+
|
355 |
+
# Update layout
|
356 |
+
fig.update_layout(
|
357 |
+
title="Data distribution for "
|
358 |
+
+ feature
|
359 |
+
+ f" in the subgroup ({description}) and the baseline",
|
360 |
+
xaxis_title=feature,
|
361 |
+
yaxis_title=agg.capitalize() + " of data in respective group",
|
362 |
+
)
|
363 |
+
# Update yaxis range to max value
|
364 |
+
# max_y = max(
|
365 |
+
# max(fig.data[0].y), # FIXME: How to get the max value of the histogram object?
|
366 |
+
# max(fig.data[1].y)
|
367 |
+
# )
|
368 |
+
# fig.update_layout(yaxis=dict(range=[0, max_y]))
|
369 |
+
|
370 |
+
return fig
|
371 |
+
|
372 |
+
|
373 |
+
def get_feat_val_violin_plots(shap_values_df, sg_feature=None):
|
374 |
+
"""Returns a violin plot for the features SHAP values in the subgroup and the baseline"""
|
375 |
+
shap_values_df = transform_box_data(shap_values_df, sg_feature)
|
376 |
+
|
377 |
+
# Create the fig and add hoverdata of confidence intervals
|
378 |
+
fig = px.violin(
|
379 |
+
shap_values_df,
|
380 |
+
x="feature",
|
381 |
+
y="shap_value",
|
382 |
+
color="group",
|
383 |
+
points="all",
|
384 |
+
title="Feature contributions to model loss for subgroup and baseline. <br> "
|
385 |
+
+ "The lower the value, the higher its informativeness.",
|
386 |
+
hover_data=shap_values_df.columns,
|
387 |
+
height=600,
|
388 |
+
)
|
389 |
+
|
390 |
+
|
391 |
+
# Update layout
|
392 |
+
fig.update_layout(
|
393 |
+
yaxis_title="SHAP log loss value",
|
394 |
+
xaxis_title="Feature",
|
395 |
+
violinmode="group",
|
396 |
+
)
|
397 |
+
# Update height
|
398 |
+
fig.update_layout(height=600)
|
399 |
+
return fig
|
400 |
+
|
401 |
+
|
402 |
+
def get_feat_val_violin_plot(X, shap_df, sg_feature, feature, description, nbins=20):
|
403 |
+
"""Returns a figure with a violin plot for the feature value SHAP contributions in the subgroup and the baseline"""
|
404 |
+
feature_type = "categorical"
|
405 |
+
orig_df = X.copy()
|
406 |
+
# If data is continuous, we need to bin it
|
407 |
+
if X[feature].dtype in [np.float64, np.int64]:
|
408 |
+
X[feature] = pd.cut(X[feature], bins=nbins)
|
409 |
+
bins = X[feature].cat.categories.astype(str)
|
410 |
+
sorted_bins = sorted(
|
411 |
+
bins, key=lambda x: float(x.split(",")[0].strip("(").strip(" ").strip("]"))
|
412 |
+
)
|
413 |
+
X[feature] = X[feature].astype(str)
|
414 |
+
feature_type = "continuous"
|
415 |
+
|
416 |
+
# Merge the shap values with the feature values such that we can plot the violin plot of shap values per feature value
|
417 |
+
concat_df = pd.concat([shap_df[feature], X[feature]], axis=1)
|
418 |
+
concat_df.columns = ["SHAP", feature]
|
419 |
+
|
420 |
+
# Create a violin plot for the feature values
|
421 |
+
fig = go.Figure()
|
422 |
+
|
423 |
+
# Add trace for baseline
|
424 |
+
fig.add_trace(
|
425 |
+
go.Violin(
|
426 |
+
x=concat_df[feature],
|
427 |
+
y=concat_df["SHAP"],
|
428 |
+
name="Baseline",
|
429 |
+
box_visible=True,
|
430 |
+
meanline_visible=True,
|
431 |
+
points="all",
|
432 |
+
text=orig_df.apply(lambda row: ' '.join(f'{i}:{v}' for i, v in row.items()), axis=1),
|
433 |
+
hoverinfo="y+text",
|
434 |
+
)
|
435 |
+
)
|
436 |
+
|
437 |
+
# Add trace for subgroup
|
438 |
+
fig.add_trace(
|
439 |
+
go.Violin(
|
440 |
+
x=concat_df[feature][sg_feature],
|
441 |
+
y=concat_df["SHAP"][sg_feature],
|
442 |
+
name="Subgroup",
|
443 |
+
box_visible=True,
|
444 |
+
meanline_visible=True,
|
445 |
+
points="all",
|
446 |
+
text=orig_df[sg_feature].apply(lambda row: ' '.join(f'{i}:{v}' for i, v in row.items()), axis=1), hoverinfo="y+text",
|
447 |
+
)
|
448 |
+
)
|
449 |
+
slider_note = (
|
450 |
+
"Use slider below to adjust the (max) number of bins for the violin plot"
|
451 |
+
if feature_type == "continuous"
|
452 |
+
else "When the selected feature is categorical, slider changes do not affect the plot."
|
453 |
+
)
|
454 |
+
# Update layout
|
455 |
+
fig.update_layout(
|
456 |
+
title="Feature distribution for "
|
457 |
+
+ feature
|
458 |
+
+ f" in the subgroup ({description}) and the baseline <br>"
|
459 |
+
+ "The lower the feature's values, the higher its informativeness",
|
460 |
+
xaxis_title=f"Feature values of {feature} <br> Feature type: {feature_type}; {slider_note}",
|
461 |
+
yaxis_title="SHAP log loss value",
|
462 |
+
violinmode="group",
|
463 |
+
# yaxis=dict(range=[-0.4, 0.4])
|
464 |
+
)
|
465 |
+
# If feature is continuous, we need to sort the bins
|
466 |
+
if feature_type == "continuous":
|
467 |
+
fig.update_xaxes(categoryorder="array", categoryarray=sorted_bins)
|
468 |
+
else:
|
469 |
+
fig.update_xaxes(categoryorder="category ascending")
|
470 |
+
# Update height
|
471 |
+
fig.update_layout(height=700)
|
472 |
+
return fig
|
473 |
+
|
474 |
+
|
475 |
+
def get_feat_val_box(X, shap_df, sg_feature, feature, description, nbins=20):
|
476 |
+
"""Returns a figure with a box plot for the feature value SHAP contributions in the subgroup and the baseline"""
|
477 |
+
feature_type = "categorical"
|
478 |
+
orig_df = X.copy()
|
479 |
+
# If data is continuous, we need to bin it
|
480 |
+
if X[feature].dtype in [np.float64, np.int64]:
|
481 |
+
X[feature] = pd.cut(X[feature], bins=nbins)
|
482 |
+
bins = X[feature].cat.categories.astype(str)
|
483 |
+
sorted_bins = sorted(
|
484 |
+
bins, key=lambda x: float(x.split(",")[0].strip("(").strip(" ").strip("]"))
|
485 |
+
)
|
486 |
+
X[feature] = X[feature].astype(str)
|
487 |
+
feature_type = "continuous"
|
488 |
+
|
489 |
+
# Merge the shap values with the feature values such that we can plot the violin plot of shap values per feature value
|
490 |
+
concat_df = pd.concat([shap_df[feature], X[feature]], axis=1)
|
491 |
+
concat_df.columns = ["SHAP", feature]
|
492 |
+
|
493 |
+
# Create a violin plot for the feature values
|
494 |
+
fig = go.Figure()
|
495 |
+
|
496 |
+
# Add trace for baseline
|
497 |
+
fig.add_trace(
|
498 |
+
go.Box(
|
499 |
+
x=concat_df[feature],
|
500 |
+
y=concat_df["SHAP"],
|
501 |
+
name="Baseline",
|
502 |
+
boxpoints="all",
|
503 |
+
text=orig_df.apply(lambda row: ' '.join(f'{i}:{v}' for i, v in row.items()), axis=1),
|
504 |
+
hoverinfo="y+text",
|
505 |
+
)
|
506 |
+
)
|
507 |
+
|
508 |
+
# Add trace for subgroup
|
509 |
+
fig.add_trace(
|
510 |
+
go.Box(
|
511 |
+
x=concat_df[feature][sg_feature],
|
512 |
+
y=concat_df["SHAP"][sg_feature],
|
513 |
+
name="Subgroup",
|
514 |
+
boxpoints="all",
|
515 |
+
text=orig_df[sg_feature].apply(lambda row: ' '.join(f'{i}:{v}' for i, v in row.items()), axis=1),
|
516 |
+
hoverinfo="y+text",
|
517 |
+
)
|
518 |
+
)
|
519 |
+
slider_note = (
|
520 |
+
"Use slider below to adjust the (max) number of bins for the violin plot"
|
521 |
+
if feature_type == "continuous"
|
522 |
+
else "When the selected feature is categorical, slider changes do not affect the plot."
|
523 |
+
)
|
524 |
+
# Update layout
|
525 |
+
fig.update_layout(
|
526 |
+
title="Feature distribution for "
|
527 |
+
+ feature
|
528 |
+
+ f" in the subgroup ({description}) and the baseline <br>"
|
529 |
+
+ "The lower the feature's values, the higher its informativeness",
|
530 |
+
xaxis_title=f"Feature"
|
531 |
+
+ f" values of {feature} <br> Feature type: {feature_type}; {slider_note}",
|
532 |
+
yaxis_title="SHAP log loss value",
|
533 |
+
)
|
534 |
+
# If feature is continuous, we need to sort the bins
|
535 |
+
if feature_type == "continuous":
|
536 |
+
fig.update_xaxes(categoryorder="array", categoryarray=sorted_bins)
|
537 |
+
else:
|
538 |
+
fig.update_xaxes(categoryorder="category ascending")
|
539 |
+
# Update height
|
540 |
+
fig.update_layout(height=700)
|
541 |
+
return fig
|
542 |
+
|
543 |
+
|
544 |
+
def get_feat_val_bar(X, shap_df, sg_feature, feature, description, nbins=20, agg="mean"):
|
545 |
+
"""Returns a figure with a bar plot for the feature value SHAP contributions in the subgroup and the baseline"""
|
546 |
+
feature_type = "categorical"
|
547 |
+
orig_df = X.copy()
|
548 |
+
# If data is continuous, we need to bin it
|
549 |
+
if X[feature].dtype in [np.float64, np.int64]:
|
550 |
+
X[feature] = pd.cut(X[feature], bins=nbins)
|
551 |
+
bins = X[feature].cat.categories.astype(str)
|
552 |
+
sorted_bins = sorted(
|
553 |
+
bins, key=lambda x: float(x.split(",")[0].strip("(").strip(" ").strip("]"))
|
554 |
+
)
|
555 |
+
X[feature] = X[feature].astype(str)
|
556 |
+
feature_type = "continuous"
|
557 |
+
|
558 |
+
# Merge the shap values with the feature values such that we can plot the violin plot of shap values per feature value
|
559 |
+
concat_df = pd.concat([shap_df[feature], X[feature]], axis=1)
|
560 |
+
concat_df.columns = ["SHAP", feature]
|
561 |
+
|
562 |
+
# Create a violin plot for the feature values
|
563 |
+
fig = go.Figure()
|
564 |
+
|
565 |
+
if agg == "sum_weighted":
|
566 |
+
# Calculate the weighted sum of the SHAP values by using a sum divided over the total group size
|
567 |
+
baseline_df = concat_df.groupby(feature)["SHAP"].agg("sum") / len(concat_df)
|
568 |
+
sg_df = concat_df[sg_feature].groupby(feature)["SHAP"].agg("sum") / len(concat_df[sg_feature])
|
569 |
+
title = "Weighted sum"
|
570 |
+
else:
|
571 |
+
baseline_df = concat_df.groupby(feature)["SHAP"].agg(agg)
|
572 |
+
sg_df = concat_df[sg_feature].groupby(feature)["SHAP"].agg(agg)
|
573 |
+
title = agg.capitalize()
|
574 |
+
|
575 |
+
# Add trace for baseline
|
576 |
+
fig.add_trace(
|
577 |
+
go.Bar(
|
578 |
+
x=baseline_df.index,
|
579 |
+
y=baseline_df.values,
|
580 |
+
name="Baseline",
|
581 |
+
text=orig_df.apply(lambda row: ' '.join(f'{i}:{v}' for i, v in row.items()), axis=1),
|
582 |
+
hoverinfo="y+text",
|
583 |
+
)
|
584 |
+
)
|
585 |
+
# Add trace for subgroup
|
586 |
+
fig.add_trace(
|
587 |
+
go.Bar(
|
588 |
+
x=sg_df.index,
|
589 |
+
y=sg_df.values,
|
590 |
+
name="Subgroup",
|
591 |
+
text=orig_df[sg_feature].apply(lambda row: ' '.join(f'{i}:{v}' for i, v in row.items()), axis=1),
|
592 |
+
hoverinfo="y+text",
|
593 |
+
)
|
594 |
+
)
|
595 |
+
slider_note = (
|
596 |
+
"Use slider below to adjust the (max) number of bins for the violin plot"
|
597 |
+
if feature_type == "continuous"
|
598 |
+
else "When the selected feature is categorical, slider changes do not affect the plot."
|
599 |
+
)
|
600 |
+
# Update layout
|
601 |
+
fig.update_layout(
|
602 |
+
title="SHAP feature value loss contributions for "
|
603 |
+
+ feature
|
604 |
+
+ f" in the subgroup ({description}) and the baseline <br>"
|
605 |
+
+ "The lower the feature's values, the higher its informativeness",
|
606 |
+
xaxis_title=f"Feature"
|
607 |
+
+ f" values of {feature} <br> Feature type: {feature_type}; {slider_note}",
|
608 |
+
yaxis_title=title + " of SHAP log loss values",
|
609 |
+
)
|
610 |
+
# If feature is continuous, we need to sort the bins
|
611 |
+
if feature_type == "continuous":
|
612 |
+
fig.update_xaxes(categoryorder="array", categoryarray=sorted_bins)
|
613 |
+
else:
|
614 |
+
fig.update_xaxes(categoryorder="category ascending")
|
615 |
+
# Update height
|
616 |
+
fig.update_layout(height=700)
|
617 |
+
return fig
|
618 |
+
|
619 |
+
|
620 |
+
def get_feat_bar(shap_values_df, sg_feature, agg, error_bars=False) -> go.Figure:
|
621 |
+
"""Returns a figure with the feature contributions to the model loss
|
622 |
+
|
623 |
+
Args:
|
624 |
+
shap_values_df (pd.DataFrame): The shap values dataframe
|
625 |
+
sg_feature (pd.Series): The subgroup feature
|
626 |
+
agg (str): The aggregation method (mean, median, sum, mean_diff)
|
627 |
+
error_bars (bool): Whether to add error bars to the plot
|
628 |
+
Returns:
|
629 |
+
go.Figure: The figure with the feature contributions to the model loss
|
630 |
+
"""
|
631 |
+
|
632 |
+
sg_shap_values_df = shap_values_df[sg_feature]
|
633 |
+
if agg == "mean":
|
634 |
+
# Get the mean absolute shap value for each feature
|
635 |
+
shap_values_df_agg = shap_values_df.mean(numeric_only=True)
|
636 |
+
sg_shap_values_agg = sg_shap_values_df.mean(numeric_only=True)
|
637 |
+
elif agg == "median":
|
638 |
+
# Get the median absolute shap value for each feature
|
639 |
+
shap_values_df_agg = shap_values_df.median(numeric_only=True)
|
640 |
+
sg_shap_values_agg = sg_shap_values_df.median(numeric_only=True)
|
641 |
+
elif agg == "sum":
|
642 |
+
# Get the sum of absolute shap value for each feature
|
643 |
+
shap_values_df_agg = shap_values_df.sum(numeric_only=True)
|
644 |
+
sg_shap_values_agg = sg_shap_values_df.sum(numeric_only=True)
|
645 |
+
elif agg == "mean_diff":
|
646 |
+
# Get the difference in mean shap value for each feature
|
647 |
+
shap_values_df_agg = sg_shap_values_df.mean(numeric_only=True) - shap_values_df.mean(numeric_only=True)
|
648 |
+
# Then produce only a single bar plot with the difference
|
649 |
+
fig = go.Figure()
|
650 |
+
fig.add_trace(
|
651 |
+
go.Bar(
|
652 |
+
x=shap_values_df_agg.index,
|
653 |
+
y=shap_values_df_agg,
|
654 |
+
name="Difference (Subgroup - Baseline)",
|
655 |
+
)
|
656 |
+
)
|
657 |
+
fig.update_layout(
|
658 |
+
title="Difference in mean of the SHAP feature contributions to model loss (Subgroup - Baseline)",
|
659 |
+
xaxis_title="Feature",
|
660 |
+
yaxis_title="Difference in SHAP contributions to loss",
|
661 |
+
)
|
662 |
+
return fig
|
663 |
+
|
664 |
+
# Combine the two dataframes
|
665 |
+
shap_values_df_agg = pd.concat([shap_values_df_agg, sg_shap_values_agg], axis=1)
|
666 |
+
shap_values_df_agg.columns = ["Baseline", "Subgroup"]
|
667 |
+
|
668 |
+
# Create the fig and add hoverdata of confidence intervals
|
669 |
+
fig = go.Figure()
|
670 |
+
ytitle = agg.capitalize() + " SHAP log loss value - feature contribution to loss"
|
671 |
+
|
672 |
+
if error_bars:
|
673 |
+
ytitle += " <br> With standard deviation error bars"
|
674 |
+
fig.add_trace(
|
675 |
+
go.Bar(
|
676 |
+
x=shap_values_df_agg.index,
|
677 |
+
y=shap_values_df_agg["Baseline"],
|
678 |
+
name="Baseline",
|
679 |
+
customdata=shap_values_df.std(axis=0, numeric_only=True),
|
680 |
+
hovertemplate="Feature: %{x}<br>Baseline: %{y}<br>Standard deviation: %{customdata}",
|
681 |
+
error_y=dict(
|
682 |
+
type="data",
|
683 |
+
array=shap_values_df.std(axis=0, numeric_only=True),
|
684 |
+
visible=True,
|
685 |
+
),
|
686 |
+
)
|
687 |
+
)
|
688 |
+
fig.add_trace(
|
689 |
+
go.Bar(
|
690 |
+
x=shap_values_df_agg.index,
|
691 |
+
y=shap_values_df_agg["Subgroup"],
|
692 |
+
name="Subgroup",
|
693 |
+
customdata=sg_shap_values_df.std(axis=0, numeric_only=True),
|
694 |
+
hovertemplate="Feature: %{x}<br>Subgroup: %{y}<br>Standard deviation: %{customdata}",
|
695 |
+
error_y=dict(
|
696 |
+
type="data",
|
697 |
+
array=sg_shap_values_df.std(axis=0, numeric_only=True),
|
698 |
+
visible=True,
|
699 |
+
),
|
700 |
+
)
|
701 |
+
)
|
702 |
+
else:
|
703 |
+
fig.add_trace(
|
704 |
+
go.Bar(
|
705 |
+
x=shap_values_df_agg.index,
|
706 |
+
y=shap_values_df_agg["Baseline"],
|
707 |
+
name="Baseline",
|
708 |
+
)
|
709 |
+
)
|
710 |
+
fig.add_trace(
|
711 |
+
go.Bar(
|
712 |
+
x=shap_values_df_agg.index,
|
713 |
+
y=shap_values_df_agg["Subgroup"],
|
714 |
+
name="Subgroup",
|
715 |
+
)
|
716 |
+
)
|
717 |
+
|
718 |
+
# Update the fig
|
719 |
+
fig.update_layout(
|
720 |
+
barmode="group",
|
721 |
+
yaxis_tickangle=-45,
|
722 |
+
title="Feature contributions to model loss for subgroup and baseline. <br> "
|
723 |
+
+ "The lower the value, the higher its informativeness.",
|
724 |
+
yaxis_title=ytitle,
|
725 |
+
xaxis_title="Feature",
|
726 |
+
height=600,
|
727 |
+
)
|
728 |
+
# Order x axis by the absolute value of the difference between the subgroup and the baseline
|
729 |
+
fig.update_xaxes(
|
730 |
+
categoryorder="array",
|
731 |
+
categoryarray=shap_values_df_agg.abs().diff(axis=1).sort_values(by="Subgroup").index,
|
732 |
+
)
|
733 |
+
# Turn y labels
|
734 |
+
fig.update_layout(xaxis_tickangle=-25)
|
735 |
+
return fig
|
736 |
+
|
737 |
+
|
738 |
+
def get_feat_box(shap_values_df, sg_feature=None) -> go.Figure:
|
739 |
+
"""Returns a figure with the feature contributions to the model loss
|
740 |
+
|
741 |
+
Args:
|
742 |
+
shap_values_df (pd.DataFrame): The shap values dataframe
|
743 |
+
sg_feature (pd.Series): The subgroup feature
|
744 |
+
title (str): The title of the figure
|
745 |
+
Returns:
|
746 |
+
go.Figure: The figure with the feature contributions to the model loss
|
747 |
+
"""
|
748 |
+
shap_values_df = transform_box_data(shap_values_df, sg_feature)
|
749 |
+
|
750 |
+
# Create the fig and add hoverdata of confidence intervals
|
751 |
+
fig = px.box(
|
752 |
+
shap_values_df,
|
753 |
+
x="feature",
|
754 |
+
y="shap_value",
|
755 |
+
color="group",
|
756 |
+
points=False,
|
757 |
+
title="Feature contributions to model loss for subgroup and baseline. <br> "
|
758 |
+
+ "The lower the value, the higher its informativeness.",
|
759 |
+
hover_data=shap_values_df.columns,
|
760 |
+
height=600,
|
761 |
+
)
|
762 |
+
|
763 |
+
# Update the fig
|
764 |
+
fig.update_layout(
|
765 |
+
yaxis_title="SHAP log loss value - feature contribution to loss",
|
766 |
+
xaxis_title="Feature",
|
767 |
+
)
|
768 |
+
fig.update_traces(boxmean=True)
|
769 |
+
|
770 |
+
return fig
|
771 |
+
|
772 |
+
|
773 |
+
def transform_box_data(shap_values_df, sg_feature):
|
774 |
+
"""Transforms the shap values dataframe to be suitable for a box plot"""
|
775 |
+
shap_values_df = shap_values_df.drop(columns="group", errors="ignore")
|
776 |
+
|
777 |
+
if sg_feature is not None:
|
778 |
+
sg_shap_values_df = shap_values_df[sg_feature]
|
779 |
+
|
780 |
+
# Put all shap values of different features in a single column with feature names as a new column
|
781 |
+
sg_shap_values_df = sg_shap_values_df.reset_index()
|
782 |
+
sg_shap_values_df = sg_shap_values_df.melt(
|
783 |
+
id_vars="index", var_name="feature", value_name="shap_value"
|
784 |
+
)
|
785 |
+
sg_shap_values_df["group"] = "Subgroup"
|
786 |
+
|
787 |
+
# Put all shap values of different features in a single column with feature names as a new column
|
788 |
+
shap_values_df = shap_values_df.reset_index()
|
789 |
+
shap_values_df = shap_values_df.melt(
|
790 |
+
id_vars="index", var_name="feature", value_name="shap_value"
|
791 |
+
)
|
792 |
+
shap_values_df["group"] = "Baseline"
|
793 |
+
|
794 |
+
if sg_feature is not None:
|
795 |
+
# Combine the two dataframes
|
796 |
+
shap_values_df = pd.concat([shap_values_df, sg_shap_values_df], axis=0)
|
797 |
+
# Drop column index
|
798 |
+
shap_values_df = shap_values_df.drop(columns="index")
|
799 |
+
return shap_values_df
|
800 |
+
|
801 |
+
|
802 |
+
def get_feat_table(shap_values_df, sg_feature, sensitivity=4, alpha=0.05):
|
803 |
+
"""Returns a data table with the feature contributions to the model loss summary and tests for significance"""
|
804 |
+
|
805 |
+
sg_shap_values_df = shap_values_df[sg_feature]
|
806 |
+
|
807 |
+
# Get the mean absolute shap value for each feature
|
808 |
+
shap_values_df_mean = shap_values_df.mean(numeric_only=True).round(5)
|
809 |
+
sg_shap_values_mean = sg_shap_values_df.mean(numeric_only=True).round(5)
|
810 |
+
shap_values_df_mean = pd.concat([shap_values_df_mean, sg_shap_values_mean], axis=1)
|
811 |
+
shap_values_df_mean.columns = ["Baseline", "Subgroup"]
|
812 |
+
|
813 |
+
# Get the standard deviation of the shap values
|
814 |
+
shap_values_df_std = shap_values_df.std(numeric_only=True).round(5)
|
815 |
+
sg_shap_values_df_std = sg_shap_values_df.std(numeric_only=True).round(5)
|
816 |
+
shap_values_df_std = pd.concat([shap_values_df_std, sg_shap_values_df_std], axis=1)
|
817 |
+
shap_values_df_std.columns = ["Baseline", "Subgroup"]
|
818 |
+
|
819 |
+
# Get the p-value of the shap values
|
820 |
+
shap_values_df_p = pd.DataFrame(
|
821 |
+
index=shap_values_df_mean.index, columns=["KS_p_value"]
|
822 |
+
)
|
823 |
+
|
824 |
+
# Round shap values to 5 decimal places
|
825 |
+
shap_values_df = shap_values_df.round(sensitivity)
|
826 |
+
sg_shap_values_df = sg_shap_values_df.round(sensitivity)
|
827 |
+
|
828 |
+
for feature in shap_values_df_mean.index:
|
829 |
+
# Run KS test
|
830 |
+
statistic, p_value = ks_2samp(
|
831 |
+
shap_values_df[feature], sg_shap_values_df[feature]
|
832 |
+
)
|
833 |
+
shap_values_df_p.loc[feature, "KS_p_value"] = p_value.round(6)
|
834 |
+
shap_values_df_p.loc[feature, "KS_statistic"] = statistic
|
835 |
+
|
836 |
+
# Calculate Wasserstein distance
|
837 |
+
# wasserstein_dist = wasserstein_distance(
|
838 |
+
# shap_values_df[feature], sg_shap_values_df[feature]
|
839 |
+
# )
|
840 |
+
# shap_values_df_p.loc[feature, "Wasserstein_distance"] = wasserstein_dist.round(6)
|
841 |
+
|
842 |
+
shap_values_df_p = shap_values_df_p.round(6)
|
843 |
+
|
844 |
+
# Merge the dataframes
|
845 |
+
df = shap_values_df_mean.merge(
|
846 |
+
shap_values_df_std, left_index=True, right_index=True
|
847 |
+
)
|
848 |
+
df = df.merge(shap_values_df_p, left_index=True, right_index=True)
|
849 |
+
df = df.reset_index()
|
850 |
+
df.columns = [
|
851 |
+
"Feature",
|
852 |
+
"Baseline_avg",
|
853 |
+
"Subgroup_avg",
|
854 |
+
"Baseline_std",
|
855 |
+
"Subgroup_std",
|
856 |
+
"KS p-value",
|
857 |
+
"KS statistic",
|
858 |
+
# "Wasserstein dist",
|
859 |
+
]
|
860 |
+
|
861 |
+
df["Cohen's d"] = (df["Subgroup_avg"] - df["Baseline_avg"]) / np.sqrt(
|
862 |
+
(df["Baseline_std"] ** 2 + df["Subgroup_std"] ** 2) / 2
|
863 |
+
)
|
864 |
+
df["Cohen's d"] = df["Cohen's d"].round(5)
|
865 |
+
|
866 |
+
# Order df rows based on the p-value and the mean
|
867 |
+
df = df.sort_values(by=["KS statistic"], ascending=[False])
|
868 |
+
|
869 |
+
# Merge avg and std columns
|
870 |
+
df["Baseline (mean ± std)"] = (
|
871 |
+
df["Baseline_avg"].astype(str) + " ± " + df["Baseline_std"].astype(str)
|
872 |
+
)
|
873 |
+
df["Subgroup (mean ± std)"] = (
|
874 |
+
df["Subgroup_avg"].astype(str) + " ± " + df["Subgroup_std"].astype(str)
|
875 |
+
)
|
876 |
+
df = df.drop(
|
877 |
+
columns=["Baseline_avg", "Subgroup_avg", "Baseline_std", "Subgroup_std"]
|
878 |
+
)
|
879 |
+
|
880 |
+
# Reorder columns
|
881 |
+
df = df[
|
882 |
+
[
|
883 |
+
"Feature",
|
884 |
+
"Baseline (mean ± std)",
|
885 |
+
"Subgroup (mean ± std)",
|
886 |
+
"Cohen's d",
|
887 |
+
"KS statistic",
|
888 |
+
"KS p-value",
|
889 |
+
# "Wasserstein dist",
|
890 |
+
]
|
891 |
+
]
|
892 |
+
|
893 |
+
# Generate a data table with the feature contributions to the model loss
|
894 |
+
data_table = dash_table.DataTable(
|
895 |
+
style_table={"overflowX": "auto"},
|
896 |
+
style_data={"whiteSpace": "normal", "height": "auto"},
|
897 |
+
data=df.to_dict("records"),
|
898 |
+
# Format p-values in bold font if below 0.05
|
899 |
+
style_data_conditional=[
|
900 |
+
{
|
901 |
+
"if": {"filter_query": "{KS p-value} < " + str(alpha)},
|
902 |
+
"fontWeight": "bold",
|
903 |
+
},
|
904 |
+
# Change background of cells to gray in the column "Cohen's d"
|
905 |
+
{
|
906 |
+
"if": {"column_id": "Cohen's d"},
|
907 |
+
"backgroundColor": "#f9f9f9",
|
908 |
+
},
|
909 |
+
],
|
910 |
+
)
|
911 |
+
return data_table
|
requirements.txt
ADDED
@@ -0,0 +1,19 @@
matplotlib==3.7.2
pandas==1.5.3
scikit-learn==1.3.0
scikit-image==0.21.0
tqdm==4.65.0
numpy==1.24.4
fairlearn==0.9.0
seaborn==0.12.2
pmlb==1.0.1.post3
shap==0.42.1
lime==0.2.0.1
statsmodels==0.14.0
plotly==5.17.0
dash==2.14.1
dice-ml==0.11
explainerdashboard==0.4.3
kaleido==0.2.1
rfpimp==1.3.7
dash-bootstrap-components==1.0.0
utils.py
ADDED
@@ -0,0 +1,118 @@
import numpy as np
import pandas as pd
import shap
from fairlearn.metrics import make_derived_metric
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score

roc_auc_diff = make_derived_metric(metric=roc_auc_score, transform="difference")
f1_score_diff = make_derived_metric(metric=f1_score, transform="difference")
precision_diff = make_derived_metric(metric=precision_score, transform="difference")
recall_diff = make_derived_metric(metric=recall_score, transform="difference")


def combine_all_one_hot_shap_logloss(
    sage_values_df: pd.DataFrame, target_features, cat_features
):
    """Combine all one-hot encoded columns into their parent categorical features"""
    sage_values_df = sage_values_df.copy()
    non_cat_features = [col for col in target_features if col not in cat_features]
    for cat_feat in cat_features:
        # Get the column mask of one-hot columns belonging to this categorical feature
        col_mask = [
            col.startswith(cat_feat) and col not in non_cat_features
            for col in sage_values_df.columns
        ]
        # Sum the masked columns into a single parent column
        sage_values_df[cat_feat] = sage_values_df.loc[:, col_mask].sum(axis=1)
        # Extend the mask so the newly added parent column is kept
        if len(col_mask) < sage_values_df.shape[1]:
            col_mask.append(False)
        # Drop the original one-hot columns
        sage_values_df.drop(sage_values_df.columns[col_mask], axis=1, inplace=True)
    return sage_values_df


def combine_one_hot(shap_values, name, mask, return_original=True):
    """Combines one-hot-encoded features into a single feature

    Args:
        shap_values: an Explanation object
        name: name of the new feature
        mask: bool array of the same length as the features

    This function assumes that shap_values[:, mask] makes up a one-hot-encoded feature
    """
    mask = np.array(mask)
    mask_col_names = np.array(shap_values.feature_names, dtype="object")[mask]

    sv_name = shap.Explanation(
        shap_values.values[:, mask],
        feature_names=list(mask_col_names),
        data=shap_values.data[:, mask],
        base_values=shap_values.base_values,
        display_data=shap_values.display_data,
        instance_names=shap_values.instance_names,
        output_names=shap_values.output_names,
        output_indexes=shap_values.output_indexes,
        lower_bounds=shap_values.lower_bounds,
        upper_bounds=shap_values.upper_bounds,
        main_effects=shap_values.main_effects,
        hierarchical_values=shap_values.hierarchical_values,
        clustering=shap_values.clustering,
    )

    # Recover the active category index from the one-hot columns
    new_data = (sv_name.data * np.arange(sum(mask))).sum(axis=1).astype(int)

    svdata = np.concatenate(
        [shap_values.data[:, ~mask], new_data.reshape(-1, 1)], axis=1
    )

    if shap_values.display_data is None:
        svdd = shap_values.data[:, ~mask]
    else:
        svdd = shap_values.display_data[:, ~mask]

    svdisplay_data = np.concatenate(
        [svdd, mask_col_names[new_data].reshape(-1, 1)], axis=1
    )

    new_values = sv_name.values.sum(axis=1)

    # Reshape new_values to match the dims of shap_values.values
    svvalues = np.concatenate(
        [shap_values.values[:, ~mask], new_values.reshape(-1, 1, 2)], axis=1
    )

    svfeature_names = list(np.array(shap_values.feature_names)[~mask]) + [name]

    sv = shap.Explanation(
        svvalues,
        base_values=shap_values.base_values,
        data=svdata,
        display_data=svdisplay_data,
        instance_names=shap_values.instance_names,
        feature_names=svfeature_names,
        output_names=shap_values.output_names,
        output_indexes=shap_values.output_indexes,
        lower_bounds=shap_values.lower_bounds,
        upper_bounds=shap_values.upper_bounds,
        main_effects=shap_values.main_effects,
        hierarchical_values=shap_values.hierarchical_values,
        clustering=shap_values.clustering,
    )
    if return_original:
        return sv, sv_name
    else:
        return sv


def drop_description_attributes(shap_values_df, description):
    """Remove shap values for columns that appear in the subgroup description

    Args:
        shap_values_df: pd.DataFrame
        description: fairsd.Description
    """
    attributes = description.get_attributes()
    shap_values_df.drop(attributes, axis=1, errors="ignore", inplace=True)
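
Usage sketch for combine_all_one_hot_shap_logloss (editorial addition, not part of the committed files); the column names are illustrative assumptions. combine_one_hot works analogously on a shap.Explanation whose masked columns form a single one-hot encoded feature.

import pandas as pd

from utils import combine_all_one_hot_shap_logloss  # assumed import path

# Per-instance loss contributions with "workclass" one-hot encoded
shap_loss_df = pd.DataFrame(
    {
        "age": [0.010, -0.020],
        "workclass_Private": [0.005, 0.001],
        "workclass_State-gov": [-0.003, 0.002],
    }
)
combined = combine_all_one_hot_shap_logloss(
    shap_loss_df,
    target_features=["age", "workclass"],
    cat_features=["workclass"],
)
# combined has columns ["age", "workclass"]; "workclass" is the row-wise sum
# of the one-hot columns' contributions: [0.002, 0.003]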