from transformers import TextClassificationPipeline

import preprocess
import segment


class SponsorBlockClassificationPipeline(TextClassificationPipeline):
    def __init__(self, model, tokenizer):
        # Run on the model's device; device.index is None on CPU, which
        # the pipelines API expects as -1.
        device = next(model.parameters()).device.index
        if device is None:
            device = -1
        super().__init__(model=model, tokenizer=tokenizer,
                         return_all_scores=True, truncation=True, device=device)

    def preprocess(self, data, **tokenizer_kwargs):
        texts = []

        # Normalise a single input into a batch of one
        if not isinstance(data, list):
            data = [data]

        for d in data:
            if isinstance(d, dict):
                # Dict input: fetch the video transcript and extract the
                # words that fall inside the requested segment
                words = preprocess.get_words(d['video_id'])
                segment_words = segment.extract_segment(
                    words, d['start'], d['end'])
                text = preprocess.clean_text(
                    ' '.join(x['text'] for x in segment_words))
                texts.append(text)
            elif isinstance(d, str):
                # String input: classify the text as-is
                texts.append(d)
            else:
                raise ValueError(f'Invalid input type: "{type(d)}"')

        return self.tokenizer(
            texts, return_tensors=self.framework, **tokenizer_kwargs)


def main():
    pass


if __name__ == '__main__':
    main()
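
A minimal usage sketch follows. The checkpoint name, video id, and timestamps are placeholders, not from the source; any sequence-classification checkpoint fine-tuned for this task should work. Dict inputs additionally require the repo's preprocess and segment helper modules to be importable.

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder checkpoint name: substitute a real fine-tuned model
checkpoint = 'your-org/sponsorblock-classifier'
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

pipe = SponsorBlockClassificationPipeline(model, tokenizer)

# Raw text is classified directly
print(pipe("Huge thanks to this video's sponsor ..."))

# A transcript segment is resolved to text via preprocess/segment first
# (video_id, start, and end are illustrative values)
print(pipe({'video_id': 'abc123', 'start': 12.5, 'end': 34.0}))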