Update README.md
Browse files
    	
        README.md
    CHANGED
    
    | @@ -23,17 +23,90 @@ It achieves the following results on the evaluation set: | |
| 23 |  | 
| 24 | 
             
            ## Model description
         | 
| 25 |  | 
| 26 | 
            -
             | 
| 27 | 
            -
             | 
| 28 | 
            -
            ##  | 
| 29 | 
            -
             | 
| 30 | 
            -
             | 
| 31 | 
            -
             | 
| 32 | 
            -
             | 
| 33 | 
            -
             | 
| 34 | 
            -
             | 
| 35 | 
            -
             | 
| 36 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 37 |  | 
| 38 | 
             
            ### Training hyperparameters
         | 
| 39 |  | 
|  | |
| 23 |  | 
| 24 | 
             
            ## Model description
         | 
| 25 |  | 
| 26 | 
            +
            Machine Translation model from Hindi to English on bart small model.
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            ## Inference and evaluation
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            ```python
         | 
| 31 | 
            +
            import torch
         | 
| 32 | 
            +
            import evaluate
         | 
| 33 | 
            +
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            class BartSmall():
         | 
| 36 | 
            +
                def __init__(self, model_path = 'ar5entum/bart_hin_eng_mt', device = None):
         | 
| 37 | 
            +
                    self.tokenizer = AutoTokenizer.from_pretrained(model_path)
         | 
| 38 | 
            +
                    self.model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
         | 
| 39 | 
            +
                    if not device:
         | 
| 40 | 
            +
                        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         | 
| 41 | 
            +
                    self.device = device
         | 
| 42 | 
            +
                    self.model.to(device)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def predict(self, input_text):
         | 
| 45 | 
            +
                    inputs = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)
         | 
| 46 | 
            +
                    pred_ids = self.model.generate(inputs.input_ids, max_length=512, num_beams=4, early_stopping=True)
         | 
| 47 | 
            +
                    prediction = self.tokenizer.decode(pred_ids[0], skip_special_tokens=True)
         | 
| 48 | 
            +
                    return prediction
         | 
| 49 | 
            +
                
         | 
| 50 | 
            +
                def predict_batch(self, input_texts, batch_size=32):
         | 
| 51 | 
            +
                    all_predictions = []
         | 
| 52 | 
            +
                    for i in range(0, len(input_texts), batch_size):
         | 
| 53 | 
            +
                        batch_texts = input_texts[i:i+batch_size]
         | 
| 54 | 
            +
                        inputs = self.tokenizer(batch_texts, return_tensors="pt", max_length=512, 
         | 
| 55 | 
            +
                                                truncation=True, padding=True).to(self.device)
         | 
| 56 | 
            +
                        
         | 
| 57 | 
            +
                        with torch.no_grad():
         | 
| 58 | 
            +
                            pred_ids = self.model.generate(inputs.input_ids, 
         | 
| 59 | 
            +
                                                           max_length=512, 
         | 
| 60 | 
            +
                                                           num_beams=4, 
         | 
| 61 | 
            +
                                                           early_stopping=True)
         | 
| 62 | 
            +
                        
         | 
| 63 | 
            +
                        predictions = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
         | 
| 64 | 
            +
                        all_predictions.extend(predictions)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    return all_predictions
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            model = BartSmall(device='cuda')
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            input_texts = [
         | 
| 71 | 
            +
                "यह शोध्य रकम है।", 
         | 
| 72 | 
            +
                "जानने के लिए देखें ये वीडियो.",
         | 
| 73 | 
            +
                "वह दो बेटियों व एक बेटे का पिता था।"
         | 
| 74 | 
            +
                ]
         | 
| 75 | 
            +
            ground_truths = [
         | 
| 76 | 
            +
                "This is a repayable amount.",
         | 
| 77 | 
            +
                "Watch this video to find out.",
         | 
| 78 | 
            +
                "He was a father of two daughters and a son."
         | 
| 79 | 
            +
                ]
         | 
| 80 | 
            +
            import time
         | 
| 81 | 
            +
            start = time.time()
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            predictions = model.predict_batch(input_texts, batch_size=len(input_texts))
         | 
| 84 | 
            +
            end = time.time()
         | 
| 85 | 
            +
            print("TIME: ", end-start)
         | 
| 86 | 
            +
            for i in range(len(input_texts)):
         | 
| 87 | 
            +
                print("‾‾‾‾‾‾‾‾‾‾‾‾")
         | 
| 88 | 
            +
                print("Input text:\t", input_texts[i])
         | 
| 89 | 
            +
                print("Prediction:\t", predictions[i])
         | 
| 90 | 
            +
                print("Ground Truth:\t", ground_truths[i])
         | 
| 91 | 
            +
            bleu = evaluate.load("bleu")
         | 
| 92 | 
            +
            results = bleu.compute(predictions=predictions, references=ground_truths)
         | 
| 93 | 
            +
            print(results)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            # TIME:  1.2374696731567383
         | 
| 96 | 
            +
            # ‾‾‾‾‾‾‾‾‾‾‾‾
         | 
| 97 | 
            +
            # Input text:	 यह शोध्य रकम है।
         | 
| 98 | 
            +
            # Prediction:	 This is a repayable amount.
         | 
| 99 | 
            +
            # Ground Truth:	 This is a repayable amount.
         | 
| 100 | 
            +
            # ‾‾‾‾‾‾‾‾‾‾‾‾
         | 
| 101 | 
            +
            # Input text:	 जानने के लिए देखें ये वीडियो.
         | 
| 102 | 
            +
            # Prediction:	 View these videos to know.
         | 
| 103 | 
            +
            # Ground Truth:	 Watch this video to find out.
         | 
| 104 | 
            +
            # ‾‾‾‾‾‾‾‾‾‾‾‾
         | 
| 105 | 
            +
            # Input text:	 वह दो बेटियों व एक बेटे का पिता था।
         | 
| 106 | 
            +
            # Prediction:	 He was a father of two daughters and a son.
         | 
| 107 | 
            +
            # Ground Truth:	 He was a father of two daughters and a son.
         | 
| 108 | 
            +
            # {'bleu': 0.747875245486914, 'precisions': [0.8260869565217391, 0.75, 0.7647058823529411, 0.7857142857142857], 'brevity_penalty': 0.9574533680683809, 'length_ratio': 0.9583333333333334, 'translation_length': 23, 'reference_length': 24}
         | 
| 109 | 
            +
            ```
         | 
| 110 |  | 
| 111 | 
             
            ### Training hyperparameters
         | 
| 112 |  | 
