Spaces:
Sleeping
Sleeping
import fastapi | |
import enum | |
import numpy | |
import pydantic | |
import typing | |
import jwt | |
import jwt.algorithms | |
import os | |
class ModelName(str, enum.Enum):
    """Closed set of model identifiers accepted by the ``get_model`` handler.

    The ``str`` mix-in makes every member compare equal to its value string,
    e.g. ``ModelName.lenet == "lenet"``.
    """

    # Member order is kept as-is: it is the iteration order of the enum.
    alexnet = "alexnet"
    resnet = "resnet"
    lenet = "lenet"
# The ASGI application object.
# NOTE(review): no ``@app.<method>(...)`` route decorators survive in this copy
# of the file, so the handler functions below are not registered on ``app``
# unless they are decorated elsewhere -- confirm against the real source.
app = fastapi.FastAPI()
# Integers 0..999 served in windows by ``get_array``; built directly by NumPy
# instead of materialising a Python ``range`` first.
array = numpy.arange(1000)
def check_valid_email(email: str):
    """Return *email* unchanged if it contains an ``@``.

    This is a minimal sanity check, not full address validation.

    Raises:
        ValueError: if *email* contains no ``@``.
    """
    if "@" not in email:
        raise ValueError("Email is invalid")
    return email
def index():
    """Return a static placeholder payload."""
    body = dict(name="First Data")
    return body
async def read_item(item_id: int):
    """Echo *item_id* back in a one-key dict."""
    response = dict(item_id=item_id)
    return response
async def read_user_me():
    """Return a fixed placeholder for the current user's id."""
    placeholder = "the current user"
    return {"user_id": placeholder}
async def read_user(user_id: str):
    """Echo *user_id* back in a one-key dict."""
    response = dict(user_id=user_id)
    return response
async def get_model(model_name: ModelName):
    """Return a canned message for the selected model.

    Args:
        model_name: the requested model.

    Returns:
        ``{"model_name": ..., "message": ...}`` where the message depends on
        the model; anything other than alexnet or lenet (i.e. resnet) gets
        the "residuals" message.
    """
    # Both branches compare against enum members. Because ModelName mixes in
    # ``str``, this also matches a plain "lenet" string, whereas the previous
    # ``model_name.value`` form would raise AttributeError on a bare str.
    if model_name == ModelName.alexnet:
        message = "Deep Learning FTW!"
    elif model_name == ModelName.lenet:
        message = "LeCNN all the images"
    else:
        message = "Have some residuals"
    return {"model_name": model_name, "message": message}
async def read_file(file_path: str):
    """Echo *file_path* back; the file itself is never opened."""
    response = dict(file_path=file_path)
    return response
async def get_array(start: int, skip: int = 10):
    """Return a window of the module-level ``array``.

    Args:
        start: index of the first element in the window.
        skip: number of elements in the window (a length, despite the name).

    Returns:
        ``{"array": [...]}`` holding ``array[start:start + skip]`` as a plain
        Python list. Standard slice semantics apply, so a window running past
        the end is truncated. NOTE(review): a negative ``start`` or ``skip``
        yields an empty or surprising window; consider validating.
    """
    # Slice and convert once; the debug print that repeated this work is gone.
    window = array[start : start + skip].tolist()
    return {"array": window}
async def test(test : bool): | |
return {"test": test} | |
# NOTE(review): this rebinds the module-level name ``read_item`` that is also
# defined earlier in the file; if both are meant to be routes, each must be
# registered (decorated) before the name is rebound.
async def read_item(
    channel_owner: str, product_id: int, q: typing.Optional[str] = None
):
    """Echo a channel owner's product lookup back to the caller.

    Args:
        channel_owner: owner of the channel; returned as ``product_owner``.
        product_id: product identifier.
        q: optional free-text query; ``None`` when omitted. Annotated as
            ``Optional`` because implicit Optional (``str = None``) is
            disallowed by PEP 484.

    Returns:
        dict with ``product_owner``, ``product_id`` and ``q``.
    """
    return {"product_owner": channel_owner, "product_id": product_id, "q": q}
class User(pydantic.BaseModel):
    """Pydantic model for the user data taken by ``create_user`` and ``login``.

    Field constraints use ``pydantic.Field``: ``fastapi.Query`` is a marker for
    request query parameters and does not belong on a model field.
    """

    # 3 to 32 characters, enforced by pydantic during validation.
    username: typing.Annotated[str, pydantic.Field(min_length=3, max_length=32)]
    display_name: str
    # Run through check_valid_email after pydantic's own str validation.
    email: typing.Annotated[str, pydantic.AfterValidator(check_valid_email)]
async def create_user(user: User):
    """Echo the validated *user* back under the ``"user"`` key."""
    response = dict(user=user)
    return response
async def get_followers(
    channel_owner: str,
    limit: typing.Annotated[
        int, fastapi.Path(title="The number of followers to retrieve")
    ],
):
    """Echo the channel owner and requested follower count.

    NOTE(review): ``fastapi.Path`` marks ``limit`` as a path parameter, so the
    route template must contain ``{limit}``; the decorator is not visible
    here -- confirm.
    """
    response = dict(channel_owner=channel_owner, limit=limit)
    return response
def _load_jwt_settings():
    """Return ``(key, algorithm)`` from ENCRYPTION_KEY / ENCRYPTION_ALGORITHM.

    Raises:
        ValueError: if either environment variable is unset.
    """
    encryption_key = os.environ.get("ENCRYPTION_KEY")
    encryption_algorithm = os.environ.get("ENCRYPTION_ALGORITHM")
    if encryption_key is None:
        raise ValueError("No ENCRYPTION_KEY set for the application")
    if encryption_algorithm is None:
        raise ValueError("No ENCRYPTION_ALGORITHM set for the application")
    return encryption_key, encryption_algorithm


def encode_data(user: User):
    """Sign a JWT carrying *user*'s username, display_name and email.

    The signing key and algorithm come from the ENCRYPTION_KEY and
    ENCRYPTION_ALGORITHM environment variables. The key is deliberately
    never printed or logged.

    Returns:
        The token produced by ``jwt.encode``.

    Raises:
        ValueError: if either environment variable is unset.
    """
    encryption_key, encryption_algorithm = _load_jwt_settings()
    # NOTE(review): no "exp" claim is set, so issued tokens never expire.
    claims = {
        "username": user.username,
        "display_name": user.display_name,
        "email": user.email,
    }
    return jwt.encode(claims, encryption_key, algorithm=encryption_algorithm)


def decode_data(token: str):
    """Verify *token* and return its claims, or ``None`` if it is rejected.

    Only the configured ENCRYPTION_ALGORITHM is accepted, which prevents a
    token from choosing its own (weaker) algorithm.

    Raises:
        ValueError: if either environment variable is unset.
    """
    encryption_key, encryption_algorithm = _load_jwt_settings()
    try:
        return jwt.decode(token, encryption_key, algorithms=[encryption_algorithm])
    except jwt.PyJWTError:
        # PyJWT's base exception. ``jwt.JWTError`` is python-jose's name; it
        # does not exist in PyJWT, so referencing it raised AttributeError
        # whenever a token failed verification.
        return None
async def login(user: User):
    """Issue a signed token for *user* via ``encode_data``.

    NOTE(review): no password or other credential is checked here, so any
    caller can obtain a token for any username -- confirm this is intended.

    Returns:
        ``{"token": <encoded JWT>}``.
    """
    # The former debug decode round-trip printed the user's claims on every
    # login and doubled the crypto work; it has been removed.
    token = encode_data(user)
    return {"token": token}
async def verify(token: str):
    """Return whatever ``decode_data`` yields for *token*.

    That is the decoded claims, or ``None`` when ``decode_data`` reports
    failure.
    """
    claims = decode_data(token)
    return claims
# @app.get("") |