from typing import Tuple, Union, Dict, List

from multi_amr.data.postprocessing_graph import ParsedStatus
from multi_amr.data.tokenization import AMRTokenizerWrapper
from optimum.bettertransformer import BetterTransformer
import penman
import streamlit as st
import torch
from torch.quantization import quantize_dynamic
from torch import nn, qint8
from transformers import MBartForConditionalGeneration, AutoConfig


def get_resources(multilingual: bool, src_lang: str, quantize: bool = True, no_cuda: bool = False) -> Tuple[MBartForConditionalGeneration, AMRTokenizerWrapper]:
    """Get the relevant model and tokenizer wrapper. Which model is loaded depends on whether the multilingual
    model is requested. If not, a single-language model matching 'src_lang' is loaded. The model can optionally
    be quantized for faster, lighter CPU inference.

    :param multilingual: whether to load the multilingual model or not
    :param src_lang: source language
    :param quantize: whether to quantize the model with PyTorch's 'quantize_dynamic'
    :param no_cuda: whether to disable CUDA, even if it is available
    :return: the loaded model and tokenizer wrapper
    """
| model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en_es_nl" | |
| if not multilingual: | |
| if src_lang == "English": | |
| model_name = "BramVanroy/mbart-large-cc25-ft-amr30-en" | |
| elif src_lang == "Spanish": | |
| model_name = "BramVanroy/mbart-large-cc25-ft-amr30-es" | |
| elif src_lang == "Dutch": | |
| model_name = "BramVanroy/mbart-large-cc25-ft-amr30-nl" | |
| else: | |
| raise ValueError(f"Language {src_lang} not supported") | |
    # Tokenizer src_lang is reset during translation to the correct language
    tok_wrapper = AMRTokenizerWrapper.from_pretrained(model_name, src_lang="en_XX")

    config = AutoConfig.from_pretrained(model_name)
    config.decoder_start_token_id = tok_wrapper.amr_token_id

    model = MBartForConditionalGeneration.from_pretrained(model_name, config=config)
    model.eval()

    # Make sure the embedding matrix covers any tokens that were added to the tokenizer
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tok_wrapper.tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tok_wrapper.tokenizer))

    model = BetterTransformer.transform(model, keep_original_model=False)

    if torch.cuda.is_available() and not no_cuda:
        model = model.to("cuda")
    elif quantize:  # Quantization is not supported on CUDA
        model = quantize_dynamic(model, {nn.Linear, nn.Dropout, nn.LayerNorm}, dtype=qint8)

    return model, tok_wrapper


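# Example usage of get_resources (an illustrative sketch, not part of the original code;
# the argument values are assumptions for demonstration only): load the multilingual
# model on CPU with dynamic quantization.
#
#   model, tok_wrapper = get_resources(multilingual=True, src_lang="English", quantize=True, no_cuda=True)

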
def translate(texts: List[str], src_lang: str, model: MBartForConditionalGeneration, tok_wrapper: AMRTokenizerWrapper, **gen_kwargs) -> Dict[str, List[Union[penman.Graph, ParsedStatus]]]:
    """Translates the given texts from the given source language into linearized AMR with the given model and
    tokenizer wrapper. Generation is guided by optional keyword arguments, which can include arguments such as
    max length, logits processors, and the number of beams.

    :param texts: source texts to translate (potentially a batch)
    :param src_lang: source language
    :param model: MBART model
    :param tok_wrapper: MBART tokenizer wrapper
    :param gen_kwargs: keyword arguments for the generation process (must include 'num_beams')
    :return: a dictionary with "graph" (penman.Graph) and "status" (ParsedStatus) lists, one entry per input text
    """
    if isinstance(texts, str):
        texts = [texts]

    tok_wrapper.src_lang = LANGUAGES[src_lang]
    encoded = tok_wrapper(texts, return_tensors="pt").to(model.device)

    with torch.no_grad():
        generated = model.generate(**encoded, output_scores=True, return_dict_in_generate=True, **gen_kwargs)

    generated["sequences"] = generated["sequences"].cpu()
    generated["sequences_scores"] = generated["sequences_scores"].cpu()

    best_scoring_results = {"graph": [], "status": []}
    beam_size = gen_kwargs["num_beams"]

    # Select the best item from the beam: the sequence with the best status and the highest score
    for sample_idx in range(0, len(generated["sequences_scores"]), beam_size):
        sequences = generated["sequences"][sample_idx: sample_idx + beam_size]
        scores = generated["sequences_scores"][sample_idx: sample_idx + beam_size].tolist()
        outputs = tok_wrapper.batch_decode_amr_ids(sequences)
        statuses = outputs["status"]
        graphs = outputs["graph"]
        zipped = zip(statuses, scores, graphs)
        # Lowest status first (OK=0, FIXED=1, BACKOFF=2), highest score second
        best = sorted(zipped, key=lambda item: (item[0].value, -item[1]))[0]
        best_scoring_results["graph"].append(best[2])
        best_scoring_results["status"].append(best[0])

    # Dictionary with "graph" and "status" keys
    return best_scoring_results


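# Note: translate() reads 'num_beams' from gen_kwargs to group the returned beam candidates
# per input, so callers should always pass it explicitly (e.g. num_beams=5; the value 5 is
# only an assumed example, not a requirement of the original code).

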
LANGUAGES = {
    "English": "en_XX",
    "Dutch": "nl_XX",
    "Spanish": "es_XX",
}
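

# Minimal end-to-end sketch (an illustrative assumption, not part of the original module):
# load the English-only model on CPU and parse one sentence. The sentence, 'num_beams', and
# 'max_length' values are arbitrary example settings.
if __name__ == "__main__":
    model, tok_wrapper = get_resources(multilingual=False, src_lang="English", quantize=True, no_cuda=True)
    results = translate(
        ["The boy wants to go to New York."],
        src_lang="English",
        model=model,
        tok_wrapper=tok_wrapper,
        num_beams=5,
        max_length=512,
    )
    # Each graph is a penman.Graph; serialize it to PENMAN notation for printing
    for graph, status in zip(results["graph"], results["status"]):
        print(status)
        print(penman.encode(graph))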