Spaces:
Runtime error
Runtime error
| # routes.py | |
| from fastapi import APIRouter, Depends, Request | |
| from starlette.responses import RedirectResponse | |
| from auth import oauth | |
| from database import get_or_create_user, update_user_credits, get_user_by_id | |
| from authlib.integrations.starlette_client import OAuthError | |
| import gradio as gr | |
| from utils.stripe_utils import create_checkout_session, verify_webhook, retrieve_stripe_session | |
# Module-level APIRouter. NOTE(review): no route decorators are visible in this
# view, so presumably the handlers below are registered on it elsewhere -- confirm.
| router = APIRouter() | |
def get_user(request: Request):
    """Return the signed-in user's name, or ``None`` when nobody is signed in.

    Reads the ``'user'`` entry of ``request.session``. A missing or falsy
    entry yields ``None``; otherwise the entry's ``'name'`` value is returned.
    """
    session_user = request.session.get('user')
    if not session_user:
        return None
    return session_user['name']
def public(request: Request, user = Depends(get_user)):
    """Redirect to ``/gradio/`` for a signed-in visitor, else to ``/main/``.

    Args:
        request: Incoming request, used to resolve the application's root URL.
        user: Value injected from ``get_user``; truthy when someone is
            signed in.

    Returns:
        A ``RedirectResponse`` whose target is built on the resolved root URL.
    """
    root_url = gr.route_utils.get_root_url(request, "/", None)
    print(f'Root URL: {root_url}')
    destination = f'{root_url}/gradio/' if user else f'{root_url}/main/'
    return RedirectResponse(url=destination)
async def logout(request: Request):
    """Remove the ``'user'`` entry from the session and redirect to ``/``."""
    # pop() with a default keeps this a no-op when nobody is signed in.
    request.session.pop('user', None)
    home = RedirectResponse(url='/')
    return home
async def login(request: Request):
    """Start the Google OAuth flow.

    Builds the callback URL as ``<root URL>/auth`` and hands it to
    ``oauth.google.authorize_redirect``, whose response is returned as is.
    """
    base_url = gr.route_utils.get_root_url(request, "/login", None)
    callback_url = f"{base_url}/auth"
    return await oauth.google.authorize_redirect(request, callback_url)
async def auth(request: Request):
    """Complete the Google OAuth flow and store the user in the session.

    Exchanges the callback for a token with
    ``oauth.google.authorize_access_token``, reads the token's ``userinfo``
    claims, passes them to ``get_or_create_user`` and saves the result under
    ``request.session['user']``.

    Returns:
        A redirect to ``/gradio`` on success, or to ``/main`` when the token
        carries no ``userinfo`` or the OAuth exchange raises ``OAuthError``.

    Raises:
        KeyError: If ``userinfo`` lacks ``sub`` or ``email``.
    """
    # Only the token exchange can raise OAuthError, so only it is guarded.
    try:
        token = await oauth.google.authorize_access_token(request)
    except OAuthError as e:
        print(f"OAuth Error: {str(e)}")
        return RedirectResponse(url='/main')

    user_info = token.get('userinfo')
    if not user_info:
        return RedirectResponse(url='/main')

    google_id = user_info['sub']
    email = user_info['email']
    # 'name' and 'given_name' are optional OpenID Connect claims, like
    # 'picture'. Indexing them directly turned a sparse profile into an
    # uncaught KeyError. The fallbacks stay non-empty because get_user()
    # returns the stored 'name' and callers treat a falsy value as "not
    # signed in" (assumes get_or_create_user stores the name it is given --
    # TODO confirm).
    name = user_info.get('name') or email
    given_name = user_info.get('given_name') or name
    profile_picture = user_info.get('picture', '')

    user = get_or_create_user(google_id, email, name, given_name, profile_picture)
    request.session['user'] = user
    return RedirectResponse(url='/gradio')
| # Handle Stripe payments | |
async def buy_credits(request: Request):
    """Create a Stripe Checkout session for the signed-in user and redirect to it.

    Returns:
        ``{"error": "User not authenticated"}`` when the session holds no
        user; otherwise a 303 redirect to the Checkout session's ``url``.

    Side effects:
        Stores ``stripe_session_id`` and ``user_id`` in ``request.session``,
        which ``payment_success`` reads back after the redirect from Stripe.
    """
    user = request.session.get('user')
    if not user:
        return {"error": "User not authenticated"}

    user_id = user['id']
    session = create_checkout_session(100, 50, user_id)  # $1 for 50 credits

    # Store the session ID and user ID in the session
    request.session['stripe_session_id'] = session['id']
    request.session['user_id'] = user_id
    print(f"Stripe session created: {session['id']} for user {user_id}")

    # RedirectResponse defaults to 307, which replays the original method and
    # body. 303 always makes the browser follow with a GET, matching the
    # other payment redirects in this module.
    return RedirectResponse(session['url'], status_code=303)
async def stripe_webhook(request: Request):
    """Handle a Stripe webhook delivery and credit the buyer on completed checkouts.

    The raw body and the ``Stripe-Signature`` header are passed to
    ``verify_webhook``. For a ``checkout.session.completed`` event the user
    named by ``client_reference_id`` receives 50 generation credits.

    Returns:
        A 400 JSON response ``{"error": ...}`` when ``verify_webhook`` returns
        ``None``; otherwise ``{"status": "success"}``, including for event
        types that are ignored and for sessions that cannot be fulfilled.
    """
    # Local import keeps this block self-contained; the file already imports
    # from starlette.responses.
    from starlette.responses import JSONResponse

    payload = await request.body()
    sig_header = request.headers.get("Stripe-Signature")
    event = verify_webhook(payload, sig_header)
    if event is None:
        # Any 2xx tells Stripe the delivery succeeded, so a rejected payload
        # must be answered with an error status.
        return JSONResponse(
            {"error": "Invalid payload or signature"}, status_code=400
        )

    if event['type'] != 'checkout.session.completed':
        return {"status": "success"}

    session = event['data']['object']

    # The checkout can complete while the payment is still pending; only
    # fulfil once it is no longer 'unpaid', in line with the payment_status
    # check done in payment_success.
    if session.get('payment_status') == 'unpaid':
        print(f"Payment not completed for session: {session.get('id')}")
        return {"status": "success"}

    user_id = session.get('client_reference_id')
    if not user_id:
        print("No client_reference_id found in the session")
        return {"status": "success"}

    # Fetch the user from the database
    user = get_user_by_id(user_id)
    if not user:
        print(f"User not found for ID: {user_id}")
        return {"status": "success"}

    # NOTE(review): the amount is hard-coded to match the 50 credits passed to
    # create_checkout_session in buy_credits. Nothing here records which
    # events were already processed, so a redelivered event adds credits
    # again -- consider de-duplicating on the event id.
    new_credits = user['generation_credits'] + 50
    update_user_credits(user['id'], new_credits, user['train_credits'])
    print(f"Credits updated for user {user['id']}")
    return {"status": "success"}
| # @router.get("/success") | |
| # async def payment_success(request: Request): | |
| # print("Payment successful") | |
| # user = request.session.get('user') | |
| # print(user) | |
| # if user: | |
| # updated_user = get_user_by_id(user['id']) | |
| # if updated_user: | |
| # request.session['user'] = updated_user | |
| # return RedirectResponse(url='/gradio', status_code=303) | |
| # return RedirectResponse(url='/login', status_code=303) | |
async def payment_cancel(request: Request):
    """Redirect after a cancelled payment.

    Logs the cancellation and the session's ``'user'`` entry, then issues a
    303 redirect to ``/gradio`` when a user is present, else to ``/login``.
    """
    print("Payment cancelled")
    session_user = request.session.get('user')
    print(session_user)
    target = '/gradio' if session_user else '/login'
    return RedirectResponse(url=target, status_code=303)
async def payment_success(request: Request):
    """Refresh the session user after a paid checkout and redirect.

    Reads ``stripe_session_id`` and ``user_id`` from ``request.session``,
    retrieves the Stripe session and, when its ``payment_status`` is
    ``'paid'``, reloads the user with ``get_user_by_id`` into
    ``request.session['user']`` and clears both stored ids.

    Returns:
        A 303 redirect to ``/gradio`` on success. Every failure path logs its
        reason and redirects (303) to ``/login``.
    """
    print("Payment successful")
    stripe_session_id = request.session.get('stripe_session_id')
    user_id = request.session.get('user_id')
    print(f"Session data: stripe_session_id={stripe_session_id}, user_id={user_id}")

    if not (stripe_session_id and user_id):
        print("No Stripe session ID or user ID found in the session")
        return RedirectResponse(url='/login', status_code=303)

    # Retrieve the Stripe session to confirm the payment.
    stripe_session = retrieve_stripe_session(stripe_session_id)
    if stripe_session.get('payment_status') != 'paid':
        print(f"Payment not completed for session: {stripe_session_id}")
        return RedirectResponse(url='/login', status_code=303)

    refreshed_user = get_user_by_id(user_id)
    if not refreshed_user:
        print(f"User not found for ID: {user_id}")
        return RedirectResponse(url='/login', status_code=303)

    # Put the latest user data in the session.
    request.session['user'] = refreshed_user
    print(f"User session updated: {refreshed_user}")

    # The checkout bookkeeping is no longer needed once it has been consumed.
    request.session.pop('stripe_session_id', None)
    request.session.pop('user_id', None)
    return RedirectResponse(url='/gradio', status_code=303)