"""Export transcription results as SRT, VTT, ASS, TSV, plain text, or JSON."""
import json
import os
import warnings
from typing import List, Tuple, Union, Callable
from itertools import chain

from .stabilization import valid_ts

__all__ = ['result_to_srt_vtt', 'result_to_ass', 'result_to_tsv', 'result_to_txt', 'save_as_json', 'load_result']
SUPPORTED_FORMATS = ('srt', 'vtt', 'ass', 'tsv', 'txt')


def _save_as_file(content: str, path: str):
    """Write ``content`` to ``path`` as UTF-8 and print the absolute path that was written."""
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
    print(f'Saved: {os.path.abspath(path)}')


def _get_segments(result: (dict, list), min_dur: float, reverse_text: Union[bool, tuple] = False):
    """
    Return the segments of ``result`` as a list of segment dicts.

    - ``dict``: its ``'segments'`` value is returned as-is (``min_dur`` and ``reverse_text`` are not applied).
    - objects exposing ``segments_to_dicts`` (e.g. ``WhisperResult``): ``min_dur`` is applied to a copy
      (``inplace=False``) and the segments are converted to dicts.
    - anything else (e.g. a list of segment dicts) is returned unchanged.
    """
    if isinstance(result, dict):
        if reverse_text:
            warnings.warn(f'[reverse_text]=True only applies to WhisperResult but result is {type(result)}')
        return result.get('segments')
    elif not isinstance(result, list) and callable(getattr(result, 'segments_to_dicts', None)):
        return result.apply_min_dur(min_dur, inplace=False).segments_to_dicts(reverse_text=reverse_text)
    return result


def finalize_text(text: str, strip: bool = True):
    """If ``strip``, strip ``text`` and drop the space directly after each newline; otherwise return it as-is."""
    if not strip:
        return text
    return text.strip().replace('\n ', '\n')


def sec2hhmmss(seconds: (float, int)):
    """Split ``seconds`` into ``(hours, minutes, seconds)``; hours and minutes are whole numbers."""
    mm, ss = divmod(seconds, 60)
    hh, mm = divmod(mm, 60)
    return hh, mm, ss


def sec2milliseconds(seconds: (float, int)) -> int:
    """Convert ``seconds`` to rounded integer milliseconds."""
    return round(seconds * 1000)


def sec2centiseconds(seconds: (float, int)) -> int:
    """Convert ``seconds`` to rounded integer centiseconds."""
    return round(seconds * 100)


def sec2vtt(seconds: (float, int)) -> str:
    """
    Format ``seconds`` as a WebVTT timestamp ``HH:MM:SS.mmm``.

    The value is rounded to milliseconds *before* it is split into fields so that e.g. 59.9996
    carries into the minutes (``00:01:00.000``) instead of producing the invalid ``00:00:60.000``.
    """
    hh, mm, ss = sec2hhmmss(round(seconds, 3))
    return f'{hh:0>2.0f}:{mm:0>2.0f}:{ss:0>6.3f}'


def sec2srt(seconds: (float, int)) -> str:
    """Format ``seconds`` as an SRT timestamp ``HH:MM:SS,mmm``."""
    return sec2vtt(seconds).replace(".", ",")


def sec2ass(seconds: (float, int)) -> str:
    """
    Format ``seconds`` as an ASS timestamp ``H:MM:SS.cc``.

    Rounded to centiseconds before splitting (so 59.996 becomes ``0:01:00.00``), and the seconds
    field is zero-padded to two digits (``0:00:05.00`` rather than ``0:00:5.00``).
    """
    hh, mm, ss = sec2hhmmss(round(seconds, 2))
    return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>5.2f}'


def segment2vttblock(segment: dict, strip=True) -> str:
    """Render one segment as a VTT cue (timing line followed by text)."""
    return f'{sec2vtt(segment["start"])} --> {sec2vtt(segment["end"])}\n' \
           f'{finalize_text(segment["text"], strip)}'


def segment2srtblock(segment: dict, idx: int, strip=True) -> str:
    """Render one segment as an SRT block numbered ``idx``."""
    return f'{idx}\n{sec2srt(segment["start"])} --> {sec2srt(segment["end"])}\n' \
           f'{finalize_text(segment["text"], strip)}'


def segment2assblock(segment: dict, idx: int, strip=True) -> str:
    """Render one segment as an ASS ``Dialogue`` line; ``idx`` is written into the Layer field."""
    return f'Dialogue: {idx},{sec2ass(segment["start"])},{sec2ass(segment["end"])},Default,,0,0,0,,' \
           f'{finalize_text(segment["text"], strip)}'


def segment2tsvblock(segment: dict, strip=True) -> str:
    """Render one segment as ``start_ms<TAB>end_ms<TAB>text``."""
    return f'{sec2milliseconds(segment["start"])}' \
           f'\t{sec2milliseconds(segment["end"])}' \
           f'\t{segment["text"].strip() if strip else segment["text"]}'


def words2segments(words: List[dict], tag: Tuple[str, str], reverse_text: bool = False) -> List[dict]:
    """
    Expand the words of one segment into one output segment per word.

    Each output segment contains the text of *all* the words, with only the current word wrapped in
    ``tag``. Gaps between consecutive words become extra entries whose word is ``''``; those entries
    show the text without any wrapped word, so the timeline stays continuous.
    If ``reverse_text``, the words are joined in reverse order.
    """
    def add_tag(idx: int):
        return ''.join(
            (
                # keep the leading space outside the tag
                f" {tag[0]}{w['word'][1:]}{tag[1]}"
                if w['word'].startswith(' ') else
                f"{tag[0]}{w['word']}{tag[1]}"
            )
            if w['word'] not in ('', ' ') and idx_ == idx else
            w['word']
            for idx_, w in idx_filled_words
        )

    filled_words = []
    last_idx = len(words) - 1
    for i, word in enumerate(words):
        curr_end = round(word['end'], 3)
        filled_words.append(dict(word=word['word'], start=round(word['start'], 3), end=curr_end))
        # compare by position; dict equality with words[-1] also matched earlier identical words
        if i != last_idx:
            next_start = round(words[i + 1]['start'], 3)
            if next_start - curr_end != 0:
                filled_words.append(dict(word='', start=curr_end, end=next_start))

    # add_tag (above) reads this list when the segments below are built
    idx_filled_words = list(enumerate(filled_words))
    if reverse_text:
        idx_filled_words = list(reversed(idx_filled_words))

    return [
        dict(text=add_tag(i), start=fw['start'], end=fw['end'])
        for i, fw in enumerate(filled_words)
    ]


def to_word_level_segments(segments: List[dict], tag: Tuple[str, str]) -> List[dict]:
    """Apply ``words2segments`` to every segment and flatten the results into one list."""
    return list(
        chain.from_iterable(
            words2segments(s['words'], tag, reverse_text=s.get('reversed_text'))
            for s in segments
        )
    )


def to_vtt_word_level_segments(segments: List[dict], tag: Tuple[str, str] = None) -> List[dict]:
    """
    Build VTT segments whose text carries inline ``<HH:MM:SS.mmm>`` timestamps between words.

    When a word starts exactly where the previous one ended, a single ``<start>`` tag is inserted.
    Otherwise, a ``<prev_end> <start>`` pair is inserted and one space around the gap is dropped.
    ``tag`` is accepted for interface compatibility and is not used.
    The input segments are not modified.
    """
    def to_segment_string(segment: dict):
        segment_string = ''
        prev_end = 0
        for i, word in enumerate(segment['words']):
            text = word['word']
            if i != 0:
                curr_start = word['start']
                if prev_end == curr_start:
                    segment_string += f"<{sec2vtt(curr_start)}>"
                else:
                    # drop one space so the gap tags do not double the spacing
                    if segment_string.endswith(' '):
                        segment_string = segment_string[:-1]
                    elif text.startswith(' '):
                        text = text[1:]
                    segment_string += f"<{sec2vtt(prev_end)}> <{sec2vtt(curr_start)}>"
            segment_string += text
            prev_end = word['end']
        return segment_string

    return [
        dict(
            text=to_segment_string(s),
            start=s['start'],
            end=s['end']
        )
        for s in segments
    ]


def to_ass_word_level_segments(segments: List[dict], tag: Tuple[str, str], karaoke: bool = False) -> List[dict]:
    r"""
    Build ASS segments whose words are each prefixed with a ``{\k<cs>}`` duration tag
    (``{\kf<cs>}`` when ``karaoke``), where ``<cs>`` is the word duration in centiseconds.

    A word's leading space is kept before its tag. ``tag`` is accepted for interface compatibility and is not used.
    """
    def to_segment_string(segment: dict):
        segment_string = ''
        for word in segment['words']:
            curr_word, space = (word['word'][1:], " ") if word['word'].startswith(" ") else (word['word'], "")
            segment_string += (
                    space +
                    r"{\k" +
                    ("f" if karaoke else "") +
                    f"{sec2centiseconds(word['end'] - word['start'])}" +
                    r"}" +
                    curr_word
            )
        return segment_string

    return [
        dict(
            text=to_segment_string(s),
            start=s['start'],
            end=s['end']
        )
        for s in segments
    ]


def to_word_level(segments: List[dict]) -> List[dict]:
    """Flatten segments into one segment per word (word text plus its own timing)."""
    return [dict(text=w['word'], start=w['start'], end=w['end']) for s in segments for w in s['words']]


def _confirm_word_level(segments: List[dict]) -> bool:
    """Return True if every segment has non-empty ``'words'``; otherwise warn and return False."""
    if not all(bool(s.get('words')) for s in segments):
        warnings.warn('Result is missing word timestamps. Word-level timing cannot be exported. '
                      'Use "word_level=False" to avoid this warning')
        return False
    return True


def _preprocess_args(result: (dict, list),
                     segment_level: bool,
                     word_level: bool,
                     min_dur: float,
                     reverse_text: Union[bool, tuple] = False):
    """Validate the level flags, extract segments, and disable word level if word timestamps are missing."""
    assert segment_level or word_level, '`segment_level` or `word_level` must be True'
    segments = _get_segments(result, min_dur, reverse_text=reverse_text)
    if word_level:
        word_level = _confirm_word_level(segments)
    return segments, segment_level, word_level


def result_to_any(result: (dict, list),
                  filepath: str = None,
                  filetype: str = None,
                  segments2blocks: Callable = None,
                  segment_level=True,
                  word_level=True,
                  min_dur: float = 0.02,
                  tag: Tuple[str, str] = None,
                  default_tag: Tuple[str, str] = None,
                  strip=True,
                  reverse_text: Union[bool, tuple] = False,
                  to_word_level_string_callback: Callable = None):
    """
    Generate file from ``result`` to display segment-level and/or word-level timestamp.

    Parameters
    ----------
    result : dict or list or stable_whisper.result.WhisperResult
        Result of transcription.
    filepath : str, default None, meaning content will be returned as a ``str``
        Path to save file. ``.{filetype}`` is appended if the path does not already end with it.
    filetype : str, default None, meaning inferred from the extension of ``filepath``, else ``'srt'``
        Must be one of ``SUPPORTED_FORMATS``.
    segments2blocks : callable, default None, meaning numbered SRT blocks
        Converts the final list of segment dicts into the output string.
    segment_level : bool, default True
        Whether to use segment-level timestamps in output.
    word_level : bool, default True
        Whether to use word-level timestamps in output.
    min_dur : float, default 0.02
        Applied via ``apply_min_dur`` for results exposing ``segments_to_dicts``; not applied to dicts or lists.
    tag : tuple of (str, str), default None
        Pair wrapped around the current word when both levels are enabled.
    default_tag : tuple of (str, str), default None
        Used instead of the built-in default when ``tag`` is None.
    strip : bool, default True
        Used by the default SRT block builder; custom ``segments2blocks`` handle stripping themselves.
    reverse_text : bool or tuple, default False
        Forwarded to ``segments_to_dicts``.
    to_word_level_string_callback : callable, default None, meaning ``to_word_level_segments``
        ``(segments, tag) -> segments`` used when both levels are enabled.

    Returns
    -------
    str
        String of the content if ``filepath`` is ``None``.

    Raises
    ------
    NotImplementedError
        If ``filetype`` is not supported.
    """
    segments, segment_level, word_level = _preprocess_args(
        result, segment_level, word_level, min_dur, reverse_text=reverse_text
    )

    if filetype is None:
        # infer from the extension; with no path (or no extension) fall back to SRT
        filetype = (os.path.splitext(filepath)[-1][1:] if filepath else '') or 'srt'
    if filetype.lower() not in SUPPORTED_FORMATS:
        raise NotImplementedError(f'{filetype} not supported')
    if filepath and not filepath.lower().endswith(f'.{filetype}'):
        filepath += f'.{filetype}'

    if word_level and segment_level:
        if tag is None:
            if default_tag is None:
                tag = ('<font color="#00ff00">', '</font>') if filetype == 'srt' else ('<u>', '</u>')
            else:
                tag = default_tag
        if to_word_level_string_callback is None:
            to_word_level_string_callback = to_word_level_segments
        segments = to_word_level_string_callback(segments, tag)
    elif word_level:
        segments = to_word_level(segments)

    valid_ts(segments)

    if segments2blocks is None:
        # SRT cue numbers are 1-based
        sub_str = '\n\n'.join(segment2srtblock(s, i, strip=strip) for i, s in enumerate(segments, start=1))
    else:
        sub_str = segments2blocks(segments)

    if filepath:
        _save_as_file(sub_str, filepath)
    else:
        return sub_str


def result_to_srt_vtt(result: (dict, list),
                      filepath: str = None,
                      segment_level=True,
                      word_level=True,
                      min_dur: float = 0.02,
                      tag: Tuple[str, str] = None,
                      vtt: bool = None,
                      strip=True,
                      reverse_text: Union[bool, tuple] = False):
    """
    Generate SRT/VTT from ``result`` to display segment-level and/or word-level timestamp.

    Parameters
    ----------
    result : dict or list or stable_whisper.result.WhisperResult
        Result of transcription.
    filepath : str, default None, meaning content will be returned as a ``str``
        Path to save file.
    segment_level : bool, default True
        Whether to use segment-level timestamps in output.
    word_level : bool, default True
        Whether to use word-level timestamps in output.
    min_dur : float, default 0.02
        Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
    tag: tuple of (str, str), default None, meaning ('<font color="#00ff00">', '</font>') for SRT
        Tag used to change the properties of a word at its timestamp.
        For VTT, ``None`` means inline ``<HH:MM:SS.mmm>`` timestamps are used instead of a tag pair.
    vtt : bool, default None, meaning determined by extension of ``filepath`` or ``False`` if no valid extension.
        Whether to output VTT.
    strip : bool, default True
        Whether to remove spaces before and after text on each segment for output.
    reverse_text: bool or tuple, default False
        Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
        ``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.

    Returns
    -------
    str
        String of the content if ``filepath`` is ``None``.

    Notes
    -----
    ``reverse_text`` will not fix RTL text not displaying tags properly, which is an issue with some video players.
    VLC seems to not suffer from this issue.

    Examples
    --------
    >>> import stable_whisper
    >>> model = stable_whisper.load_model('base')
    >>> result = model.transcribe('audio.mp3')
    >>> result.to_srt_vtt('audio.srt')
    Saved: audio.srt
    """
    is_srt = (filepath is None or not filepath.lower().endswith('.vtt')) if vtt is None else not vtt
    if is_srt:
        segments2blocks = None
        to_word_level_string_callback = None
    else:
        def segments2blocks(segments):
            return 'WEBVTT\n\n' + '\n\n'.join(segment2vttblock(s, strip=strip) for s in segments)

        # Without a tag, VTT uses inline timestamps. With a tag, fall back to the default tag-wrapping
        # callback (None). Previously the tag tuple itself was used as the callback and then called.
        to_word_level_string_callback = to_vtt_word_level_segments if tag is None else None

    return result_to_any(
        result=result,
        filepath=filepath,
        filetype=('vtt', 'srt')[is_srt],
        segments2blocks=segments2blocks,
        segment_level=segment_level,
        word_level=word_level,
        min_dur=min_dur,
        tag=tag,
        strip=strip,
        reverse_text=reverse_text,
        to_word_level_string_callback=to_word_level_string_callback
    )


def result_to_tsv(result: (dict, list),
                  filepath: str = None,
                  segment_level: bool = None,
                  word_level: bool = None,
                  min_dur: float = 0.02,
                  strip=True,
                  reverse_text: Union[bool, tuple] = False):
    """
    Generate TSV from ``result`` to display segment-level and/or word-level timestamp.

    Parameters
    ----------
    result : dict or list or stable_whisper.result.WhisperResult
        Result of transcription.
    filepath : str, default None, meaning content will be returned as a ``str``
        Path to save file.
    segment_level : bool, default None, meaning True when ``word_level`` is also None
        Whether to use segment-level timestamps in output.
    word_level : bool, default None
        Whether to use word-level timestamps in output.
        Only one of ``segment_level`` and ``word_level`` can be True, since [tag] is not supported for this format.
    min_dur : float, default 0.02
        Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
    strip : bool, default True
        Whether to remove spaces before and after text on each segment for output.
    reverse_text: bool or tuple, default False
        Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
        ``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.

    Returns
    -------
    str
        String of the content if ``filepath`` is ``None``.

    Notes
    -----
    ``reverse_text`` will not fix RTL text not displaying tags properly, which is an issue with some video players.
    VLC seems to not suffer from this issue.

    Examples
    --------
    >>> import stable_whisper
    >>> model = stable_whisper.load_model('base')
    >>> result = model.transcribe('audio.mp3')
    >>> result.to_tsv('audio.tsv')
    Saved: audio.tsv
    """
    if segment_level is None and word_level is None:
        segment_level = True
    assert word_level is not segment_level, '[word_level] and [segment_level] cannot be the same ' \
                                            'since [tag] is not support for this format'

    def segments2blocks(segments):
        return '\n\n'.join(segment2tsvblock(s, strip=strip) for s in segments)

    return result_to_any(
        result=result,
        filepath=filepath,
        filetype='tsv',
        segments2blocks=segments2blocks,
        segment_level=segment_level,
        word_level=word_level,
        min_dur=min_dur,
        strip=strip,
        reverse_text=reverse_text
    )


def result_to_ass(result: (dict, list),
                  filepath: str = None,
                  segment_level=True,
                  word_level=True,
                  min_dur: float = 0.02,
                  tag: Union[Tuple[str, str], int] = None,
                  font: str = None,
                  font_size: int = 24,
                  strip=True,
                  highlight_color: str = None,
                  karaoke=False,
                  reverse_text: Union[bool, tuple] = False,
                  **kwargs):
    """
    Generate Advanced SubStation Alpha (ASS) file from ``result`` to display segment-level and/or word-level timestamp.

    Parameters
    ----------
    result : dict or list or stable_whisper.result.WhisperResult
        Result of transcription.
    filepath : str, default None, meaning content will be returned as a ``str``
        Path to save file.
    segment_level : bool, default True
        Whether to use segment-level timestamps in output.
    word_level : bool, default True
        Whether to use word-level timestamps in output.
    min_dur : float, default 0.02
        Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
    tag: tuple of (str, str) or int, default None, meaning use default highlighting
        Tag used to change the properties of a word at its timestamp. -1 for individual word highlight tag.
    font : str, default None, meaning `Arial`
        Word font.
    font_size : int, default 24
        Word font size. When truthy it overrides a ``Fontsize`` given in ``kwargs``.
    strip : bool, default True
        Whether to remove spaces before and after text on each segment for output.
    highlight_color : str, default '00ff00'
        Hexadecimal of the color used for default highlights, in ASS ``bbggrr`` order
        (the ``&H`` prefix is optional).
    karaoke : bool, default False
        Whether to use progressive filling highlights (for karaoke effect).
    reverse_text: bool or tuple, default False
        Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
        ``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.
    kwargs:
        Format styles:
        'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
        'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
        'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'

    Returns
    -------
    str
        String of the content if ``filepath`` is ``None``.

    Notes
    -----
    ``reverse_text`` will not fix RTL text not displaying tags properly, which is an issue with some video players.
    VLC seems to not suffer from this issue.

    Examples
    --------
    >>> import stable_whisper
    >>> model = stable_whisper.load_model('base')
    >>> result = model.transcribe('audio.mp3')
    >>> result.to_ass('audio.ass')
    Saved: audio.ass
    """
    if tag == ['-1']:  # CLI
        tag = -1
    if highlight_color is None:
        highlight_color = '00ff00'
    # ASS colours are written as "&H<bbggrr>"; normalise once so the style and the override tag agree
    ass_highlight = highlight_color if highlight_color.startswith('&H') else f'&H{highlight_color}'

    def segments2blocks(segments):
        fmt_style_dict = {
            'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
            'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
            'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100',
            'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0',
            'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'
        }
        # colour overrides given without the "&H" prefix get it added (kwargs itself is left untouched)
        style_overrides = {
            k: f'&H{v}' if 'colour' in k.lower() and not str(v).startswith('&H') else v
            for k, v in kwargs.items()
        }
        fmt_style_dict.update((k, v) for k, v in style_overrides.items() if k in fmt_style_dict)

        if tag is None and 'PrimaryColour' not in kwargs:
            fmt_style_dict['PrimaryColour'] = ass_highlight
        if font:
            fmt_style_dict.update(Fontname=font)
        if font_size:
            fmt_style_dict.update(Fontsize=font_size)

        fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'
        styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'

        sub_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
                  f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
                  f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'
        sub_str += '\n'.join(segment2assblock(s, i, strip=strip) for i, s in enumerate(segments))

        return sub_str

    if tag is not None and karaoke:
        warnings.warn(f'[tag] is not support for [karaoke]=True; [tag] will be ignored.')

    return result_to_any(
        result=result,
        filepath=filepath,
        filetype='ass',
        segments2blocks=segments2blocks,
        segment_level=segment_level,
        word_level=word_level,
        min_dur=min_dur,
        tag=None if tag == -1 else tag,
        default_tag=(r'{\1c' + ass_highlight + '&}', r'{\r}'),
        strip=strip,
        reverse_text=reverse_text,
        to_word_level_string_callback=(
            (lambda s, t: to_ass_word_level_segments(s, t, karaoke=karaoke))
            if karaoke or (word_level and segment_level and tag is None)
            else None
        )
    )


def result_to_txt(
        result: (dict, list),
        filepath: str = None,
        min_dur: float = 0.02,
        strip=True,
        reverse_text: Union[bool, tuple] = False
):
    """
    Generate plain-text without timestamps from ``result``.

    Parameters
    ----------
    result : dict or list or stable_whisper.result.WhisperResult
        Result of transcription.
    filepath : str, default None, meaning content will be returned as a ``str``
        Path to save file.
    min_dur : float, default 0.02
        Minimum duration allowed for any word/segment before the word/segments are merged with adjacent word/segments.
    strip : bool, default True
        Whether to remove spaces before and after text on each segment for output.
    reverse_text: bool or tuple, default False
        Whether to reverse the order of words for each segment or provide the ``prepend_punctuations`` and
        ``append_punctuations`` as tuple pair instead of ``True`` which is for the default punctuations.

    Returns
    -------
    str
        String of the content if ``filepath`` is ``None``.

    Notes
    -----
    ``reverse_text`` will not fix RTL text not displaying tags properly, which is an issue with some video players.
    VLC seems to not suffer from this issue.

    Examples
    --------
    >>> import stable_whisper
    >>> model = stable_whisper.load_model('base')
    >>> result = model.transcribe('audio.mp3')
    >>> result.to_txt('audio.txt')
    Saved: audio.txt
    """

    def segments2blocks(segments: List[dict]) -> str:
        # honour the caller's ``strip`` (previously a hard-coded default that was never overridden)
        return '\n'.join(segment['text'].strip() if strip else segment['text'] for segment in segments)

    return result_to_any(
        result=result,
        filepath=filepath,
        filetype='txt',
        segments2blocks=segments2blocks,
        segment_level=True,
        word_level=False,
        min_dur=min_dur,
        strip=strip,
        reverse_text=reverse_text
    )


def save_as_json(result: dict, path: str, ensure_ascii: bool = False, **kwargs):
    """
    Save ``result`` as JSON file to ``path``.

    Parameters
    ----------
    result : dict or list or stable_whisper.result.WhisperResult
        Result of transcription. Objects with a callable ``to_dict`` are converted first.
    path : str
        Path to save file. ``.json`` is appended if missing.
    ensure_ascii : bool, default False
        Whether to escape non-ASCII characters.
    kwargs
        Forwarded to ``json.dumps``.

    Examples
    --------
    >>> import stable_whisper
    >>> model = stable_whisper.load_model('base')
    >>> result = model.transcribe('audio.mp3')
    >>> result.save_as_json('audio.json')
    Saved: audio.json
    """
    # default of None so plain lists (no ``to_dict``) are dumped as-is instead of raising AttributeError
    to_dict = getattr(result, 'to_dict', None)
    if not isinstance(result, dict) and callable(to_dict):
        result = to_dict()
    if not path.lower().endswith('.json'):
        path += '.json'
    result = json.dumps(result, allow_nan=True, ensure_ascii=ensure_ascii, **kwargs)
    _save_as_file(result, path)


def load_result(json_path: str) -> dict:
    """
    Return a ``dict`` of the contents in ``json_path``.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)