# FastAPI_Test / main.py
# Last commit: f728e97 "Update main.py" (verified), author: xqt
# Size: 3.9 kB
import fastapi
import enum
import numpy
import pydantic
import typing
import jwt
import jwt.algorithms
import os
class ModelName(str, enum.Enum):
    """Model architectures accepted by the ``/models/{model_name}`` endpoint."""

    # The ``str`` mix-in makes every member compare equal to its value string.
    alexnet = "alexnet"
    resnet = "resnet"
    lenet = "lenet"
# Demo data served by the "/array_at_val" endpoint: the integers 0..999.
array = numpy.arange(1000)

app = fastapi.FastAPI()

# Debug output at import time: the algorithm table PyJWT reports as available.
print(jwt.algorithms.get_default_algorithms())
def check_valid_email(email: str) -> str:
    """Return ``email`` unchanged if it has the shape ``local@domain``.

    This is a minimal sanity check, not full RFC 5322 validation. It only
    requires an ``@`` with non-empty text on both sides of the last ``@``.

    Args:
        email: Candidate e-mail address.

    Returns:
        The same string, so the function can be used as a pydantic
        ``AfterValidator``.

    Raises:
        ValueError: If there is no ``@``, or the local part or domain is empty.
    """
    # Split on the last "@": a quoted local part may itself contain "@",
    # but the domain cannot.
    local_part, separator, domain = email.rpartition("@")
    if not separator or not local_part or not domain:
        raise ValueError("Email is invalid")
    return email
@app.get("/")
def index():
    """Return a fixed sample payload."""
    payload = dict(name="First Data")
    return payload
@app.get("/items/{item_id}")
async def read_item(item_id: int):
    """Echo back the integer ``item_id`` taken from the path."""
    response = dict(item_id=item_id)
    return response
@app.get("/users/me")
async def read_user_me():
    """Return a placeholder describing the current user."""
    # Registered before "/users/{user_id}" on purpose: FastAPI matches routes
    # in registration order, so this literal path must come first.
    placeholder = "the current user"
    return {"user_id": placeholder}
@app.get("/users/{user_id}")
async def read_user(user_id: str):
    """Echo back the ``user_id`` path segment as a string."""
    response = dict(user_id=user_id)
    return response
@app.get("/models/{model_name}")
async def get_model(model_name: ModelName):
    """Return a short message about the requested model architecture."""
    # Compare against enum members (not raw strings), so each branch is
    # spelled the same way and a renamed value only has to change in the enum.
    if model_name == ModelName.alexnet:
        return {"model_name": model_name, "message": "Deep Learning FTW!"}
    if model_name == ModelName.lenet:
        return {"model_name": model_name, "message": "LeCNN all the images"}
    # Only ModelName.resnet remains.
    return {"model_name": model_name, "message": "Have some residuals"}
@app.get("/files/{file_path:path}")
async def read_file(file_path: str):
    """Echo back the rest of the URL after ``/files/``."""
    # The ``:path`` converter lets ``file_path`` span several "/"-separated segments.
    result = dict(file_path=file_path)
    return result
@app.get("/array_at_val")
async def get_array(
    start: typing.Annotated[int, fastapi.Query(ge=0)],
    skip: typing.Annotated[int, fastapi.Query(ge=0)] = 10,
):
    """Return up to ``skip`` consecutive values of the demo array, beginning at index ``start``."""
    # ``ge=0`` rejects negative values with a 422. A negative ``start`` would
    # index from the end of the array, giving silently empty or shifted
    # windows (e.g. start=-3, skip=10 slices [-3:7], which is empty).
    # Slicing past the end simply yields a shorter (possibly empty) list.
    window = array[start : start + skip].tolist()
    print(window)
    return {"array": window}
@app.get("/test")
async def test(test: bool):
    """Echo back the ``test`` query parameter as a boolean."""
    flag = test
    return {"test": flag}
@app.get("/c/{channel_owner}/store/{product_id}")
async def read_item(
    channel_owner: str,
    product_id: int,
    q: typing.Optional[str] = None,
):
    """Return the channel owner, the product id and the optional ``q`` query string."""
    # ``q`` defaults to None, so its annotation must allow None (the bare
    # ``str`` annotation mis-declared it).
    # NOTE(review): this rebinds the module-level name ``read_item``, which is
    # also used by the "/items/{item_id}" handler. Both routes stay registered,
    # because the decorator runs at definition time, but a distinct name would
    # avoid the shadowing.
    return {"product_owner": channel_owner, "product_id": product_id, "q": q}
class User(pydantic.BaseModel):
    """User details submitted as a request body."""

    # 3-32 characters. ``pydantic.Field`` is the idiomatic constraint marker for
    # model fields; ``fastapi.Query`` is meant for query-string parameters.
    username: typing.Annotated[str, pydantic.Field(min_length=3, max_length=32)]
    display_name: str
    # Checked by ``check_valid_email`` after the ``str`` type check; its
    # ValueError surfaces as a validation error.
    email: typing.Annotated[str, pydantic.AfterValidator(check_valid_email)]
@app.post("/create_user")
async def create_user(user: User):
    """Echo back the validated user from the request body."""
    response = dict(user=user)
    return response
@app.get("/c/{channel_owner}/followers")
async def get_followers(
    channel_owner: str,
    limit: typing.Annotated[int, fastapi.Query(title="The number of followers to retrieve")],
):
    """Return the channel owner together with the requested follower ``limit``."""
    # Route paths must begin with "/"; without it the route cannot match any
    # request. ``limit`` is not part of the path template, so it is a query
    # parameter: declared with ``fastapi.Path`` it could never be supplied.
    return {"channel_owner": channel_owner, "limit": limit}
def encode_data(user: User):
    """Build a signed JWT whose claims are the user's profile fields.

    The signing key and algorithm name are read from the ``ENCRYPTION_KEY``
    and ``ENCRYPTION_ALGORITHM`` environment variables.

    Args:
        user: The user whose ``username``, ``display_name`` and ``email``
            become the token claims.

    Returns:
        The encoded token produced by ``jwt.encode``.

    Raises:
        ValueError: If either environment variable is unset.
    """
    encryption_key = os.environ.get("ENCRYPTION_KEY")
    encryption_algorithm = os.environ.get("ENCRYPTION_ALGORITHM")
    if encryption_key is None:
        raise ValueError("No ENCRYPTION_KEY set for the application")
    if encryption_algorithm is None:
        raise ValueError("No ENCRYPTION_ALGORITHM set for the application")
    # Never print or log the key: anyone who can read stdout could forge tokens.
    claims = {
        "username": user.username,
        "display_name": user.display_name,
        "email": user.email,
    }
    # PyJWT's ``jwt.encode`` produces a signed (JWS) token. The claims are
    # base64url-encoded, not encrypted, so anyone holding the token can read them.
    return jwt.encode(claims, encryption_key, algorithm=encryption_algorithm)
def decode_data(token: str):
    """Verify ``token`` and return its claims, or ``None`` if it is rejected.

    The key and the single accepted algorithm come from the ``ENCRYPTION_KEY``
    and ``ENCRYPTION_ALGORITHM`` environment variables.

    Args:
        token: Encoded JWT to verify.

    Returns:
        The decoded claims returned by ``jwt.decode``, or ``None`` when PyJWT
        rejects the token.

    Raises:
        ValueError: If either environment variable is unset.
    """
    encryption_key = os.environ.get("ENCRYPTION_KEY")
    encryption_algorithm = os.environ.get("ENCRYPTION_ALGORITHM")
    if encryption_key is None:
        raise ValueError("No ENCRYPTION_KEY set for the application")
    if encryption_algorithm is None:
        raise ValueError("No ENCRYPTION_ALGORITHM set for the application")
    try:
        # Pin ``algorithms`` to the configured one so the token's header cannot
        # pick a different algorithm.
        return jwt.decode(token, encryption_key, algorithms=[encryption_algorithm])
    except jwt.InvalidTokenError:
        # PyJWT's base class for malformed, badly signed or expired tokens.
        # PyJWT has no ``jwt.JWTError`` (that name comes from python-jose);
        # referencing it here would raise AttributeError instead.
        return None
@app.post("/login")
async def login(user: User):
    """Issue a signed token carrying the submitted user's details."""
    # NOTE(review): no credential check happens here, so any caller can get a
    # token for any username. Add authentication before trusting these tokens.
    # The debug decode-and-print round-trip was removed: it wrote the user's
    # claims (including email) to stdout on every login.
    token = encode_data(user)
    return {"token": token}
@app.get("/verify")
async def verify(token: str):
    """Decode ``token`` and return its claims."""
    # decode_data is written to return None for a rejected token, which FastAPI
    # serializes as a null body with status 200.
    # NOTE(review): a 401 response may be more appropriate.
    claims = decode_data(token)
    return claims
# @app.get("")