import logging
from typing import Optional

from datasets import load_dataset, Dataset
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Application instance; the middleware and route decorators below register on it.
app = FastAPI()

# Catch-all error middleware: converts any exception that escapes a request
# into a JSON 500 response so the failure reason is visible to the caller.
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Forward the request and translate unhandled exceptions into a JSON 500.

    NOTE: despite its name, this middleware does not add any timing header;
    the name is kept so existing references keep working.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that passes the request on and returns the response.

    Returns:
        The downstream response, or a ``JSONResponse`` with status 500 and a
        ``{"message": <exception text>}`` body when an exception propagates.
    """
    try:
        return await call_next(request)
    except Exception as exc:
        # Record the traceback server-side; the client only receives the message.
        logging.getLogger(__name__).exception(
            "Unhandled error while processing %s %s",
            request.method,
            request.url.path,
        )
        return JSONResponse(status_code=500, content={"message": str(exc)})

# Load the "train" split of each dataset from the Hugging Face Hub at import time.
# Both names are module-level globals: update_user and cancel_order rebind them
# to a freshly built Dataset after pushing a change back to the Hub.
customers_dataset = load_dataset("dwb2023/blackbird-customers", split="train")
orders_dataset = load_dataset("dwb2023/blackbird-orders", split="train")

class GetUserInput(BaseModel):
    """Request body for ``/get_user``: a customer field name and the value to match."""
    key: str  # customer field to compare (used as ``c[input.key]``)
    value: str  # value that field must equal

class UpdateUserInput(BaseModel):
    """Request body for ``/update_user``.

    ``email`` and ``phone`` may be omitted or sent as ``null``; ``update_user``
    only overwrites a field when a non-empty value is supplied.
    """
    user_id: str  # compared against each customer's "id"
    # Optional[...] makes the None default type-correct, so an explicit JSON
    # null is accepted instead of failing string validation.
    email: Optional[str] = None
    phone: Optional[str] = None

class GetOrderByIdInput(BaseModel):
    """Request body for ``/get_order_by_id``."""
    order_id: str  # compared against each order's "id"

class GetCustomerOrdersInput(BaseModel):
    """Request body for ``/get_customer_orders``."""
    customer_id: str  # compared against each order's "customer_id"

class CancelOrderInput(BaseModel):
    """Request body for ``/cancel_order``."""
    order_id: str  # compared against each order's "id"

class GetUserInfoInput(BaseModel):
    """Request body for ``/get_user_info``: a customer field name and the value to match."""
    key: str  # customer field to compare (used as ``c[input.key]``)
    value: str  # value that field must equal

@app.post("/get_user")
def get_user(input: GetUserInput):
    """Return the first customer whose ``input.key`` field equals ``input.value``.

    Args:
        input: The field name to search on and the value to match.

    Returns:
        The first matching customer row.

    Raises:
        HTTPException: 400 when ``input.key`` is not a column of the customers
            dataset; 404 when no customer matches.
    """
    # An unknown column would otherwise raise KeyError on the first row and
    # surface as an opaque 500 instead of a client error.
    if input.key not in customers_dataset.column_names:
        raise HTTPException(status_code=400, detail=f"Unknown user field: {input.key}")

    # next() stops at the first match instead of collecting every match.
    customer = next(
        (c for c in customers_dataset if c[input.key] == input.value),
        None,
    )
    if customer is None:
        raise HTTPException(status_code=404, detail=f"Couldn't find a user with {input.key} of {input.value}")
    return customer

@app.post("/update_user")
def update_user(input: UpdateUserInput):
    """Update a customer's email and/or phone and push the result to the Hub.

    Args:
        input: The customer ID plus the optional new ``email`` / ``phone``.
            Empty or missing values leave the corresponding field untouched.

    Returns:
        ``{"message": "User information updated successfully"}``.

    Raises:
        HTTPException: 404 when no customer has the given ID.
    """
    global customers_dataset
    customers_data = customers_dataset.to_list()

    customer = next((c for c in customers_data if c["id"] == input.user_id), None)
    if customer is None:
        raise HTTPException(status_code=404, detail=f"Couldn't find a user with ID {input.user_id}")

    # Only non-empty values count as changes (same rule as the original truthiness checks).
    changes = {
        field: value
        for field, value in (("email", input.email), ("phone", input.phone))
        if value
    }
    if changes:
        customer.update(changes)
        updated_dataset = Dataset.from_list(customers_data)
        updated_dataset.push_to_hub("dwb2023/blackbird-customers", split="train")
        customers_dataset = updated_dataset  # Rebind only after the push succeeded
    # With nothing to change, skip the dataset rebuild and the network push;
    # the response stays the same so existing clients are unaffected.
    return {"message": "User information updated successfully"}

@app.post("/get_order_by_id")
def get_order_by_id(input: GetOrderByIdInput):
    """Return the first order whose ``id`` equals ``input.order_id``.

    Raises:
        HTTPException: 404 when no order has that ID.
    """
    found = next((row for row in orders_dataset if row["id"] == input.order_id), None)
    if found is None:
        raise HTTPException(status_code=404, detail="Order not found")
    return found

@app.post("/get_customer_orders")
def get_customer_orders(input: GetCustomerOrdersInput):
    """Return every order whose ``customer_id`` equals ``input.customer_id``.

    The result is an empty list when no order matches.
    """
    wanted_id = input.customer_id
    return [row for row in orders_dataset if row["customer_id"] == wanted_id]

@app.post("/cancel_order")
def cancel_order(input: CancelOrderInput):
    """Cancel an order that is still in the "Processing" state.

    The first order whose ``id`` matches is marked "Cancelled", the orders
    dataset is rebuilt and pushed to the Hub, and the module-level dataset is
    rebound to the new copy.

    Returns:
        ``{"status": "Cancelled"}`` on success.

    Raises:
        HTTPException: 404 when no order has the given ID; 400 when the
            matching order's status is anything other than "Processing".
    """
    global orders_dataset
    rows = orders_dataset.to_list()

    target = next((row for row in rows if row["id"] == input.order_id), None)
    if target is None:
        raise HTTPException(status_code=404, detail="Order not found")
    if target["status"] != "Processing":
        raise HTTPException(status_code=400, detail="Order has already shipped. Can't cancel it.")

    target["status"] = "Cancelled"
    refreshed = Dataset.from_list(rows)
    refreshed.push_to_hub("dwb2023/blackbird-orders", split="train")
    orders_dataset = refreshed  # Update the global dataset
    return {"status": "Cancelled"}

@app.post("/get_user_info")
def get_user_info(input: GetUserInfoInput):
    """Return a customer together with all of that customer's orders.

    Args:
        input: The customer field name to search on and the value to match.

    Returns:
        ``{"user_info": <customer row>, "orders": [<order rows>]}`` for the
        first matching customer; ``orders`` is empty when none match.

    Raises:
        HTTPException: 400 when ``input.key`` is not a column of the customers
            dataset; 404 when no customer matches.
    """
    # An unknown column would otherwise raise KeyError on the first row and
    # surface as an opaque 500 instead of a client error.
    if input.key not in customers_dataset.column_names:
        raise HTTPException(status_code=400, detail=f"Unknown user field: {input.key}")

    # next() stops at the first match instead of collecting every match.
    customer = next(
        (c for c in customers_dataset if c[input.key] == input.value),
        None,
    )
    if customer is None:
        raise HTTPException(status_code=404, detail=f"Couldn't find a user with {input.key} of {input.value}")

    customer_id = customer["id"]
    return {
        "user_info": customer,
        "orders": [o for o in orders_dataset if o["customer_id"] == customer_id],
    }