marquesafonso committed on
Commit
fc6dd1b
·
1 Parent(s): fc4371c

add validation for other video formats; double cache size; adapt to device_type change in transcriber

Browse files
main.py CHANGED
@@ -16,10 +16,6 @@ from fastapi.security import HTTPBasic
16
  from pydantic import BaseModel, field_validator
17
  from cachetools import TTLCache
18
 
19
- ## TODO: add word level highlighting option. WIP (Avoid caption char overflow by using a max chars heuristic in transcriber)
20
- ## TODO: prevent double for submission in process_video/
21
- ## TODO: add more video format options
22
- ## TODO: Add Box + Word highlighting mode options
23
  ## TODO: improve UI
24
 
25
  app = FastAPI()
@@ -27,9 +23,9 @@ security = HTTPBasic()
27
  static_dir = os.path.join(os.path.dirname(__file__), 'static')
28
  app.mount("/static", StaticFiles(directory=static_dir), name="static")
29
  templates = Jinja2Templates(directory=static_dir)
30
- cache = TTLCache(maxsize=1024, ttl=600)
31
 
32
- class MP4Video(BaseModel):
33
  video_file: UploadFile
34
 
35
  @property
@@ -41,7 +37,11 @@ class MP4Video(BaseModel):
41
 
42
  @field_validator('video_file')
43
  def validate_video_file(cls, v):
44
- if not v.filename.endswith('.mp4'):
 
 
 
 
45
  raise HTTPException(status_code=500, detail='Invalid video file type. Please upload an MP4 file.')
46
  return v
47
 
@@ -67,24 +67,26 @@ async def get_temp_dir():
67
  HTTPException(status_code=500, detail=str(e))
68
 
69
  @app.post("/transcribe/")
70
- async def transcribe_api(video_file: MP4Video = Depends(),
71
  task: str = Form("transcribe"),
72
  model_version: str = Form("deepdml/faster-whisper-large-v3-turbo-ct2"),
73
  max_words_per_line: int = Form(6),
 
74
  temp_dir: TemporaryDirectory = Depends(get_temp_dir)):
75
  try:
76
  video_path = os.path.join(temp_dir.name, video_file.filename)
77
  with open(video_path, 'wb') as f:
78
  shutil.copyfileobj(video_file.file, f)
79
 
80
- transcription_text, transcription_json = transcriber(video_path, max_words_per_line, task, model_version)
81
 
82
  uid = str(uuid4())
83
  cache[uid] = {
84
  "video_path": video_path,
85
  "transcription_text": transcription_text,
86
  "transcription_json": transcription_json,
87
- "temp_dir_path": temp_dir.name}
 
88
  return RedirectResponse(url=f"/process_settings/?uid={uid}", status_code=303)
89
 
90
  except Exception as e:
@@ -100,7 +102,8 @@ async def process_settings(request: Request, uid: str):
100
  "transcription_text": data["transcription_text"],
101
  "transcription_json": data["transcription_json"],
102
  "video_path": data["video_path"],
103
- "temp_dir_path": data["temp_dir_path"]
 
104
  })
105
 
106
  @app.post("/process_video/")
@@ -114,11 +117,11 @@ async def process_video_api(video_path: str = Form(...),
114
  text_color: Optional[str] = Form("white"),
115
  highlight_mode: Optional[bool] = Form(False),
116
  highlight_color: Optional[str] = Form("LightBlue"),
117
- caption_mode: Optional[str] = Form("desktop"),
118
  temp_dir: TemporaryDirectory = Depends(get_temp_dir)
119
  ):
120
  try:
121
- output_path = process_video(video_path, srt_string, srt_json, fontsize, font, bg_color, text_color, highlight_mode, highlight_color, caption_mode, temp_dir.name)
122
  with open(os.path.join(temp_dir.name, f"{video_path.split('.')[0]}.srt"), 'w+') as temp_srt_file:
123
  logging.info("Processing the video...")
124
  temp_srt_file.write(srt_string)
 
16
  from pydantic import BaseModel, field_validator
17
  from cachetools import TTLCache
18
 
 
 
 
 
19
  ## TODO: improve UI
20
 
21
  app = FastAPI()
 
23
  static_dir = os.path.join(os.path.dirname(__file__), 'static')
24
  app.mount("/static", StaticFiles(directory=static_dir), name="static")
25
  templates = Jinja2Templates(directory=static_dir)
26
+ cache = TTLCache(maxsize=2048, ttl=600)
27
 
28
+ class Video(BaseModel):
29
  video_file: UploadFile
30
 
31
  @property
 
37
 
38
  @field_validator('video_file')
39
  def validate_video_file(cls, v):
40
+ video_extensions = ('.webm', '.mkv', '.flv', '.vob', '.ogv', '.ogg', '.rrc', '.gifv',
41
+ '.mng', '.mov', '.avi', '.qt', '.wmv', '.yuv', '.rm', '.asf', '.amv', '.mp4',
42
+ '.m4p', '.m4v', '.mpg', '.mp2', '.mpeg', '.mpe', '.mpv', '.m4v', '.svi', '.3gp',
43
+ '.3g2', '.mxf', '.roq', '.nsv', '.flv', '.f4v', '.f4p', '.f4a', '.f4b', '.mod')
44
+ if not v.filename.endswith(video_extensions):
45
  raise HTTPException(status_code=500, detail='Invalid video file type. Please upload an MP4 file.')
46
  return v
47
 
 
67
  HTTPException(status_code=500, detail=str(e))
68
 
69
  @app.post("/transcribe/")
70
+ async def transcribe_api(video_file: Video = Depends(),
71
  task: str = Form("transcribe"),
72
  model_version: str = Form("deepdml/faster-whisper-large-v3-turbo-ct2"),
73
  max_words_per_line: int = Form(6),
74
+ device_type: str = Form("desktop"),
75
  temp_dir: TemporaryDirectory = Depends(get_temp_dir)):
76
  try:
77
  video_path = os.path.join(temp_dir.name, video_file.filename)
78
  with open(video_path, 'wb') as f:
79
  shutil.copyfileobj(video_file.file, f)
80
 
81
+ transcription_text, transcription_json = transcriber(video_path, max_words_per_line, task, model_version, device_type)
82
 
83
  uid = str(uuid4())
84
  cache[uid] = {
85
  "video_path": video_path,
86
  "transcription_text": transcription_text,
87
  "transcription_json": transcription_json,
88
+ "temp_dir_path": temp_dir.name,
89
+ "device_type": device_type}
90
  return RedirectResponse(url=f"/process_settings/?uid={uid}", status_code=303)
91
 
92
  except Exception as e:
 
102
  "transcription_text": data["transcription_text"],
103
  "transcription_json": data["transcription_json"],
104
  "video_path": data["video_path"],
105
+ "temp_dir_path": data["temp_dir_path"],
106
+ "device_type": data["device_type"]
107
  })
108
 
109
  @app.post("/process_video/")
 
117
  text_color: Optional[str] = Form("white"),
118
  highlight_mode: Optional[bool] = Form(False),
119
  highlight_color: Optional[str] = Form("LightBlue"),
120
+ device_type: Optional[str] = Form("desktop"),
121
  temp_dir: TemporaryDirectory = Depends(get_temp_dir)
122
  ):
123
  try:
124
+ output_path = process_video(video_path, srt_string, srt_json, fontsize, font, bg_color, text_color, highlight_mode, highlight_color, device_type, temp_dir.name)
125
  with open(os.path.join(temp_dir.name, f"{video_path.split('.')[0]}.srt"), 'w+') as temp_srt_file:
126
  logging.info("Processing the video...")
127
  temp_srt_file.write(srt_string)
static/process_settings.html CHANGED
@@ -165,17 +165,12 @@
165
  <select id="highlight_color" name="highlight_color">
166
  <option>Loading colors...</option>
167
  </select>
168
-
169
- <label for="caption_mode">Caption mode</label>
170
- <select name="caption_mode">
171
- <option value="desktop">Desktop</option>
172
- <option value="mobile">Mobile</option>
173
- </select>
174
  </div>
175
  </div>
176
 
177
  <input type="hidden" name="video_path" value="{{ video_path }}">
178
  <input type="hidden" name="temp_dir_path" value="{{ temp_dir_path }}">
 
179
  <input type="submit" name="submitButton" value="Submit">
180
  </form>
181
 
 
165
  <select id="highlight_color" name="highlight_color">
166
  <option>Loading colors...</option>
167
  </select>
 
 
 
 
 
 
168
  </div>
169
  </div>
170
 
171
  <input type="hidden" name="video_path" value="{{ video_path }}">
172
  <input type="hidden" name="temp_dir_path" value="{{ temp_dir_path }}">
173
+ <input type="hidden" name="device_type" value="{{ device_type }}">
174
  <input type="submit" name="submitButton" value="Submit">
175
  </form>
176
 
static/transcribe_video.html CHANGED
@@ -8,6 +8,8 @@
8
  form { background: white; padding: 2rem; border-radius: 10px; max-width: 600px; margin: auto; }
9
  label, select, input { display: block; width: 100%; margin-bottom: 1rem; }
10
  input[type="submit"] { background: #4CAF50; color: white; padding: 0.8rem; border: none; cursor: pointer; }
 
 
11
  </style>
12
  </head>
13
  <body>
@@ -32,6 +34,12 @@
32
 
33
  <label for="max_words_per_line">Max words per line</label>
34
  <input type="number" name="max_words_per_line" id="max_words_per_line" value="6">
 
 
 
 
 
 
35
 
36
  <div id="loading" style="display:none; text-align: center; margin-top: 10px; margin-bottom: 10px; font-weight: bold;">
37
  <i class="fas fa-spinner fa-spin"></i> Processing, please wait...
 
8
  form { background: white; padding: 2rem; border-radius: 10px; max-width: 600px; margin: auto; }
9
  label, select, input { display: block; width: 100%; margin-bottom: 1rem; }
10
  input[type="submit"] { background: #4CAF50; color: white; padding: 0.8rem; border: none; cursor: pointer; }
11
+ .radio-container { display: flex; gap: 2rem; margin-bottom: 1rem;}
12
+ .radio-option { display: flex; flex-direction: column; align-items: flex-start;}
13
  </style>
14
  </head>
15
  <body>
 
34
 
35
  <label for="max_words_per_line">Max words per line</label>
36
  <input type="number" name="max_words_per_line" id="max_words_per_line" value="6">
37
+
38
+ <label for="device_type">Device Type</label>
39
+ <select name="device_type">
40
+ <option value="desktop">Desktop</option>
41
+ <option value="mobile">Mobile</option>
42
+ </select>
43
 
44
  <div id="loading" style="display:none; text-align: center; margin-top: 10px; margin-bottom: 10px; font-weight: bold;">
45
  <i class="fas fa-spinner fa-spin"></i> Processing, please wait...
utils/process_video.py CHANGED
@@ -10,12 +10,12 @@ def process_video(invideo_file: str,
10
  text_color:str,
11
  highlight_mode: bool,
12
  highlight_color: str,
13
- caption_mode:str,
14
  temp_dir: str
15
  ):
16
  invideo_path_parts = os.path.normpath(invideo_file).split(os.path.sep)
17
  VIDEO_NAME = os.path.basename(invideo_file)
18
  OUTVIDEO_PATH = os.path.join(os.path.normpath('/'.join(invideo_path_parts[:-1])), f"result_{VIDEO_NAME}")
19
  logging.info("Subtitling...")
20
- subtitler(invideo_file, srt_string, srt_json, OUTVIDEO_PATH, fontsize, font, bg_color, text_color, highlight_mode, highlight_color, caption_mode, temp_dir)
21
  return OUTVIDEO_PATH
 
10
  text_color:str,
11
  highlight_mode: bool,
12
  highlight_color: str,
13
+ device_type:str,
14
  temp_dir: str
15
  ):
16
  invideo_path_parts = os.path.normpath(invideo_file).split(os.path.sep)
17
  VIDEO_NAME = os.path.basename(invideo_file)
18
  OUTVIDEO_PATH = os.path.join(os.path.normpath('/'.join(invideo_path_parts[:-1])), f"result_{VIDEO_NAME}")
19
  logging.info("Subtitling...")
20
+ subtitler(invideo_file, srt_string, srt_json, OUTVIDEO_PATH, fontsize, font, bg_color, text_color, highlight_mode, highlight_color, device_type, temp_dir)
21
  return OUTVIDEO_PATH
utils/subtitler.py CHANGED
@@ -18,11 +18,11 @@ def parse_srt(srt_string):
18
  i += 1
19
  return subtitles
20
 
21
- def filter_caption_width(caption_mode:str):
22
- if caption_mode == 'desktop':
23
  caption_width_ratio = 0.5
24
  caption_height_ratio = 0.8
25
- elif caption_mode == 'mobile':
26
  caption_width_ratio = 0.2
27
  caption_height_ratio = 0.7
28
  return caption_width_ratio, caption_height_ratio
@@ -38,7 +38,7 @@ def subtitler(video_file: str,
38
  text_color: str,
39
  highlight_mode: bool,
40
  highlight_color: str,
41
- caption_mode: str,
42
  temp_dir: str
43
  ):
44
  """Add subtitles to a video, with optional word-level highlighting."""
@@ -49,7 +49,7 @@ def subtitler(video_file: str,
49
 
50
  subtitle_clips = []
51
 
52
- caption_width_ratio, caption_height_ratio = filter_caption_width(caption_mode)
53
  subtitle_y_position = clip.h * caption_height_ratio
54
  if highlight_mode:
55
  srt_data = json.loads(json.dumps(eval(srt_json)))
 
18
  i += 1
19
  return subtitles
20
 
21
+ def filter_caption_width(device_type:str):
22
+ if device_type == 'desktop':
23
  caption_width_ratio = 0.5
24
  caption_height_ratio = 0.8
25
+ elif device_type == 'mobile':
26
  caption_width_ratio = 0.2
27
  caption_height_ratio = 0.7
28
  return caption_width_ratio, caption_height_ratio
 
38
  text_color: str,
39
  highlight_mode: bool,
40
  highlight_color: str,
41
+ device_type: str,
42
  temp_dir: str
43
  ):
44
  """Add subtitles to a video, with optional word-level highlighting."""
 
49
 
50
  subtitle_clips = []
51
 
52
+ caption_width_ratio, caption_height_ratio = filter_caption_width(device_type)
53
  subtitle_y_position = clip.h * caption_height_ratio
54
  if highlight_mode:
55
  srt_data = json.loads(json.dumps(eval(srt_json)))
utils/transcriber.py CHANGED
@@ -5,7 +5,8 @@ from dotenv import load_dotenv
5
  def transcriber(invideo_file:str,
6
  max_words_per_line:int,
7
  task:str,
8
- model_version:str
 
9
  ):
10
  load_dotenv()
11
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -17,6 +18,7 @@ def transcriber(invideo_file:str,
17
  max_words_per_line=max_words_per_line,
18
  task=task,
19
  model_version=model_version,
 
20
  api_name="/predict"
21
  )
22
  return result[0], result[3]
 
5
  def transcriber(invideo_file:str,
6
  max_words_per_line:int,
7
  task:str,
8
+ model_version:str,
9
+ device_type:str
10
  ):
11
  load_dotenv()
12
  HF_TOKEN = os.getenv("HF_TOKEN")
 
18
  max_words_per_line=max_words_per_line,
19
  task=task,
20
  model_version=model_version,
21
+ device_type=device_type,
22
  api_name="/predict"
23
  )
24
  return result[0], result[3]