Juan Leal commited on
Commit
6454648
·
1 Parent(s): d9396ca

Fixes MPS fallback error

Browse files
Files changed (1) hide show
  1. app_api.py +2 -4
app_api.py CHANGED
@@ -1,4 +1,6 @@
1
  # app_api.py
 
 
2
 
3
  import base64
4
  from io import BytesIO
@@ -200,10 +202,6 @@ def inpaint(request: InpaintingRequest):
200
  @app.get("/status/{job_id}")
201
  def get_status(job_id: str):
202
  task = AsyncResult(job_id, app=celery_app)
203
- # Check if the task ID exists
204
- if task.state == states.PENDING and not task.result:
205
- return {"status": "NOT_FOUND"}
206
-
207
  if task.state == states.PENDING:
208
  return {"status": "PENDING"}
209
  elif task.state == states.STARTED:
 
1
  # app_api.py
2
+ import os
3
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
4
 
5
  import base64
6
  from io import BytesIO
 
202
  @app.get("/status/{job_id}")
203
  def get_status(job_id: str):
204
  task = AsyncResult(job_id, app=celery_app)
 
 
 
 
205
  if task.state == states.PENDING:
206
  return {"status": "PENDING"}
207
  elif task.state == states.STARTED: