# Last commit by marquesafonso (fc6dd1b):
# add validation for other video formats; double cache size; adapt to device_type change in transcriber
import shutil, os, logging, uvicorn
from typing import Optional
from uuid import uuid4
from tempfile import TemporaryDirectory
from utils.transcriber import transcriber
from utils.process_video import process_video
from utils.zip_response import zip_response
from utils.read_html import read_html
from fastapi import FastAPI, UploadFile, HTTPException, Request, Form, Depends
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse, Response, RedirectResponse
from fastapi.security import HTTPBasic
from pydantic import BaseModel, field_validator
from cachetools import TTLCache
# TODO: improve UI
app = FastAPI()
# NOTE(review): `security` is not referenced by any route in this module section.
security = HTTPBasic()

# Static assets and Jinja2 templates share a single "static" directory that is
# resolved relative to this module file rather than the working directory.
static_dir = os.path.join(os.path.dirname(__file__), "static")
app.mount("/static", StaticFiles(directory=static_dir), name="static")
templates = Jinja2Templates(directory=static_dir)

# Transcription results keyed by a per-request UUID string: at most 2048 entries,
# each expiring 600 seconds after insertion.
# NOTE(review): an expired/evicted entry does not remove the temp directory it
# points at, so those directories stay on disk.
_CACHE_MAX_ENTRIES = 2048
_CACHE_TTL_SECONDS = 600
cache = TTLCache(maxsize=_CACHE_MAX_ENTRIES, ttl=_CACHE_TTL_SECONDS)
class Video(BaseModel):
video_file: UploadFile
@property
def filename(self):
return self.video_file.filename
@property
def file(self):
return self.video_file.file
@field_validator('video_file')
def validate_video_file(cls, v):
video_extensions = ('.webm', '.mkv', '.flv', '.vob', '.ogv', '.ogg', '.rrc', '.gifv',
'.mng', '.mov', '.avi', '.qt', '.wmv', '.yuv', '.rm', '.asf', '.amv', '.mp4',
'.m4p', '.m4v', '.mpg', '.mp2', '.mpeg', '.mpe', '.mpv', '.m4v', '.svi', '.3gp',
'.3g2', '.mxf', '.roq', '.nsv', '.flv', '.f4v', '.f4p', '.f4a', '.f4b', '.mod')
if not v.filename.endswith(video_extensions):
raise HTTPException(status_code=500, detail='Invalid video file type. Please upload an MP4 file.')
return v
@app.get("/")
async def root():
    """Serve the landing page (``landing_page.html`` from ``static_dir``).

    The file is located via ``static_dir`` (derived from this module's path, the
    same directory used for ``/static`` and the templates) instead of
    ``os.getcwd()``, so the page still loads when the server is started from a
    different working directory.
    """
    html_content = f"""
{read_html(os.path.join(static_dir, "landing_page.html"))}
"""
    return HTMLResponse(content=html_content)
@app.get("/transcribe_video/")
async def get_form():
    """Serve the upload form (``transcribe_video.html`` from ``static_dir``).

    Uses ``static_dir`` (module-relative, shared with ``/static`` and the
    templates) rather than ``os.getcwd()``, so the result does not depend on the
    directory the server was launched from.
    """
    html_content = f"""
{read_html(os.path.join(static_dir, "transcribe_video.html"))}
"""
    return HTMLResponse(content=html_content)
async def get_temp_dir():
dir = TemporaryDirectory(delete=False)
try:
yield dir
except Exception as e:
HTTPException(status_code=500, detail=str(e))
@app.post("/transcribe/")
async def transcribe_api(video_file: Video = Depends(),
                         task: str = Form("transcribe"),
                         model_version: str = Form("deepdml/faster-whisper-large-v3-turbo-ct2"),
                         max_words_per_line: int = Form(6),
                         device_type: str = Form("desktop"),
                         temp_dir: TemporaryDirectory = Depends(get_temp_dir)):
    """Save the uploaded video, transcribe it and redirect to the settings page.

    The upload is written into ``temp_dir`` and passed to ``transcriber``. Its two
    results, the video/temp-dir paths and ``device_type`` are stored in ``cache``
    under a fresh UUID string. The client is then redirected (303) to
    ``/process_settings/?uid=<uuid>``.

    Raises:
        HTTPException: 500 with the error text if saving, transcribing or caching
            fails. The temporary directory is removed in that case.
    """
    try:
        # The filename is client-controlled: keep only its final component so values
        # such as "../../x.mp4" or "/abs/x.mp4" cannot write outside temp_dir.
        video_path = os.path.join(temp_dir.name, os.path.basename(video_file.filename))
        with open(video_path, 'wb') as f:
            shutil.copyfileobj(video_file.file, f)
        # NOTE(review): transcriber is called synchronously inside an async endpoint,
        # so the event loop is blocked for the whole transcription; consider
        # asyncio.to_thread once transcriber's thread-safety is confirmed.
        transcription_text, transcription_json = transcriber(video_path, max_words_per_line, task, model_version, device_type)
        uid = str(uuid4())
        cache[uid] = {
            "video_path": video_path,
            "transcription_text": transcription_text,
            "transcription_json": transcription_json,
            "temp_dir_path": temp_dir.name,
            "device_type": device_type}
        return RedirectResponse(url=f"/process_settings/?uid={uid}", status_code=303)
    except Exception as e:
        # get_temp_dir creates the directory with delete=False. On failure no cache
        # entry points at it, so remove it here instead of leaking disk space.
        shutil.rmtree(temp_dir.name, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/process_settings/")
async def process_settings(request: Request, uid: str):
    """Render the subtitle-settings page for a finished transcription.

    Looks ``uid`` up in the TTL cache filled by ``transcribe_api``. It responds
    with 404 when no entry exists, either because it was never created or because
    it has expired or been evicted. Otherwise it renders ``process_settings.html``
    with the cached fields.
    """
    entry = cache.get(uid)
    if not entry:
        raise HTTPException(status_code=404, detail="Data not found")
    forwarded_keys = (
        "transcription_text",
        "transcription_json",
        "video_path",
        "temp_dir_path",
        "device_type",
    )
    context = {"request": request, **{key: entry[key] for key in forwarded_keys}}
    return templates.TemplateResponse("process_settings.html", context)
def _is_managed_temp_dir(path):
    """Return True if *path* has the shape of a ``TemporaryDirectory()`` path.

    That means a direct child of the system temp dir whose name starts with the
    default ``tempfile`` prefix. Used to keep a client-supplied path from pointing
    anywhere else before it is deleted.
    """
    import tempfile
    try:
        real_path = os.path.realpath(path)
        temp_root = os.path.realpath(tempfile.gettempdir())
    except (TypeError, ValueError):
        # TypeError: non-string input; ValueError: e.g. an embedded NUL byte.
        return False
    return (os.path.dirname(real_path) == temp_root
            and os.path.basename(real_path).startswith(tempfile.gettempprefix()))


def _is_strictly_inside(path, directory):
    """Return True if *path* resolves to a location strictly inside *directory*."""
    try:
        real_path = os.path.realpath(path)
        real_dir = os.path.realpath(directory)
        return real_path != real_dir and os.path.commonpath([real_path, real_dir]) == real_dir
    except (TypeError, ValueError):
        # ValueError also covers commonpath on different Windows drives.
        return False


def _attachment_disposition(filename):
    """Build a Content-Disposition header value for downloading *filename*.

    It includes a quoted, printable-ASCII ``filename`` fallback plus an RFC 5987
    ``filename*`` parameter. Names with spaces, quotes or non-Latin-1 characters
    therefore neither break the header nor fail header encoding.
    """
    from urllib.parse import quote
    fallback = "".join(ch if " " <= ch <= "~" and ch not in '"\\' else "_" for ch in filename)
    return f"attachment; filename=\"{fallback}\"; filename*=UTF-8''{quote(filename, safe='')}"


@app.post("/process_video/")
async def process_video_api(video_path: str = Form(...),
                            temp_dir_path: str = Form(...),
                            srt_string: str = Form(...),
                            srt_json: str = Form(...),
                            fontsize: Optional[int] = Form(42),
                            font: Optional[str] = Form("Helvetica-Bold"),
                            bg_color: Optional[str] = Form("transparent"),
                            text_color: Optional[str] = Form("white"),
                            highlight_mode: Optional[bool] = Form(False),
                            highlight_color: Optional[str] = Form("LightBlue"),
                            device_type: Optional[str] = Form("desktop"),
                            temp_dir: TemporaryDirectory = Depends(get_temp_dir)
                            ):
    """Run ``process_video`` and return its output zipped together with the SRT.

    ``video_path`` and ``temp_dir_path`` arrive as form fields, so the client
    controls them and they are treated as untrusted. The request is rejected
    unless ``temp_dir_path`` looks like a ``TemporaryDirectory`` under the system
    temp dir and ``video_path`` lies inside it. All outputs (processed video, SRT,
    zip) are written to this request's own ``temp_dir``.

    Returns:
        Response: ``application/zip`` body containing the processed video and the
        SRT, offered as ``<video stem>.zip``.

    Raises:
        HTTPException: 400 for rejected paths, 500 for any processing failure.

    ``temp_dir`` is always deleted before returning. ``temp_dir_path`` is deleted
    too, but only if it passed validation.
    """
    upload_dir_ok = _is_managed_temp_dir(temp_dir_path)
    try:
        if not upload_dir_ok or not _is_strictly_inside(video_path, temp_dir_path):
            raise HTTPException(status_code=400, detail="Invalid video_path or temp_dir_path.")
        # splitext on the basename: "my.clip.mp4" -> "my.clip", and the result is a
        # bare name, so os.path.join below really places files in temp_dir.
        stem = os.path.splitext(os.path.basename(video_path))[0]
        logging.info("Processing the video...")
        output_path = process_video(video_path, srt_string, srt_json, fontsize, font, bg_color, text_color, highlight_mode, highlight_color, device_type, temp_dir.name)
        srt_path = os.path.join(temp_dir.name, f"{stem}.srt")
        # Close (and thereby flush) the SRT before zipping so the archive gets the
        # full text. Use an explicit encoding for multilingual transcripts.
        with open(srt_path, 'w', encoding='utf-8') as temp_srt_file:
            temp_srt_file.write(srt_string)
        logging.info("Zipping response...")
        with open(os.path.join(temp_dir.name, f"{stem}.zip"), 'w+b') as temp_zip_file:
            zip_file = zip_response(temp_zip_file.name, [output_path, srt_path])
        return Response(
            content=zip_file,
            media_type="application/zip",
            headers={"Content-Disposition": _attachment_disposition(f"{stem}.zip")}
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # zip_response's return value is used directly as the response body, so it is
        # assumed to be in-memory bytes (TODO confirm); this request's working
        # directory can then be removed.
        shutil.rmtree(temp_dir.name, ignore_errors=True)
        if temp_dir_path and upload_dir_ok and os.path.exists(temp_dir_path):
            shutil.rmtree(temp_dir_path, ignore_errors=True)
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")