waleko committed
Commit 22b56e7 · 1 Parent(s): d81b460

Add translation application with Gradio interface


This commit introduces a new application for translating text from English to Russian. It uses a trained transformer model for translation and exposes the functionality through a clean Gradio interface. The application accepts a string of English text and returns a JSON structure with the translation result, including the tokenized input and output text, per-token output scores, and the cross-attention matrix.
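For reference, a minimal sketch of the payload shape the app returns; the tokens, probabilities, and attention weights below are illustrative placeholders, not actual model output:

example_output = {
    "input_text": "Hello world",                   # original English input
    "input_tokens": ["▁Hello", "▁world"],          # illustrative subword tokens
    "n_input": 2,
    "output_text": "Привет, мир",
    "output_tokens": ["▁Привет", ",", "▁мир"],
    "n_output": 3,
    # per output token: top-k candidate tokens with renormalized probabilities
    "output_scores": [
        [("▁Привет", 0.91), ("▁Здравствуй", 0.04)],
        [(",", 0.99)],
        [("▁мир", 0.95), ("▁мира", 0.02)],
    ],
    # n_output x n_input cross-attention weights
    "cross_attention": [[0.7, 0.3], [0.2, 0.8], [0.5, 0.5]],
}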

Files changed (3)
  1. .gitignore +1 -0
  2. app.py +44 -0
  3. translate.py +82 -0
.gitignore ADDED
@@ -0,0 +1 @@
+ .idea/
app.py ADDED
@@ -0,0 +1,44 @@
+ import gradio as gr
+ from translate import translator_fn
+
+
+ def predict(text):
+     result = translator_fn(text)
+     return {
+         "input_text": result.input_text,
+         "input_tokens": result.input_tokens,
+         "n_input": result.n_input,
+         "output_text": result.output_text,
+         "output_tokens": result.output_tokens,
+         "n_output": result.n_output,
+         "output_scores": result.output_scores,
+         "cross_attention": result.cross_attention.tolist(),
+     }
+
+
+ gradio_app = gr.Interface(
+     predict,
+     inputs=gr.Text(placeholder="Enter a sentence to translate...", label="Input text"),
+     outputs=[gr.Json(description="Model output", label="Model output")],
+     title="En2Ru Scientific Translator",
+     description="Translate scientific texts from English to Russian",
+     examples=[
+         [
+             r"There is no closed form to implement the KL divergence by the definition of (REF ) and (REF ) for "
+             r"Gaussian Mixture Models. Instead, we resort to the Monte Carlo simulation method proposed in [1]}. "
+             r"Then, the KL divergence can be calculated by: \(D_{KL_{MC}}(p||q) =\frac{1}{n} \sum _{i=1}^{n} log("
+             r"\frac{p(x_i)}{q(x_i)})\) \(D_{KL_{MC}}(q||p) =\frac{1}{n} \sum _{i=1}^{n} log(\frac{q(y_i)}{p(y_i)})\)"],
+         [
+             r"Almost all currently used classifiers are not intrinsically well-calibrated [1]}, which means their "
+             r"output scores can't be interpreted as probabilities. This is an issue when the model is used for "
+             r"decision making, as a component in a more general probabilistic pipeline, or simply when one needs a "
+             r"quantification of the uncertainty in model's predictions, for example in high risk applications."],
+         [
+             r"First, with the development of the high-torque electric actuators, such as [1]}, [2]} the robots are "
+             r"becoming more dynamical. These actuators allow them not only to move at high speeds, but also to "
+             r"rapidly create forces and torques to perform dynamic actions, such as running, jumping, etc."],
+     ],
+ )
+
+ if __name__ == "__main__":
+     gradio_app.launch()
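A quick way to smoke-test the app without launching the UI is to call predict directly; this sketch assumes gradio, transformers, and torch are installed and the model weights can be fetched from the Hub (the sample sentence is arbitrary):

from app import predict

result = predict("Neural networks approximate complex functions.")
print(result["output_text"])                  # Russian translation
print(result["n_input"], result["n_output"])  # token counts after special-token filtering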
translate.py ADDED
@@ -0,0 +1,82 @@
+ from dataclasses import dataclass
+ from typing import List
+ from typing import Tuple
+
+ import numpy as np
+ # Load model directly from the Hugging Face Hub
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ import torch
+
+ tokenizer = AutoTokenizer.from_pretrained("under-tree/transformer-en-ru")
+ model = AutoModelForSeq2SeqLM.from_pretrained("under-tree/transformer-en-ru")
+
+
+ @dataclass
+ class TranslationResult:
+     input_text: str
+     n_input: int
+     input_tokens: List[str]
+     n_output: int
+     output_text: str
+     output_tokens: List[str]
+     output_scores: List[List[Tuple[str, float]]]
+     cross_attention: np.ndarray
+
+
+ def translator_fn(input_text: str, k=10) -> TranslationResult:
+     # Preprocess input and mark special tokens
+     inputs = tokenizer(input_text, return_tensors="pt")
+     input_tokens = tokenizer.batch_decode(inputs.input_ids[0])
+     input_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in input_tokens])
+
+     # Generate output
+     outputs = model.generate(**inputs, return_dict_in_generate=True, output_scores=True, output_attentions=True)
+     output_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
+     output_tokens = tokenizer.batch_decode(outputs.sequences[0])
+     output_special_mask = torch.tensor([1 if t in tokenizer.all_special_tokens else 0 for t in output_tokens])
+
+     # Cross-attention matrix, averaged over layers, batch, heads and query dimension
+     cross_attention = torch.stack([torch.stack(t) for t in outputs.cross_attentions])
+     attention_matrix = cross_attention.mean(dim=4).mean(dim=3).mean(dim=2).mean(dim=1).detach().cpu().numpy()
+
+     # Get top-k candidate tokens for each generated position
+     top_scores = []
+     len_input = len(input_tokens)
+     len_output = len(output_tokens)
+
+     for i in range(len_output - 1):
+         if i + 1 < len_output and output_special_mask[i + 1] == 1:
+             # Skip special tokens (e.g. </s>, <pad>, etc.)
+             continue
+         top_elements, top_indices = outputs.scores[i].mean(dim=0).topk(k)
+         top_elements = top_elements.exp()
+         top_elements /= top_elements.sum()
+
+         top_indices = tokenizer.batch_decode(top_indices)
+
+         # Filter out special tokens
+         top_pairs = [(m, t.item()) for t, m in zip(top_elements, top_indices) if m not in tokenizer.all_special_tokens]
+         top_scores.append(top_pairs)
+
+     # Filter out special tokens from all elements
+     clean_output_tokens = [t for t, m in zip(output_tokens, output_special_mask) if m == 0]
+     clean_input_tokens = [t for t, m in zip(input_tokens, input_special_mask) if m == 0]
+     clean_attention_matrix = attention_matrix[:len_output, :len_input]  # trim padding
+     clean_attention_matrix = np.delete(clean_attention_matrix, np.where(output_special_mask == 1), axis=0)
+     clean_attention_matrix = np.delete(clean_attention_matrix, np.where(input_special_mask == 1), axis=1)
+
+     n_input = len(clean_input_tokens)
+     n_output = len(clean_output_tokens)
+
+     assert clean_attention_matrix.shape == (n_output, n_input)
+     assert len(top_scores) == n_output
+     return TranslationResult(
+         input_text=input_text,
+         n_input=n_input,
+         input_tokens=clean_input_tokens,
+         output_text=output_text,
+         n_output=n_output,
+         output_tokens=clean_output_tokens,
+         output_scores=top_scores,
+         cross_attention=clean_attention_matrix
+     )
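For intuition, the top-k renormalization inside translator_fn is just a softmax restricted to the k best candidates. A standalone toy sketch (made-up logits, not real model scores):

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0, -3.0])  # toy vocabulary of 5 tokens
top_elements, top_indices = logits.topk(3)

# Exponentiate the top-k scores and renormalize so they sum to 1,
# mirroring top_elements.exp() / top_elements.sum() in translator_fn.
top_probs = top_elements.exp()
top_probs /= top_probs.sum()
print(top_indices.tolist())  # [0, 1, 2]
print(top_probs.tolist())    # approx. [0.63, 0.23, 0.14]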