AMontiB committed · Commit 734982a · Parent(s): 7ea5247

upload
Files changed:
- .gitattributes +1 -0
- .gitignore +11 -0
- README.md +67 -8
- attribution_demonstrator/__init__.py +0 -0
- attribution_demonstrator/assets/Logo_ID.png +3 -0
- attribution_demonstrator/assets/Logo_MUR.png +3 -0
- attribution_demonstrator/assets/Logo_NGEU.png +3 -0
- attribution_demonstrator/assets/Logo_Serics.png +3 -0
- attribution_demonstrator/connector/__init__.py +0 -0
- attribution_demonstrator/connector/abstract_connector.py +24 -0
- attribution_demonstrator/connector/azure_model_connector.py +180 -0
- attribution_demonstrator/main.py +177 -0
- pyproject.toml +20 -0
- requirements.txt +5 -0
- uv.lock +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,11 @@
+roles/
+venv*
+.idea/
+
+*.log
+*.pyc
+__pycache__
+
+*.sqlite3
+
+.gradio/
README.md CHANGED
@@ -1,13 +1,72 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: FF4ALL WILD DEMONSTRATOR
+emoji: 🏞️
+colorFrom: purple
+colorTo: red
 sdk: gradio
-sdk_version: 5.
-app_file:
+sdk_version: 5.33.0
+app_file: attribution_demonstrator/main.py
 pinned: false
-license: cc-by-4.0
 ---
 
-
+# Demonstrator
+
+## Run the demonstrator
+
+Export the required environment variables:
+
+```bash
+export AZURE_CONNECTION_STRING="...."
+export CONTAINER_NAME="..."
+export DATABRICKS_HOST="..."
+export DATABRICKS_CLIENT_ID="..."
+export DATABRICKS_SECRET="..."
+export MODELS="..."
+```
+
+## Using pip
+
+1. Install the required packages:
+```bash
+pip install -r requirements.txt
+```
+2. Run the demonstrator with the command:
+```bash
+export PYTHONPATH=${PWD}/.
+python -m attribution_demonstrator.main
+```
+
+## Using uv
+
+1. Install uv; please refer to the [uv documentation](https://docs.astral.sh/uv/)
+2. Run the demonstrator with the command:
+
+```bash
+export PYTHONPATH=${PWD}/.
+uv run -m attribution_demonstrator.main
+```
+
+## Exporting the uv environment to a requirements.txt
+
+```bash
+uv export --no-emit-workspace --no-dev --no-annotate --no-header --no-hashes --output-file requirements.txt
+```
+
+# Environment variables
+
+## AZURE_CONNECTION_STRING
+The connection string for the Azure Blob Storage account where the model files are stored.
+
+## CONTAINER_NAME
+The name of the Azure Blob Storage container used for temporary image storage.
+
+## DATABRICKS_HOST
+The base URL of the Databricks workspace hosting the model serving endpoints.
+
+## DATABRICKS_CLIENT_ID and DATABRICKS_SECRET
+The client ID and secret for the Databricks workspace where the models are hosted. These are used to authenticate and access the models.
+The token can be generated in the Databricks workspace under `Workspace settings > Identity and access > Service principals`.
+
+## MODELS
+A JSON string containing the model names and their corresponding URLs. The format is:
+```json
+{"model_name_1": "url", "model_name_2": "url"}
+```
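For clarity, here is a minimal sketch (assuming Python 3.11+ and the environment variable names documented above) of how a consumer such as `attribution_demonstrator/main.py` is expected to parse `MODELS` at startup:

```python
import json
import os

# MODELS must hold a JSON object mapping model names to serving-endpoint URLs,
# e.g. MODELS='{"model_name_1": "https://...", "model_name_2": "https://..."}'
raw = os.getenv("MODELS")
if raw is None:
    raise RuntimeError("MODELS environment variable is not set")

models = json.loads(raw)  # raises json.JSONDecodeError on malformed input
if not isinstance(models, dict):
    raise ValueError("MODELS must be a JSON object, not a list or scalar")
print(f"Loaded {len(models)} model endpoint(s): {', '.join(models)}")
```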
attribution_demonstrator/__init__.py ADDED
File without changes

attribution_demonstrator/assets/Logo_ID.png ADDED
Git LFS Details

attribution_demonstrator/assets/Logo_MUR.png ADDED
Git LFS Details

attribution_demonstrator/assets/Logo_NGEU.png ADDED
Git LFS Details

attribution_demonstrator/assets/Logo_Serics.png ADDED
Git LFS Details

attribution_demonstrator/connector/__init__.py ADDED
File without changes
attribution_demonstrator/connector/abstract_connector.py ADDED
@@ -0,0 +1,24 @@
+from abc import ABC, abstractmethod
+from typing import List
+
+from PIL import Image
+
+
+class ModelConnector(ABC):
+    """
+    Abstract base class for model connectors.
+    """
+
+    @abstractmethod
+    async def perform_inference(self, image: Image.Image, model_list: List[str]) -> dict:
+        """
+        Perform inference on the given image.
+
+        Args:
+            image (Image.Image): The input image.
+            model_list (List[str]): List of model identifiers to use for inference.
+
+        Returns:
+            dict: Mapping from model identifier to its prediction result.
+        """
+        raise NotImplementedError("Subclasses must implement this method.")
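As an illustration of the contract this base class defines, here is a minimal sketch of a concrete subclass (the `ConstantConnector` name and its fixed scores are hypothetical, for local testing only):

```python
import asyncio
from typing import List

from PIL import Image

from attribution_demonstrator.connector.abstract_connector import ModelConnector


class ConstantConnector(ModelConnector):
    """Toy connector that returns a fixed score per requested model."""

    async def perform_inference(self, image: Image.Image, model_list: List[str]) -> dict:
        # Pretend every model assigns the same probability to one generator.
        return {model_id: {"stable_diffusion": 0.5} for model_id in model_list}


if __name__ == "__main__":
    img = Image.new("RGB", (64, 64))  # blank test image
    print(asyncio.run(ConstantConnector().perform_inference(img, ["demo_model"])))
```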
attribution_demonstrator/connector/azure_model_connector.py ADDED
@@ -0,0 +1,180 @@
+import asyncio
+import logging
+import uuid
+from datetime import datetime
+from io import BytesIO
+from typing import Dict, Tuple, Union, List
+
+import httpx
+import pandas as pd
+from PIL import Image
+from azure.storage.blob import BlobServiceClient
+
+from attribution_demonstrator.connector.abstract_connector import ModelConnector
+
+logger = logging.getLogger("orchestrator.common.manager.image")
+
+
+class AzureDatabricksModelConnector(ModelConnector):
+    """
+    Azure model connector for models served on Databricks.
+    This class is responsible for performing inference on images using a model hosted on Azure.
+    """
+
+    def __init__(self,
+                 models: Dict[str, str],
+                 databricks_host: str,
+                 databricks_client_id: str,
+                 databricks_secret: str,
+                 azure_connection_string: str,
+                 azure_container_name: str
+                 ):
+        """
+        Initialize the connector with the model endpoints and the Azure/Databricks credentials.
+        """
+        super().__init__()
+
+        self.models = models
+        self.azure_connection_string = azure_connection_string
+        self.azure_container_name = azure_container_name
+
+        self._databricks_secret = databricks_secret
+        self._databricks_client_id = databricks_client_id
+        self._databricks_host = databricks_host
+
+        self._token = None
+        self._token_expiration = None
+
+    def upload_image_to_azure(self, image: Image.Image) -> str:
+        """
+        Upload the image to Azure Blob Storage and return the blob name.
+
+        Args:
+            image (Image.Image): The input image.
+
+        Returns:
+            str: The name of the uploaded blob in Azure Blob Storage.
+        """
+        uploaded_file_name = f"ff4all_demonstrator_{uuid.uuid4()}.tiff"
+        logger.info(f"Uploading [{uploaded_file_name}] image to Azure Blob Storage")
+        blob_service_client = BlobServiceClient.from_connection_string(self.azure_connection_string)
+        blob_client = blob_service_client.get_blob_client(container=self.azure_container_name, blob=uploaded_file_name)
+
+        tiff_bytes = BytesIO()
+        image.save(tiff_bytes, format="TIFF", compression=None)
+        tiff_bytes.seek(0)
+
+        blob_client.upload_blob(tiff_bytes, overwrite=True)
+        logger.info("Image uploaded successfully")
+        return uploaded_file_name
+
+    def _delete_blob(self, blob_name: str):
+        """
+        Delete the blob from Azure Blob Storage.
+
+        Args:
+            blob_name (str): The name of the blob to delete.
+        """
+        logger.info(f"Deleting blob {blob_name} from Azure Blob Storage")
+        blob_service_client = BlobServiceClient.from_connection_string(self.azure_connection_string)
+        blob_client = blob_service_client.get_blob_client(container=self.azure_container_name, blob=blob_name)
+
+        blob_client.delete_blob()
+        logger.info("Blob deleted successfully")
+
+    def _get_databricks_token(self) -> str:
+        """
+        Get a Databricks OAuth token for authentication, reusing a cached token while it is still valid.
+
+        Returns:
+            str: The Databricks token.
+        """
+        if self._token and self._token_expiration and self._token_expiration > datetime.now():
+            logger.info(f"Using cached Databricks token, valid until {self._token_expiration.isoformat()}")
+            return self._token
+        else:
+            logger.info("Fetching new Databricks token")
+            client = httpx.Client()
+            response = client.post(
+                f"{self._databricks_host}/oidc/v1/token",
+                auth=(self._databricks_client_id, self._databricks_secret),
+                data={'grant_type': 'client_credentials', 'scope': 'all-apis'}
+            )
+
+            if response.status_code in [200, 201, 202]:
+                token_data = response.json()
+                self._token = token_data['access_token']
+                self._token_expiration = datetime.now() + pd.to_timedelta(token_data['expires_in'], unit='s')
+                return self._token
+            else:
+                raise Exception(f"Failed to get Databricks token: {response.text}")
+
+    async def invoke_model(self, client: httpx.AsyncClient, model_id: str, model_url: str, payload: dict, auth_token: str) -> Tuple[str, Union[list, str]]:
+        logger.info(f"Invoking model {model_id}")
+        try:
+            response = await client.post(
+                model_url,
+                json=payload,
+                headers={'Authorization': f'Bearer {auth_token}',
+                         'Content-Type': 'application/json'},
+                timeout=600  # set a timeout of 10 minutes
+            )
+        except Exception as exc:
+            # Return the error as the result so invoke_models can still build its dict.
+            result = f"Error: {str(exc)}"
+        else:
+            if response.status_code == 200:
+                result = response.json()["predictions"][0]["prediction"]  # only one image is processed at a time
+            else:
+                result = f"HTTP {response.status_code}"
+
+        logger.info(f"Model {model_id} invocation completed")
+
+        return model_id, result
+
+    async def invoke_models(self, payload: dict, model_list: List[str]) -> dict:
+        model_to_invoke = {model_id: self.models[model_id] for model_id in model_list if model_id in self.models}
+        logger.info(f"Models to invoke: {model_to_invoke.keys()}")
+
+        token = self._get_databricks_token()
+        async with httpx.AsyncClient() as client:
+            tasks = [
+                self.invoke_model(
+                    client=client,
+                    model_id=model_id,
+                    model_url=model_url,
+                    payload=payload,
+                    auth_token=token
+                ) for model_id, model_url in model_to_invoke.items()
+            ]
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        logger.info("Invoked all models")
+
+        return dict(results)
+
+    async def perform_inference(self, image: Image.Image, model_list: List[str]) -> dict:
+        """
+        Perform inference on the given image using the models served on Databricks.
+
+        Args:
+            image (Image.Image): The input image.
+            model_list (List[str]): List of model identifiers to use for inference.
+
+        Returns:
+            dict: Mapping from model identifier to its prediction result.
+        """
+        uploaded_file_name = self.upload_image_to_azure(image)
+
+        df = pd.DataFrame([{
+            "file_path": uploaded_file_name
+        }])
+        payload = {"dataframe_split": df.to_dict(orient="split")}
+
+        logger.info(f"Sending payload to the models [{payload}]")
+        response = await self.invoke_models(payload, model_list)
+
+        self._delete_blob(uploaded_file_name)
+
+        return response
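A minimal sketch of driving the connector outside Gradio, assuming the environment variables from the README (the local file name `suspect.png` is a placeholder):

```python
import asyncio
import json
import os

from PIL import Image

from attribution_demonstrator.connector.azure_model_connector import AzureDatabricksModelConnector

connector = AzureDatabricksModelConnector(
    models=json.loads(os.environ["MODELS"]),
    databricks_host=os.environ["DATABRICKS_HOST"],
    databricks_client_id=os.environ["DATABRICKS_CLIENT_ID"],
    databricks_secret=os.environ["DATABRICKS_SECRET"],
    azure_connection_string=os.environ["AZURE_CONNECTION_STRING"],
    azure_container_name=os.environ["CONTAINER_NAME"],
)

# Upload a test image, fan out to every configured endpoint, then clean up the blob.
image = Image.open("suspect.png")  # hypothetical local file
scores = asyncio.run(connector.perform_inference(image, list(connector.models)))
print(scores)  # {"model_name_1": {...}, "model_name_2": {...}}
```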
attribution_demonstrator/main.py ADDED
@@ -0,0 +1,177 @@
+import asyncio
+import base64
+import json
+import logging.config
+import os
+from pathlib import Path
+from typing import List, Any, Tuple
+
+import gradio as gr
+import matplotlib.pyplot as plt
+import pandas as pd
+import plotly.express as px
+from PIL import Image
+
+from attribution_demonstrator.connector.azure_model_connector import AzureDatabricksModelConnector
+
+ASSETS_DIR = Path(__file__).parent / "assets"  # path to the logo images
+
+logging.basicConfig(level=logging.INFO)
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Configuration
+# ---------------------------------------------------------------------------
+models = json.loads(os.getenv("MODELS"))
+
+logger.info(f"Models loaded from environment: {models.keys()}")
+
+connector = AzureDatabricksModelConnector(
+    models=models,
+    databricks_host=os.getenv("DATABRICKS_HOST"),
+    databricks_client_id=os.getenv("DATABRICKS_CLIENT_ID"),
+    databricks_secret=os.getenv("DATABRICKS_SECRET"),
+    azure_connection_string=os.getenv("AZURE_CONNECTION_STRING"),
+    azure_container_name=os.getenv("CONTAINER_NAME")
+)
+
+
+# ---------------------------------------------------------------------------
+# Inference helpers
+# ---------------------------------------------------------------------------
+async def call_tagging_services(image: Image.Image, model_list: List[str]) -> dict:
+    result = await connector.perform_inference(image, model_list)
+    return result
+
+
+def process_image(image: Image.Image, model_list: List[str]) -> Tuple[Any, dict]:
+    if not model_list:
+        raise gr.Error("Please select at least one model to perform inference.")
+    result = asyncio.run(call_tagging_services(image, model_list))
+
+    df_splits = []
+
+    for model, res in result.items():
+        if isinstance(res, dict):
+            for generator, score in res.items():
+                df_splits.append({
+                    "model": model,
+                    "generator": generator,
+                    "score": score
+                })
+        else:
+            logger.info(f"Skipping model with non-dict result: {res}")
+
+    if df_splits:
+        df = pd.DataFrame(df_splits)
+        fig = px.histogram(df, x="generator", y="score",
+                           color='model', barmode='group')
+        fig.update_layout(xaxis_tickangle=45)
+    else:
+        # return an empty plot if no valid results
+        fig, ax = plt.subplots()
+
+    return fig, result
+
+
+def load_image_as_base64(image_path: str) -> str:
+    """
+    Load an image from the given path and return it as a base64 string.
+    """
+    with open(image_path, "rb") as f:
+        return base64.b64encode(f.read()).decode('utf-8')
+
+
+# ---------------------------------------------------------------------------
+# Gradio UI
+# ---------------------------------------------------------------------------
+
+demo = gr.Interface(
+    fn=process_image,
+    inputs=[
+        gr.Image(type="pil", label="Upload an Image"),
+        gr.Dropdown(
+            choices=list(models.keys()),
+            label="Select Model/s",
+            multiselect=True,  # removed deprecated type="value"
+        ),
+    ],
+    outputs=[
+        gr.Plot(label="Model Scores Barplot"),
+        gr.JSON(label="Attribution Results"),
+    ],
+    title="Detection of Deep Fake Media and Life-Long Media Authentication (FF4ALL)",
+    description=(
+        "The **FF4ALL** project (Detection of Deep-Fake Media and Life-Long Media Authentication) is part of the "
+        "extended partnership **SERICS – Security and Rights in the CyberSpace**, funded by Italy’s National "
+        "Recovery and Resilience Plan with *Next Generation EU* resources and coordinated by the University of "
+        "Cagliari. FF4ALL’s mission is to design open methodologies, tools and public datasets that make it easier "
+        "to identify manipulated or AI-generated images, video and audio, and to preserve their authenticity "
+        "throughout the entire life-cycle of a file. The research blends computer vision, machine learning, "
+        "cryptography and blockchain to create a unified framework for deep-fake detection, source attribution and "
+        "tamper-proof traceability. The resulting technologies are intended to support journalists, law-enforcement "
+        "agencies, social-media platforms and ordinary citizens in safeguarding the reliability of digital "
+        "information. The project also champions an open-source ethos: it publishes code, data and evaluation "
+        "protocols under permissive licences, encourages the adoption of open standards, and trains a new "
+        "generation of specialists in forensic media analysis. Through close collaboration with industrial partners "
+        "and public institutions, FF4ALL aims to build a resilient national ecosystem capable of shielding society "
+        "from the threats posed by disinformation. The demo you are about to use is a public proof-of-concept that "
+        "exposes the classifiers developed so far, allowing real-time testing and community feedback.\n\n"
+        "### How this demo works\n\n"
+        "Upload a JPG or PNG image that you suspect was generated by a diffusion or GAN-based model, then tick one "
+        "or more of the **four classifiers** provided by FF4ALL partners:\n\n"
+        "• **DE-FAKE**: a hybrid classifier that leverages multimodal features. "
+        "It utilizes both the image and its corresponding text prompt, processing them through CLIP's respective encoders. "
+        "For images lacking a prompt, it employs the BLIP-2 model to generate an estimated description, ensuring its dual-input mechanism can always be used.\n\n"
+        "• **CLIP+MLP**: uses the CLIP Large model to extract image features without considering any text prompt. "
+        "These features are then classified by a Multi-Layer Perceptron (MLP) with two hidden layers, trained with the Adam optimizer.\n\n"
+        "• **EfficientNetB4**: operates on a patch-level basis, analyzing 50 randomly extracted 96x96 pixel patches from an image. "
+        "The final attribution score is the average of the scores from these patches, a method designed to enhance robustness against common manipulations "
+        "like compression by focusing on synthetic artifacts over semantic content.\n\n"
+        "• **Vision Transformer Classifier (VTC)**: employs the ViT-Base model of [1], pre-trained on ImageNet and fine-tuned on the WILD dataset. "
+        "This model processes images by dividing them into 16x16 pixel patches, which are then fed into the transformer's encoder. "
+        "The final classification score is derived by aggregating predictions from each patch, a strategy intended to improve resilience to localized image distortions.\n\n"
+        "Each network is trained to recognise the visual fingerprints left by popular generators (e.g. Stable "
+        "Diffusion, Midjourney, DALL-E 2). Once processed, the app returns a probability distribution that shows "
+        "how likely it is that the image was produced by each of the supported generators.\n\n"
+        "**Note:** the detectors are *not* intended for authentic photographs — if you upload a real photo the "
+        "scores will be uninformative.\n\n"
+        "Reference: WILD: a new in-the-Wild Image Linkage Dataset for synthetic image attribution, "
+        "Pietro Bongini, Sara Mandelli, Andrea Montibeller, Mirko Casu, Orazio Pontorno, Claudio Vittorio Ragaglia, "
+        "Luca Zanchetta, Mattia Aquilina, Taiba Majid Wani, Luca Guarnera, Benedetta Tondi, Giulia Boato, Paolo Bestagini, "
+        "Irene Amerini, Francesco De Natale, Sebastiano Battiato, Mauro Barni. "
+        "Link to paper: https://arxiv.org/abs/2504.19595 \n\n"
+        "[1] A. Dosovitskiy, L. Beyer, A. Kolesnikov, D. Weissenborn, X. Zhai, "
+        "T. Unterthiner, M. Dehghani, M. Minderer, G. Heigold, S. Gelly et al., "
+        "“An image is worth 16x16 words: Transformers for image recognition "
+        "at scale,” arXiv preprint arXiv:2010.11929, 2020.\n\n"
+    ),
+    article=f"""
+    <br><br>
+    <div style="display:flex;justify-content:center;align-items:center;gap:2rem;flex-wrap:wrap;">
+      <!-- SERICS logo -->
+      <img src="data:image/png;base64,{load_image_as_base64(f'{ASSETS_DIR}/Logo_Serics.png')}" alt="SERICS logo" style="height:60px;" />
+      <!-- Three-logo banner: NGEU, MUR, Italia Domani -->
+      <div style="display:flex;gap:1rem;">
+        <img src="data:image/png;base64,{load_image_as_base64(f'{ASSETS_DIR}/Logo_NGEU.png')}" alt="Next Generation EU logo" style="height:60px;" />
+        <img src="data:image/png;base64,{load_image_as_base64(f'{ASSETS_DIR}/Logo_MUR.png')}" alt="MUR logo" style="height:60px;" />
+        <img src="data:image/png;base64,{load_image_as_base64(f'{ASSETS_DIR}/Logo_ID.png')}" alt="Italia Domani logo" style="height:60px;" />
+      </div>
+    </div>
+    """,
+    allow_flagging="never",
+)
+
+# ---------------------------------------------------------------------------
+# Entrypoint
+# ---------------------------------------------------------------------------
+if __name__ == "__main__":
+    gr.set_static_paths([str(ASSETS_DIR)])  # ensure assets are served correctly
+    demo.launch(allowed_paths=[str(ASSETS_DIR)])
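To make the plotting step concrete, here is a small self-contained sketch (model and generator names are made up) of the `{model: {generator: score}}` structure that `process_image` flattens into a long-form DataFrame for `px.histogram`:

```python
import pandas as pd
import plotly.express as px

# Hypothetical connector output: one score per (model, generator) pair.
result = {
    "CLIP+MLP": {"stable_diffusion": 0.71, "midjourney": 0.19, "dalle2": 0.10},
    "EfficientNetB4": {"stable_diffusion": 0.55, "midjourney": 0.30, "dalle2": 0.15},
}

rows = [
    {"model": model, "generator": generator, "score": score}
    for model, scores in result.items()
    for generator, score in scores.items()
]
df = pd.DataFrame(rows)

fig = px.histogram(df, x="generator", y="score", color="model", barmode="group")
fig.update_layout(xaxis_tickangle=45)
fig.show()  # grouped bars: one group per generator, one bar per model
```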
pyproject.toml ADDED
@@ -0,0 +1,20 @@
+[project]
+name = "attribution-demonstrator"
+version = "0.1.0"
+description = "Gradio demonstrator for FF4ALL synthetic image attribution"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+    "azure-storage-blob>=12.25.1",
+    "gradio>=5.33.0,<6",
+    "matplotlib>=3.10.3",
+    "plotly>=6.1.2",
+    "seaborn>=0.13.2",
+]
+
+[dependency-groups]
+dev = [
+    "flake8>=7.2.0",
+    "pre-commit>=4.2.0",
+    "ruff>=0.11.13",
+]
requirements.txt ADDED
@@ -0,0 +1,5 @@
+azure-storage-blob>=12.25.1
+gradio>=5.33.0,<6
+seaborn~=0.13.2
+matplotlib>=3.10.3
+plotly~=6.1.2
uv.lock ADDED
The diff for this file is too large to render. See raw diff.