# NOTE: removed Hugging Face Spaces page chrome (git-blame hashes and line-number
# gutter) that was captured with the file; it is not part of the module source.
"""Streamlit session state.
In the future, this could be the serialization format between front and back.
"""
from __future__ import annotations
import base64
import dataclasses
import datetime
from typing import Any
import uuid
from etils import epath
import pandas as pd
import requests
import streamlit as st
from core.constants import OAUTH_CLIENT_ID
from core.constants import OAUTH_CLIENT_SECRET
from core.constants import PAST_PROJECTS_PATH
from core.constants import PROJECT_FOLDER_PATTERN
from core.constants import REDIRECT_URI
from core.constants import TABS
from core.names import find_unique_name
import mlcroissant as mlc
def create_class(mlc_class: type, instance: Any, **kwargs) -> Any:
    """Creates the mlcroissant class `mlc_class` from the editor `instance`.

    Copies every dataclass field of `mlc_class` that exists as an attribute on
    `instance`, except those explicitly overridden through `kwargs`.
    """
    params: dict[str, Any] = {
        field.name: getattr(instance, field.name)
        for field in dataclasses.fields(mlc_class)
        if hasattr(instance, field.name) and field.name not in kwargs
    }
    return mlc_class(**params, **kwargs)
@dataclasses.dataclass
class User:
    """The connected user."""
    access_token: str
    id_token: str
    username: str
    @classmethod
    def connect(cls, code: str):
        """Exchanges the OAuth authorization `code` for tokens and returns a User.

        Performs the Hugging Face OAuth flow in two steps:
        1. POST to /oauth/token with Basic auth to obtain the tokens.
        2. GET /oauth/userinfo with the access token to obtain the username.

        Raises:
            Exception: if any step fails. NOTE: `response` is deliberately
                reused/shadowed below, so the raised message interpolates
                whichever object (HTTP response or parsed JSON) it last held.
        """
        # Basic auth header: base64("client_id:client_secret").
        credentials = base64.b64encode(
            f"{OAUTH_CLIENT_ID}:{OAUTH_CLIENT_SECRET}".encode()
        ).decode()
        headers = {
            "Authorization": f"Basic {credentials}",
        }
        data = {
            "client_id": OAUTH_CLIENT_ID,
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": REDIRECT_URI,
        }
        url = "https://huggingface.co/oauth/token"
        response = requests.post(url, data=data, headers=headers)
        if response.status_code == 200:
            # From here on, `response` holds the parsed JSON payload.
            response = response.json()
            access_token = response.get("access_token")
            id_token = response.get("id_token")
            if access_token and id_token:
                url = "https://huggingface.co/oauth/userinfo"
                headers = {"Authorization": f"Bearer {access_token}"}
                response = requests.get(url, headers=headers)
                if response.status_code == 200:
                    response = response.json()
                    username = response.get("preferred_username")
                    if username:
                        return User(
                            access_token=access_token,
                            username=username,
                            id_token=id_token,
                        )
        # Any fall-through (bad status, missing tokens or username) ends here.
        raise Exception(
            f"Could not connect to Hugging Face. Please, go to {REDIRECT_URI}."
            f" ({response=})."
        )
def get_user():
    """Returns the connected `User` from the Streamlit session state (or None)."""
    connected_user = st.session_state.get(User)
    return connected_user
@dataclasses.dataclass
class CurrentProject:
    """The selected project."""

    path: epath.Path

    @classmethod
    def create_new(cls) -> CurrentProject | None:
        """Creates a project whose folder is named after the current time."""
        now = datetime.datetime.now()
        return cls.from_timestamp(now.strftime(PROJECT_FOLDER_PATTERN))

    @classmethod
    def from_timestamp(cls, timestamp: str) -> CurrentProject | None:
        """Builds the project stored under `<past projects>/<timestamp>`.

        Returns None when OAuth is configured but nobody is logged in.
        """
        user = get_user()
        if user is None and OAUTH_CLIENT_ID:
            return None
        return CurrentProject(path=PAST_PROJECTS_PATH(user) / timestamp)
# Marker class: used as a st.session_state key, never instantiated with data.
class SelectedResource:
    """The selected FileSet or FileObject on the `Resources` page."""
    pass
@dataclasses.dataclass
class SelectedRecordSet:
    """The selected RecordSet on the `RecordSets` page."""
    # Presumably the index within `Metadata.record_sets` — confirm with callers.
    record_set_key: int
    record_set: RecordSet
@dataclasses.dataclass
class Node:
    """Base class for editor nodes: carries the Croissant context and identity."""

    ctx: mlc.Context = dataclasses.field(default_factory=mlc.Context)
    id: str | None = None
    name: str | None = None

    def get_name_or_id(self):
        """Returns `name` for Croissant 0.8 contexts, `id` otherwise."""
        return self.name if self.ctx.is_v0() else self.id
@dataclasses.dataclass
class FileObject(Node):
    """FileObject analogue for editor"""
    description: str | None = None
    # Identifiers (names in Croissant 0.8) of parent resources, if any.
    contained_in: list[str] | None = dataclasses.field(default_factory=list)
    content_size: str | None = None
    content_url: str | None = None
    encoding_format: str | None = None
    sha256: str | None = None
    # NOTE(review): presumably an in-memory preview of the file's content — confirm.
    df: pd.DataFrame | None = None
    # NOTE(review): presumably the local folder holding the file — confirm with callers.
    folder: epath.PathLike | None = None
@dataclasses.dataclass
class FileSet(Node):
    """FileSet analogue for editor"""
    # Identifiers (names in Croissant 0.8) of parent resources.
    contained_in: list[str] = dataclasses.field(default_factory=list)
    description: str | None = None
    encoding_format: str | None = ""
    # Glob pattern selecting the files belonging to this set.
    includes: str | None = ""
@dataclasses.dataclass
class Field(Node):
    """Field analogue for editor"""
    description: str | None = None
    data_types: str | list[str] | None = None
    # Where the field's values come from (a resource or another field).
    source: mlc.Source | None = None
    # Join key to another field, if this field references one.
    references: mlc.Source | None = None
@dataclasses.dataclass
class RecordSet(Node):
    """Record Set analogue for editor"""
    data: list[Any] | None = None
    data_types: list[str] | None = None
    description: str | None = None
    is_enumeration: bool | None = None
    key: str | list[str] | None = None
    fields: list[Field] = dataclasses.field(default_factory=list)
@dataclasses.dataclass
class Metadata(Node):
    """main croissant data object, helper functions exist to load and unload this into the mlcroissant version"""
    description: str | None = None
    cite_as: str | None = None
    creators: list[mlc.Person] = dataclasses.field(default_factory=list)
    date_published: datetime.datetime | None = None
    license: str | None = ""
    # RAI extension attributes
    data_collection: str | None = None
    data_collection_type: str | None = None
    data_collection_missing_data: str | None = None
    data_collection_raw_data: str | None = None
    data_collection_timeframe: datetime.datetime | None = None
    data_imputation_protocol: str | None = None
    # NOTE(review): the list-typed RAI fields below default to None (not []),
    # so callers must handle None; annotations widened to `| None` to match.
    data_preprocessing_protocol: list[str] | None = None
    data_manipulation_protocol: str | None = None
    data_annotation_protocol: str | None = None
    data_annotation_platform: str | None = None
    data_annotation_analysis: str | None = None
    annotation_per_item: str | None = None
    annotator_demographics: str | None = None
    machine_annotation_tools: str | None = None
    data_biases: list[str] | None = None
    data_use_cases: list[str] | None = None
    data_limitations: list[str] | None = None
    data_social_impact: str | None = None
    personal_sensitive_information: list[str] | None = None
    data_release_maintenance_plan: str | None = None
    url: str = ""
    distribution: list[FileObject | FileSet] = dataclasses.field(default_factory=list)
    record_sets: list[RecordSet] = dataclasses.field(default_factory=list)
    version: str | None = None
    def __bool__(self):
        """Truthy only once both `name` and `url` are set and non-empty."""
        return self.name != "" and self.url != ""
    def rename_distribution(self, old_name: str, new_name: str):
        """Renames a resource by changing all the references to this resource."""
        # Update other resources:
        for i, resource in enumerate(self.distribution):
            if resource.id == old_name:
                self.distribution[i].id = new_name
            contained_in = resource.contained_in
            if contained_in and old_name in contained_in:
                self.distribution[i].contained_in = [
                    new_name if name == old_name else name for name in contained_in
                ]
        # Updating source/references works just as with RecordSets.
        self.rename_record_set(old_name, new_name)
    def rename_record_set(self, old_name: str, new_name: str):
        """Renames a RecordSet by changing all the references to this RecordSet."""
        for i, record_set in enumerate(self.record_sets):
            if record_set.id == old_name:
                self.record_sets[i].id = new_name
            for j, field in enumerate(record_set.fields):
                # Fields belonging to the renamed RecordSet have ids "<old_name>/...".
                possible_uuid = f"{old_name}/"
                # Update source
                source = field.source
                if source and source.field and source.field.startswith(possible_uuid):
                    # Replace only the first occurrence (the RecordSet prefix).
                    new_uuid = source.field.replace(old_name, new_name, 1)
                    self.record_sets[i].fields[j].source.field = new_uuid
                if source and source.file_object and source.file_object == old_name:
                    self.record_sets[i].fields[j].source.file_object = new_name
                if source and source.file_set and source.file_set == old_name:
                    self.record_sets[i].fields[j].source.file_set = new_name
                if source and source.distribution and source.distribution == old_name:
                    self.record_sets[i].fields[j].source.distribution = new_name
                # Update references
                references = field.references
                if (
                    references
                    and references.field
                    and references.field.startswith(possible_uuid)
                ):
                    new_uuid = references.field.replace(old_name, new_name, 1)
                    self.record_sets[i].fields[j].references.field = new_uuid
                if (
                    references
                    and references.file_object
                    and references.file_object == old_name
                ):
                    self.record_sets[i].fields[j].references.file_object = new_name
                if (
                    references
                    and references.file_set
                    and references.file_set == old_name
                ):
                    self.record_sets[i].fields[j].references.file_set = new_name
                if (
                    references
                    and references.distribution
                    and references.distribution == old_name
                ):
                    self.record_sets[i].fields[j].references.distribution = new_name
    def rename_field(self, old_name: str, new_name: str):
        """Renames a field by changing all the references to this field."""
        for i, record_set in enumerate(self.record_sets):
            for j, field in enumerate(record_set.fields):
                # Field ids end with "/<old_name>" (the RecordSet is the prefix).
                possible_uuid = f"/{old_name}"
                # Update source
                source = field.source
                # The difference with RecordSet is the `.endswith` here:
                if source and source.field and source.field.endswith(possible_uuid):
                    new_uuid = source.field.replace(old_name, new_name, 1)
                    self.record_sets[i].fields[j].source.field = new_uuid
                # Update references
                references = field.references
                if (
                    references
                    and references.field
                    and references.field.endswith(possible_uuid)
                ):
                    new_uuid = references.field.replace(old_name, new_name, 1)
                    self.record_sets[i].fields[j].references.field = new_uuid
    def rename_id(self, old_id: str, new_id: str):
        """Renames `old_id` to `new_id` across resources, RecordSets and fields."""
        for resource in self.distribution:
            if resource.id == old_id:
                resource.id = new_id
            if resource.contained_in and old_id in resource.contained_in:
                resource.contained_in = [
                    new_id if uuid == old_id else uuid for uuid in resource.contained_in
                ]
        for record_set in self.record_sets:
            if record_set.id == old_id:
                record_set.id = new_id
            for field in record_set.fields:
                if field.id == old_id:
                    field.id = new_id
                # Also patch every Source attribute that may point at old_id.
                for p in ["distribution", "field", "file_object", "file_set"]:
                    if field.source and getattr(field.source, p) == old_id:
                        setattr(field.source, p, new_id)
                    if field.references and getattr(field.references, p) == old_id:
                        setattr(field.references, p, new_id)
    def add_distribution(self, distribution: FileSet | FileObject) -> None:
        """Appends a resource to `distribution`."""
        self.distribution.append(distribution)
    def remove_distribution(self, key: int) -> None:
        """Removes the resource at index `key` from `distribution`."""
        del self.distribution[key]
    def add_record_set(self, record_set: RecordSet) -> None:
        """Appends a RecordSet, de-duplicating its name against existing names."""
        name = find_unique_name(self.names(), record_set.name)
        record_set.name = name
        self.record_sets.append(record_set)
    def remove_record_set(self, key: int) -> None:
        """Removes the RecordSet at index `key`."""
        del self.record_sets[key]
    def _find_record_set(self, record_set_key: int) -> RecordSet:
        """Returns the RecordSet at index `record_set_key`; raises on bad index."""
        if record_set_key >= len(self.record_sets):
            raise ValueError(f"Wrong index when finding a RecordSet: {record_set_key}")
        return self.record_sets[record_set_key]
    def add_field(self, record_set_key: int, field: Field) -> None:
        """Appends `field` to the RecordSet at index `record_set_key`."""
        record_set = self._find_record_set(record_set_key)
        record_set.fields.append(field)
    def remove_field(self, record_set_key: int, field_key: int) -> None:
        """Removes the field at index `field_key` from the given RecordSet."""
        record_set = self._find_record_set(record_set_key)
        if field_key >= len(record_set.fields):
            raise ValueError(f"Wrong index when removing field: {field_key}")
        del record_set.fields[field_key]
    def to_canonical(self) -> mlc.Metadata:
        """Converts this editor Metadata into the canonical mlcroissant Metadata."""
        distribution = []
        ctx = self.ctx
        for file in self.distribution:
            if isinstance(file, FileObject):
                distribution.append(create_class(mlc.FileObject, file, ctx=ctx))
            elif isinstance(file, FileSet):
                distribution.append(create_class(mlc.FileSet, file, ctx=ctx))
        record_sets = []
        for record_set in self.record_sets:
            fields = []
            for field in record_set.fields:
                fields.append(create_class(mlc.Field, field, ctx=ctx))
            record_sets.append(
                create_class(mlc.RecordSet, record_set, ctx=ctx, fields=fields)
            )
        return create_class(
            mlc.Metadata,
            self,
            distribution=distribution,
            record_sets=record_sets,
        )
    @classmethod
    def from_canonical(cls, canonical_metadata: mlc.Metadata) -> Metadata:
        """Builds an editor Metadata from a canonical mlcroissant Metadata."""
        distribution = []
        for file in canonical_metadata.distribution:
            if isinstance(file, mlc.FileObject):
                distribution.append(create_class(FileObject, file))
            else:
                distribution.append(create_class(FileSet, file))
        record_sets = []
        for record_set in canonical_metadata.record_sets:
            fields = []
            for field in record_set.fields:
                fields.append(create_class(Field, field))
            record_sets.append(
                create_class(
                    RecordSet,
                    record_set,
                    fields=fields,
                )
            )
        return create_class(
            cls,
            canonical_metadata,
            distribution=distribution,
            record_sets=record_sets,
        )
    def names(self) -> set[str]:
        """Returns all names/ids in use (resources, RecordSets and fields)."""
        distribution = set()
        record_sets = set()
        fields = set()
        for resource in self.distribution:
            distribution.add(resource.get_name_or_id())
        for record_set in self.record_sets:
            record_sets.add(record_set.get_name_or_id())
            for field in record_set.fields:
                fields.add(field.get_name_or_id())
        return distribution.union(record_sets).union(fields)
# Marker class: used as a st.session_state key storing the open tab's index.
class OpenTab:
    pass
def get_tab():
    """Returns the index of the currently open tab (0 when none is stored)."""
    stored_index = st.session_state.get(OpenTab)
    return 0 if stored_index is None else stored_index
def set_tab(tab: str):
    """Stores the index of `tab` in the session state; unknown tabs are ignored."""
    if tab in TABS:
        st.session_state[OpenTab] = TABS.index(tab)