import re
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn.functional as F
from fuzzysearch import find_near_matches
from pyarabic import araby
from torch import nn
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .preprocess import ArabertPreprocessor, url_regexes, user_mention_regex

# Squeeze any character repeated three or more times down to two occurrences.
multiple_char_pattern = re.compile(r"(.)\1{2,}", re.DOTALL)
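# Illustrative example: the pattern collapses any run of three or more
# identical characters to exactly two, e.g.
#   multiple_char_pattern.sub(r"\1\1", "جمييييييل") == "جمييل"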
# ASAD-NEW_AraBERT_PREP-Balanced
class NewArabicPreprocessorBalanced(ArabertPreprocessor):
    def __init__(
        self,
        model_name: str,
        keep_emojis: bool = False,
        remove_html_markup: bool = True,
        replace_urls_emails_mentions: bool = True,
        strip_tashkeel: bool = True,
        strip_tatweel: bool = True,
        insert_white_spaces: bool = True,
        remove_non_digit_repetition: bool = True,
        replace_slash_with_dash: bool = None,
        map_hindi_numbers_to_arabic: bool = None,
        apply_farasa_segmentation: bool = None,
    ):
        if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
            keep_emojis = True
            remove_non_digit_repetition = True
        super().__init__(
            model_name=model_name,
            keep_emojis=keep_emojis,
            remove_html_markup=remove_html_markup,
            replace_urls_emails_mentions=replace_urls_emails_mentions,
            strip_tashkeel=strip_tashkeel,
            strip_tatweel=strip_tatweel,
            insert_white_spaces=insert_white_spaces,
            remove_non_digit_repetition=remove_non_digit_repetition,
            replace_slash_with_dash=replace_slash_with_dash,
            map_hindi_numbers_to_arabic=map_hindi_numbers_to_arabic,
            apply_farasa_segmentation=apply_farasa_segmentation,
        )
        self.true_model_name = model_name

    def preprocess(self, text):
        if "UBC-NLP" in self.true_model_name:
            return self.ubc_prep(text)
        # Fall back to the stock AraBERT preprocessing for other models.
        return super().preprocess(text)
    def ubc_prep(self, text):
        text = re.sub(r"\s", " ", text)
        text = text.replace("\\n", " ")
        text = text.replace("\\r", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # replace all possible URLs
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        text = re.sub(r"(URL\s*)+", " URL ", text)
        # replace mentions with USER
        text = re.sub(user_mention_regex, " USER ", text)
        text = re.sub(r"(USER\s*)+", " USER ", text)
        # replace hashtags with HASHTAG
        # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
        text = text.replace("#", " HASH ")
        text = text.replace("_", " ")
        text = " ".join(text.split())
        # text = re.sub("\B\\[Uu]\w+", "", text)
        # Render escaped emoji code points that survived as literal text, and
        # normalize rare emojis to more frequent near-equivalents (insertion
        # order is preserved, so replacements run in the order listed).
        emoji_replacements = {
            "\\U0001f97a": "🥺",
            "\\U0001f928": "🤨",
            "\\U0001f9d8": "😀",
            "\\U0001f975": "😥",
            "\\U0001f92f": "😲",
            "\\U0001f92d": "🤭",
            "\\U0001f9d1": "😐",
            "\\U000e0067": "",
            "\\U000e006e": "",
            "\\U0001f90d": "♥",
            "\\U0001f973": "🎉",
            "\\U0001fa79": "",
            "\\U0001f92b": "🤐",
            "\\U0001f9da": "🦋",
            "\\U0001f90e": "♥",
            "\\U0001f9d0": "🧐",
            "\\U0001f9cf": "",
            "\\U0001f92c": "😠",
            "\\U0001f9f8": "😸",
            "\\U0001f9b6": "💩",
            "\\U0001f932": "🤲",
            "\\U0001f9e1": "🧡",
            "\\U0001f974": "☹",
            "\\U0001f91f": "",
            "\\U0001f9fb": "💩",
            "\\U0001f92a": "🤪",
            "\\U0001f9fc": "",
            "\\U000e0065": "",
            "\\U0001f92e": "💩",
            "\\U000e007f": "",
            "\\U0001f970": "🥰",
            "\\U0001f929": "🤩",
            "\\U0001f6f9": "",
            "🤍": "♥",
            "🦠": "😷",
            "🤢": "مقرف",
            "🤮": "مقرف",
            "🕠": "⌚",
            "🤬": "😠",
            "🤧": "😷",
            "🥳": "🎉",
            "🥵": "🔥",
            "🥴": "☹",
            "🤫": "🤐",
            "🤥": "كذاب",
        }
        for old, new in emoji_replacements.items():
            text = text.replace(old, new)
        # Drop zero-width joiners/non-joiners (escaped forms before bare ones,
        # since the bare form is a substring of the escaped form) and stray
        # control sequences; normalize double quotes to single quotes.
        text = text.replace("\\u200d", " ")
        text = text.replace("u200d", " ")
        text = text.replace("\\u200c", " ")
        text = text.replace("u200c", " ")
        text = text.replace('"', "'")
        text = text.replace("\\xa0", "")
        text = text.replace("\\u2066", " ")
        # Remove any remaining literal escaped code points.
        text = re.sub(r"\B\\[Uu]\w+", "", text)
        text = super().preprocess(text)
        text = " ".join(text.split())
        return text
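# Usage sketch (assumes a UBC-NLP checkpoint name such as "UBC-NLP/MARBERT";
# any other name falls through to the plain AraBERT preprocessing):
#
#   prep = NewArabicPreprocessorBalanced(model_name="UBC-NLP/MARBERT")
#   prep.preprocess("شفت #المباراة مع @أحمد https://t.co/abc 😍")
#   # URLs become " URL ", mentions " USER ", "#" becomes " HASH ", and the
#   # emoji map above normalizes rare emojis before tokenization.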
| """CNNMarbertArabicPreprocessor""" | |
| # ASAD-CNN_MARBERT | |
| class CNNMarbertArabicPreprocessor(ArabertPreprocessor): | |
| def __init__( | |
| self, | |
| model_name, | |
| keep_emojis=False, | |
| remove_html_markup=True, | |
| replace_urls_emails_mentions=True, | |
| remove_elongations=True, | |
| ): | |
| if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name: | |
| keep_emojis = True | |
| remove_elongations = False | |
| super().__init__( | |
| model_name, | |
| keep_emojis, | |
| remove_html_markup, | |
| replace_urls_emails_mentions, | |
| remove_elongations, | |
| ) | |
| self.true_model_name = model_name | |
| def preprocess(self, text): | |
| if "UBC-NLP" in self.true_model_name: | |
| return self.ubc_prep(text) | |
    def ubc_prep(self, text):
        text = re.sub(r"\s", " ", text)
        text = text.replace("\\n", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # replace all possible URLs
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        text = re.sub(r"(URL\s*)+", " URL ", text)
        # replace mentions with USER
        text = re.sub(user_mention_regex, " USER ", text)
        text = re.sub(r"(USER\s*)+", " USER ", text)
        # replace hashtags with HASHTAG
        # text = re.sub(r"#[\w\d]+", " HASH TAG ", text)
        text = text.replace("#", " HASH ")
        text = text.replace("_", " ")
        text = " ".join(text.split())
        text = super().preprocess(text)
        # Drop zero-width joiners/non-joiners and normalize quotes.
        text = text.replace("\u200d", " ")
        text = text.replace("u200d", " ")
        text = text.replace("\u200c", " ")
        text = text.replace("u200c", " ")
        text = text.replace('"', "'")
        # text = re.sub('[\d\.]+', ' NUM ', text)
        # text = re.sub('(NUM\s*)+', ' NUM ', text)
        # Squeeze character elongations (3+ repeats) down to two.
        text = multiple_char_pattern.sub(r"\1\1", text)
        text = " ".join(text.split())
        return text
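# Usage sketch (same call pattern as the class above, assuming a UBC-NLP
# checkpoint name). The distinguishing step is elongation squeezing:
#
#   prep = CNNMarbertArabicPreprocessor(model_name="UBC-NLP/MARBERT")
#   prep.preprocess("جمييييل")  # the repeated letter collapses to two: "جمييل"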
| """Trial5ArabicPreprocessor""" | |
| class Trial5ArabicPreprocessor(ArabertPreprocessor): | |
| def __init__( | |
| self, | |
| model_name, | |
| keep_emojis=False, | |
| remove_html_markup=True, | |
| replace_urls_emails_mentions=True, | |
| ): | |
| if "UBC-NLP" in model_name: | |
| keep_emojis = True | |
| super().__init__( | |
| model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions | |
| ) | |
| self.true_model_name = model_name | |
| def preprocess(self, text): | |
| if "UBC-NLP" in self.true_model_name: | |
| return self.ubc_prep(text) | |
| def ubc_prep(self, text): | |
| text = re.sub("\s", " ", text) | |
| text = text.replace("\\n", " ") | |
| text = araby.strip_tashkeel(text) | |
| text = araby.strip_tatweel(text) | |
| # replace all possible URLs | |
| for reg in url_regexes: | |
| text = re.sub(reg, " URL ", text) | |
| # replace mentions with USER | |
| text = re.sub(user_mention_regex, " USER ", text) | |
| # replace hashtags with HASHTAG | |
| # text = re.sub(r"#[\w\d]+", " HASH TAG ", text) | |
| text = text.replace("#", " HASH TAG ") | |
| text = text.replace("_", " ") | |
| text = " ".join(text.split()) | |
| text = super(Trial5ArabicPreprocessor, self).preprocess(text) | |
| # text = text.replace("السلام عليكم"," ") | |
| # text = text.replace(find_near_matches("السلام عليكم",text,max_deletions=3,max_l_dist=3)[0].matched," ") | |
| return text | |
| """SarcasmArabicPreprocessor""" | |
| class SarcasmArabicPreprocessor(ArabertPreprocessor): | |
| def __init__( | |
| self, | |
| model_name, | |
| keep_emojis=False, | |
| remove_html_markup=True, | |
| replace_urls_emails_mentions=True, | |
| ): | |
| if "UBC-NLP" in model_name: | |
| keep_emojis = True | |
| super().__init__( | |
| model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions | |
| ) | |
| self.true_model_name = model_name | |
| def preprocess(self, text): | |
| if "UBC-NLP" in self.true_model_name: | |
| return self.ubc_prep(text) | |
| else: | |
| return super(SarcasmArabicPreprocessor, self).preprocess(text) | |
| def ubc_prep(self, text): | |
| text = re.sub("\s", " ", text) | |
| text = text.replace("\\n", " ") | |
| text = araby.strip_tashkeel(text) | |
| text = araby.strip_tatweel(text) | |
| # replace all possible URLs | |
| for reg in url_regexes: | |
| text = re.sub(reg, " URL ", text) | |
| # replace mentions with USER | |
| text = re.sub(user_mention_regex, " USER ", text) | |
| # replace hashtags with HASHTAG | |
| # text = re.sub(r"#[\w\d]+", " HASH TAG ", text) | |
| text = text.replace("#", " HASH TAG ") | |
| text = text.replace("_", " ") | |
| text = text.replace('"', " ") | |
| text = " ".join(text.split()) | |
| text = super(SarcasmArabicPreprocessor, self).preprocess(text) | |
| return text | |
| """NoAOAArabicPreprocessor""" | |
| class NoAOAArabicPreprocessor(ArabertPreprocessor): | |
| def __init__( | |
| self, | |
| model_name, | |
| keep_emojis=False, | |
| remove_html_markup=True, | |
| replace_urls_emails_mentions=True, | |
| ): | |
| if "UBC-NLP" in model_name: | |
| keep_emojis = True | |
| super().__init__( | |
| model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions | |
| ) | |
| self.true_model_name = model_name | |
| def preprocess(self, text): | |
| if "UBC-NLP" in self.true_model_name: | |
| return self.ubc_prep(text) | |
| else: | |
| return super(NoAOAArabicPreprocessor, self).preprocess(text) | |
| def ubc_prep(self, text): | |
| text = re.sub("\s", " ", text) | |
| text = text.replace("\\n", " ") | |
| text = araby.strip_tashkeel(text) | |
| text = araby.strip_tatweel(text) | |
| # replace all possible URLs | |
| for reg in url_regexes: | |
| text = re.sub(reg, " URL ", text) | |
| # replace mentions with USER | |
| text = re.sub(user_mention_regex, " USER ", text) | |
| # replace hashtags with HASHTAG | |
| # text = re.sub(r"#[\w\d]+", " HASH TAG ", text) | |
| text = text.replace("#", " HASH TAG ") | |
| text = text.replace("_", " ") | |
| text = " ".join(text.split()) | |
| text = super(NoAOAArabicPreprocessor, self).preprocess(text) | |
| text = text.replace("السلام عليكم", " ") | |
| text = text.replace("ورحمة الله وبركاته", " ") | |
| matched = find_near_matches("السلام عليكم", text, max_deletions=3, max_l_dist=3) | |
| if len(matched) > 0: | |
| text = text.replace(matched[0].matched, " ") | |
| matched = find_near_matches( | |
| "ورحمة الله وبركاته", text, max_deletions=3, max_l_dist=3 | |
| ) | |
| if len(matched) > 0: | |
| text = text.replace(matched[0].matched, " ") | |
| return text | |
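# Illustration of the fuzzy greeting removal: find_near_matches tolerates up
# to three edits, so a misspelled greeting is still caught, e.g.
#
#   find_near_matches("السلام عليكم", "السلم عليكم يا شباب",
#                     max_deletions=3, max_l_dist=3)
#   # -> [Match(...)] whose .matched substring is then replaced with a space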
class CnnBertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        # A small CNN head over the last four BERT layers: one Conv2d per
        # filter size, each spanning the full hidden dimension.
        filter_sizes = [1, 2, 3, 4, 5]
        num_filters = 32
        self.convs1 = nn.ModuleList(
            [nn.Conv2d(4, num_filters, (K, config.hidden_size)) for K in filter_sizes]
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(len(filter_sizes) * num_filters, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,  # the CNN head always needs the layer outputs
            return_dict=return_dict,
        )
        # Stack the last four hidden-state layers as input channels:
        # (batch, 4, seq_len, hidden_size).
        x = outputs[2][-4:]
        x = torch.stack(x, dim=1)
        # Convolve, max-pool over the sequence, and concatenate all filter maps.
        x = [F.relu(conv(x)).squeeze(3) for conv in self.convs1]
        x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x]
        x = torch.cat(x, 1)
        x = self.dropout(x)
        logits = self.classifier(x)
        loss = None
        if labels is not None:
            # Infer the problem type on first use, mirroring
            # BertForSequenceClassification.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=outputs.attentions,
        )
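# Loading sketch ("path/to/checkpoint" is a placeholder): the CNN head reads
# the last four hidden-state layers, so hidden states must be available:
#
#   from transformers import AutoConfig
#   config = AutoConfig.from_pretrained("path/to/checkpoint", output_hidden_states=True)
#   model = CnnBertForSequenceClassification.from_pretrained("path/to/checkpoint", config=config)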
class CNNTextClassificationPipeline:
    def __init__(self, model_path, device, return_all_scores=False):
        self.model_path = model_path
        self.model = CnnBertForSequenceClassification.from_pretrained(self.model_path)
        # Special handling: negative device indices mean CPU, as in
        # transformers pipelines.
        self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
        if self.device.type == "cuda":
            self.model = self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.return_all_scores = return_all_scores

    @contextmanager
    def device_placement(self):
        """
        Context manager allowing tensor allocation on the user-specified device in a framework-agnostic way.

        Returns:
            Context manager

        Examples::

            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework-specific tensor allocation will be done on the requested device
                output = pipe(...)
        """
        if self.device.type == "cuda":
            torch.cuda.set_device(self.device)
        yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.

        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.

        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
        """
        return {
            name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
            for name, tensor in inputs.items()
        }
    def __call__(self, text):
        """
        Classify the text(s) given as inputs.

        Args:
            text (:obj:`str` or :obj:`List[str]`):
                One or several texts to classify.

        Return:
            A list or a list of lists of :obj:`dict`: Each result comes as a list of dictionaries with the
            following keys:

            - **label** (:obj:`str`) -- The label predicted.
            - **score** (:obj:`float`) -- The corresponding probability.

            If ``self.return_all_scores=True``, one such dictionary is returned per label.
        """
        if isinstance(text, str):
            text = [text]
        inputs = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=64,
            padding=True,
            truncation="longest_first",
            return_tensors="pt",
        )
        with torch.no_grad():
            inputs = self.ensure_tensor_on_device(**inputs)
            predictions = self.model(**inputs)[0].cpu()
        predictions = predictions.numpy()
        if self.model.config.num_labels == 1:
            # Single-logit models are scored with a sigmoid ...
            scores = 1.0 / (1.0 + np.exp(-predictions))
        else:
            # ... multi-logit models with a softmax.
            scores = np.exp(predictions) / np.exp(predictions).sum(-1, keepdims=True)
        if self.return_all_scores:
            return [
                [
                    {"label": self.model.config.id2label[i], "score": score.item()}
                    for i, score in enumerate(item)
                ]
                for item in scores
            ]
        return [
            {
                "label": self.model.config.id2label[int(item.argmax())],
                "score": item.max().item(),
            }
            for item in scores
        ]
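# End-to-end sketch (model path and texts are placeholders; device=-1 selects CPU):
#
#   pipe = CNNTextClassificationPipeline("path/to/cnn-marbert", device=-1,
#                                        return_all_scores=True)
#   pipe(["الخدمة ممتازة", "تجربة سيئة جدا"])
#   # -> one list of {"label", "score"} dicts per input, one dict per label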