Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	Duplicate from OFA-Sys/OFA-vqa
Browse files
Co-authored-by: Julien Chaumond <[email protected]>
This view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +27 -0
- LICENSE +201 -0
- README.md +15 -0
- Starry_Night.jpeg +0 -0
- app.py +153 -0
- cat-4894153_1920.jpg +0 -0
- checkpoints.md +13 -0
- colab.md +8 -0
- criterions/__init__.py +2 -0
- criterions/label_smoothed_cross_entropy.py +343 -0
- criterions/scst_loss.py +280 -0
- data/__init__.py +0 -0
- data/data_utils.py +601 -0
- data/file_dataset.py +102 -0
- data/mm_data/__init__.py +0 -0
- data/mm_data/caption_dataset.py +154 -0
- data/mm_data/refcoco_dataset.py +168 -0
- data/mm_data/vqa_gen_dataset.py +211 -0
- data/ofa_dataset.py +74 -0
- datasets.md +10 -0
- evaluate.py +156 -0
- fairseq/.github/ISSUE_TEMPLATE.md +3 -0
- fairseq/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
- fairseq/.github/ISSUE_TEMPLATE/documentation.md +15 -0
- fairseq/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
- fairseq/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
- fairseq/.github/PULL_REQUEST_TEMPLATE.md +16 -0
- fairseq/.github/stale.yml +30 -0
- fairseq/.github/workflows/build.yml +55 -0
- fairseq/.github/workflows/build_wheels.yml +41 -0
- fairseq/.gitignore +136 -0
- fairseq/.gitmodules +4 -0
- fairseq/CODE_OF_CONDUCT.md +77 -0
- fairseq/CONTRIBUTING.md +28 -0
- fairseq/LICENSE +21 -0
- fairseq/README.md +229 -0
- fairseq/docs/Makefile +20 -0
- fairseq/docs/_static/theme_overrides.css +9 -0
- fairseq/docs/command_line_tools.rst +85 -0
- fairseq/docs/conf.py +134 -0
- fairseq/docs/criterions.rst +31 -0
- fairseq/docs/data.rst +58 -0
- fairseq/docs/docutils.conf +2 -0
- fairseq/docs/fairseq_logo.png +0 -0
- fairseq/docs/getting_started.rst +216 -0
- fairseq/docs/hydra_integration.md +284 -0
- fairseq/docs/index.rst +49 -0
- fairseq/docs/lr_scheduler.rst +34 -0
- fairseq/docs/make.bat +36 -0
- fairseq/docs/models.rst +104 -0
    	
        .gitattributes
    ADDED
    
    | @@ -0,0 +1,27 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            *.7z filter=lfs diff=lfs merge=lfs -text
         | 
| 2 | 
            +
            *.arrow filter=lfs diff=lfs merge=lfs -text
         | 
| 3 | 
            +
            *.bin filter=lfs diff=lfs merge=lfs -text
         | 
| 4 | 
            +
            *.bin.* filter=lfs diff=lfs merge=lfs -text
         | 
| 5 | 
            +
            *.bz2 filter=lfs diff=lfs merge=lfs -text
         | 
| 6 | 
            +
            *.ftz filter=lfs diff=lfs merge=lfs -text
         | 
| 7 | 
            +
            *.gz filter=lfs diff=lfs merge=lfs -text
         | 
| 8 | 
            +
            *.h5 filter=lfs diff=lfs merge=lfs -text
         | 
| 9 | 
            +
            *.joblib filter=lfs diff=lfs merge=lfs -text
         | 
| 10 | 
            +
            *.lfs.* filter=lfs diff=lfs merge=lfs -text
         | 
| 11 | 
            +
            *.model filter=lfs diff=lfs merge=lfs -text
         | 
| 12 | 
            +
            *.msgpack filter=lfs diff=lfs merge=lfs -text
         | 
| 13 | 
            +
            *.onnx filter=lfs diff=lfs merge=lfs -text
         | 
| 14 | 
            +
            *.ot filter=lfs diff=lfs merge=lfs -text
         | 
| 15 | 
            +
            *.parquet filter=lfs diff=lfs merge=lfs -text
         | 
| 16 | 
            +
            *.pb filter=lfs diff=lfs merge=lfs -text
         | 
| 17 | 
            +
            *.pt filter=lfs diff=lfs merge=lfs -text
         | 
| 18 | 
            +
            *.pth filter=lfs diff=lfs merge=lfs -text
         | 
| 19 | 
            +
            *.rar filter=lfs diff=lfs merge=lfs -text
         | 
| 20 | 
            +
            saved_model/**/* filter=lfs diff=lfs merge=lfs -text
         | 
| 21 | 
            +
            *.tar.* filter=lfs diff=lfs merge=lfs -text
         | 
| 22 | 
            +
            *.tflite filter=lfs diff=lfs merge=lfs -text
         | 
| 23 | 
            +
            *.tgz filter=lfs diff=lfs merge=lfs -text
         | 
| 24 | 
            +
            *.xz filter=lfs diff=lfs merge=lfs -text
         | 
| 25 | 
            +
            *.zip filter=lfs diff=lfs merge=lfs -text
         | 
| 26 | 
            +
            *.zstandard filter=lfs diff=lfs merge=lfs -text
         | 
| 27 | 
            +
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         | 
    	
        LICENSE
    ADDED
    
    | @@ -0,0 +1,201 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
                                             Apache License
         | 
| 2 | 
            +
                                       Version 2.0, January 2004
         | 
| 3 | 
            +
                                    http://www.apache.org/licenses/
         | 
| 4 | 
            +
             | 
| 5 | 
            +
               TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
         | 
| 6 | 
            +
             | 
| 7 | 
            +
               1. Definitions.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                  "License" shall mean the terms and conditions for use, reproduction,
         | 
| 10 | 
            +
                  and distribution as defined by Sections 1 through 9 of this document.
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                  "Licensor" shall mean the copyright owner or entity authorized by
         | 
| 13 | 
            +
                  the copyright owner that is granting the License.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                  "Legal Entity" shall mean the union of the acting entity and all
         | 
| 16 | 
            +
                  other entities that control, are controlled by, or are under common
         | 
| 17 | 
            +
                  control with that entity. For the purposes of this definition,
         | 
| 18 | 
            +
                  "control" means (i) the power, direct or indirect, to cause the
         | 
| 19 | 
            +
                  direction or management of such entity, whether by contract or
         | 
| 20 | 
            +
                  otherwise, or (ii) ownership of fifty percent (50%) or more of the
         | 
| 21 | 
            +
                  outstanding shares, or (iii) beneficial ownership of such entity.
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                  "You" (or "Your") shall mean an individual or Legal Entity
         | 
| 24 | 
            +
                  exercising permissions granted by this License.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                  "Source" form shall mean the preferred form for making modifications,
         | 
| 27 | 
            +
                  including but not limited to software source code, documentation
         | 
| 28 | 
            +
                  source, and configuration files.
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                  "Object" form shall mean any form resulting from mechanical
         | 
| 31 | 
            +
                  transformation or translation of a Source form, including but
         | 
| 32 | 
            +
                  not limited to compiled object code, generated documentation,
         | 
| 33 | 
            +
                  and conversions to other media types.
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                  "Work" shall mean the work of authorship, whether in Source or
         | 
| 36 | 
            +
                  Object form, made available under the License, as indicated by a
         | 
| 37 | 
            +
                  copyright notice that is included in or attached to the work
         | 
| 38 | 
            +
                  (an example is provided in the Appendix below).
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                  "Derivative Works" shall mean any work, whether in Source or Object
         | 
| 41 | 
            +
                  form, that is based on (or derived from) the Work and for which the
         | 
| 42 | 
            +
                  editorial revisions, annotations, elaborations, or other modifications
         | 
| 43 | 
            +
                  represent, as a whole, an original work of authorship. For the purposes
         | 
| 44 | 
            +
                  of this License, Derivative Works shall not include works that remain
         | 
| 45 | 
            +
                  separable from, or merely link (or bind by name) to the interfaces of,
         | 
| 46 | 
            +
                  the Work and Derivative Works thereof.
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                  "Contribution" shall mean any work of authorship, including
         | 
| 49 | 
            +
                  the original version of the Work and any modifications or additions
         | 
| 50 | 
            +
                  to that Work or Derivative Works thereof, that is intentionally
         | 
| 51 | 
            +
                  submitted to Licensor for inclusion in the Work by the copyright owner
         | 
| 52 | 
            +
                  or by an individual or Legal Entity authorized to submit on behalf of
         | 
| 53 | 
            +
                  the copyright owner. For the purposes of this definition, "submitted"
         | 
| 54 | 
            +
                  means any form of electronic, verbal, or written communication sent
         | 
| 55 | 
            +
                  to the Licensor or its representatives, including but not limited to
         | 
| 56 | 
            +
                  communication on electronic mailing lists, source code control systems,
         | 
| 57 | 
            +
                  and issue tracking systems that are managed by, or on behalf of, the
         | 
| 58 | 
            +
                  Licensor for the purpose of discussing and improving the Work, but
         | 
| 59 | 
            +
                  excluding communication that is conspicuously marked or otherwise
         | 
| 60 | 
            +
                  designated in writing by the copyright owner as "Not a Contribution."
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                  "Contributor" shall mean Licensor and any individual or Legal Entity
         | 
| 63 | 
            +
                  on behalf of whom a Contribution has been received by Licensor and
         | 
| 64 | 
            +
                  subsequently incorporated within the Work.
         | 
| 65 | 
            +
             | 
| 66 | 
            +
               2. Grant of Copyright License. Subject to the terms and conditions of
         | 
| 67 | 
            +
                  this License, each Contributor hereby grants to You a perpetual,
         | 
| 68 | 
            +
                  worldwide, non-exclusive, no-charge, royalty-free, irrevocable
         | 
| 69 | 
            +
                  copyright license to reproduce, prepare Derivative Works of,
         | 
| 70 | 
            +
                  publicly display, publicly perform, sublicense, and distribute the
         | 
| 71 | 
            +
                  Work and such Derivative Works in Source or Object form.
         | 
| 72 | 
            +
             | 
| 73 | 
            +
               3. Grant of Patent License. Subject to the terms and conditions of
         | 
| 74 | 
            +
                  this License, each Contributor hereby grants to You a perpetual,
         | 
| 75 | 
            +
                  worldwide, non-exclusive, no-charge, royalty-free, irrevocable
         | 
| 76 | 
            +
                  (except as stated in this section) patent license to make, have made,
         | 
| 77 | 
            +
                  use, offer to sell, sell, import, and otherwise transfer the Work,
         | 
| 78 | 
            +
                  where such license applies only to those patent claims licensable
         | 
| 79 | 
            +
                  by such Contributor that are necessarily infringed by their
         | 
| 80 | 
            +
                  Contribution(s) alone or by combination of their Contribution(s)
         | 
| 81 | 
            +
                  with the Work to which such Contribution(s) was submitted. If You
         | 
| 82 | 
            +
                  institute patent litigation against any entity (including a
         | 
| 83 | 
            +
                  cross-claim or counterclaim in a lawsuit) alleging that the Work
         | 
| 84 | 
            +
                  or a Contribution incorporated within the Work constitutes direct
         | 
| 85 | 
            +
                  or contributory patent infringement, then any patent licenses
         | 
| 86 | 
            +
                  granted to You under this License for that Work shall terminate
         | 
| 87 | 
            +
                  as of the date such litigation is filed.
         | 
| 88 | 
            +
             | 
| 89 | 
            +
               4. Redistribution. You may reproduce and distribute copies of the
         | 
| 90 | 
            +
                  Work or Derivative Works thereof in any medium, with or without
         | 
| 91 | 
            +
                  modifications, and in Source or Object form, provided that You
         | 
| 92 | 
            +
                  meet the following conditions:
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                  (a) You must give any other recipients of the Work or
         | 
| 95 | 
            +
                      Derivative Works a copy of this License; and
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                  (b) You must cause any modified files to carry prominent notices
         | 
| 98 | 
            +
                      stating that You changed the files; and
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                  (c) You must retain, in the Source form of any Derivative Works
         | 
| 101 | 
            +
                      that You distribute, all copyright, patent, trademark, and
         | 
| 102 | 
            +
                      attribution notices from the Source form of the Work,
         | 
| 103 | 
            +
                      excluding those notices that do not pertain to any part of
         | 
| 104 | 
            +
                      the Derivative Works; and
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                  (d) If the Work includes a "NOTICE" text file as part of its
         | 
| 107 | 
            +
                      distribution, then any Derivative Works that You distribute must
         | 
| 108 | 
            +
                      include a readable copy of the attribution notices contained
         | 
| 109 | 
            +
                      within such NOTICE file, excluding those notices that do not
         | 
| 110 | 
            +
                      pertain to any part of the Derivative Works, in at least one
         | 
| 111 | 
            +
                      of the following places: within a NOTICE text file distributed
         | 
| 112 | 
            +
                      as part of the Derivative Works; within the Source form or
         | 
| 113 | 
            +
                      documentation, if provided along with the Derivative Works; or,
         | 
| 114 | 
            +
                      within a display generated by the Derivative Works, if and
         | 
| 115 | 
            +
                      wherever such third-party notices normally appear. The contents
         | 
| 116 | 
            +
                      of the NOTICE file are for informational purposes only and
         | 
| 117 | 
            +
                      do not modify the License. You may add Your own attribution
         | 
| 118 | 
            +
                      notices within Derivative Works that You distribute, alongside
         | 
| 119 | 
            +
                      or as an addendum to the NOTICE text from the Work, provided
         | 
| 120 | 
            +
                      that such additional attribution notices cannot be construed
         | 
| 121 | 
            +
                      as modifying the License.
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                  You may add Your own copyright statement to Your modifications and
         | 
| 124 | 
            +
                  may provide additional or different license terms and conditions
         | 
| 125 | 
            +
                  for use, reproduction, or distribution of Your modifications, or
         | 
| 126 | 
            +
                  for any such Derivative Works as a whole, provided Your use,
         | 
| 127 | 
            +
                  reproduction, and distribution of the Work otherwise complies with
         | 
| 128 | 
            +
                  the conditions stated in this License.
         | 
| 129 | 
            +
             | 
| 130 | 
            +
               5. Submission of Contributions. Unless You explicitly state otherwise,
         | 
| 131 | 
            +
                  any Contribution intentionally submitted for inclusion in the Work
         | 
| 132 | 
            +
                  by You to the Licensor shall be under the terms and conditions of
         | 
| 133 | 
            +
                  this License, without any additional terms or conditions.
         | 
| 134 | 
            +
                  Notwithstanding the above, nothing herein shall supersede or modify
         | 
| 135 | 
            +
                  the terms of any separate license agreement you may have executed
         | 
| 136 | 
            +
                  with Licensor regarding such Contributions.
         | 
| 137 | 
            +
             | 
| 138 | 
            +
               6. Trademarks. This License does not grant permission to use the trade
         | 
| 139 | 
            +
                  names, trademarks, service marks, or product names of the Licensor,
         | 
| 140 | 
            +
                  except as required for reasonable and customary use in describing the
         | 
| 141 | 
            +
                  origin of the Work and reproducing the content of the NOTICE file.
         | 
| 142 | 
            +
             | 
| 143 | 
            +
               7. Disclaimer of Warranty. Unless required by applicable law or
         | 
| 144 | 
            +
                  agreed to in writing, Licensor provides the Work (and each
         | 
| 145 | 
            +
                  Contributor provides its Contributions) on an "AS IS" BASIS,
         | 
| 146 | 
            +
                  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
         | 
| 147 | 
            +
                  implied, including, without limitation, any warranties or conditions
         | 
| 148 | 
            +
                  of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
         | 
| 149 | 
            +
                  PARTICULAR PURPOSE. You are solely responsible for determining the
         | 
| 150 | 
            +
                  appropriateness of using or redistributing the Work and assume any
         | 
| 151 | 
            +
                  risks associated with Your exercise of permissions under this License.
         | 
| 152 | 
            +
             | 
| 153 | 
            +
               8. Limitation of Liability. In no event and under no legal theory,
         | 
| 154 | 
            +
                  whether in tort (including negligence), contract, or otherwise,
         | 
| 155 | 
            +
                  unless required by applicable law (such as deliberate and grossly
         | 
| 156 | 
            +
                  negligent acts) or agreed to in writing, shall any Contributor be
         | 
| 157 | 
            +
                  liable to You for damages, including any direct, indirect, special,
         | 
| 158 | 
            +
                  incidental, or consequential damages of any character arising as a
         | 
| 159 | 
            +
                  result of this License or out of the use or inability to use the
         | 
| 160 | 
            +
                  Work (including but not limited to damages for loss of goodwill,
         | 
| 161 | 
            +
                  work stoppage, computer failure or malfunction, or any and all
         | 
| 162 | 
            +
                  other commercial damages or losses), even if such Contributor
         | 
| 163 | 
            +
                  has been advised of the possibility of such damages.
         | 
| 164 | 
            +
             | 
| 165 | 
            +
               9. Accepting Warranty or Additional Liability. While redistributing
         | 
| 166 | 
            +
                  the Work or Derivative Works thereof, You may choose to offer,
         | 
| 167 | 
            +
                  and charge a fee for, acceptance of support, warranty, indemnity,
         | 
| 168 | 
            +
                  or other liability obligations and/or rights consistent with this
         | 
| 169 | 
            +
                  License. However, in accepting such obligations, You may act only
         | 
| 170 | 
            +
                  on Your own behalf and on Your sole responsibility, not on behalf
         | 
| 171 | 
            +
                  of any other Contributor, and only if You agree to indemnify,
         | 
| 172 | 
            +
                  defend, and hold each Contributor harmless for any liability
         | 
| 173 | 
            +
                  incurred by, or claims asserted against, such Contributor by reason
         | 
| 174 | 
            +
                  of your accepting any such warranty or additional liability.
         | 
| 175 | 
            +
             | 
| 176 | 
            +
               END OF TERMS AND CONDITIONS
         | 
| 177 | 
            +
             | 
| 178 | 
            +
               APPENDIX: How to apply the Apache License to your work.
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                  To apply the Apache License to your work, attach the following
         | 
| 181 | 
            +
                  boilerplate notice, with the fields enclosed by brackets "[]"
         | 
| 182 | 
            +
                  replaced with your own identifying information. (Don't include
         | 
| 183 | 
            +
                  the brackets!)  The text should be enclosed in the appropriate
         | 
| 184 | 
            +
                  comment syntax for the file format. We also recommend that a
         | 
| 185 | 
            +
                  file or class name and description of purpose be included on the
         | 
| 186 | 
            +
                  same "printed page" as the copyright notice for easier
         | 
| 187 | 
            +
                  identification within third-party archives.
         | 
| 188 | 
            +
             | 
| 189 | 
            +
               Copyright 1999-2022 Alibaba Group Holding Ltd.
         | 
| 190 | 
            +
             | 
| 191 | 
            +
               Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 192 | 
            +
               you may not use this file except in compliance with the License.
         | 
| 193 | 
            +
               You may obtain a copy of the License at
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                   http://www.apache.org/licenses/LICENSE-2.0
         | 
| 196 | 
            +
             | 
| 197 | 
            +
               Unless required by applicable law or agreed to in writing, software
         | 
| 198 | 
            +
               distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 199 | 
            +
               WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 200 | 
            +
               See the License for the specific language governing permissions and
         | 
| 201 | 
            +
               limitations under the License.
         | 
    	
        README.md
    ADDED
    
    | @@ -0,0 +1,15 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ---
         | 
| 2 | 
            +
            title: OFA-Visual_Question_Answering
         | 
| 3 | 
            +
            emoji: 🎓
         | 
| 4 | 
            +
            colorFrom: blue
         | 
| 5 | 
            +
            colorTo: pink
         | 
| 6 | 
            +
            sdk: gradio
         | 
| 7 | 
            +
            app_file: app.py
         | 
| 8 | 
            +
            pinned: false
         | 
| 9 | 
            +
            license: apache-2.0
         | 
| 10 | 
            +
            duplicated_from: OFA-Sys/OFA-vqa
         | 
| 11 | 
            +
            ---
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
    	
        Starry_Night.jpeg
    ADDED
    
    |   | 
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,153 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            os.system('git clone https://github.com/pytorch/fairseq.git; cd fairseq;'
         | 
| 4 | 
            +
                      'pip install --use-feature=in-tree-build ./; cd ..')
         | 
| 5 | 
            +
            os.system('ls -l')
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import numpy as np
         | 
| 9 | 
            +
            import re
         | 
| 10 | 
            +
            from fairseq import utils,tasks
         | 
| 11 | 
            +
            from fairseq import checkpoint_utils
         | 
| 12 | 
            +
            from fairseq import distributed_utils, options, tasks, utils
         | 
| 13 | 
            +
            from fairseq.dataclass.utils import convert_namespace_to_omegaconf
         | 
| 14 | 
            +
            from utils.zero_shot_utils import zero_shot_step
         | 
| 15 | 
            +
            from tasks.mm_tasks.vqa_gen import VqaGenTask
         | 
| 16 | 
            +
            from models.ofa import OFAModel
         | 
| 17 | 
            +
            from PIL import Image
         | 
| 18 | 
            +
            from torchvision import transforms
         | 
| 19 | 
            +
            import gradio as gr
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            # Register VQA task
         | 
| 22 | 
            +
            tasks.register_task('vqa_gen',VqaGenTask)
         | 
| 23 | 
            +
            # turn on cuda if GPU is available
         | 
| 24 | 
            +
            use_cuda = torch.cuda.is_available()
         | 
| 25 | 
            +
            # fp16 is off by default; enable it only when a GPU is available
         | 
| 26 | 
            +
            use_fp16 = False
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            os.system('wget https://ofa-silicon.oss-us-west-1.aliyuncs.com/checkpoints/ofa_large_384.pt; '
         | 
| 29 | 
            +
                      'mkdir -p checkpoints; mv ofa_large_384.pt checkpoints/ofa_large_384.pt')
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            # specify some options for evaluation
         | 
| 32 | 
            +
            parser = options.get_generation_parser()
         | 
| 33 | 
            +
            input_args = ["", "--task=vqa_gen", "--beam=100", "--unnormalized", "--path=checkpoints/ofa_large_384.pt", "--bpe-dir=utils/BPE"]
         | 
| 34 | 
            +
            args = options.parse_args_and_arch(parser, input_args)
         | 
| 35 | 
            +
            cfg = convert_namespace_to_omegaconf(args)
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            # Load pretrained ckpt & config
         | 
| 38 | 
            +
            task = tasks.setup_task(cfg.task)
         | 
| 39 | 
            +
            models, cfg = checkpoint_utils.load_model_ensemble(
         | 
| 40 | 
            +
                utils.split_paths(cfg.common_eval.path),
         | 
| 41 | 
            +
                task=task
         | 
| 42 | 
            +
            )
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            # Move models to GPU
         | 
| 45 | 
            +
            for model in models:
         | 
| 46 | 
            +
                model.eval()
         | 
| 47 | 
            +
                if use_fp16:
         | 
| 48 | 
            +
                    model.half()
         | 
| 49 | 
            +
                if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
         | 
| 50 | 
            +
                    model.cuda()
         | 
| 51 | 
            +
                model.prepare_for_inference_(cfg)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            # Initialize generator
         | 
| 54 | 
            +
            generator = task.build_generator(models, cfg.generation)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            # Image transform
         | 
| 57 | 
            +
            from torchvision import transforms
         | 
| 58 | 
            +
            mean = [0.5, 0.5, 0.5]
         | 
| 59 | 
            +
            std = [0.5, 0.5, 0.5]
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            patch_resize_transform = transforms.Compose([
         | 
| 62 | 
            +
                lambda image: image.convert("RGB"),
         | 
| 63 | 
            +
                transforms.Resize((cfg.task.patch_image_size, cfg.task.patch_image_size), interpolation=Image.BICUBIC),
         | 
| 64 | 
            +
                transforms.ToTensor(),
         | 
| 65 | 
            +
                transforms.Normalize(mean=mean, std=std),
         | 
| 66 | 
            +
            ])
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            # Text preprocess
         | 
| 69 | 
            +
            bos_item = torch.LongTensor([task.src_dict.bos()])
         | 
| 70 | 
            +
            eos_item = torch.LongTensor([task.src_dict.eos()])
         | 
| 71 | 
            +
            pad_idx = task.src_dict.pad()
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            # Normalize the question
         | 
| 74 | 
            +
            def pre_question(question, max_ques_words):
                """Normalize a question string and cap its length in words.

                The text is lower-cased, leading punctuation from ``,.!?*#:;~`` is
                removed, ``-`` and ``/`` become spaces, runs of two or more
                whitespace characters collapse to one space, and trailing newlines
                and surrounding spaces are trimmed. If the result has more than
                ``max_ques_words`` space-separated words, only the first
                ``max_ques_words`` are kept.
                """
                lowered = question.lower().lstrip(",.!?*#:;~")
                # One pass that maps both separators to a space.
                spaced = lowered.translate(str.maketrans("-/", "  "))
                collapsed = re.sub(r"\s{2,}", ' ', spaced)
                trimmed = collapsed.rstrip('\n').strip(' ')

                words = trimmed.split(' ')
                if len(words) <= max_ques_words:
                    return trimmed
                # truncate question
                return ' '.join(words[:max_ques_words])
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            def encode_text(text, length=None, append_bos=False, append_eos=False):
                """Turn ``text`` into a 1-D long tensor of dictionary indices.

                ``text`` is passed through ``task.bpe.encode`` and the result is
                mapped with ``task.tgt_dict.encode_line`` (no new symbols added, no
                EOS appended by the dictionary). When ``length`` is given the ids
                are cut to at most that many entries *before* the optional
                ``bos_item`` / ``eos_item`` tensors are attached.
                """
                bpe_line = task.bpe.encode(text)
                token_ids = task.tgt_dict.encode_line(
                    line=bpe_line,
                    add_if_not_exist=False,
                    append_eos=False
                ).long()
                if length is not None:
                    token_ids = token_ids[:length]

                # Collect the pieces in order and concatenate once.
                pieces = [bos_item] if append_bos else []
                pieces.append(token_ids)
                if append_eos:
                    pieces.append(eos_item)
                if len(pieces) == 1:
                    return token_ids
                return torch.cat(pieces)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            # Construct input for open-domain VQA task
         | 
| 104 | 
            +
            def construct_sample(image: Image.Image, question: str):
                """Build the one-example sample dict for open-domain VQA.

                Args:
                    image: a PIL image; it is run through ``patch_resize_transform``
                        and given a leading batch dimension.
                    question: raw question text; it is normalized with
                        ``pre_question`` (word limit ``task.cfg.max_src_length``) and
                        a trailing ``?`` is added when missing.

                Returns:
                    A dict with ``id``, ``net_input`` (``src_tokens``,
                    ``src_lengths``, ``patch_images``, ``patch_masks``) and a
                    placeholder ``ref_dict``.
                """
                patch_image = patch_resize_transform(image).unsqueeze(0)
                patch_mask = torch.tensor([True])

                question = pre_question(question, task.cfg.max_src_length)
                if not question.endswith('?'):
                    question = question + '?'
                # The leading space is part of the text handed to the BPE encoder.
                src_text = encode_text(' {}'.format(question), append_bos=True, append_eos=True).unsqueeze(0)

                # Count the non-padding tokens of each row.
                src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
                ref_dict = np.array([{'yes': 1.0}])  # just placeholder
                sample = {
                    "id": np.array(['42']),
                    "net_input": {
                        "src_tokens": src_text,
                        "src_lengths": src_length,
                        "patch_images": patch_image,
                        "patch_masks": patch_mask,
                    },
                    "ref_dict": ref_dict,
                }
                return sample
         | 
| 125 | 
            +
              
         | 
| 126 | 
            +
            # Function to turn FP32 to FP16
         | 
| 127 | 
            +
            def apply_half(t):
                """Return ``t`` cast to ``torch.half`` when it is float32, else ``t`` itself."""
                return t.to(dtype=torch.half) if t.dtype is torch.float32 else t
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
            +
            # Function for open-domain visual question answering
            def open_domain_vqa(Image, Question):
                """Answer ``Question`` about ``Image`` and return the answer string.

                The parameter names are left unchanged for interface compatibility;
                note that ``Image`` shadows the PIL module inside this function.

                Args:
                    Image: the input image, forwarded to ``construct_sample``.
                    Question: the question text, forwarded to ``construct_sample``.

                Returns:
                    The ``'answer'`` entry of the first result from
                    ``zero_shot_step``.
                """
                sample = construct_sample(Image, Question)
                if use_cuda:
                    sample = utils.move_to_cuda(sample)
                if use_fp16:
                    sample = utils.apply_to_sample(apply_half, sample)
                # Run eval step for open-domain VQA
                with torch.no_grad():
                    # The second return value (scores) is not used.
                    result, _ = zero_shot_step(task, generator, models, sample)
                return result[0]['answer']
         | 
| 142 | 
            +
             | 
| 143 | 
            +
             | 
| 144 | 
            +
            # Static text shown on the demo page.
            title = "OFA-Visual_Question_Answering"
            description = (
                "Gradio Demo for OFA-Visual_Question_Answering. Upload your own image (high-resolution images are recommended) or click any one of the examples, and click "
                "\"Submit\" and then wait for OFA's answer. "
            )
            article = (
                "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github "
                "Repo</a></p> "
            )
            # Each example is an [image file, question] pair.
            examples = [
                ['cat-4894153_1920.jpg', 'where are the cats?'],
                ['men-6245003_1920.jpg', 'how many people are in the image?'],
                ['labrador-retriever-7004193_1920.jpg', 'what breed is the dog in the picture?'],
                ['Starry_Night.jpeg', 'what style does the picture belong to?'],
            ]

            # Wire open_domain_vqa to an image + textbox input and a text output.
            image_input = gr.inputs.Image(type='pil')
            answer_output = gr.outputs.Textbox(label="Answer")
            io = gr.Interface(
                fn=open_domain_vqa,
                inputs=[image_input, "textbox"],
                outputs=answer_output,
                title=title,
                description=description,
                article=article,
                examples=examples,
                allow_flagging=False,
                allow_screenshot=False,
            )
            io.launch()
         | 
    	
        cat-4894153_1920.jpg
    ADDED
    
    |   | 
    	
        checkpoints.md
    ADDED
    
    | @@ -0,0 +1,13 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Checkpoints
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            We provide links for you to download our checkpoints. We will release all the checkpoints including pretrained and finetuned models on different tasks. 
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            ## Pretraining
         | 
| 6 | 
            +
            * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/ofa_large.pt"> Pre-trained checkpoint (OFA-Large) </a>
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            ## Finetuning
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/caption_large_best_clean.pt"> Finetuned checkpoint for Caption on COCO </a>
         | 
| 11 | 
            +
            * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcoco_large_best.pt"> Finetuned checkpoint for RefCOCO </a>
         | 
| 12 | 
            +
            * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocoplus_large_best.pt"> Finetuned checkpoint for RefCOCO+ </a>
         | 
| 13 | 
            +
            * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/checkpoints/refcocog_large_best.pt"> Finetuned checkpoint for RefCOCOg </a>
         | 
    	
        colab.md
    ADDED
    
    | @@ -0,0 +1,8 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Colab Notebooks
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            We provide Colab notebooks for different downstream tasks so that you can try OFA. See below.
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            * Image Captioning: [![][colab]](https://colab.research.google.com/drive/1Q4eNhhhLcgOP4hHqwZwU1ijOlabgve1W?usp=sharing)
         | 
| 6 | 
            +
            * Referring Expression Comprehension: [![][colab]](https://colab.research.google.com/drive/1AHQNRdaUpRTgr3XySHSlba8aXwBAjwPB?usp=sharing)
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            [colab]: <https://colab.research.google.com/assets/colab-badge.svg>
         | 
    	
        criterions/__init__.py
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .scst_loss import ScstRewardCriterion
         | 
| 2 | 
            +
            from .label_smoothed_cross_entropy import AjustLabelSmoothedCrossEntropyCriterion
         | 
    	
        criterions/label_smoothed_cross_entropy.py
    ADDED
    
    | @@ -0,0 +1,343 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the MIT license found in the
         | 
| 4 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import math
         | 
| 7 | 
            +
            from dataclasses import dataclass, field
         | 
| 8 | 
            +
            from typing import Optional
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            import torch.nn.functional as F
         | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            from fairseq import metrics, utils
         | 
| 14 | 
            +
            from fairseq.criterions import FairseqCriterion, register_criterion
         | 
| 15 | 
            +
            from fairseq.dataclass import FairseqDataclass
         | 
| 16 | 
            +
            from omegaconf import II
         | 
| 17 | 
            +
             | 
| 18 | 
            +
             | 
| 19 | 
            +
            @dataclass
            class AjustLabelSmoothedCrossEntropyCriterionConfig(FairseqDataclass):
                """Options for the ``ajust_label_smoothed_cross_entropy`` criterion."""

                # --- label smoothing and reporting ---
                label_smoothing: float = field(default=0.0, metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"})
                report_accuracy: bool = field(default=False, metadata={"help": "report accuracy metric"})

                # --- target positions left out of the loss ---
                ignore_prefix_size: int = field(default=0, metadata={"help": "Ignore first N tokens"})
                ignore_eos: bool = field(default=False, metadata={"help": "Ignore eos token"})

                # Interpolated from the optimization config group.
                sentence_avg: bool = II("optimization.sentence_avg")

                # --- dropping the highest-loss items ---
                drop_worst_ratio: float = field(default=0.0, metadata={"help": "ratio for discarding bad samples"})
                drop_worst_after: int = field(default=0, metadata={"help": "steps for discarding bad samples"})

                # --- R-Drop regularization ---
                use_rdrop: bool = field(default=False, metadata={"help": "use R-Drop"})
                reg_alpha: float = field(default=1.0, metadata={"help": "weight for R-Drop"})

                sample_patch_num: int = field(default=196, metadata={"help": "sample patchs for v1"})

                # Parsed by the criterion as "start,end".
                constraint_range: Optional[str] = field(default=None, metadata={"help": "constraint range"})
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            def construct_rdrop_sample(x):
                """Duplicate a (possibly nested) sample so every example appears twice.

                The doubled batch is laid out as ``[batch, batch]``: the loss code
                later treats the first half and the second half as the two passes
                over the same examples.

                Handling by type:
                    * dict: every value is replaced in place by its doubled version
                      and the same dict is returned.
                    * torch.Tensor: tiled twice along dimension 0.
                    * int: multiplied by two (a count of the doubled batch).
                    * np.ndarray: concatenated with itself along axis 0, so entry
                      ``i`` and entry ``i + len(x)`` refer to the same example,
                      matching the tensor layout. (``ndarray.repeat(2)`` would
                      interleave the copies and break that alignment.)

                Raises:
                    NotImplementedError: for any other type.
                """
                if isinstance(x, dict):
                    for key in x:
                        x[key] = construct_rdrop_sample(x[key])
                    return x
                if isinstance(x, torch.Tensor):
                    return x.repeat(2, *([1] * (x.dim() - 1)))
                if isinstance(x, int):
                    return x * 2
                if isinstance(x, np.ndarray):
                    # atleast_1d keeps zero-dimensional arrays working.
                    return np.concatenate([np.atleast_1d(x)] * 2, axis=0)
                raise NotImplementedError
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def kl_loss(p, q):
                """Return the average of the two sum-reduced ``F.kl_div`` terms.

                ``p`` and ``q`` are each used once as the ``input`` of
                ``F.kl_div`` and once, exponentiated, as the ``target`` — so they
                are expected to hold log-probabilities.
                """
                p_against_q = F.kl_div(p, torch.exp(q), reduction='sum')
                q_against_p = F.kl_div(q, torch.exp(p), reduction='sum')
                return (p_against_q + q_against_p) / 2
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            def label_smoothed_nll_loss(
                    lprobs, target, epsilon, update_num, reduce=True,
                    drop_worst_ratio=0.0, drop_worst_after=0, use_rdrop=False, reg_alpha=1.0,
                    constraint_masks=None, constraint_start=None, constraint_end=None
            ):
                """Label-smoothed NLL with optional constraints, worst-item dropping and R-Drop.

                Steps, in order:
                    1. Per-item NLL is gathered from ``lprobs`` at ``target``.
                    2. The smoothing term sums ``lprobs`` over the allowed entries:
                       those selected by ``constraint_masks`` if given; otherwise
                       columns 0-3 plus ``range(constraint_start, constraint_end)``
                       if both bounds are given; otherwise every column.
                    3. When ``drop_worst_ratio > 0`` and ``update_num >
                       drop_worst_after``, only the lowest-loss fraction
                       ``1 - drop_worst_ratio`` of items is kept. With
                       ``use_rdrop`` the selection is made on the first half of
                       the batch and mirrored onto the second half.
                    4. With ``use_rdrop``, ``kl_loss`` between the two halves of
                       ``lprobs`` times ``reg_alpha`` is added to the summed loss.

                ``reduce`` is accepted but not read by this function.

                Returns:
                    ``(loss, nll_loss, ntokens)`` where the first two are summed
                    scalars and ``ntokens`` is the number of items kept.
                """
                range_given = constraint_start is not None and constraint_end is not None

                def allowed_columns():
                    # Columns 0-3 are always allowed in addition to the range.
                    return [0, 1, 2, 3] + list(range(constraint_start, constraint_end))

                if target.dim() == lprobs.dim() - 1:
                    target = target.unsqueeze(-1)
                nll_loss = -lprobs.gather(dim=-1, index=target).squeeze(-1)

                if constraint_masks is not None:
                    smooth_loss = -lprobs.masked_fill(~constraint_masks, 0).sum(dim=-1)
                    eps_i = epsilon / (constraint_masks.sum(1) - 1 + 1e-6)
                elif range_given:
                    columns = allowed_columns()
                    smooth_loss = -lprobs[:, columns].sum(dim=-1)
                    eps_i = epsilon / (len(columns) - 1 + 1e-6)
                else:
                    smooth_loss = -lprobs.sum(dim=-1)
                    eps_i = epsilon / (lprobs.size(-1) - 1)
                loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss

                if drop_worst_ratio > 0 and update_num > drop_worst_after:
                    if use_rdrop:
                        half = loss.size(0) // 2
                        _, kept = torch.topk(loss[:half], k=int(half * (1 - drop_worst_ratio)), largest=False)
                        # Keep the same examples in both halves of the doubled batch.
                        kept = torch.cat([kept, kept + half])
                        loss = loss[kept]
                    else:
                        loss, kept = torch.topk(loss, k=int(loss.shape[0] * (1 - drop_worst_ratio)), largest=False)
                    nll_loss = nll_loss[kept]
                    lprobs = lprobs[kept]

                ntokens = loss.numel()
                nll_loss = nll_loss.sum()
                loss = loss.sum()

                if use_rdrop:
                    half = lprobs.size(0) // 2
                    first_pass, second_pass = lprobs[:half], lprobs[half:]
                    if range_given:
                        columns = allowed_columns()
                        first_pass = first_pass[:, columns]
                        second_pass = second_pass[:, columns]
                    loss += kl_loss(first_pass, second_pass) * reg_alpha

                return loss, nll_loss, ntokens
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
            @register_criterion(
         | 
| 131 | 
            +
                "ajust_label_smoothed_cross_entropy", dataclass=AjustLabelSmoothedCrossEntropyCriterionConfig
         | 
| 132 | 
            +
            )
         | 
| 133 | 
            +
            class AjustLabelSmoothedCrossEntropyCriterion(FairseqCriterion):
         | 
| 134 | 
            +
                def __init__(
                    self,
                    task,
                    sentence_avg,
                    label_smoothing,
                    ignore_prefix_size=0,
                    ignore_eos=False,
                    report_accuracy=False,
                    drop_worst_ratio=0,
                    drop_worst_after=0,
                    use_rdrop=False,
                    reg_alpha=1.0,
                    sample_patch_num=196,
                    constraint_range=None
                ):
                    """Store the criterion options on the instance.

                    ``label_smoothing`` is kept as ``self.eps``. ``constraint_range``,
                    when given, must be a ``"start,end"`` string; its two parts are
                    parsed as integers into ``self.constraint_start`` and
                    ``self.constraint_end``, which are ``None`` otherwise.
                    """
                    super().__init__(task)

                    # Loss shaping
                    self.sentence_avg = sentence_avg
                    self.eps = label_smoothing
                    self.ignore_prefix_size = ignore_prefix_size
                    self.ignore_eos = ignore_eos
                    self.report_accuracy = report_accuracy

                    # Worst-item dropping and R-Drop
                    self.drop_worst_ratio = drop_worst_ratio
                    self.drop_worst_after = drop_worst_after
                    self.use_rdrop = use_rdrop
                    self.reg_alpha = reg_alpha
                    self.sample_patch_num = sample_patch_num

                    # Vocabulary constraint bounds; unset unless a range string is given.
                    self.constraint_start = self.constraint_end = None
                    if constraint_range is not None:
                        start_text, end_text = constraint_range.split(',')
                        self.constraint_start = int(start_text)
                        self.constraint_end = int(end_text)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                def forward(self, model, sample, update_num=0, reduce=True):
                    """Compute the loss for ``sample``.

                    ``sample`` is either a single sample dict or a list whose first
                    two entries are sample dicts. For a list, each entry is scored
                    by a recursive call, the two losses are divided by their own
                    sample sizes and added, and the combined sample size is 1.

                    Returns a 3-tuple:
                        1) the loss
                        2) the sample size (denominator for the gradient)
                        3) a dict of logging outputs
                    """
                    if isinstance(sample, list):
                        if self.sample_patch_num > 0:
                            sample[0]['net_input']['sample_patch_num'] = self.sample_patch_num
                        first_loss, first_size, first_log = self.forward(model, sample[0], update_num, reduce)
                        second_loss, second_size, second_log = self.forward(model, sample[1], update_num, reduce)

                        loss = first_loss / first_size + second_loss / second_size
                        sample_size = 1
                        combined_nll = (
                            first_log["nll_loss"].data / first_size
                            + second_log["nll_loss"].data / second_size
                        )
                        logging_output = {
                            "loss": loss.data,
                            "loss_v1": first_loss.data,
                            "loss_v2": second_loss.data,
                            "nll_loss": combined_nll,
                            "ntokens": first_log["ntokens"] + second_log["ntokens"],
                            "nsentences": first_log["nsentences"] + second_log["nsentences"],
                            "sample_size": 1,
                            "sample_size_v1": first_size,
                            "sample_size_v2": second_size,
                        }
                        return loss, sample_size, logging_output

                    if self.use_rdrop:
                        # Doubles the sample dict in place.
                        construct_rdrop_sample(sample)

                    net_output = model(**sample["net_input"])
                    loss, nll_loss, ntokens = self.compute_loss(model, net_output, sample, update_num, reduce=reduce)

                    if self.sentence_avg:
                        sample_size = sample["target"].size(0)
                    else:
                        sample_size = ntokens

                    logging_output = {
                        "loss": loss.data,
                        "nll_loss": nll_loss.data,
                        "ntokens": sample["ntokens"],
                        "nsentences": sample["nsentences"],
                        "sample_size": sample_size,
                    }
                    if self.report_accuracy:
                        n_correct, total = self.compute_accuracy(model, net_output, sample)
                        logging_output["n_correct"] = utils.item(n_correct.data)
                        logging_output["total"] = utils.item(total.data)
                    return loss, sample_size, logging_output
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                def get_lprobs_and_target(self, model, net_output, sample):
                    """Build flattened log-probabilities, targets and constraint masks.

                    The logits in ``net_output[0]`` are modified in place: positions
                    outside ``sample["constraint_masks"]`` (when present) and vocabulary
                    indices in ``[4, constraint_start)`` and ``[constraint_end, ...)``
                    (when both bounds are configured) are set to ``-inf``.

                    Returns:
                        A ``(lprobs, target, constraint_masks)`` triple where ``lprobs``
                        is viewed as ``(-1, vocab)``, ``target`` as ``(-1,)`` and
                        ``constraint_masks`` is either ``None`` or viewed as
                        ``(-1, vocab)``.
                    """
                    # Optional per-sample confidence used to scale the log-probabilities.
                    scale = 1
                    if 'conf' in sample and sample['conf'] is not None:
                        scale = sample['conf'][:, None, None]

                    constraint_masks = None
                    if "constraint_masks" in sample and sample["constraint_masks"] is not None:
                        constraint_masks = sample["constraint_masks"]
                        net_output[0].masked_fill_(~constraint_masks, -math.inf)

                    range_configured = not (self.constraint_start is None or self.constraint_end is None)
                    if range_configured:
                        # Indices 0-3 are never masked (presumably special tokens -- TODO confirm).
                        net_output[0][:, :, 4:self.constraint_start] = -math.inf
                        net_output[0][:, :, self.constraint_end:] = -math.inf

                    lprobs = model.get_normalized_probs(net_output, log_probs=True) * scale
                    target = model.get_targets(sample, net_output)

                    skip = self.ignore_prefix_size
                    if skip > 0:
                        lprobs = lprobs[:, skip:, :].contiguous()
                        target = target[:, skip:].contiguous()
                        if constraint_masks is not None:
                            constraint_masks = constraint_masks[:, skip:, :].contiguous()

                    if self.ignore_eos:
                        bsz, seq_len, embed_dim = lprobs.size()
                        # The reshape below relies on exactly one EOS token per row.
                        keep = ~target.eq(self.task.tgt_dict.eos())
                        lprobs = lprobs[keep].reshape(bsz, seq_len - 1, embed_dim)
                        if constraint_masks is not None:
                            constraint_masks = constraint_masks[keep].reshape(bsz, seq_len - 1, embed_dim)
                        target = target[keep].reshape(bsz, seq_len - 1)

                    flat_masks = None
                    if constraint_masks is not None:
                        flat_masks = constraint_masks.view(-1, constraint_masks.size(-1))
                    return lprobs.view(-1, lprobs.size(-1)), target.view(-1), flat_masks
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                def compute_loss(self, model, net_output, sample, update_num, reduce=True):
                    """Compute the label-smoothed loss over non-padding target tokens.

                    Flattened log-probabilities, targets and (optional) constraint
                    masks come from ``get_lprobs_and_target``; entries whose target
                    equals ``self.padding_idx`` are dropped before the loss is taken.

                    Returns:
                        The ``(loss, nll_loss, ntokens)`` triple produced by
                        ``label_smoothed_nll_loss``.
                    """
                    lprobs, target, constraint_masks = self.get_lprobs_and_target(model, net_output, sample)

                    # One boolean selector shared by all three tensors.
                    not_pad = target != self.padding_idx
                    if constraint_masks is not None:
                        constraint_masks = constraint_masks[not_pad]
                    lprobs = lprobs[not_pad]
                    target = target[not_pad]

                    options = dict(
                        reduce=reduce,
                        drop_worst_ratio=self.drop_worst_ratio,
                        drop_worst_after=self.drop_worst_after,
                        use_rdrop=self.use_rdrop,
                        reg_alpha=self.reg_alpha,
                        constraint_masks=constraint_masks,
                        constraint_start=self.constraint_start,
                        constraint_end=self.constraint_end,
                    )
                    loss, nll_loss, ntokens = label_smoothed_nll_loss(
                        lprobs, target, self.eps, update_num, **options
                    )
                    return loss, nll_loss, ntokens
         | 
| 265 | 
            +
             | 
| 266 | 
            +
                def compute_accuracy(self, model, net_output, sample):
                    """Count correctly predicted non-padding target tokens.

                    Args:
                        model: model forwarded to ``get_lprobs_and_target``.
                        net_output: model output forwarded to ``get_lprobs_and_target``.
                        sample: batch forwarded to ``get_lprobs_and_target``.

                    Returns:
                        ``(n_correct, total)`` tensors: the number of non-padding
                        positions whose arg-max prediction equals the target, and the
                        number of non-padding positions.
                    """
                    # get_lprobs_and_target returns (lprobs, target, constraint_masks);
                    # the masks are not needed here. Unpacking only two values raised
                    # "ValueError: too many values to unpack".
                    lprobs, target, _ = self.get_lprobs_and_target(model, net_output, sample)
                    mask = target.ne(self.padding_idx)
                    n_correct = torch.sum(
                        lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
                    )
                    total = torch.sum(mask)
                    return n_correct, total
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                @classmethod
                def reduce_metrics(cls, logging_outputs) -> None:
                    """Aggregate logging outputs from data parallel training."""

                    def total_of(key):
                        # Sum one field over all workers; missing fields count as 0.
                        return sum(log.get(key, 0) for log in logging_outputs)

                    loss_sum = total_of("loss")
                    loss_sum_v1 = total_of("loss_v1")
                    loss_sum_v2 = total_of("loss_v2")
                    nll_loss_sum = total_of("nll_loss")
                    ntokens = total_of("ntokens")
                    nsentences = total_of("nsentences")
                    sample_size = total_of("sample_size")
                    sample_size_v1 = total_of("sample_size_v1")
                    sample_size_v2 = total_of("sample_size_v2")

                    # The v1/v2 denominators are clamped to at least 1.
                    denom_v1 = max(sample_size_v1, 1)
                    denom_v2 = max(sample_size_v2, 1)

                    metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
                    metrics.log_scalar("loss_v1", loss_sum_v1 / denom_v1, denom_v1, round=3)
                    metrics.log_scalar("loss_v2", loss_sum_v2 / denom_v2, denom_v2, round=3)
                    metrics.log_scalar("nll_loss", nll_loss_sum / sample_size, ntokens, round=3)
                    metrics.log_derived(
                        "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
                    )

                    # Plain counters, each logged with weight 1.
                    counters = (
                        ("ntokens", ntokens),
                        ("nsentences", nsentences),
                        ("sample_size", sample_size),
                        ("sample_size_v1", sample_size_v1),
                        ("sample_size_v2", sample_size_v2),
                    )
                    for name, value in counters:
                        metrics.log_scalar(name, value, 1, round=3)

                    total = utils.item(total_of("total"))
                    if not (total > 0):
                        return

                    def accuracy(meters):
                        if meters["total"].sum > 0:
                            return round(
                                meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
                            )
                        return float("nan")

                    metrics.log_scalar("total", total)
                    n_correct = utils.item(total_of("n_correct"))
                    metrics.log_scalar("n_correct", n_correct)
                    metrics.log_derived("accuracy", accuracy)
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                @staticmethod
         | 
| 337 | 
            +
                def logging_outputs_can_be_summed() -> bool:
         | 
| 338 | 
            +
                    """
         | 
| 339 | 
            +
                    Whether the logging outputs returned by `forward` can be summed
         | 
| 340 | 
            +
                    across workers prior to calling `reduce_metrics`. Setting this
         | 
| 341 | 
            +
                    to True will improves distributed training speed.
         | 
| 342 | 
            +
                    """
         | 
| 343 | 
            +
                    return True
         | 
    	
        criterions/scst_loss.py
    ADDED
    
    | @@ -0,0 +1,280 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the MIT license found in the
         | 
| 4 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import math
         | 
| 7 | 
            +
            import string
         | 
| 8 | 
            +
            from dataclasses import dataclass, field
         | 
| 9 | 
            +
            from collections import OrderedDict
         | 
| 10 | 
            +
            from typing import Optional
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            from fairseq import metrics, utils
         | 
| 14 | 
            +
            from fairseq.criterions import FairseqCriterion, register_criterion
         | 
| 15 | 
            +
            from fairseq.dataclass import FairseqDataclass
         | 
| 16 | 
            +
            from omegaconf import II
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from data import data_utils
         | 
| 19 | 
            +
            from utils.cider.pyciderevalcap.ciderD.ciderD import CiderD
         | 
| 20 | 
            +
             | 
| 21 | 
            +
             | 
| 22 | 
            +
            def scst_loss(lprobs, target, reward, ignore_index=None, reduce=True):
                """Reward-weighted negative log-likelihood of the sampled tokens.

                Args:
                    lprobs: log-probabilities whose last axis is indexed by ``target``.
                    target: token indices; one fewer axis than ``lprobs``.
                    reward: per-sequence reward, broadcast over the token axis via
                        ``reward.unsqueeze(-1)``.
                    ignore_index: if given, positions where ``target`` equals this
                        value contribute zero loss and are excluded from ``ntokens``.
                    reduce: when True, return the summed loss instead of the
                        per-token loss tensor.

                Returns:
                    ``(loss, ntokens)``. ``ntokens`` is a tensor when ``ignore_index``
                    is given and a Python int otherwise.
                """
                # Squeeze only the gathered axis. A bare squeeze() also drops a size-1
                # time axis, after which the product with reward.unsqueeze(-1)
                # broadcasts a (B,) vector into a (B, B) matrix and corrupts the loss.
                token_lprobs = lprobs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
                loss = -token_lprobs * reward.unsqueeze(-1)
                if ignore_index is not None:
                    pad_mask = target.eq(ignore_index)
                    loss.masked_fill_(pad_mask, 0.0)
                    ntokens = (~pad_mask).sum()
                else:
                    ntokens = target.numel()
                if reduce:
                    loss = loss.sum()
                return loss, ntokens
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            @dataclass
            class ScstRewardCriterionConfig(FairseqDataclass):
                """Configuration fields consumed by ``ScstRewardCriterion``."""

                # Handed to the CIDEr-D scorer as its ``df`` argument.
                scst_cider_cached_tokens: str = field(
                    default="coco-train-words.p",
                    metadata={"help": "path to cached cPickle file used to calculate CIDEr scores"},
                )
                # Number of leading target positions dropped before the loss.
                ignore_prefix_size: int = field(default=0, metadata={"help": "Ignore first N tokens"})
                # Interpolated from the optimization config.
                sentence_avg: bool = II("optimization.sentence_avg")
                # Parsed by the criterion as "start,end".
                constraint_range: Optional[str] = field(default=None, metadata={"help": "constraint range"})
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            @register_criterion(
         | 
| 53 | 
            +
                "scst_reward_criterion", dataclass=ScstRewardCriterionConfig
         | 
| 54 | 
            +
            )
         | 
| 55 | 
            +
            class ScstRewardCriterion(FairseqCriterion):
         | 
| 56 | 
            +
                CIDER_REWARD_WEIGHT = 1
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def __init__(
                    self,
                    task,
                    scst_cider_cached_tokens,
                    sentence_avg,
                    ignore_prefix_size=0,
                    constraint_range=None
                ):
                    """Set up the CIDEr-D scorer and the optional vocabulary range.

                    Args:
                        task: task object passed to the base criterion.
                        scst_cider_cached_tokens: passed to ``CiderD`` as ``df``.
                        sentence_avg: if true, the sample size is the number of
                            sentences rather than the number of tokens.
                        ignore_prefix_size: number of leading positions to skip.
                        constraint_range: optional ``"start,end"`` string; both parts
                            are parsed as integers.
                    """
                    super().__init__(task)
                    self.scst_cider_scorer = CiderD(df=scst_cider_cached_tokens)
                    self.sentence_avg = sentence_avg
                    self.ignore_prefix_size = ignore_prefix_size
                    # Translation table that deletes every ASCII punctuation character.
                    self.transtab = str.maketrans('', '', string.punctuation)

                    bounds = (None, None)
                    if constraint_range is not None:
                        start_text, end_text = constraint_range.split(',')
                        bounds = (int(start_text), int(end_text))
                    self.constraint_start, self.constraint_end = bounds
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                def forward(self, model, sample, update_num=0, reduce=True):
                    """Compute the loss for the given sample.

                    Returns a tuple with three elements:
                    1) the loss
                    2) the sample size, which is used as the denominator for the gradient
                    3) logging outputs to display while training
                    """
                    loss, score, ntokens, nsentences = self.compute_loss(model, sample, reduce=reduce)

                    # Normalise per sentence or per token, depending on configuration.
                    if self.sentence_avg:
                        sample_size = nsentences
                    else:
                        sample_size = ntokens

                    logging_output = dict(
                        loss=loss.data,
                        score=score,
                        ntokens=ntokens,
                        nsentences=nsentences,
                        sample_size=sample_size,
                    )
                    return loss, sample_size, logging_output
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                def _calculate_eval_scores(self, gen_res, gt_idx, gt_res):
         | 
| 102 | 
            +
                    '''
         | 
| 103 | 
            +
                    gen_res: generated captions, list of str
         | 
| 104 | 
            +
                    gt_idx: list of int, of the same length as gen_res
         | 
| 105 | 
            +
                    gt_res: ground truth captions, list of list of str.
         | 
| 106 | 
            +
                        gen_res[i] corresponds to gt_res[gt_idx[i]]
         | 
| 107 | 
            +
                        Each image can have multiple ground truth captions
         | 
| 108 | 
            +
                    '''
         | 
| 109 | 
            +
                    gen_res_size = len(gen_res)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    res = OrderedDict()
         | 
| 112 | 
            +
                    for i in range(gen_res_size):
         | 
| 113 | 
            +
                        res[i] = [self._wrap_sentence(gen_res[i].strip().translate(self.transtab))]
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    gts = OrderedDict()
         | 
| 116 | 
            +
                    gt_res_ = [
         | 
| 117 | 
            +
                        [self._wrap_sentence(gt_res[i][j].strip().translate(self.transtab)) for j in range(len(gt_res[i]))]
         | 
| 118 | 
            +
                            for i in range(len(gt_res))
         | 
| 119 | 
            +
                    ]
         | 
| 120 | 
            +
                    for i in range(gen_res_size):
         | 
| 121 | 
            +
                        gts[i] = gt_res_[gt_idx[i]]
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    res_ = [{'image_id':i, 'caption': res[i]} for i in range(len(res))]
         | 
| 124 | 
            +
                    _, batch_cider_scores = self.scst_cider_scorer.compute_score(gts, res_)
         | 
| 125 | 
            +
                    scores = self.CIDER_REWARD_WEIGHT * batch_cider_scores
         | 
| 126 | 
            +
                    return scores
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                @classmethod
         | 
| 129 | 
            +
                def _wrap_sentence(self, s):
         | 
| 130 | 
            +
                    # ensure the sentence ends with <eos> token
         | 
| 131 | 
            +
                    # in order to keep consisitent with cider_cached_tokens
         | 
| 132 | 
            +
                    r = s.strip()
         | 
| 133 | 
            +
                    if r.endswith('.'):
         | 
| 134 | 
            +
                        r = r[:-1]
         | 
| 135 | 
            +
                    r += ' <eos>'
         | 
| 136 | 
            +
                    return r
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                def get_generator_out(self, model, sample):
                    """Generate captions for ``sample`` and decode them with the references.

                    Generation runs with the model in eval mode and gradients disabled.

                    Returns:
                        ``(gen_target, gen_res, gt_res)``: generated token tensors (int,
                        on CPU), their decoded strings, and, per batch item, the
                        decoded target string split on ``'&&'``.
                    """
                    def to_text(tokens):
                        # Token ids -> dictionary string -> BPE-decoded text.
                        ids = tokens.int().cpu()
                        text = self.task.tgt_dict.string(ids)
                        text = self.task.bpe.decode(text).strip()
                        return ids, text

                    model.eval()
                    with torch.no_grad():
                        self.task.scst_generator.model.eval()
                        gen_out = self.task.scst_generator.generate([model], sample)

                    gen_target, gen_res, gt_res = [], [], []
                    for row, hypotheses in enumerate(gen_out):
                        for hypothesis in hypotheses:
                            ids, text = to_text(hypothesis["tokens"])
                            gen_target.append(ids)
                            gen_res.append(text)
                        reference = utils.strip_pad(sample["target"][row], self.padding_idx)
                        _, reference_text = to_text(reference)
                        gt_res.append(reference_text.split('&&'))

                    return gen_target, gen_res, gt_res
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                def get_reward_and_scores(self, gen_res, gt_res, device):
                    """Compute self-critical rewards for the generated captions.

                    Each caption's reward is its score minus the mean score of the
                    other captions generated for the same image (a leave-one-out
                    baseline). With a single caption per image there are no other
                    captions, so the baseline is zero and the reward is the raw score.

                    Args:
                        gen_res: generated captions, grouped consecutively per image.
                        gt_res: one list of reference captions per image.
                        device: device for the returned reward tensor.

                    Returns:
                        ``(reward, scores)``: a float64 tensor of length
                        ``len(gen_res)`` and the raw scores from
                        ``_calculate_eval_scores`` (assumed to be a NumPy array --
                        it is reshaped and summed with ``keepdims``).
                    """
                    batch_size = len(gt_res)
                    gen_res_size = len(gen_res)
                    seq_per_img = gen_res_size // batch_size

                    gt_idx = [i // seq_per_img for i in range(gen_res_size)]
                    scores = self._calculate_eval_scores(gen_res, gt_idx, gt_res)
                    sc_ = scores.reshape(batch_size, seq_per_img)
                    if seq_per_img > 1:
                        baseline = (sc_.sum(1, keepdims=True) - sc_) / (seq_per_img - 1)
                    else:
                        # Dividing by (seq_per_img - 1) == 0 would give NaN rewards.
                        baseline = 0.0
                    # sample - baseline; the subtraction allocates a new array, so the
                    # reward tensor never aliases ``scores``.
                    reward = (sc_ - baseline).reshape(gen_res_size)
                    reward = torch.as_tensor(reward, device=device, dtype=torch.float64)

                    return reward, scores
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                def get_net_output(self, model, sample, gen_target):
                    """Run the model in train mode on the generated captions.

                    The source-side inputs in ``sample['net_input']`` are repeated
                    ``len(gen_target) // len(sample["target"])`` times along dim 0 so
                    that every generated caption has its own copy.

                    Returns:
                        ``(net_output, gen_target_tokens)``: the model output and the
                        collated, right-padded int64 tensor of generated tokens.
                    """
                    tgt_dict = self.task.tgt_dict
                    device = sample["target"].device

                    def collate(token_lists, eos_idx, shift):
                        # Right-pad into one batch; ``shift`` moves EOS to the front.
                        batch = data_utils.collate_tokens(
                            token_lists,
                            pad_idx=self.padding_idx,
                            eos_idx=eos_idx,
                            left_pad=False,
                            move_eos_to_beginning=shift,
                        )
                        return torch.as_tensor(batch, device=device, dtype=torch.int64)

                    seq_per_img = len(gen_target) // len(sample["target"])

                    model.train()
                    net_input = sample['net_input']
                    repeated = {
                        key: torch.repeat_interleave(net_input[key], seq_per_img, dim=0)
                        for key in ('src_tokens', 'src_lengths', 'patch_images', 'patch_masks')
                    }
                    gen_prev_output_tokens = collate(gen_target, tgt_dict.bos(), True)
                    gen_target_tokens = collate(gen_target, tgt_dict.eos(), False)

                    net_output = model(prev_output_tokens=gen_prev_output_tokens, **repeated)

                    return net_output, gen_target_tokens
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                def get_lprobs_and_target(self, model, net_output, gen_target):
                    """Normalize the model output and align it with the generated targets.

                    When both ``self.constraint_start`` and ``self.constraint_end`` are
                    set, ``net_output[0]`` is modified in place: along its last
                    dimension the entries ``4:constraint_start`` and
                    ``constraint_end:`` are overwritten with ``-inf`` before
                    ``model.get_normalized_probs`` is called. When
                    ``self.ignore_prefix_size`` is positive, that many leading positions
                    are dropped from both the log-probabilities and the targets.

                    Returns:
                        tuple: ``(lprobs, gen_target)``.
                    """
                    has_constraint = (
                        self.constraint_start is not None
                        and self.constraint_end is not None
                    )
                    if has_constraint:
                        logits = net_output[0]
                        # Indices 0-3 are never masked; everything else outside
                        # [constraint_start, constraint_end) is.
                        logits[:, :, 4:self.constraint_start] = -math.inf
                        logits[:, :, self.constraint_end:] = -math.inf

                    lprobs = model.get_normalized_probs(net_output, log_probs=True)

                    prefix = self.ignore_prefix_size
                    if not prefix > 0:
                        return lprobs, gen_target

                    # The prefix lives on dim 1 when the output is flagged batch-first,
                    # otherwise on dim 0.
                    if getattr(lprobs, "batch_first", False):
                        trimmed_lprobs = lprobs[:, prefix:, :].contiguous()
                        trimmed_target = gen_target[:, prefix:].contiguous()
                    else:
                        trimmed_lprobs = lprobs[prefix:, :, :].contiguous()
                        trimmed_target = gen_target[prefix:, :].contiguous()
                    return trimmed_lprobs, trimmed_target
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                def compute_loss(self, model, sample, reduce=True):
                    """Compute the loss for one batch by chaining the criterion's helpers.

                    Calls ``self.get_generator_out``, ``self.get_reward_and_scores``,
                    ``self.get_net_output`` and ``self.get_lprobs_and_target`` in that
                    order, then hands the log-probabilities, targets and reward to
                    ``scst_loss``.

                    Returns:
                        tuple: ``(loss, score_sum, ntokens, nsentences)`` where
                        ``score_sum`` is ``scores.sum()`` and ``nsentences`` is the
                        size of the first dimension of the final target tensor.
                    """
                    generated, gen_res, gt_res = self.get_generator_out(model, sample)
                    target_device = sample["target"].device
                    reward, scores = self.get_reward_and_scores(
                        gen_res, gt_res, device=target_device
                    )

                    net_output, raw_targets = self.get_net_output(model, sample, generated)
                    lprobs, targets = self.get_lprobs_and_target(
                        model, net_output, raw_targets
                    )

                    loss, ntokens = scst_loss(
                        lprobs,
                        targets,
                        reward,
                        ignore_index=self.padding_idx,
                        reduce=reduce,
                    )
                    nsentences = targets.size(0)
                    return loss, scores.sum(), ntokens, nsentences
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                @classmethod
                def reduce_metrics(cls, logging_outputs) -> None:
                    """Sum the per-worker logging outputs and report them via ``metrics``.

                    ``loss`` is logged as the summed loss divided by the summed
                    ``sample_size``; ``score`` as the summed score divided by the summed
                    ``nsentences``. ``ntokens``, ``nsentences`` and ``sample_size`` are
                    logged as plain totals. Missing keys count as 0.
                    """

                    def total(key):
                        return sum(entry.get(key, 0) for entry in logging_outputs)

                    keys = ("loss", "score", "ntokens", "nsentences", "sample_size")
                    totals = {key: total(key) for key in keys}
                    sample_size = totals["sample_size"]
                    nsentences = totals["nsentences"]

                    # Averages are weighted by their own denominators.
                    metrics.log_scalar(
                        "loss", totals["loss"] / sample_size, sample_size, round=3
                    )
                    metrics.log_scalar(
                        "score", totals["score"] / nsentences, nsentences, round=3
                    )

                    # Raw counters are logged with weight 1.
                    for name in ("ntokens", "nsentences", "sample_size"):
                        metrics.log_scalar(name, totals[name], 1, round=3)
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                @staticmethod
         | 
| 274 | 
            +
                def logging_outputs_can_be_summed() -> bool:
         | 
| 275 | 
            +
                    """
         | 
| 276 | 
            +
                    Whether the logging outputs returned by `forward` can be summed
         | 
| 277 | 
            +
                    across workers prior to calling `reduce_metrics`. Setting this
         | 
| 278 | 
            +
                    to True will improves distributed training speed.
         | 
| 279 | 
            +
                    """
         | 
| 280 | 
            +
                    return True
         | 
    	
        data/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        data/data_utils.py
    ADDED
    
    | @@ -0,0 +1,601 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the MIT license found in the
         | 
| 4 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            try:
         | 
| 7 | 
            +
                from collections.abc import Iterable
         | 
| 8 | 
            +
            except ImportError:
         | 
| 9 | 
            +
                from collections import Iterable
         | 
| 10 | 
            +
            import contextlib
         | 
| 11 | 
            +
            import itertools
         | 
| 12 | 
            +
            import logging
         | 
| 13 | 
            +
            import re
         | 
| 14 | 
            +
            import warnings
         | 
| 15 | 
            +
            from typing import Optional, Tuple
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import numpy as np
         | 
| 18 | 
            +
            import torch
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            from fairseq.file_io import PathManager
         | 
| 21 | 
            +
            from fairseq import utils
         | 
| 22 | 
            +
            import os
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            def infer_language_pair(path):
                """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx

                Scans the entries returned by ``PathManager.ls(path)`` and returns the
                two dash-separated parts of the second dot-separated field of the first
                name that has at least three fields. Returns ``(None, None)`` when no
                entry matches.
                """
                for entry in PathManager.ls(path):
                    fields = entry.split(".")
                    if len(fields) < 3:
                        continue
                    langs = fields[1].split("-")
                    if len(langs) == 2:
                        return langs
                return None, None
         | 
| 35 | 
            +
             | 
| 36 | 
            +
             | 
| 37 | 
            +
            def collate_tokens(
                values,
                pad_idx,
                eos_idx=None,
                left_pad=False,
                move_eos_to_beginning=False,
                pad_to_length=None,
                pad_to_multiple=1,
                pad_to_bsz=None,
            ):
                """Convert a list of 1d (or 2d) tensors into one padded batch tensor.

                Args:
                    values: non-empty list of tensors, all 1d or all 2d; they are
                        padded along their first dimension.
                    pad_idx: fill value for every padded position.
                    eos_idx: value written at position 0 of each row when
                        ``move_eos_to_beginning`` is set. If ``None``, the last
                        element of each source tensor is used instead.
                    left_pad: put the padding before the data instead of after it.
                    move_eos_to_beginning: rotate each row right by one, placing the
                        eos value first. Only allowed for 1d inputs.
                    pad_to_length: minimum padded length.
                    pad_to_multiple: round the padded length up to a multiple of this.
                    pad_to_bsz: minimum number of rows in the result; rows beyond
                        ``len(values)`` contain only ``pad_idx``.

                Returns:
                    A tensor of shape ``(rows, size)`` for 1d inputs or
                    ``(rows, size, values[0].size(1))`` for 2d inputs.

                Raises:
                    NotImplementedError: if ``values[0]`` is neither 1d nor 2d.
                """
                size = max(v.size(0) for v in values)
                if pad_to_length is not None:
                    size = max(size, pad_to_length)
                if pad_to_multiple != 1 and size % pad_to_multiple != 0:
                    # Round up to the next multiple of pad_to_multiple.
                    size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)

                # Honor pad_to_bsz: the batch never shrinks below len(values).
                rows = len(values) if pad_to_bsz is None else max(len(values), pad_to_bsz)

                def copy_tensor(src, dst):
                    assert dst.numel() == src.numel()
                    if move_eos_to_beginning:
                        if eos_idx is None:
                            # if no eos_idx is specified, then use the last token in src
                            dst[0] = src[-1]
                        else:
                            dst[0] = eos_idx
                        dst[1:] = src[:-1]
                    else:
                        dst.copy_(src)

                first = values[0]
                if first.dim() == 1:
                    res = first.new(rows, size).fill_(pad_idx)
                elif first.dim() == 2:
                    assert move_eos_to_beginning is False
                    res = first.new(rows, size, first.size(1)).fill_(pad_idx)
                else:
                    raise NotImplementedError

                for i, v in enumerate(values):
                    copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][: len(v)])
                return res
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            def load_indexed_dataset(
                path, dictionary=None, dataset_impl=None, combine=False, default="cached"
            ):
                """Load one indexed dataset, or several numbered shards of it.

                Args:
                    path (str): path to indexed dataset (e.g., 'data-bin/train')
                    dictionary (~fairseq.data.Dictionary): data dictionary
                    dataset_impl (str, optional): which dataset implementation to use.
                        If not provided, it is inferred per shard; when inference
                        yields nothing, *default* is used.
                    combine (bool, optional): keep loading 'data-bin/train',
                        'data-bin/train1', ... until a shard cannot be made, and
                        return them as a single ConcatDataset instance.
                    default (str, optional): fallback implementation name.

                Returns:
                    ``None`` if nothing was loaded, the dataset itself if exactly one
                    shard was loaded, otherwise a ConcatDataset of all shards.
                """
                import fairseq.data.indexed_dataset as indexed_dataset
                from fairseq.data.concat_dataset import ConcatDataset

                loaded = []
                for shard in itertools.count():
                    # Shard 0 has no numeric suffix.
                    shard_path = path + ("" if shard <= 0 else str(shard))
                    try:
                        shard_path = indexed_dataset.get_indexed_dataset_to_local(shard_path)
                    except Exception as e:
                        # Only a missing remote path is tolerated; anything else propagates.
                        if "StorageException: [404] Path not found" not in str(e):
                            raise e
                        logger.warning(f"path_k: {e} not found")

                    impl = dataset_impl
                    if impl is None:
                        impl = indexed_dataset.infer_dataset_impl(shard_path)
                    shard_dataset = indexed_dataset.make_dataset(
                        shard_path,
                        impl=impl or default,
                        fix_lua_indexing=True,
                        dictionary=dictionary,
                    )
                    if shard_dataset is None:
                        break
                    logger.info(
                        "loaded {:,} examples from: {}".format(len(shard_dataset), shard_path)
                    )
                    loaded.append(shard_dataset)
                    if not combine:
                        break

                if len(loaded) == 0:
                    return None
                if len(loaded) == 1:
                    return loaded[0]
                return ConcatDataset(loaded)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
             | 
| 131 | 
            +
            @contextlib.contextmanager
            def numpy_seed(seed, *addl_seeds):
                """Temporarily seed NumPy's global PRNG, restoring its state on exit.

                With ``seed=None`` the context manager does nothing. When extra seeds
                are given, they are combined with ``seed`` by hashing the tuple and
                reducing it modulo 1e6.
                """
                if seed is not None:
                    if addl_seeds:
                        seed = int(hash((seed,) + addl_seeds) % 1e6)
                    saved_state = np.random.get_state()
                    np.random.seed(seed)
                    try:
                        yield
                    finally:
                        # Restore even if the body raised.
                        np.random.set_state(saved_state)
                else:
                    yield
         | 
| 146 | 
            +
             | 
| 147 | 
            +
             | 
| 148 | 
            +
            def collect_filtered(function, iterable, filtered):
                """
                Lazily yield the elements accepted by ``function``, like :func:`filter`,
                while appending every rejected element to ``filtered``.

                Args:
                    function (callable): predicate; a falsy result means the element
                        is rejected
                    iterable (iterable): elements to examine
                    filtered (list): receives the rejected elements, in order
                """
                for item in iterable:
                    if not function(item):
                        filtered.append(item)
                        continue
                    yield item
         | 
| 163 | 
            +
             | 
| 164 | 
            +
             | 
| 165 | 
            +
            def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
         | 
| 166 | 
            +
                def compare_leq(a, b):
         | 
| 167 | 
            +
                    return a <= b if not isinstance(a, tuple) else max(a) <= b
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                def check_size(idx):
         | 
| 170 | 
            +
                    if isinstance(max_positions, float) or isinstance(max_positions, int):
         | 
| 171 | 
            +
                        return size_fn(idx) <= max_positions
         | 
| 172 | 
            +
                    elif isinstance(max_positions, dict):
         | 
| 173 | 
            +
                        idx_size = size_fn(idx)
         | 
| 174 | 
            +
                        assert isinstance(idx_size, dict)
         | 
| 175 | 
            +
                        intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
         | 
| 176 | 
            +
                        return all(
         | 
| 177 | 
            +
                            all(
         | 
| 178 | 
            +
                                a is None or b is None or a <= b
         | 
| 179 | 
            +
                                for a, b in zip(idx_size[key], max_positions[key])
         | 
| 180 | 
            +
                            )
         | 
| 181 | 
            +
                            for key in intersect_keys
         | 
| 182 | 
            +
                        )
         | 
| 183 | 
            +
                    else:
         | 
| 184 | 
            +
                        # For MultiCorpusSampledDataset, will generalize it later
         | 
| 185 | 
            +
                        if not isinstance(size_fn(idx), Iterable):
         | 
| 186 | 
            +
                            return all(size_fn(idx) <= b for b in max_positions)
         | 
| 187 | 
            +
                        return all(
         | 
| 188 | 
            +
                            a is None or b is None or a <= b
         | 
| 189 | 
            +
                            for a, b in zip(size_fn(idx), max_positions)
         | 
| 190 | 
            +
                        )
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                ignored = []
         | 
| 193 | 
            +
                itr = collect_filtered(check_size, indices, ignored)
         | 
| 194 | 
            +
                indices = np.fromiter(itr, dtype=np.int64, count=-1)
         | 
| 195 | 
            +
                return indices, ignored
         | 
| 196 | 
            +
             | 
| 197 | 
            +
             | 
| 198 | 
            +
            def filter_by_size(indices, dataset, max_positions, raise_exception=False):
                """
                [deprecated] Drop the indices whose samples exceed ``max_positions``.
                Use `FairseqDataset::filter_indices_by_size` instead.

                Args:
                    indices (List[int]): ordered list of dataset indices
                    dataset (FairseqDataset): fairseq dataset instance
                    max_positions (tuple): filter elements larger than this size.
                        Comparisons are done component-wise.
                    raise_exception (bool, optional): if ``True``, raise an exception if
                        any elements are filtered (default: False).

                Returns:
                    The indices that passed the size check. A deprecation warning is
                    always emitted, and skipped samples are reported through the
                    module logger.
                """
                warnings.warn(
                    "data_utils.filter_by_size is deprecated. "
                    "Use `FairseqDataset::filter_indices_by_size` instead.",
                    stacklevel=2,
                )

                # A scalar limit can be applied with array indexing when the dataset
                # exposes its sizes as an ndarray or as a one-element list.
                use_size_vector = False
                size_vector = None
                if isinstance(max_positions, (float, int)) and hasattr(dataset, "sizes"):
                    if isinstance(dataset.sizes, np.ndarray):
                        use_size_vector, size_vector = True, dataset.sizes
                    elif isinstance(dataset.sizes, list) and len(dataset.sizes) == 1:
                        use_size_vector, size_vector = True, dataset.sizes[0]

                if use_size_vector:
                    ignored = indices[size_vector[indices] > max_positions].tolist()
                    indices = indices[size_vector[indices] <= max_positions]
                else:
                    indices, ignored = _filter_by_size_dynamic(
                        indices, dataset.size, max_positions
                    )

                if len(ignored) == 0:
                    return indices

                if raise_exception:
                    raise Exception(
                        (
                            "Size of sample #{} is invalid (={}) since max_positions={}, "
                            "skip this example with --skip-invalid-size-inputs-valid-test"
                        ).format(ignored[0], dataset.size(ignored[0]), max_positions)
                    )
                logger.warning(
                    (
                        "{} samples have invalid sizes and will be skipped, "
                        "max_positions={}, first few sample ids={}"
                    ).format(len(ignored), max_positions, ignored[:10])
                )
                return indices
         | 
| 249 | 
            +
             | 
| 250 | 
            +
             | 
| 251 | 
            +
            def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
                """Filter a list of sample indices. Remove those that are longer
                    than specified in max_sizes.

                Args:
                    src_sizes (np.array): source-side sizes, indexable by ``indices``
                    tgt_sizes (np.array or None): target-side sizes; when ``None`` only
                        the source side is checked
                    indices (np.array): original array of sample indices
                    max_sizes (int or float or list[int] or tuple[int] or None): max
                        sample size, can be defined separately for src and tgt (then
                        list or tuple). NumPy scalar values are accepted as a single
                        shared limit. ``None`` disables filtering.

                Returns:
                    np.array: filtered sample array
                    list: list of removed indices
                """
                if max_sizes is None:
                    return indices, []

                # isinstance (rather than an exact type match) so that NumPy scalars
                # are treated as a single limit instead of being unpacked as a pair.
                if isinstance(max_sizes, (int, float, np.integer, np.floating)):
                    max_src_size = max_tgt_size = max_sizes
                else:
                    max_src_size, max_tgt_size = max_sizes

                # A sample is dropped when either side exceeds its limit.
                too_long = src_sizes[indices] > max_src_size
                if tgt_sizes is not None:
                    too_long = too_long | (tgt_sizes[indices] > max_tgt_size)

                ignored = indices[too_long]
                if len(ignored) > 0:
                    indices = indices[~too_long]
                return indices, ignored.tolist()
         | 
| 285 | 
            +
             | 
| 286 | 
            +
             | 
| 287 | 
            +
            def batch_by_size(
                indices,
                num_tokens_fn,
                num_tokens_vec=None,
                max_tokens=None,
                max_sentences=None,
                required_batch_size_multiple=1,
                fixed_shapes=None,
            ):
                """
                Group dataset indices into mini-batches bucketed by size. A batch may
                hold sequences of different lengths.

                Args:
                    indices (List[int]): ordered list of dataset indices
                    num_tokens_fn (callable): returns the number of tokens at a given
                        index
                    num_tokens_vec (List[int], optional): precomputed number of tokens
                        for each index in indices (enables faster batch generation)
                    max_tokens (int, optional): max number of tokens in each batch
                        (default: None).
                    max_sentences (int, optional): max number of sentences in each
                        batch (default: None).
                    required_batch_size_multiple (int, optional): require batch size to
                        be less than N or a multiple of N (default: 1).
                    fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
                        only be created with the given shapes. *max_sentences* and
                        *required_batch_size_multiple* will be ignored (default: None).

                Raises:
                    ImportError / ValueError: when the compiled Cython helpers cannot
                        be imported.
                """
                try:
                    from fairseq.data.data_utils_fast import (
                        batch_by_size_fn,
                        batch_by_size_vec,
                        batch_fixed_shapes_fast,
                    )
                except ImportError:
                    raise ImportError(
                        "Please build Cython components with: "
                        "`python setup.py build_ext --inplace`"
                    )
                except ValueError:
                    raise ValueError(
                        "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
                    )

                # -1 is handed to the helpers when a limit is not set; int() avoids
                # "TypeError: an integer is required" for non-int max_tokens.
                token_cap = -1 if max_tokens is None else int(max_tokens)
                sentence_cap = -1 if max_sentences is None else max_sentences

                if not isinstance(indices, np.ndarray):
                    indices = np.fromiter(indices, dtype=np.int64, count=-1)

                if not (num_tokens_vec is None or isinstance(num_tokens_vec, np.ndarray)):
                    num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)

                if fixed_shapes is not None:
                    shapes = np.array(fixed_shapes, dtype=np.int64)
                    length_rank = shapes[:, 1].argsort()  # length
                    bsz_rank = shapes[:, 0].argsort()  # bsz
                    order = np.lexsort([length_rank, bsz_rank])
                    return batch_fixed_shapes_fast(indices, num_tokens_fn, shapes[order])

                if num_tokens_vec is None:
                    return batch_by_size_fn(
                        indices,
                        num_tokens_fn,
                        token_cap,
                        sentence_cap,
                        required_batch_size_multiple,
                    )

                return batch_by_size_vec(
                    indices,
                    num_tokens_vec,
                    token_cap,
                    sentence_cap,
                    required_batch_size_multiple,
                )
         | 
| 373 | 
            +
             | 
| 374 | 
            +
             | 
| 375 | 
            +
            def post_process(sentence: str, symbol: str):
                """Undo a tokenization scheme on *sentence*.

                Args:
                    sentence: the tokenized text.
                    symbol: name of the scheme to undo ("sentencepiece", "wordpiece",
                        "letter", "silence", "_EOW", "subword_nmt", "@@ ", "@@"), or
                        "none" / None to return the sentence untouched.

                Returns:
                    The post-processed sentence.

                Raises:
                    NotImplementedError: for any other non-None *symbol*.
                """
                # Schemes that drop every space and then turn one marker into a space.
                word_markers = {
                    "sentencepiece": "\u2581",
                    "wordpiece": "_",
                    "letter": "|",
                    "_EOW": "_EOW",
                }
                if symbol in word_markers:
                    return sentence.replace(" ", "").replace(word_markers[symbol], " ").strip()

                if symbol == "silence":
                    import re

                    without_sil = sentence.replace("<SIL>", "")
                    return re.sub(' +', ' ', without_sil).strip()

                if symbol in {"subword_nmt", "@@ ", "@@"}:
                    marker = "@@ " if symbol == "subword_nmt" else symbol
                    # The trailing space lets a marker at the very end match as well.
                    return (sentence + " ").replace(marker, "").rstrip()

                if symbol is None or symbol == "none":
                    return sentence

                raise NotImplementedError(f"Unknown post_process option: {symbol}")
         | 
| 397 | 
            +
             | 
| 398 | 
            +
             | 
| 399 | 
            +
            def compute_mask_indices(
                shape: Tuple[int, int],
                padding_mask: Optional[torch.Tensor],
                mask_prob: float,
                mask_length: int,
                mask_type: str = "static",
                mask_other: float = 0.0,
                min_masks: int = 0,
                no_overlap: bool = False,
                min_space: int = 0,
            ) -> np.ndarray:
                """
                Computes random mask spans for a given shape

                Args:
                    shape: the shape for which to compute masks.
                        should be of size 2 where first element is batch size and 2nd is timesteps
                    padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
                    mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
                        number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
                        however due to overlaps, the actual number will be smaller (unless no_overlap is True)
                    mask_length: span length for "static"; distribution parameter for the other mask types (see mask_type)
                    mask_type: how to compute mask lengths
                        static = fixed size
                        uniform = sample from uniform distribution [mask_other, mask_length*2]
                        normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
                        poisson = sample from poisson distribution with lambda = mask length
                    mask_other: secondary distribution parameter, see mask_type
                    min_masks: minimum number of masked spans
                    no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
                    min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans

                Returns:
                    boolean array of shape (batch size, timesteps), True at masked
                    positions. Every row has the same number of masked positions: rows
                    with more are randomly subsampled down to the smallest row.

                Raises:
                    Exception: if mask_type is not one of the values listed above.
                """

                bsz, all_sz = shape
                mask = np.full((bsz, all_sz), False)

                all_num_mask = int(
                    # add a random number for probabilistic rounding
                    mask_prob * all_sz / float(mask_length)
                    + np.random.rand()
                )

                all_num_mask = max(min_masks, all_num_mask)

                mask_idcs = []
                for i in range(bsz):
                    if padding_mask is not None:
                        # only the non-padded prefix length counts for this row
                        sz = all_sz - padding_mask[i].long().sum().item()
                        num_mask = int(
                            # add a random number for probabilistic rounding
                            mask_prob * sz / float(mask_length)
                            + np.random.rand()
                        )
                        num_mask = max(min_masks, num_mask)
                    else:
                        sz = all_sz
                        num_mask = all_num_mask

                    if mask_type == "static":
                        lengths = np.full(num_mask, mask_length)
                    elif mask_type == "uniform":
                        lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
                    elif mask_type == "normal":
                        lengths = np.random.normal(mask_length, mask_other, size=num_mask)
                        lengths = [max(1, int(round(x))) for x in lengths]
                    elif mask_type == "poisson":
                        lengths = np.random.poisson(mask_length, size=num_mask)
                        lengths = [int(round(x)) for x in lengths]
                    else:
                        raise Exception("unknown mask selection " + mask_type)

                    # make sure at least one span has a non-zero length
                    if sum(lengths) == 0:
                        lengths[0] = min(mask_length, sz - 1)

                    if no_overlap:
                        mask_idc = []

                        def arrange(s, e, length, keep_length):
                            # place one span inside [s, e) and return the leftover
                            # pieces that are still large enough to be reused
                            span_start = np.random.randint(s, e - length)
                            mask_idc.extend(span_start + i for i in range(length))

                            new_parts = []
                            if span_start - s - min_space >= keep_length:
                                new_parts.append((s, span_start - min_space + 1))
                            if e - span_start - keep_length - min_space > keep_length:
                                new_parts.append((span_start + length + min_space, e))
                            return new_parts

                        parts = [(0, sz)]
                        min_length = min(lengths)
                        for length in sorted(lengths, reverse=True):
                            # np.int was removed in NumPy 1.24; the builtin int is the
                            # type it used to alias.
                            lens = np.fromiter(
                                (e - s if e - s >= length + min_space else 0 for s, e in parts),
                                int,
                            )
                            l_sum = np.sum(lens)
                            if l_sum == 0:
                                break
                            # pick a free part with probability proportional to its size
                            probs = lens / np.sum(lens)
                            c = np.random.choice(len(parts), p=probs)
                            s, e = parts.pop(c)
                            parts.extend(arrange(s, e, length, min_length))
                        mask_idc = np.asarray(mask_idc)
                    else:
                        min_len = min(lengths)
                        if sz - min_len <= num_mask:
                            min_len = sz - num_mask - 1

                        # distinct span starts; spans themselves may still overlap
                        mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

                        mask_idc = np.asarray(
                            [
                                mask_idc[j] + offset
                                for j in range(len(mask_idc))
                                for offset in range(lengths[j])
                            ]
                        )

                    mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

                # equalise the number of masked positions across the batch
                min_len = min([len(m) for m in mask_idcs])
                for i, mask_idc in enumerate(mask_idcs):
                    if len(mask_idc) > min_len:
                        mask_idc = np.random.choice(mask_idc, min_len, replace=False)
                    mask[i, mask_idc] = True

                return mask
         | 
| 524 | 
            +
             | 
| 525 | 
            +
             | 
| 526 | 
            +
            def get_mem_usage():
                """Return a one-line summary of used and available system memory.

                The figures come from ``psutil.virtual_memory()`` divided by 1024 * 1024.
                If psutil cannot be imported the string "N/A" is returned instead.
                """
                try:
                    import psutil

                    bytes_per_mb = 1024 * 1024
                    used_mb = psutil.virtual_memory().used / bytes_per_mb
                    avail_mb = psutil.virtual_memory().available / bytes_per_mb
                    return f"used={used_mb}Mb; avail={avail_mb}Mb"
                except ImportError:
                    return "N/A"
         | 
| 534 | 
            +
             | 
| 535 | 
            +
             | 
| 536 | 
            +
            def lengths_to_padding_mask(lens):
                """Build a padding mask from sequence lengths.

                Args:
                    lens (torch.LongTensor): one length per sequence.

                Returns:
                    torch.BoolTensor of shape (len(lens), max(lens)); an entry is True
                    where the position index is >= that row's length.
                """
                batch = lens.size(0)
                longest = torch.max(lens).item()
                positions = torch.arange(longest).to(lens.device).view(1, longest)
                position_grid = positions.expand(batch, -1)
                length_grid = lens.view(batch, 1).expand(-1, longest)
                return position_grid >= length_grid
         | 
| 543 | 
            +
             | 
| 544 | 
            +
             | 
| 545 | 
            +
            def lengths_to_mask(lens):
                """Return the inverse of ``lengths_to_padding_mask(lens)``.

                Args:
                    lens (torch.LongTensor): one length per sequence.

                Returns:
                    torch.BoolTensor: the element-wise negation of the padding mask.
                """
                padding = lengths_to_padding_mask(lens)
                return ~padding
         | 
| 549 | 
            +
             | 
| 550 | 
            +
             | 
| 551 | 
            +
            def get_buckets(sizes, num_buckets):
                """Compute bucket upper edges from the percentiles of *sizes*.

                The 0..100 range is split into *num_buckets* equal percentile steps and
                each edge is rounded down to a value that actually occurs in *sizes*
                ("lower" rule). The 0th percentile is dropped and duplicate edges are
                merged, so fewer than *num_buckets* edges may be returned.

                Args:
                    sizes: array-like of sample sizes.
                    num_buckets (int): requested number of buckets.

                Returns:
                    np.ndarray: sorted, unique bucket edges.
                """
                quantiles = np.linspace(0, 100, num_buckets + 1)
                try:
                    # NumPy >= 1.22 names the rounding rule `method`; the old
                    # `interpolation` keyword is deprecated there.
                    edges = np.percentile(sizes, quantiles, method='lower')
                except TypeError:
                    # Older NumPy does not know `method` yet.
                    edges = np.percentile(sizes, quantiles, interpolation='lower')
                return np.unique(edges[1:])
         | 
| 560 | 
            +
             | 
| 561 | 
            +
             | 
| 562 | 
            +
            def get_bucketed_sizes(orig_sizes, buckets):
                """Round each size up to the edge of the bucket it falls into.

                Args:
                    orig_sizes: array of non-negative sizes; it is copied, not modified.
                    buckets: bucket upper edges, visited in the given order.

                Returns:
                    A copy of *orig_sizes* where every value in (previous edge, edge] is
                    replaced by that edge. Values above the last edge are left as is.
                """
                bucketed = np.copy(orig_sizes)
                assert np.min(bucketed) >= 0
                lower_edge = -1
                for upper_edge in buckets:
                    in_bucket = (bucketed > lower_edge) & (bucketed <= upper_edge)
                    bucketed[in_bucket] = upper_edge
                    lower_edge = upper_edge
                return bucketed
         | 
| 571 | 
            +
             | 
| 572 | 
            +
             | 
| 573 | 
            +
             | 
| 574 | 
            +
            def _find_extra_valid_paths(dataset_path: str) -> set:
                """Collect the extension-less names of entries matching 'valid*[0-9].*'.

                *dataset_path* is split with ``utils.split_paths``; each resulting
                directory is listed with ``PathManager.ls`` and every entry matching
                the pattern contributes its basename with the final extension removed
                (.bin, .idx etc).
                """
                matched_names = set()
                for sub_dir in utils.split_paths(dataset_path):
                    for entry in PathManager.ls(sub_dir):
                        if re.match("valid*[0-9].*", entry) is not None:
                            matched_names.add(os.path.basename(entry))
                # Remove .bin, .idx etc
                return {os.path.splitext(name)[0] for name in matched_names}
         | 
| 584 | 
            +
             | 
| 585 | 
            +
             | 
| 586 | 
            +
            def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
                """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored.

                Nothing is checked when any of ``ignore_unused_valid_subsets``,
                ``combine_valid_subsets`` or ``disable_validation`` is set on
                ``train_cfg.dataset``, or when ``train_cfg.task`` has no ``data``
                attribute.

                Raises:
                    ValueError: if a matching valid subset found under
                        ``train_cfg.task.data`` is not listed in
                        ``train_cfg.dataset.valid_subset``.
                """
                dataset_cfg = train_cfg.dataset
                check_not_needed = (
                    dataset_cfg.ignore_unused_valid_subsets
                    or dataset_cfg.combine_valid_subsets
                    or dataset_cfg.disable_validation
                    or not hasattr(train_cfg.task, "data")
                )
                if check_not_needed:
                    return

                found_subsets = _find_extra_valid_paths(train_cfg.task.data)
                requested_subsets = train_cfg.dataset.valid_subset.split(",")
                ignored_paths = [p for p in found_subsets if p not in requested_subsets]
                if not ignored_paths:
                    return

                advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
                msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
                raise ValueError(msg)
         | 
    	
        data/file_dataset.py
    ADDED
    
    | @@ -0,0 +1,102 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import pickle
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            class FileDataset:
         | 
| 7 | 
            +
                def __init__(self, file_path, selected_col_ids=None, dtypes=None, separator="\t", cached_index=False):
                    """Open a delimited text file and prepare this process's slice of it.

                    Args:
                        file_path: path of the local data file; it must exist.
                        selected_col_ids: comma-separated column indices to keep, or
                            None to keep every column found in the first line.
                        dtypes: comma-separated type expressions (one per selected
                            column), or None to treat every column as ``str``.
                        separator: column separator.
                        cached_index: if True, the line index is loaded from
                            ``<file_path>.index`` instead of being rebuilt by scanning
                            the file.
                    """
                    self.file_path = file_path
                    assert os.path.exists(self.file_path), "Error: The local datafile {} not exists!".format(self.file_path)

                    self.separator = separator
                    if selected_col_ids is None:
                        # default to all fields; read the first line inside a context
                        # manager so the file handle is closed right away
                        with open(self.file_path) as first_line_fp:
                            first_line = first_line_fp.readline()
                        self.selected_col_ids = list(
                            range(len(first_line.rstrip("\n").split(self.separator))))
                    else:
                        self.selected_col_ids = [int(col_id) for col_id in selected_col_ids.split(",")]
                    if dtypes is None:
                        # default to str
                        self.dtypes = [str for col_id in self.selected_col_ids]
                    else:
                        # NOTE(review): eval() executes whatever expression is passed in
                        # `dtypes`; only feed it trusted configuration values.
                        self.dtypes = [eval(col_dtype) for col_dtype in dtypes.split(",")]
                        assert len(self.dtypes) == len(self.selected_col_ids)

                    self.data_cnt = 0
                    try:
                        self.slice_id = torch.distributed.get_rank()
                        self.slice_count = torch.distributed.get_world_size()
                    except Exception:
                        # best effort: if the rank/world size cannot be queried, act
                        # as a single slice covering the whole file
                        self.slice_id = 0
                        self.slice_count = 1
                    self.cached_index = cached_index
                    self._init_seek_index()
                    self._reader = self._get_reader()
                    print("file {} slice_id {} row count {} total row count {}".format(
                        self.file_path, self.slice_id, self.row_count, self.total_row_count)
                    )
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                def _init_seek_index(self):
                    """Set ``total_row_count`` and ``lineid_to_offset`` for the data file.

                    With ``cached_index`` set, both values are unpickled from
                    ``<file_path>.index`` (which must exist). Otherwise the file is
                    scanned once, recording for every line the running sum of the
                    UTF-8 encoded lengths of the lines before it. Afterwards
                    ``_compute_start_pos_and_row_count`` is called.
                    """
                    if self.cached_index:
                        cache_path = "{}.index".format(self.file_path)
                        assert os.path.exists(cache_path), "cache file {} not exists!".format(cache_path)
                        # NOTE(review): pickle.load can execute arbitrary code from a
                        # crafted file; only load index files from a trusted source.
                        with open(cache_path, "rb") as cache_fp:
                            self.total_row_count, self.lineid_to_offset = pickle.load(cache_fp)
                        print("local datafile {} slice_id {} use cached row_count and line_idx-to-offset mapping".format(
                            self.file_path, self.slice_id))
                    else:
                        # make an iteration over the file to get row_count and line_idx-to-offset mapping;
                        # the context manager closes the handle once the scan is done
                        with open(self.file_path, "r") as fp:
                            print("local datafile {} slice_id {} begin to initialize row_count and line_idx-to-offset mapping".format(
                                self.file_path, self.slice_id))
                            self.total_row_count = 0
                            offset = 0
                            self.lineid_to_offset = []
                            for line in fp:
                                self.lineid_to_offset.append(offset)
                                self.total_row_count += 1
                                offset += len(line.encode('utf-8'))
                    self._compute_start_pos_and_row_count()
                    print("local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping".format(
                        self.file_path, self.slice_id))
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def _compute_start_pos_and_row_count(self):
         | 
| 63 | 
            +
                    self.row_count = self.total_row_count // self.slice_count
         | 
| 64 | 
            +
                    if self.slice_id < self.total_row_count - self.row_count * self.slice_count:
         | 
| 65 | 
            +
                        self.row_count += 1
         | 
| 66 | 
            +
                        self.start_pos = self.row_count * self.slice_id
         | 
| 67 | 
            +
                    else:
         | 
| 68 | 
            +
                        self.start_pos = self.row_count * self.slice_id + (self.total_row_count - self.row_count * self.slice_count)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def _get_reader(self):
         | 
| 71 | 
            +
                    fp = open(self.file_path, "r")
         | 
| 72 | 
            +
                    fp.seek(self.lineid_to_offset[self.start_pos])
         | 
| 73 | 
            +
                    return fp
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def _seek(self, offset=0):
         | 
| 76 | 
            +
                    try:
         | 
| 77 | 
            +
                        print("slice_id {} seek offset {}".format(self.slice_id, self.start_pos + offset))
         | 
| 78 | 
            +
                        self._reader.seek(self.lineid_to_offset[self.start_pos + offset])
         | 
| 79 | 
            +
                        self.data_cnt = offset
         | 
| 80 | 
            +
                    except Exception:
         | 
| 81 | 
            +
                        print("slice_id {} seek offset {}".format(self.slice_id, offset))
         | 
| 82 | 
            +
                        self._reader.seek(self.lineid_to_offset[offset])
         | 
| 83 | 
            +
                        self.data_cnt = offset
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                def __del__(self):
         | 
| 86 | 
            +
                    self._reader.close()
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def __len__(self):
         | 
| 89 | 
            +
                    return self.row_count
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                def get_total_row_count(self):
                    """Return ``total_row_count``, the row count of the whole file across all slices."""
                    rows_in_file = self.total_row_count
                    return rows_in_file
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                def __getitem__(self, index):
         | 
| 95 | 
            +
                    if self.data_cnt == self.row_count:
         | 
| 96 | 
            +
                        print("reach the end of datafile, start a new reader")
         | 
| 97 | 
            +
                        self.data_cnt = 0
         | 
| 98 | 
            +
                        self._reader = self._get_reader()
         | 
| 99 | 
            +
                    column_l = self._reader.readline().rstrip("\n").split(self.separator)
         | 
| 100 | 
            +
                    self.data_cnt += 1
         | 
| 101 | 
            +
                    column_l = [dtype(column_l[col_id]) for col_id, dtype in zip(self.selected_col_ids, self.dtypes)]
         | 
| 102 | 
            +
                    return column_l
         | 
    	
        data/mm_data/__init__.py
    ADDED
    
    | 
            File without changes
         | 
    	
        data/mm_data/caption_dataset.py
    ADDED
    
    | @@ -0,0 +1,154 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the MIT license found in the
         | 
| 4 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
            from io import BytesIO
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import logging
         | 
| 8 | 
            +
            import warnings
         | 
| 9 | 
            +
            import string
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            import base64
         | 
| 14 | 
            +
            from torchvision import transforms
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            from PIL import Image, ImageFile
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from data import data_utils
         | 
| 19 | 
            +
            from data.ofa_dataset import OFADataset
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            ImageFile.LOAD_TRUNCATED_IMAGES = True
         | 
| 22 | 
            +
            ImageFile.MAX_IMAGE_PIXELS = None
         | 
| 23 | 
            +
            Image.MAX_IMAGE_PIXELS = None
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 26 | 
            +
            warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
         | 
| 29 | 
            +
            IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            def collate(samples, pad_idx, eos_idx):
                """Assemble a list of caption samples into one mini-batch dict.

                Args:
                    samples: list of per-example dicts with keys ``id``, ``source``,
                        ``patch_image``, ``patch_mask`` and optionally ``target`` /
                        ``prev_output_tokens``.
                    pad_idx: padding token index, used for padding and length counting.
                    eos_idx: end-of-sentence index forwarded to ``data_utils.collate_tokens``.

                Returns:
                    dict: ``{}`` for an empty ``samples``; otherwise a batch with keys
                    ``id``, ``nsentences``, ``ntokens``, ``net_input`` and ``target``.
                """
                if len(samples) == 0:
                    return {}

                def pad_field(field):
                    # Pad every sample's ``field`` tensor to a common length.
                    return data_utils.collate_tokens(
                        [item[field] for item in samples],
                        pad_idx,
                        eos_idx=eos_idx,
                    )

                def non_pad_lengths(field):
                    # Number of non-padding tokens of ``field`` for each sample.
                    return torch.LongTensor([item[field].ne(pad_idx).long().sum() for item in samples])

                sample_ids = np.array([item["id"] for item in samples])
                src_tokens = pad_field("source")
                src_lengths = non_pad_lengths("source")

                patch_images = torch.stack([item['patch_image'] for item in samples], dim=0)
                patch_masks = torch.cat([item['patch_mask'] for item in samples])

                # Only the first sample is inspected to decide whether targets exist.
                head = samples[0]
                if head.get("target", None) is None:
                    target = None
                    prev_output_tokens = None
                    ntokens = src_lengths.sum().item()
                else:
                    target = pad_field("target")
                    ntokens = non_pad_lengths("target").sum().item()
                    if head.get("prev_output_tokens", None) is not None:
                        prev_output_tokens = pad_field("prev_output_tokens")
                    else:
                        prev_output_tokens = None

                net_input = {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                    "patch_images": patch_images,
                    "patch_masks": patch_masks,
                    "prev_output_tokens": prev_output_tokens
                }
                return {
                    "id": sample_ids,
                    "nsentences": len(samples),
                    "ntokens": ntokens,
                    "net_input": net_input,
                    "target": target,
                }
         | 
| 77 | 
            +
             | 
| 78 | 
            +
             | 
| 79 | 
            +
            class CaptionDataset(OFADataset):
                """Dataset yielding (prompt, image patch, caption) examples for captioning.

                Each row of ``dataset`` is unpacked as ``(uniq_id, image, caption)`` where
                ``image`` is a URL-safe base64 string. ``encode_text``, ``bos_item``,
                ``eos_item``, ``pad`` and ``eos`` are used but not defined here; presumably
                they come from ``OFADataset`` - confirm against the base class.
                """

                def __init__(
                    self,
                    split,
                    dataset,
                    bpe,
                    src_dict,
                    tgt_dict=None,
                    max_src_length=128,
                    max_tgt_length=30,
                    patch_image_size=224,
                    imagenet_default_mean_and_std=False,
                    scst=False
                ):
                    """Store the length limits and build the image preprocessing pipeline.

                    Args:
                        split, dataset, bpe, src_dict, tgt_dict: forwarded to ``OFADataset``.
                        max_src_length: stored on the instance (not read in this class).
                        max_tgt_length: maximum number of caption words kept for training.
                        patch_image_size: side length images are resized to.
                        imagenet_default_mean_and_std: use the ImageNet statistics instead
                            of 0.5 for every channel when normalizing.
                        scst: when true, training captions are handled like evaluation ones.
                    """
                    super().__init__(split, dataset, bpe, src_dict, tgt_dict)
                    self.max_src_length = max_src_length
                    self.max_tgt_length = max_tgt_length
                    self.patch_image_size = patch_image_size
                    self.scst = scst

                    # Translation table that deletes every ASCII punctuation character.
                    self.transtab = str.maketrans(dict.fromkeys(string.punctuation))

                    if imagenet_default_mean_and_std:
                        norm_mean, norm_std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
                    else:
                        norm_mean, norm_std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

                    target_size = (patch_image_size, patch_image_size)
                    pipeline = [
                        lambda image: image.convert("RGB"),
                        transforms.Resize(target_size, interpolation=Image.BICUBIC),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=norm_mean, std=norm_std),
                    ]
                    self.patch_resize_transform = transforms.Compose(pipeline)

                def __getitem__(self, index):
                    """Build one example dict for row ``index`` of the underlying dataset.

                    Returns:
                        dict with keys ``id``, ``source``, ``patch_image``, ``patch_mask``,
                        ``target`` and ``prev_output_tokens``.
                    """
                    uniq_id, encoded_image, raw_caption = self.dataset[index]

                    pil_image = Image.open(BytesIO(base64.urlsafe_b64decode(encoded_image)))
                    patch_image = self.patch_resize_transform(pil_image)
                    patch_mask = torch.tensor([True])

                    if self.split != 'train' or self.scst:
                        # Keep every '&&'-separated caption, stripping punctuation from each.
                        squeezed = ' '.join(raw_caption.split())
                        cleaned = [
                            piece.translate(self.transtab).strip()
                            for piece in squeezed.split('&&')
                        ]
                        tgt_caption = '&&'.join(cleaned)
                    else:
                        # Plain training: strip punctuation and cap the number of words.
                        words = raw_caption.translate(self.transtab).split()
                        tgt_caption = ' '.join(words[:self.max_tgt_length])

                    prompt_tokens = self.encode_text(" what does the image describe?")
                    caption_tokens = self.encode_text(" {}".format(tgt_caption))

                    return {
                        "id": uniq_id,
                        "source": torch.cat([self.bos_item, prompt_tokens, self.eos_item]),
                        "patch_image": patch_image,
                        "patch_mask": patch_mask,
                        "target": torch.cat([caption_tokens, self.eos_item]),
                        "prev_output_tokens": torch.cat([self.bos_item, caption_tokens])
                    }

                def collater(self, samples, pad_to_length=None):
                    """Merge a list of samples into a mini-batch via the module-level ``collate``.

                    Args:
                        samples (List[dict]): samples to collate
                        pad_to_length: accepted but not used here.
                    Returns:
                        dict: the batch built by ``collate`` with this dataset's ``pad``
                        and ``eos`` indices.
                    """
                    return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
         | 
    	
        data/mm_data/refcoco_dataset.py
    ADDED
    
    | @@ -0,0 +1,168 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the MIT license found in the
         | 
| 4 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
            from io import BytesIO
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import logging
         | 
| 8 | 
            +
            import warnings
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import numpy as np
         | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            import base64
         | 
| 13 | 
            +
            import utils.transforms as T
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from PIL import Image, ImageFile
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from data import data_utils
         | 
| 18 | 
            +
            from data.ofa_dataset import OFADataset
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            ImageFile.LOAD_TRUNCATED_IMAGES = True
         | 
| 21 | 
            +
            ImageFile.MAX_IMAGE_PIXELS = None
         | 
| 22 | 
            +
            Image.MAX_IMAGE_PIXELS = None
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 25 | 
            +
            warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
         | 
| 28 | 
            +
            IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def collate(samples, pad_idx, eos_idx):
                """Assemble a list of grounding samples into one mini-batch dict.

                Args:
                    samples: list of per-example dicts with keys ``id``, ``source``,
                        ``patch_image``, ``patch_mask``, ``w_resize_ratio``,
                        ``h_resize_ratio``, ``region_coord`` and optionally ``target`` /
                        ``prev_output_tokens``.
                    pad_idx: padding token index, used for padding and length counting.
                    eos_idx: end-of-sentence index forwarded to ``data_utils.collate_tokens``.

                Returns:
                    dict: ``{}`` for an empty ``samples``; otherwise a batch with keys
                    ``id``, ``nsentences``, ``ntokens``, ``net_input``, ``target``,
                    ``w_resize_ratios``, ``h_resize_ratios`` and ``region_coords``.
                """
                if len(samples) == 0:
                    return {}

                def pad_field(field):
                    # Pad every sample's ``field`` tensor to a common length.
                    return data_utils.collate_tokens(
                        [entry[field] for entry in samples],
                        pad_idx,
                        eos_idx=eos_idx,
                    )

                def stack_field(field):
                    # Stack the per-sample tensors of ``field`` along a new first axis.
                    return torch.stack([entry[field] for entry in samples], dim=0)

                def token_counts(field):
                    # Number of non-padding tokens of ``field`` for each sample.
                    return torch.LongTensor([entry[field].ne(pad_idx).long().sum() for entry in samples])

                batch_ids = np.array([entry["id"] for entry in samples])
                src_tokens = pad_field("source")
                src_lengths = token_counts("source")

                patch_images = stack_field('patch_image')
                patch_masks = torch.cat([entry['patch_mask'] for entry in samples])

                w_resize_ratios = stack_field("w_resize_ratio")
                h_resize_ratios = stack_field("h_resize_ratio")
                region_coords = stack_field('region_coord')

                # Only the first sample is inspected to decide whether targets exist.
                probe = samples[0]
                target, prev_output_tokens = None, None
                if probe.get("target", None) is None:
                    ntokens = src_lengths.sum().item()
                else:
                    target = pad_field("target")
                    ntokens = token_counts("target").sum().item()
                    if probe.get("prev_output_tokens", None) is not None:
                        prev_output_tokens = pad_field("prev_output_tokens")

                return {
                    "id": batch_ids,
                    "nsentences": len(samples),
                    "ntokens": ntokens,
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                        "patch_images": patch_images,
                        "patch_masks": patch_masks,
                        "prev_output_tokens": prev_output_tokens
                    },
                    "target": target,
                    "w_resize_ratios": w_resize_ratios,
                    "h_resize_ratios": h_resize_ratios,
                    "region_coords": region_coords
                }
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            class RefcocoDataset(OFADataset):
         | 
| 86 | 
            +
                def __init__(
         | 
| 87 | 
            +
                    self,
         | 
| 88 | 
            +
                    split,
         | 
| 89 | 
            +
                    dataset,
         | 
| 90 | 
            +
                    bpe,
         | 
| 91 | 
            +
                    src_dict,
         | 
| 92 | 
            +
                    tgt_dict=None,
         | 
| 93 | 
            +
                    max_src_length=80,
         | 
| 94 | 
            +
                    max_tgt_length=30,
         | 
| 95 | 
            +
                    patch_image_size=512,
         | 
| 96 | 
            +
                    imagenet_default_mean_and_std=False,
         | 
| 97 | 
            +
                    num_bins=1000,
         | 
| 98 | 
            +
                    max_image_size=512
         | 
| 99 | 
            +
                ):
         | 
| 100 | 
            +
                    super().__init__(split, dataset, bpe, src_dict, tgt_dict)
         | 
| 101 | 
            +
                    self.max_src_length = max_src_length
         | 
| 102 | 
            +
                    self.max_tgt_length = max_tgt_length
         | 
| 103 | 
            +
                    self.patch_image_size = patch_image_size
         | 
| 104 | 
            +
                    self.num_bins = num_bins
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    if imagenet_default_mean_and_std:
         | 
| 107 | 
            +
                        mean = IMAGENET_DEFAULT_MEAN
         | 
| 108 | 
            +
                        std = IMAGENET_DEFAULT_STD
         | 
| 109 | 
            +
                    else:
         | 
| 110 | 
            +
                        mean = [0.5, 0.5, 0.5]
         | 
| 111 | 
            +
                        std = [0.5, 0.5, 0.5]
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    # for positioning
         | 
| 114 | 
            +
                    self.positioning_transform = T.Compose([
         | 
| 115 | 
            +
                        T.RandomResize([patch_image_size], max_size=patch_image_size),
         | 
| 116 | 
            +
                        T.ToTensor(),
         | 
| 117 | 
            +
                        T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
         | 
| 118 | 
            +
                    ])
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                def __getitem__(self, index):
         | 
| 121 | 
            +
                    uniq_id, base64_str, text, region_coord = self.dataset[index]
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    image = Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))).convert("RGB")
         | 
| 124 | 
            +
                    w, h = image.size
         | 
| 125 | 
            +
                    boxes_target = {"boxes": [], "labels": [], "area": [], "size": torch.tensor([h, w])}
         | 
| 126 | 
            +
                    x0, y0, x1, y1 = region_coord.strip().split(',')
         | 
| 127 | 
            +
                    region = torch.tensor([float(x0), float(y0), float(x1), float(y1)])
         | 
| 128 | 
            +
                    boxes_target["boxes"] = torch.tensor([[float(x0), float(y0), float(x1), float(y1)]])
         | 
| 129 | 
            +
                    boxes_target["labels"] = np.array([0])
         | 
| 130 | 
            +
                    boxes_target["area"] = torch.tensor([(float(x1) - float(x0)) * (float(y1) - float(y0))])
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                    patch_image, patch_boxes = self.positioning_transform(image, boxes_target)
         | 
| 133 | 
            +
                    resize_h, resize_w = patch_boxes["size"][0], patch_boxes["size"][1]
         | 
| 134 | 
            +
                    patch_mask = torch.tensor([True])
         | 
| 135 | 
            +
                    quant_x0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][0] * (self.num_bins - 1)).round()))
         | 
| 136 | 
            +
                    quant_y0 = "<bin_{}>".format(int((patch_boxes["boxes"][0][1] * (self.num_bins - 1)).round()))
         | 
| 137 | 
            +
                    quant_x1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][2] * (self.num_bins - 1)).round()))
         | 
| 138 | 
            +
                    quant_y1 = "<bin_{}>".format(int((patch_boxes["boxes"][0][3] * (self.num_bins - 1)).round()))
         | 
| 139 | 
            +
                    region_coord = "{} {} {} {}".format(quant_x0, quant_y0, quant_x1, quant_y1)
         | 
| 140 | 
            +
                    src_caption = self.pre_caption(text, self.max_src_length)
         | 
| 141 | 
            +
                    src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
         | 
| 142 | 
            +
                    tgt_item = self.encode_text(region_coord, use_bpe=False)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    src_item = torch.cat([self.bos_item, src_item, self.eos_item])
         | 
| 145 | 
            +
                    target_item = torch.cat([tgt_item, self.eos_item])
         | 
| 146 | 
            +
                    prev_output_item = torch.cat([self.bos_item, tgt_item])
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    example = {
         | 
| 149 | 
            +
                        "id": uniq_id,
         | 
| 150 | 
            +
                        "source": src_item,
         | 
| 151 | 
            +
                        "patch_image": patch_image,
         | 
| 152 | 
            +
                        "patch_mask": patch_mask,
         | 
| 153 | 
            +
                        "target": target_item,
         | 
| 154 | 
            +
                        "prev_output_tokens": prev_output_item,
         | 
| 155 | 
            +
                        "w_resize_ratio": resize_w / w,
         | 
| 156 | 
            +
                        "h_resize_ratio": resize_h / h,
         | 
| 157 | 
            +
                        "region_coord": region
         | 
| 158 | 
            +
                    }
         | 
| 159 | 
            +
                    return example
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                def collater(self, samples, pad_to_length=None):
                    """Collate a list of samples into one mini-batch.

                    Args:
                        samples (List[dict]): the samples to combine.
                        pad_to_length: accepted for interface compatibility; this
                            method does not use it.

                    Returns:
                        Whatever the module-level ``collate`` helper returns for
                        ``samples`` with this dataset's pad and eos indices.
                    """
                    pad_idx, eos_idx = self.pad, self.eos
                    return collate(samples, pad_idx=pad_idx, eos_idx=eos_idx)
         | 
    	
        data/mm_data/vqa_gen_dataset.py
    ADDED
    
    | @@ -0,0 +1,211 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # This source code is licensed under the MIT license found in the
         | 
| 4 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 5 | 
            +
            from io import BytesIO
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import logging
         | 
| 8 | 
            +
            import warnings
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            import numpy as np
         | 
| 11 | 
            +
            import torch
         | 
| 12 | 
            +
            import base64
         | 
| 13 | 
            +
            from torchvision import transforms
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from PIL import Image, ImageFile
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from data import data_utils
         | 
| 18 | 
            +
            from data.ofa_dataset import OFADataset
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            ImageFile.LOAD_TRUNCATED_IMAGES = True
         | 
| 21 | 
            +
            ImageFile.MAX_IMAGE_PIXELS = None
         | 
| 22 | 
            +
            Image.MAX_IMAGE_PIXELS = None
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 25 | 
            +
            warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
         | 
| 28 | 
            +
            IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            def collate(samples, pad_idx, eos_idx):
                """Combine per-example dicts into a single batch dict.

                Args:
                    samples: sequence of example dicts. Each must provide ``"id"``,
                        ``"source"``, ``"patch_image"`` and ``"patch_mask"``. The keys
                        ``"conf"``, ``"ref_dict"``, ``"constraint_mask"``,
                        ``"decoder_prompt"``, ``"target"`` and ``"prev_output_tokens"``
                        are optional; only the FIRST sample is inspected to decide
                        whether an optional key is collated for the whole batch.
                    pad_idx: padding index forwarded to ``data_utils.collate_tokens``
                        and used to count non-pad tokens.
                    eos_idx: end-of-sentence index forwarded to
                        ``data_utils.collate_tokens``.

                Returns:
                    ``{}`` when ``samples`` is empty, otherwise a dict with the keys
                    ``id``, ``nsentences``, ``ntokens``, ``net_input``, ``conf``,
                    ``ref_dict``, ``constraint_masks``, ``decoder_prompts`` and
                    ``target``. Optional entries that were absent are ``None``.
                """
                if not len(samples):
                    return {}

                first = samples[0]

                def present(key):
                    # An optional field counts as available only if the first sample has it.
                    return first.get(key, None) is not None

                def merge(key):
                    per_sample = [entry[key] for entry in samples]
                    return data_utils.collate_tokens(per_sample, pad_idx, eos_idx=eos_idx)

                def count_non_pad(key):
                    # Per-sample number of positions that differ from the pad index.
                    return torch.LongTensor(
                        [entry[key].ne(pad_idx).long().sum() for entry in samples]
                    )

                sample_ids = np.array([entry["id"] for entry in samples])
                src_tokens = merge("source")
                src_lengths = count_non_pad("source")

                patch_images = torch.stack([entry['patch_image'] for entry in samples], dim=0)
                patch_masks = torch.cat([entry['patch_mask'] for entry in samples])

                conf = (
                    torch.cat([entry['conf'] for entry in samples], dim=0)
                    if present("conf") else None
                )
                ref_dict = (
                    np.array([entry['ref_dict'] for entry in samples])
                    if present("ref_dict") else None
                )
                constraint_masks = merge("constraint_mask") if present("constraint_mask") else None
                decoder_prompts = (
                    np.array([entry['decoder_prompt'].tolist() for entry in samples])
                    if present("decoder_prompt") else None
                )

                target = None
                prev_output_tokens = None
                if present("target"):
                    target = merge("target")
                    # With targets available, the token count is taken from the target side.
                    ntokens = count_non_pad("target").sum().item()
                    if present("prev_output_tokens"):
                        prev_output_tokens = merge("prev_output_tokens")
                else:
                    ntokens = src_lengths.sum().item()

                return {
                    "id": sample_ids,
                    "nsentences": len(samples),
                    "ntokens": ntokens,
                    "net_input": {
                        "src_tokens": src_tokens,
                        "src_lengths": src_lengths,
                        "patch_images": patch_images,
                        "patch_masks": patch_masks,
                        "prev_output_tokens": prev_output_tokens
                    },
                    "conf": conf,
                    "ref_dict": ref_dict,
                    "constraint_masks": constraint_masks,
                    "decoder_prompts": decoder_prompts,
                    "target": target
                }
         | 
| 98 | 
            +
             | 
| 99 | 
            +
             | 
| 100 | 
            +
            class VqaGenDataset(OFADataset):
                """Dataset that turns VQA rows into generation-style examples.

                Each row of ``dataset`` is unpacked as ``(uniq_id, image, question, ref,
                predict_objects)``, optionally followed by a sixth ``caption`` field that
                is ignored. ``image`` is URL-safe base64 data; ``ref`` is a ``'&&'``
                separated list of ``"<confidence>|!+<answer>"`` entries.
                """

                def __init__(
                    self,
                    split,
                    dataset,
                    bpe,
                    src_dict,
                    tgt_dict=None,
                    max_src_length=128,
                    max_object_length=30,
                    max_tgt_length=30,
                    patch_image_size=224,
                    add_object=False,
                    constraint_trie=None,
                    imagenet_default_mean_and_std=False,
                    prompt_type="none"
                ):
                    """Store the configuration and build the image preprocessing pipeline.

                    Args:
                        split, dataset, bpe, src_dict, tgt_dict: forwarded unchanged to
                            ``OFADataset.__init__``.
                        max_src_length: word limit passed to ``pre_question``.
                        max_object_length: maximum number of ``'&&'``-separated object
                            entries appended to the source when ``add_object`` is set.
                        max_tgt_length: stored on the instance; not read in this class.
                        patch_image_size: images are resized to a square of this side.
                        add_object: append the predicted-object text to the source.
                        constraint_trie: if not ``None``, a ``constraint_mask`` is added
                            to every example using its ``get_next_layer`` method.
                        imagenet_default_mean_and_std: normalize with the ImageNet
                            constants instead of 0.5 / 0.5.
                        prompt_type: one of ``'none'``, ``'src'`` or ``'prev_output'``.
                    """
                    super().__init__(split, dataset, bpe, src_dict, tgt_dict)
                    self.max_src_length = max_src_length
                    self.max_object_length = max_object_length
                    self.max_tgt_length = max_tgt_length
                    self.patch_image_size = patch_image_size

                    self.add_object = add_object
                    self.constraint_trie = constraint_trie
                    self.prompt_type = prompt_type

                    mean, std = (
                        (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
                        if imagenet_default_mean_and_std
                        else ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
                    )

                    side = (patch_image_size, patch_image_size)
                    self.patch_resize_transform = transforms.Compose([
                        lambda img: img.convert("RGB"),
                        transforms.Resize(side, interpolation=Image.BICUBIC),
                        transforms.ToTensor(),
                        transforms.Normalize(mean=mean, std=std),
                    ])

                def __getitem__(self, index):
                    """Build the example dict for row ``index``.

                    Returns:
                        dict with keys ``id``, ``source``, ``patch_image``,
                        ``patch_mask``, ``target``, ``prev_output_tokens``,
                        ``decoder_prompt``, ``ref_dict`` and ``conf``; plus
                        ``constraint_mask`` when a constraint trie was configured.

                    Raises:
                        NotImplementedError: if ``prompt_type`` is not one of the
                            three supported values.
                    """
                    row = self.dataset[index]
                    if len(row) == 5:
                        uniq_id, image, question, ref, predict_objects = row
                    else:
                        # Rows with a sixth field carry a caption that is not used here.
                        uniq_id, image, question, ref, predict_objects, _caption = row

                    raw_bytes = base64.urlsafe_b64decode(image)
                    pil_image = Image.open(BytesIO(raw_bytes))
                    patch_image = self.patch_resize_transform(pil_image)
                    patch_mask = torch.tensor([True])

                    question = self.pre_question(question, self.max_src_length)
                    if not question.endswith('?'):
                        question = question + '?'
                    src_item = self.encode_text(' {}'.format(question))

                    # Map every candidate answer to its confidence.
                    ref_dict = {}
                    for entry in ref.split('&&'):
                        fields = entry.split('|!+')
                        candidate = fields[1]
                        ref_dict[candidate] = float(fields[0])
                    # The highest-confidence candidate becomes the training target.
                    answer = max(ref_dict, key=ref_dict.get)
                    conf = torch.tensor([ref_dict[answer]])
                    tgt_item = self.encode_text(" {}".format(answer))

                    if self.add_object and predict_objects is not None:
                        object_names = predict_objects.strip().split('&&')
                        predict_object_seq = ' '.join(object_names[:self.max_object_length])
                        predict_object_item = self.encode_text(" object: {}".format(predict_object_seq))
                        src_item = torch.cat([src_item, predict_object_item])

                    src_item = torch.cat([self.bos_item, src_item, self.eos_item])

                    # The decoder prompt is the prefix placed in front of the answer tokens.
                    if self.prompt_type == 'none':
                        decoder_prompt = self.bos_item
                    elif self.prompt_type == 'src':
                        decoder_prompt = src_item
                    elif self.prompt_type == 'prev_output':
                        decoder_prompt = src_item[:-1]
                    else:
                        raise NotImplementedError
                    prev_output_item = torch.cat([decoder_prompt, tgt_item])
                    target_item = torch.cat([prev_output_item[1:], self.eos_item])

                    # Index in target_item of the first answer token; everything before
                    # it came from the prompt and is overwritten with the pad index.
                    answer_start = len(target_item) - len(tgt_item) - 1
                    target_item[:answer_start] = self.tgt_dict.pad()

                    example = {
                        "id": uniq_id,
                        "source": src_item,
                        "patch_image": patch_image,
                        "patch_mask": patch_mask,
                        "target": target_item,
                        "prev_output_tokens": prev_output_item,
                        "decoder_prompt": decoder_prompt,
                        "ref_dict": ref_dict,
                        "conf": conf,
                    }
                    if self.constraint_trie is not None:
                        total_len = len(target_item)
                        constraint_mask = torch.zeros((total_len, len(self.tgt_dict))).bool()
                        for pos in range(answer_start, total_len):
                            # Prefix = bos followed by the answer tokens emitted so far.
                            prefix = [self.tgt_dict.bos()] + target_item[answer_start:pos].tolist()
                            allowed = self.constraint_trie.get_next_layer(prefix)
                            constraint_mask[pos][allowed] = True
                        example["constraint_mask"] = constraint_mask
                    return example

                def collater(self, samples, pad_to_length=None):
                    """Collate a list of samples into one mini-batch.

                    Args:
                        samples (List[dict]): the samples to combine.
                        pad_to_length: accepted for interface compatibility; this
                            method does not use it.

                    Returns:
                        dict: the batch built by the module-level ``collate`` helper
                        using this dataset's pad and eos indices.
                    """
                    pad_idx, eos_idx = self.pad, self.eos
                    return collate(samples, pad_idx=pad_idx, eos_idx=eos_idx)
         | 
    	
        data/ofa_dataset.py
    ADDED
    
    | @@ -0,0 +1,74 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import logging
         | 
| 2 | 
            +
            import re
         | 
| 3 | 
            +
            import torch.utils.data
         | 
| 4 | 
            +
            from fairseq.data import FairseqDataset
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class OFADataset(FairseqDataset):
                """Base dataset holding the shared text-encoding and text-cleaning helpers."""

                def __init__(self, split, dataset, bpe, src_dict, tgt_dict):
                    """Store the inputs and cache the special-token indices.

                    The bos / eos / pad indices are read from ``src_dict``; one-element
                    ``LongTensor`` versions of bos and eos are kept for concatenation.
                    """
                    self.split, self.dataset, self.bpe = split, dataset, bpe
                    self.src_dict, self.tgt_dict = src_dict, tgt_dict

                    self.bos = src_dict.bos()
                    self.eos = src_dict.eos()
                    self.pad = src_dict.pad()
                    self.bos_item = torch.LongTensor([self.bos])
                    self.eos_item = torch.LongTensor([self.eos])

                def __len__(self):
                    """Return the number of rows in the wrapped dataset."""
                    return len(self.dataset)

                def encode_text(self, text, length=None, append_bos=False, append_eos=False, use_bpe=True):
                    """Encode ``text`` into a long tensor of indices from ``tgt_dict``.

                    Args:
                        text: the string to encode.
                        length: if not ``None``, keep at most this many leading tokens
                            (applied before bos / eos are attached).
                        append_bos: put the bos index in front of the tokens.
                        append_eos: put the eos index after the tokens.
                        use_bpe: run ``self.bpe.encode`` on ``text`` first; otherwise
                            the raw text is handed to the dictionary.
                    """
                    line = self.bpe.encode(text) if use_bpe else text
                    tokens = self.tgt_dict.encode_line(
                        line=line,
                        add_if_not_exist=False,
                        append_eos=False
                    ).long()
                    if length is not None:
                        tokens = tokens[:length]
                    if append_bos or append_eos:
                        head = [self.bos_item] if append_bos else []
                        tail = [self.eos_item] if append_eos else []
                        tokens = torch.cat(head + [tokens] + tail)
                    return tokens

                @staticmethod
                def _squeeze_and_truncate(text, max_words):
                    """Collapse whitespace runs, trim the ends and keep at most ``max_words`` words."""
                    text = re.sub(r"\s{2,}", ' ', text)
                    text = text.rstrip('\n').strip(' ')
                    words = text.split(' ')
                    if len(words) > max_words:
                        text = ' '.join(words[:max_words])
                    return text

                def pre_question(self, question, max_ques_words):
                    """Normalize a question string.

                    Lower-cases it, drops leading punctuation, turns ``-`` and ``/``
                    into spaces, collapses whitespace and truncates to
                    ``max_ques_words`` space-separated words.
                    """
                    lowered = question.lower().lstrip(",.!?*#:;~")
                    cleaned = lowered.replace('-', ' ').replace('/', ' ')
                    return self._squeeze_and_truncate(cleaned, max_ques_words)

                def pre_caption(self, caption, max_words):
                    """Normalize a caption string.

                    Same cleaning as ``pre_question`` plus replacing the literal
                    ``<person>`` tag with ``person``; truncates to ``max_words``
                    space-separated words.
                    """
                    lowered = caption.lower().lstrip(",.!?*#:;~")
                    cleaned = lowered.replace('-', ' ').replace('/', ' ').replace('<person>', 'person')
                    return self._squeeze_and_truncate(cleaned, max_words)
         | 
    	
        datasets.md
    ADDED
    
    | @@ -0,0 +1,10 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Datasets
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            We provide links to download our preprocessed datasets. If you would like to process the data on your own, we will soon provide scripts for you to do so. 
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            ## Finetuning
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/caption_data/caption_data.zip"> Dataset for Caption </a>
         | 
| 8 | 
            +
             * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcoco_data/refcoco_data.zip"> Dataset for RefCOCO </a>
         | 
| 9 | 
            +
             * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocoplus_data/refcocoplus_data.zip"> Dataset for RefCOCO+ </a>
         | 
| 10 | 
            +
             * <a href="https://ofa-beijing.oss-cn-beijing.aliyuncs.com/datasets/refcocog_data/refcocog_data.zip"> Dataset for RefCOCOg </a>
         | 
    	
        evaluate.py
    ADDED
    
    | @@ -0,0 +1,156 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python3 -u
         | 
| 2 | 
            +
            # Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # This source code is licensed under the MIT license found in the
         | 
| 5 | 
            +
            # LICENSE file in the root directory of this source tree.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import logging
         | 
| 8 | 
            +
            import os
         | 
| 9 | 
            +
            import sys
         | 
| 10 | 
            +
            import json
         | 
| 11 | 
            +
            from itertools import chain
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            import numpy as np
         | 
| 14 | 
            +
            import torch
         | 
| 15 | 
            +
            import torch.distributed as dist
         | 
| 16 | 
            +
            from fairseq import distributed_utils, options, tasks, utils
         | 
| 17 | 
            +
            from fairseq.dataclass.utils import convert_namespace_to_omegaconf
         | 
| 18 | 
            +
            from fairseq.logging import progress_bar
         | 
| 19 | 
            +
            from fairseq.utils import reset_logging
         | 
| 20 | 
            +
            from omegaconf import DictConfig
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from utils import checkpoint_utils
         | 
| 23 | 
            +
            from utils.eval_utils import eval_step
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            logging.basicConfig(
         | 
| 26 | 
            +
                format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
         | 
| 27 | 
            +
                datefmt="%Y-%m-%d %H:%M:%S",
         | 
| 28 | 
            +
                level=os.environ.get("LOGLEVEL", "INFO").upper(),
         | 
| 29 | 
            +
                stream=sys.stdout,
         | 
| 30 | 
            +
            )
         | 
| 31 | 
            +
            logger = logging.getLogger("ofa.evaluate")
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            def apply_half(t):
                """Cast a float32 tensor to half precision; return any other tensor unchanged."""
                if t.dtype is not torch.float32:
                    return t
                return t.to(dtype=torch.half)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            def main(cfg: DictConfig, **kwargs):
                """Run generation-based evaluation over ``cfg.dataset.gen_subset``.

                Loads the checkpoint(s) listed in ``cfg.common_eval.path``, iterates the
                subset in (optionally sharded) batches, calls ``eval_step`` on every
                batch and writes the collected results as JSON to
                ``<cfg.common_eval.results_path>/<gen_subset>_predict.json``. When
                ``eval_step`` returns scores, their sum, count and mean are logged.

                Args:
                    cfg: Evaluation configuration. The ``common``, ``common_eval``,
                        ``dataset``, ``generation``, ``checkpoint`` and
                        ``distributed_training`` groups are read.
                    **kwargs: ``ema_eval`` (optional) -- when truthy, each model's
                        state dict is replaced with the one returned by
                        ``checkpoint_utils.load_ema_from_checkpoint``. Treated as
                        ``False`` when not supplied.

                Raises:
                    AssertionError: if neither ``max_tokens`` nor ``batch_size`` is set.
                """
                import ast  # local import: only needed to parse --model-overrides

                utils.import_user_module(cfg.common)

                reset_logging()
                logger.info(cfg)

                assert (
                    cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
                ), "Must specify batch size either with --max-tokens or --batch-size"

                # Fix seed for stochastic decoding
                if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
                    np.random.seed(cfg.common.seed)
                    utils.set_torch_seed(cfg.common.seed)

                use_fp16 = cfg.common.fp16
                use_cuda = torch.cuda.is_available() and not cfg.common.cpu
                # Missing keyword means "do not use EMA weights" instead of a KeyError.
                ema_eval = kwargs.get("ema_eval", False)

                if use_cuda:
                    torch.cuda.set_device(cfg.distributed_training.device_id)

                # Load ensemble
                # NOTE(security): this string comes from the command line and used to be
                # passed to eval(); literal_eval accepts the same dict literals without
                # executing arbitrary code.
                overrides = ast.literal_eval(cfg.common_eval.model_overrides)
                logger.info("loading model(s) from {}".format(cfg.common_eval.path))
                ckpt_paths = utils.split_paths(cfg.common_eval.path)
                models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
                    ckpt_paths,
                    arg_overrides=overrides,
                    suffix=cfg.checkpoint.checkpoint_suffix,
                    strict=(cfg.checkpoint.checkpoint_shard_count == 1),
                    num_shards=cfg.checkpoint.checkpoint_shard_count,
                )

                # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config
                task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

                # Move models to GPU
                for model, ckpt_path in zip(models, ckpt_paths):
                    if ema_eval:
                        logger.info("loading EMA weights from {}".format(ckpt_path))
                        model.load_state_dict(checkpoint_utils.load_ema_from_checkpoint(ckpt_path)['model'])
                    model.eval()
                    if use_fp16:
                        model.half()
                    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
                        model.cuda()
                    model.prepare_for_inference_(cfg)

                # Load dataset (possibly sharded)
                itr = task.get_batch_iterator(
                    dataset=task.dataset(cfg.dataset.gen_subset),
                    max_tokens=cfg.dataset.max_tokens,
                    max_sentences=cfg.dataset.batch_size,
                    max_positions=utils.resolve_max_positions(
                        task.max_positions(), *[m.max_positions() for m in models]
                    ),
                    ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
                    required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
                    seed=cfg.common.seed,
                    num_shards=cfg.distributed_training.distributed_world_size,
                    shard_id=cfg.distributed_training.distributed_rank,
                    num_workers=cfg.dataset.num_workers,
                    data_buffer_size=cfg.dataset.data_buffer_size,
                ).next_epoch_itr(shuffle=False)
                progress = progress_bar.progress_bar(
                    itr,
                    log_format=cfg.common.log_format,
                    log_interval=cfg.common.log_interval,
                    default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
                )

                # Initialize generator
                generator = task.build_generator(models, cfg.generation)

                results = []
                # Score accumulators live on the same device as the samples; an
                # unconditional .cuda() here would crash on CPU-only runs.
                score_sum = torch.FloatTensor([0])
                score_cnt = torch.FloatTensor([0])
                if use_cuda:
                    score_sum = score_sum.cuda()
                    score_cnt = score_cnt.cuda()

                for sample in progress:
                    if "net_input" not in sample:
                        continue
                    sample = utils.move_to_cuda(sample) if use_cuda else sample
                    sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
                    with torch.no_grad():
                        result, scores = eval_step(task, generator, models, sample)
                    results += result
                    score_sum += sum(scores) if scores is not None else 0
                    score_cnt += len(scores) if scores is not None else 0
                    progress.log({"sentences": sample["nsentences"]})

                # In distributed runs, collect every rank's results and combine scores.
                gather_results = None
                if cfg.distributed_training.distributed_world_size > 1:
                    gather_results = [None for _ in range(dist.get_world_size())]
                    dist.all_gather_object(gather_results, results)
                    dist.all_reduce(score_sum.data)
                    dist.all_reduce(score_cnt.data)
                if score_cnt.item() > 0:
                    logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
                        score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
                    ))

                # Only one process (the single worker, or rank 0) writes the output file.
                if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
                    os.makedirs(cfg.common_eval.results_path, exist_ok=True)
                    output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
                    gather_results = (
                        list(chain.from_iterable(gather_results))
                        if gather_results is not None
                        else results
                    )
                    with open(output_path, 'w') as fw:
                        json.dump(gather_results, fw)
         | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            def cli_main():
                """Parse the command line and launch ``main``.

                Builds the generation argument parser, adds the ``--ema-eval`` flag,
                converts the parsed namespace to a config object and hands it to
                ``distributed_utils.call_main`` together with the flag's value.
                """
                arg_parser = options.get_generation_parser()
                arg_parser.add_argument(
                    "--ema-eval",
                    action='store_true',
                    help="Use EMA weights to make evaluation.",
                )
                parsed = options.parse_args_and_arch(arg_parser)
                config = convert_namespace_to_omegaconf(parsed)
                distributed_utils.call_main(config, main, ema_eval=parsed.ema_eval)


            if __name__ == "__main__":
                cli_main()
         | 
    	
        fairseq/.github/ISSUE_TEMPLATE.md
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
         | 
    	
        fairseq/.github/ISSUE_TEMPLATE/bug_report.md
    ADDED
    
    | @@ -0,0 +1,43 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ---
         | 
| 2 | 
            +
            name: 🐛 Bug Report
         | 
| 3 | 
            +
            about: Submit a bug report to help us improve
         | 
| 4 | 
            +
            labels: 'bug, needs triage'
         | 
| 5 | 
            +
            ---
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            ## 🐛 Bug
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            <!-- A clear and concise description of what the bug is. -->
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            ### To Reproduce
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            Steps to reproduce the behavior (**always include the command you ran**):
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            1. Run cmd '....'
         | 
| 16 | 
            +
            2. See error
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            <!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            #### Code sample
         | 
| 22 | 
            +
            <!-- Ideally attach a minimal code sample to reproduce the described issue.
         | 
| 23 | 
            +
            Minimal means having the shortest code but still preserving the bug. -->
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            ### Expected behavior
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            <!-- A clear and concise description of what you expected to happen. -->
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            ### Environment
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             - fairseq Version (e.g., 1.0 or main):
         | 
| 32 | 
            +
             - PyTorch Version (e.g., 1.0)
         | 
| 33 | 
            +
             - OS (e.g., Linux):
         | 
| 34 | 
            +
             - How you installed fairseq (`pip`, source):
         | 
| 35 | 
            +
             - Build command you used (if compiling from source):
         | 
| 36 | 
            +
             - Python version:
         | 
| 37 | 
            +
             - CUDA/cuDNN version:
         | 
| 38 | 
            +
             - GPU models and configuration:
         | 
| 39 | 
            +
             - Any other relevant information:
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            ### Additional context
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            <!-- Add any other context about the problem here. -->
         | 
    	
        fairseq/.github/ISSUE_TEMPLATE/documentation.md
    ADDED
    
    | @@ -0,0 +1,15 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ---
         | 
| 2 | 
            +
            name: 📚 Documentation/Typos
         | 
| 3 | 
            +
            about: Report an issue related to documentation or a typo
         | 
| 4 | 
            +
            labels: 'documentation, needs triage'
         | 
| 5 | 
            +
            ---
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            ## 📚 Documentation
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            For typos and doc fixes, please go ahead and:
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            1. Create an issue.
         | 
| 12 | 
            +
            2. Fix the typo.
         | 
| 13 | 
            +
            3. Submit a PR.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            Thanks!
         | 
    	
        fairseq/.github/ISSUE_TEMPLATE/feature_request.md
    ADDED
    
    | @@ -0,0 +1,24 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ---
         | 
| 2 | 
            +
            name: 🚀 Feature Request
         | 
| 3 | 
            +
            about: Submit a proposal/request for a new feature
         | 
| 4 | 
            +
            labels: 'enhancement, help wanted, needs triage'
         | 
| 5 | 
            +
            ---
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            ## 🚀 Feature Request
         | 
| 8 | 
            +
            <!-- A clear and concise description of the feature proposal -->
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            ### Motivation
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            <!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            ### Pitch
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            <!-- A clear and concise description of what you want to happen. -->
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            ### Alternatives
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            <!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            ### Additional context
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            <!-- Add any other context or screenshots about the feature request here. -->
         | 
    	
        fairseq/.github/ISSUE_TEMPLATE/how-to-question.md
    ADDED
    
    | @@ -0,0 +1,33 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ---
         | 
| 2 | 
            +
            name: ❓ Questions/Help
         | 
| 3 | 
            +
            about: If you have questions, please first search existing issues and docs
         | 
| 4 | 
            +
            labels: 'question, needs triage'
         | 
| 5 | 
            +
            ---
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            ## ❓ Questions and Help
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            ### Before asking:
         | 
| 10 | 
            +
            1. search the issues.
         | 
| 11 | 
            +
            2. search the docs.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            <!-- If you still can't find what you need: -->
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            #### What is your question?
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            #### Code
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            <!-- Please paste a code snippet if your question requires it! -->
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            #### What have you tried?
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            #### What's your environment?
         | 
| 24 | 
            +
             | 
| 25 | 
            +
             - fairseq Version (e.g., 1.0 or main):
         | 
| 26 | 
            +
             - PyTorch Version (e.g., 1.0)
         | 
| 27 | 
            +
             - OS (e.g., Linux):
         | 
| 28 | 
            +
             - How you installed fairseq (`pip`, source):
         | 
| 29 | 
            +
             - Build command you used (if compiling from source):
         | 
| 30 | 
            +
             - Python version:
         | 
| 31 | 
            +
             - CUDA/cuDNN version:
         | 
| 32 | 
            +
             - GPU models and configuration:
         | 
| 33 | 
            +
             - Any other relevant information:
         | 
    	
        fairseq/.github/PULL_REQUEST_TEMPLATE.md
    ADDED
    
    | @@ -0,0 +1,16 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Before submitting
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            - [ ] Was this discussed/approved via a GitHub issue? (no need for typos, doc improvements)
         | 
| 4 | 
            +
            - [ ] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/main/CONTRIBUTING.md)?
         | 
| 5 | 
            +
            - [ ] Did you make sure to update the docs?
         | 
| 6 | 
            +
            - [ ] Did you write any new necessary tests?
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            ## What does this PR do?
         | 
| 9 | 
            +
            Fixes # (issue).
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            ## PR review
         | 
| 12 | 
            +
            Anyone in the community is free to review the PR once the tests have passed.
         | 
| 13 | 
            +
            If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            ## Did you have fun?
         | 
| 16 | 
            +
            Make sure you had fun coding 🙃
         | 
    	
        fairseq/.github/stale.yml
    ADDED
    
    | @@ -0,0 +1,30 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Configuration for probot-stale - https://github.com/probot/stale
         | 
| 2 | 
            +
            # Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml
         | 
| 3 | 
            +
            # Number of days of inactivity before an issue becomes stale
         | 
| 4 | 
            +
            daysUntilStale: 90
         | 
| 5 | 
            +
            # Number of days of inactivity before a stale issue is closed
         | 
| 6 | 
            +
            daysUntilClose: 7
         | 
| 7 | 
            +
            # Issues with these labels will never be considered stale
         | 
| 8 | 
            +
            exemptLabels:
         | 
| 9 | 
            +
              - bug
         | 
| 10 | 
            +
            # Label to use when marking an issue as stale
         | 
| 11 | 
            +
            staleLabel: stale
         | 
| 12 | 
            +
            issues:
         | 
| 13 | 
            +
              # Comment to post when marking an issue as stale.
         | 
| 14 | 
            +
              markComment: >
         | 
| 15 | 
            +
                This issue has been automatically marked as stale.
         | 
| 16 | 
            +
                **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open.
         | 
| 17 | 
            +
                We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment!
         | 
| 18 | 
            +
              # Comment to post when closing a stale issue.
         | 
| 19 | 
            +
              closeComment: >
         | 
| 20 | 
            +
                Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you!
         | 
| 21 | 
            +
            pulls:
         | 
| 22 | 
            +
              # Comment to post when marking a pull request as stale.
         | 
| 23 | 
            +
              markComment: >
         | 
| 24 | 
            +
                This pull request has been automatically marked as stale.
         | 
| 25 | 
            +
                **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open.
         | 
| 26 | 
            +
                We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated.
         | 
| 27 | 
            +
              # Comment to post when closing a stale pull request.
         | 
| 28 | 
            +
              closeComment: >
         | 
| 29 | 
            +
                Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you!
         | 
| 30 | 
            +
             | 
    	
        fairseq/.github/workflows/build.yml
    ADDED
    
    | @@ -0,0 +1,55 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            name: build
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            on:
         | 
| 4 | 
            +
              # Trigger the workflow on push to main or any pull request
         | 
| 5 | 
            +
              push:
         | 
| 6 | 
            +
                branches:
         | 
| 7 | 
            +
                  - main
         | 
| 8 | 
            +
              pull_request:
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            jobs:
         | 
| 11 | 
            +
              build:
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                strategy:
         | 
| 14 | 
            +
                  max-parallel: 4
         | 
| 15 | 
            +
                  matrix:
         | 
| 16 | 
            +
                    platform: [ubuntu-latest, macos-latest]
         | 
| 17 | 
            +
                    python-version: [3.6, 3.7]
         | 
| 18 | 
            +
             | 
| 19 | 
            +
                runs-on: ${{ matrix.platform }}
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                steps:
         | 
| 22 | 
            +
                - uses: actions/checkout@v2
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                - name: Set up Python ${{ matrix.python-version }}
         | 
| 25 | 
            +
                  uses: actions/setup-python@v2
         | 
| 26 | 
            +
                  with:
         | 
| 27 | 
            +
                    python-version: ${{ matrix.python-version }}
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                - name: Conditionally install pytorch
         | 
| 30 | 
            +
                  if: matrix.platform == 'windows-latest'
         | 
| 31 | 
            +
                  run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                - name: Install locally
         | 
| 34 | 
            +
                  run: |
         | 
| 35 | 
            +
                    python -m pip install --upgrade pip
         | 
| 36 | 
            +
                    git submodule update --init --recursive
         | 
| 37 | 
            +
                    python setup.py build_ext --inplace
         | 
| 38 | 
            +
                    python -m pip install --editable .
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                - name: Install optional test requirements
         | 
| 41 | 
            +
                  run: |
         | 
| 42 | 
            +
                    python -m pip install iopath transformers pyarrow
         | 
| 43 | 
            +
                    python -m pip install git+https://github.com/facebookresearch/fairscale.git@main
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                - name: Lint with flake8
         | 
| 46 | 
            +
                  run: |
         | 
| 47 | 
            +
                    pip install flake8
         | 
| 48 | 
            +
                    # stop the build if there are Python syntax errors or undefined names
         | 
| 49 | 
            +
                    flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron
         | 
| 50 | 
            +
                    # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
         | 
| 51 | 
            +
                    flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                - name: Run tests
         | 
| 54 | 
            +
                  run: |
         | 
| 55 | 
            +
                      python setup.py test
         | 
    	
        fairseq/.github/workflows/build_wheels.yml
    ADDED
    
    | @@ -0,0 +1,41 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            name: build_wheels
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            on:
         | 
| 4 | 
            +
              push:
         | 
| 5 | 
            +
                branches:
         | 
| 6 | 
            +
                  - v[0-9]+.[0-9]+.[x0-9]+
         | 
| 7 | 
            +
                tags:
         | 
| 8 | 
            +
                  - v*
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            jobs:
         | 
| 11 | 
            +
              build_wheels:
         | 
| 12 | 
            +
                name: Build wheels on ${{ matrix.os }}
         | 
| 13 | 
            +
                runs-on: ${{ matrix.os }}
         | 
| 14 | 
            +
                strategy:
         | 
| 15 | 
            +
                  matrix:
         | 
| 16 | 
            +
                    os: [ubuntu-latest, macos-latest]
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                steps:
         | 
| 19 | 
            +
                  - uses: actions/checkout@v2
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                  - name: Install Python
         | 
| 22 | 
            +
                    uses: actions/setup-python@v2
         | 
| 23 | 
            +
                    with:
         | 
| 24 | 
            +
                      python-version: '3.7'
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                  - name: Install cibuildwheel
         | 
| 27 | 
            +
                    run: |
         | 
| 28 | 
            +
                      python -m pip install cibuildwheel
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                  - name: Build wheels for CPython
         | 
| 31 | 
            +
                    run: |
         | 
| 32 | 
            +
                      python -m cibuildwheel --output-dir dist
         | 
| 33 | 
            +
                    env:
         | 
| 34 | 
            +
                      CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64"
         | 
| 35 | 
            +
                      CIBW_MANYLINUX_X86_64_IMAGE: manylinux1
         | 
| 36 | 
            +
                      CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install .
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                  - uses: actions/upload-artifact@v2
         | 
| 39 | 
            +
                    with:
         | 
| 40 | 
            +
                      name: wheels
         | 
| 41 | 
            +
                      path: ./dist/*.whl
         | 
    	
        fairseq/.gitignore
    ADDED
    
    | @@ -0,0 +1,136 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # JetBrains PyCharm IDE
         | 
| 2 | 
            +
            .idea/
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            # Byte-compiled / optimized / DLL files
         | 
| 5 | 
            +
            __pycache__/
         | 
| 6 | 
            +
            *.py[cod]
         | 
| 7 | 
            +
            *$py.class
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            # C extensions
         | 
| 10 | 
            +
            *.so
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            # macOS dir files
         | 
| 13 | 
            +
            .DS_Store
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            # Distribution / packaging
         | 
| 16 | 
            +
            .Python
         | 
| 17 | 
            +
            env/
         | 
| 18 | 
            +
            build/
         | 
| 19 | 
            +
            develop-eggs/
         | 
| 20 | 
            +
            dist/
         | 
| 21 | 
            +
            downloads/
         | 
| 22 | 
            +
            eggs/
         | 
| 23 | 
            +
            .eggs/
         | 
| 24 | 
            +
            lib/
         | 
| 25 | 
            +
            lib64/
         | 
| 26 | 
            +
            parts/
         | 
| 27 | 
            +
            sdist/
         | 
| 28 | 
            +
            var/
         | 
| 29 | 
            +
            wheels/
         | 
| 30 | 
            +
            *.egg-info/
         | 
| 31 | 
            +
            .installed.cfg
         | 
| 32 | 
            +
            *.egg
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            # Checkpoints
         | 
| 35 | 
            +
            checkpoints
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            # PyInstaller
         | 
| 38 | 
            +
            #  Usually these files are written by a python script from a template
         | 
| 39 | 
            +
            #  before PyInstaller builds the exe, so as to inject date/other infos into it.
         | 
| 40 | 
            +
            *.manifest
         | 
| 41 | 
            +
            *.spec
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            # Installer logs
         | 
| 44 | 
            +
            pip-log.txt
         | 
| 45 | 
            +
            pip-delete-this-directory.txt
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            # Unit test / coverage reports
         | 
| 48 | 
            +
            htmlcov/
         | 
| 49 | 
            +
            .tox/
         | 
| 50 | 
            +
            .coverage
         | 
| 51 | 
            +
            .coverage.*
         | 
| 52 | 
            +
            .cache
         | 
| 53 | 
            +
            nosetests.xml
         | 
| 54 | 
            +
            coverage.xml
         | 
| 55 | 
            +
            *.cover
         | 
| 56 | 
            +
            .hypothesis/
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            # Translations
         | 
| 59 | 
            +
            *.mo
         | 
| 60 | 
            +
            *.pot
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            # Django stuff:
         | 
| 63 | 
            +
            *.log
         | 
| 64 | 
            +
            local_settings.py
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            # Flask stuff:
         | 
| 67 | 
            +
            instance/
         | 
| 68 | 
            +
            .webassets-cache
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            # Scrapy stuff:
         | 
| 71 | 
            +
            .scrapy
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            # Sphinx documentation
         | 
| 74 | 
            +
            docs/_build/
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            # PyBuilder
         | 
| 77 | 
            +
            target/
         | 
| 78 | 
            +
             | 
| 79 | 
            +
            # Jupyter Notebook
         | 
| 80 | 
            +
            .ipynb_checkpoints
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            # pyenv
         | 
| 83 | 
            +
            .python-version
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            # celery beat schedule file
         | 
| 86 | 
            +
            celerybeat-schedule
         | 
| 87 | 
            +
             | 
| 88 | 
            +
            # SageMath parsed files
         | 
| 89 | 
            +
            *.sage.py
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            # dotenv
         | 
| 92 | 
            +
            .env
         | 
| 93 | 
            +
             | 
| 94 | 
            +
            # virtualenv
         | 
| 95 | 
            +
            .venv
         | 
| 96 | 
            +
            venv/
         | 
| 97 | 
            +
            ENV/
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            # Spyder project settings
         | 
| 100 | 
            +
            .spyderproject
         | 
| 101 | 
            +
            .spyproject
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            # Rope project settings
         | 
| 104 | 
            +
            .ropeproject
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            # mkdocs documentation
         | 
| 107 | 
            +
            /site
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            # mypy
         | 
| 110 | 
            +
            .mypy_cache/
         | 
| 111 | 
            +
             | 
| 112 | 
            +
            # Generated files
         | 
| 113 | 
            +
            /fairseq/temporal_convolution_tbc
         | 
| 114 | 
            +
            /fairseq/modules/*_layer/*_forward.cu
         | 
| 115 | 
            +
            /fairseq/modules/*_layer/*_backward.cu
         | 
| 116 | 
            +
            /fairseq/version.py
         | 
| 117 | 
            +
             | 
| 118 | 
            +
            # data
         | 
| 119 | 
            +
            data-bin/
         | 
| 120 | 
            +
             | 
| 121 | 
            +
            # reranking
         | 
| 122 | 
            +
            /examples/reranking/rerank_data
         | 
| 123 | 
            +
             | 
| 124 | 
            +
            # Cython-generated C++ source files
         | 
| 125 | 
            +
            /fairseq/data/data_utils_fast.cpp
         | 
| 126 | 
            +
            /fairseq/data/token_block_utils_fast.cpp
         | 
| 127 | 
            +
             | 
| 128 | 
            +
            # VSCODE
         | 
| 129 | 
            +
            .vscode/ftp-sync.json
         | 
| 130 | 
            +
            .vscode/settings.json
         | 
| 131 | 
            +
             | 
| 132 | 
            +
            # Experimental Folder
         | 
| 133 | 
            +
            experimental/*
         | 
| 134 | 
            +
             | 
| 135 | 
            +
            # Weights and Biases logs
         | 
| 136 | 
            +
            wandb/
         | 
    	
        fairseq/.gitmodules
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            [submodule "fairseq/model_parallel/megatron"]
         | 
| 2 | 
            +
                path = fairseq/model_parallel/megatron
         | 
| 3 | 
            +
                url = https://github.com/ngoyal2707/Megatron-LM
         | 
| 4 | 
            +
                branch = fairseq
         | 
    	
        fairseq/CODE_OF_CONDUCT.md
    ADDED
    
    | @@ -0,0 +1,77 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Code of Conduct
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            ## Our Pledge
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            In the interest of fostering an open and welcoming environment, we as
         | 
| 6 | 
            +
            contributors and maintainers pledge to make participation in our project and
         | 
| 7 | 
            +
            our community a harassment-free experience for everyone, regardless of age, body
         | 
| 8 | 
            +
            size, disability, ethnicity, sex characteristics, gender identity and expression,
         | 
| 9 | 
            +
            level of experience, education, socio-economic status, nationality, personal
         | 
| 10 | 
            +
            appearance, race, religion, or sexual identity and orientation.
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            ## Our Standards
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            Examples of behavior that contributes to creating a positive environment
         | 
| 15 | 
            +
            include:
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            * Using welcoming and inclusive language
         | 
| 18 | 
            +
            * Being respectful of differing viewpoints and experiences
         | 
| 19 | 
            +
            * Gracefully accepting constructive criticism
         | 
| 20 | 
            +
            * Focusing on what is best for the community
         | 
| 21 | 
            +
            * Showing empathy towards other community members
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            Examples of unacceptable behavior by participants include:
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            * The use of sexualized language or imagery and unwelcome sexual attention or
         | 
| 26 | 
            +
              advances
         | 
| 27 | 
            +
            * Trolling, insulting/derogatory comments, and personal or political attacks
         | 
| 28 | 
            +
            * Public or private harassment
         | 
| 29 | 
            +
            * Publishing others' private information, such as a physical or electronic
         | 
| 30 | 
            +
              address, without explicit permission
         | 
| 31 | 
            +
            * Other conduct which could reasonably be considered inappropriate in a
         | 
| 32 | 
            +
              professional setting
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            ## Our Responsibilities
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            Project maintainers are responsible for clarifying the standards of acceptable
         | 
| 37 | 
            +
            behavior and are expected to take appropriate and fair corrective action in
         | 
| 38 | 
            +
            response to any instances of unacceptable behavior.
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            Project maintainers have the right and responsibility to remove, edit, or
         | 
| 41 | 
            +
            reject comments, commits, code, wiki edits, issues, and other contributions
         | 
| 42 | 
            +
            that are not aligned to this Code of Conduct, or to ban temporarily or
         | 
| 43 | 
            +
            permanently any contributor for other behaviors that they deem inappropriate,
         | 
| 44 | 
            +
            threatening, offensive, or harmful.
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            ## Scope
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            This Code of Conduct applies within all project spaces, and it also applies when
         | 
| 49 | 
            +
            an individual is representing the project or its community in public spaces.
         | 
| 50 | 
            +
            Examples of representing a project or community include using an official
         | 
| 51 | 
            +
            project e-mail address, posting via an official social media account, or acting
         | 
| 52 | 
            +
            as an appointed representative at an online or offline event. Representation of
         | 
| 53 | 
            +
            a project may be further defined and clarified by project maintainers.
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            ## Enforcement
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            Instances of abusive, harassing, or otherwise unacceptable behavior may be
         | 
| 58 | 
            +
            reported by contacting the project team at <conduct@pytorch.org>. All
         | 
| 59 | 
            +
            complaints will be reviewed and investigated and will result in a response that
         | 
| 60 | 
            +
            is deemed necessary and appropriate to the circumstances. The project team is
         | 
| 61 | 
            +
            obligated to maintain confidentiality with regard to the reporter of an incident.
         | 
| 62 | 
            +
            Further details of specific enforcement policies may be posted separately.
         | 
| 63 | 
            +
             | 
| 64 | 
            +
            Project maintainers who do not follow or enforce the Code of Conduct in good
         | 
| 65 | 
            +
            faith may face temporary or permanent repercussions as determined by other
         | 
| 66 | 
            +
            members of the project's leadership.
         | 
| 67 | 
            +
             | 
| 68 | 
            +
            ## Attribution
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
         | 
| 71 | 
            +
            available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
         | 
| 72 | 
            +
             | 
| 73 | 
            +
            [homepage]: https://www.contributor-covenant.org
         | 
| 74 | 
            +
             | 
| 75 | 
            +
            For answers to common questions about this code of conduct, see
         | 
| 76 | 
            +
            https://www.contributor-covenant.org/faq
         | 
| 77 | 
            +
             | 
    	
        fairseq/CONTRIBUTING.md
    ADDED
    
    | @@ -0,0 +1,28 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq)
         | 
| 2 | 
            +
            We want to make contributing to this project as easy and transparent as
         | 
| 3 | 
            +
            possible.
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            ## Pull Requests
         | 
| 6 | 
            +
            We actively welcome your pull requests.
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            1. Fork the repo and create your branch from `main`.
         | 
| 9 | 
            +
            2. If you've added code that should be tested, add tests.
         | 
| 10 | 
            +
            3. If you've changed APIs, update the documentation.
         | 
| 11 | 
            +
            4. Ensure the test suite passes.
         | 
| 12 | 
            +
            5. Make sure your code lints.
         | 
| 13 | 
            +
            6. If you haven't already, complete the Contributor License Agreement ("CLA").
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            ## Contributor License Agreement ("CLA")
         | 
| 16 | 
            +
            In order to accept your pull request, we need you to submit a CLA. You only need
         | 
| 17 | 
            +
            to do this once to work on any of Facebook's open source projects.
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            Complete your CLA here: <https://code.facebook.com/cla>
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            ## Issues
         | 
| 22 | 
            +
            We use GitHub issues to track public bugs. Please ensure your description is
         | 
| 23 | 
            +
            clear and has sufficient instructions to be able to reproduce the issue.
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            ## License
         | 
| 26 | 
            +
            By contributing to Facebook AI Research Sequence-to-Sequence Toolkit (fairseq),
         | 
| 27 | 
            +
            you agree that your contributions will be licensed under the LICENSE file in
         | 
| 28 | 
            +
            the root directory of this source tree.
         | 
    	
        fairseq/LICENSE
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            MIT License
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Copyright (c) Facebook, Inc. and its affiliates.
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            Permission is hereby granted, free of charge, to any person obtaining a copy
         | 
| 6 | 
            +
            of this software and associated documentation files (the "Software"), to deal
         | 
| 7 | 
            +
            in the Software without restriction, including without limitation the rights
         | 
| 8 | 
            +
            to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
         | 
| 9 | 
            +
            copies of the Software, and to permit persons to whom the Software is
         | 
| 10 | 
            +
            furnished to do so, subject to the following conditions:
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            The above copyright notice and this permission notice shall be included in all
         | 
| 13 | 
            +
            copies or substantial portions of the Software.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
         | 
| 16 | 
            +
            IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
         | 
| 17 | 
            +
            FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
         | 
| 18 | 
            +
            AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
         | 
| 19 | 
            +
            LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
         | 
| 20 | 
            +
            OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
         | 
| 21 | 
            +
            SOFTWARE.
         | 
    	
        fairseq/README.md
    ADDED
    
    | @@ -0,0 +1,229 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            <p align="center">
         | 
| 2 | 
            +
              <img src="docs/fairseq_logo.png" width="150">
         | 
| 3 | 
            +
              <br />
         | 
| 4 | 
            +
              <br />
         | 
| 5 | 
            +
              <a href="https://github.com/pytorch/fairseq/blob/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/license-MIT-blue.svg" /></a>
         | 
| 6 | 
            +
              <a href="https://github.com/pytorch/fairseq/releases"><img alt="Latest Release" src="https://img.shields.io/github/release/pytorch/fairseq.svg" /></a>
         | 
| 7 | 
            +
              <a href="https://github.com/pytorch/fairseq/actions?query=workflow:build"><img alt="Build Status" src="https://github.com/pytorch/fairseq/workflows/build/badge.svg" /></a>
         | 
| 8 | 
            +
              <a href="https://fairseq.readthedocs.io/en/latest/?badge=latest"><img alt="Documentation Status" src="https://readthedocs.org/projects/fairseq/badge/?version=latest" /></a>
         | 
| 9 | 
            +
            </p>
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            --------------------------------------------------------------------------------
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            Fairseq(-py) is a sequence modeling toolkit that allows researchers and
         | 
| 14 | 
            +
            developers to train custom models for translation, summarization, language
         | 
| 15 | 
            +
            modeling and other text generation tasks.
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            We provide reference implementations of various sequence modeling papers:
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            <details><summary>List of implemented papers</summary><p>
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            * **Convolutional Neural Networks (CNN)**
         | 
| 22 | 
            +
              + [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/conv_lm/README.md)
         | 
| 23 | 
            +
              + [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
         | 
| 24 | 
            +
              + [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
         | 
| 25 | 
            +
              + [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
         | 
| 26 | 
            +
              + [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
         | 
| 27 | 
            +
            * **LightConv and DynamicConv models**
         | 
| 28 | 
            +
              + [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
         | 
| 29 | 
            +
            * **Long Short-Term Memory (LSTM) networks**
         | 
| 30 | 
            +
              + Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015)
         | 
| 31 | 
            +
            * **Transformer (self-attention) networks**
         | 
| 32 | 
            +
              + Attention Is All You Need (Vaswani et al., 2017)
         | 
| 33 | 
            +
              + [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
         | 
| 34 | 
            +
              + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
         | 
| 35 | 
            +
              + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md)
         | 
| 36 | 
            +
              + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md)
         | 
| 37 | 
            +
              + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md)
         | 
| 38 | 
            +
              + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md)
         | 
| 39 | 
            +
              + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
         | 
| 40 | 
            +
              + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
         | 
| 41 | 
            +
              + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
         | 
| 42 | 
            +
              + [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
         | 
| 43 | 
            +
              + [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
         | 
| 44 | 
            +
              + [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
         | 
| 45 | 
            +
              + [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
         | 
| 46 | 
            +
              + [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
         | 
| 47 | 
            +
              + [Generating Medical Reports from Patient-Doctor Conversations Using Sequence-to-Sequence Models (Enarvi et al., 2020)](examples/pointer_generator/README.md)
         | 
| 48 | 
            +
              + [Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)](examples/linformer/README.md)
         | 
| 49 | 
            +
              + [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
         | 
| 50 | 
            +
              + [Deep Transformers with Latent Depth (Li et al., 2020)](examples/latent_depth/README.md)
         | 
| 51 | 
            +
              + [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979)
         | 
| 52 | 
            +
              + [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu et al., 2021)](https://arxiv.org/abs/2104.01027)
         | 
| 53 | 
            +
              + [Unsupervised Speech Recognition (Baevski et al., 2021)](https://arxiv.org/abs/2105.11084)
         | 
| 54 | 
            +
            * **Non-autoregressive Transformers**
         | 
| 55 | 
            +
              + Non-Autoregressive Neural Machine Translation (Gu et al., 2017)
         | 
| 56 | 
            +
              + Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018)
         | 
| 57 | 
            +
              + Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019)
         | 
| 58 | 
            +
              + Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)
         | 
| 59 | 
            +
              + [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
         | 
| 60 | 
            +
            * **Finetuning**
         | 
| 61 | 
            +
              + [Better Fine-Tuning by Reducing Representational Collapse (Aghajanyan et al., 2020)](examples/rxf/README.md)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            </p></details>
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            ### What's New:
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            * September 2021: [`master` branch renamed to `main`](https://github.com/github/renaming).
         | 
| 68 | 
            +
            * July 2021: [Released DrNMT code](examples/discriminative_reranking_nmt/README.md)
         | 
| 69 | 
            +
            * July 2021: [Released Robust wav2vec 2.0 model](examples/wav2vec/README.md)
         | 
| 70 | 
            +
            * June 2021: [Released XLMR-XL and XLMR-XXL models](examples/xlmr/README.md)
         | 
| 71 | 
            +
            * May 2021: [Released Unsupervised Speech Recognition code](examples/wav2vec/unsupervised/README.md)
         | 
| 72 | 
            +
            * March 2021: [Added full parameter and optimizer state sharding + CPU offloading](examples/fully_sharded_data_parallel/README.md)
         | 
| 73 | 
            +
            * February 2021: [Added LASER training code](examples/laser/README.md)
         | 
| 74 | 
            +
            * December 2020: [Added Adaptive Attention Span code](examples/adaptive_span/README.md)
         | 
| 75 | 
            +
            * December 2020: [GottBERT model and code released](examples/gottbert/README.md)
         | 
| 76 | 
            +
            * November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework
         | 
| 77 | 
            +
              * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md)
         | 
| 78 | 
            +
            * November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0)
         | 
| 79 | 
            +
            * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md)
         | 
| 80 | 
            +
            * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md)
         | 
| 81 | 
            +
            * October 2020: [Added CRISS models and code](examples/criss/README.md)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            <details><summary>Previous updates</summary><p>
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            * September 2020: [Added Linformer code](examples/linformer/README.md)
         | 
| 86 | 
            +
            * September 2020: [Added pointer-generator networks](examples/pointer_generator/README.md)
         | 
| 87 | 
            +
            * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md)
         | 
| 88 | 
            +
            * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md)
         | 
| 89 | 
            +
            * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md)
         | 
| 90 | 
            +
            * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq)
         | 
| 91 | 
            +
            * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md)
         | 
| 92 | 
            +
            * April 2020: [Quant-Noise code released](examples/quant_noise/README.md)
         | 
| 93 | 
            +
            * April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md)
         | 
| 94 | 
            +
            * March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md)
         | 
| 95 | 
            +
            * February 2020: [mBART model and code released](examples/mbart/README.md)
         | 
| 96 | 
            +
            * February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/main/examples/backtranslation#training-your-own-model-wmt18-english-german)
         | 
| 97 | 
            +
            * December 2019: [fairseq 0.9.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.9.0)
         | 
| 98 | 
            +
            * November 2019: [VizSeq released (a visual analysis toolkit for evaluating fairseq models)](https://facebookresearch.github.io/vizseq/docs/getting_started/fairseq_example)
         | 
| 99 | 
            +
            * November 2019: [CamemBERT model and code released](examples/camembert/README.md)
         | 
| 100 | 
            +
            * November 2019: [BART model and code released](examples/bart/README.md)
         | 
| 101 | 
            +
            * November 2019: [XLM-R models and code released](examples/xlmr/README.md)
         | 
| 102 | 
            +
            * September 2019: [Nonautoregressive translation code released](examples/nonautoregressive_translation/README.md)
         | 
| 103 | 
            +
            * August 2019: [WMT'19 models released](examples/wmt19/README.md)
         | 
| 104 | 
            +
            * July 2019: fairseq relicensed under MIT license
         | 
| 105 | 
            +
            * July 2019: [RoBERTa models and code released](examples/roberta/README.md)
         | 
| 106 | 
            +
            * June 2019: [wav2vec models and code released](examples/wav2vec/README.md)
         | 
| 107 | 
            +
             | 
| 108 | 
            +
            </p></details>
         | 
| 109 | 
            +
             | 
| 110 | 
            +
            ### Features:
         | 
| 111 | 
            +
             | 
| 112 | 
            +
            * multi-GPU training on one machine or across multiple machines (data and model parallel)
         | 
| 113 | 
            +
            * fast generation on both CPU and GPU with multiple search algorithms implemented:
         | 
| 114 | 
            +
              + beam search
         | 
| 115 | 
            +
              + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424))
         | 
| 116 | 
            +
              + sampling (unconstrained, top-k and top-p/nucleus)
         | 
| 117 | 
            +
              + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018)
         | 
| 118 | 
            +
            * [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU
         | 
| 119 | 
            +
            * [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores))
         | 
| 120 | 
            +
            * [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers
         | 
| 121 | 
            +
            * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration
         | 
| 122 | 
            +
            * [full parameter and optimizer state sharding](examples/fully_sharded_data_parallel/README.md)
         | 
| 123 | 
            +
            * [offloading parameters to CPU](examples/fully_sharded_data_parallel/README.md)
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples)
         | 
| 126 | 
            +
            with a convenient `torch.hub` interface:
         | 
| 127 | 
            +
             | 
| 128 | 
            +
            ``` python
         | 
| 129 | 
            +
            en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
         | 
| 130 | 
            +
            en2de.translate('Hello world', beam=5)
         | 
| 131 | 
            +
            # 'Hallo Welt'
         | 
| 132 | 
            +
            ```
         | 
| 133 | 
            +
             | 
| 134 | 
            +
            See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
         | 
| 135 | 
            +
            and [RoBERTa](https://pytorch.org/hub/pytorch_fairseq_roberta/) for more examples.
         | 
| 136 | 
            +
             | 
| 137 | 
            +
            # Requirements and Installation
         | 
| 138 | 
            +
             | 
| 139 | 
            +
            * [PyTorch](http://pytorch.org/) version >= 1.5.0
         | 
| 140 | 
            +
            * Python version >= 3.6
         | 
| 141 | 
            +
            * For training new models, you'll also need an NVIDIA GPU and [NCCL](https://github.com/NVIDIA/nccl)
         | 
| 142 | 
            +
            * **To install fairseq** and develop locally:
         | 
| 143 | 
            +
             | 
| 144 | 
            +
            ``` bash
         | 
| 145 | 
            +
            git clone https://github.com/pytorch/fairseq
         | 
| 146 | 
            +
            cd fairseq
         | 
| 147 | 
            +
            pip install --editable ./
         | 
| 148 | 
            +
             | 
| 149 | 
            +
            # on MacOS:
         | 
| 150 | 
            +
            # CFLAGS="-stdlib=libc++" pip install --editable ./
         | 
| 151 | 
            +
             | 
| 152 | 
            +
            # to install the latest stable release (0.10.x)
         | 
| 153 | 
            +
            # pip install fairseq
         | 
| 154 | 
            +
            ```
         | 
| 155 | 
            +
             | 
| 156 | 
            +
            * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library:
         | 
| 157 | 
            +
             | 
| 158 | 
            +
            ``` bash
         | 
| 159 | 
            +
            git clone https://github.com/NVIDIA/apex
         | 
| 160 | 
            +
            cd apex
         | 
| 161 | 
            +
            pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
         | 
| 162 | 
            +
              --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
         | 
| 163 | 
            +
              --global-option="--fast_multihead_attn" ./
         | 
| 164 | 
            +
            ```
         | 
| 165 | 
            +
             | 
| 166 | 
            +
            * **For large datasets** install [PyArrow](https://arrow.apache.org/docs/python/install.html#using-pip): `pip install pyarrow`
         | 
| 167 | 
            +
            * If you use Docker, make sure to increase the shared memory size either with `--ipc=host` or `--shm-size`
         | 
| 168 | 
            +
             as command line options to `nvidia-docker run`.
         | 
| 169 | 
            +
             | 
| 170 | 
            +
            # Getting Started
         | 
| 171 | 
            +
             | 
| 172 | 
            +
            The [full documentation](https://fairseq.readthedocs.io/) contains instructions
         | 
| 173 | 
            +
            for getting started, training new models and extending fairseq with new model
         | 
| 174 | 
            +
            types and tasks.
         | 
| 175 | 
            +
             | 
| 176 | 
            +
            # Pre-trained models and examples
         | 
| 177 | 
            +
             | 
| 178 | 
            +
            We provide pre-trained models and pre-processed, binarized test sets for several tasks listed below,
         | 
| 179 | 
            +
            as well as example training and evaluation commands.
         | 
| 180 | 
            +
             | 
| 181 | 
            +
            * [Translation](examples/translation/README.md): convolutional and transformer models are available
         | 
| 182 | 
            +
            * [Language Modeling](examples/language_model/README.md): convolutional and transformer models are available
         | 
| 183 | 
            +
             | 
| 184 | 
            +
            We also have more detailed READMEs to reproduce results from specific papers:
         | 
| 185 | 
            +
             | 
| 186 | 
            +
            * [Cross-lingual Retrieval for Iterative Self-Supervised Training (Tran et al., 2020)](examples/criss/README.md)
         | 
| 187 | 
            +
            * [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](examples/wav2vec/README.md)
         | 
| 188 | 
            +
            * [Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)](examples/unsupervised_quality_estimation/README.md)
         | 
| 189 | 
            +
            * [Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)](examples/quant_noise/README.md)
         | 
| 190 | 
            +
            * [Neural Machine Translation with Byte-Level Subwords (Wang et al., 2020)](examples/byte_level_bpe/README.md)
         | 
| 191 | 
            +
            * [Multilingual Denoising Pre-training for Neural Machine Translation (Liu et al., 2020)](examples/mbart/README.md)
         | 
| 192 | 
            +
            * [Reducing Transformer Depth on Demand with Structured Dropout (Fan et al., 2019)](examples/layerdrop/README.md)
         | 
| 193 | 
            +
            * [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](examples/joint_alignment_translation/README.md)
         | 
| 194 | 
            +
            * [Levenshtein Transformer (Gu et al., 2019)](examples/nonautoregressive_translation/README.md)
         | 
| 195 | 
            +
            * [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md)
         | 
| 196 | 
            +
            * [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md)
         | 
| 197 | 
            +
            * [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
         | 
| 198 | 
            +
            * [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md)
         | 
| 199 | 
            +
            * [Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)](examples/pay_less_attention_paper/README.md)
         | 
| 200 | 
            +
            * [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md)
         | 
| 201 | 
            +
            * [Classical Structured Prediction Losses for Sequence to Sequence Learning (Edunov et al., 2018)](https://github.com/pytorch/fairseq/tree/classic_seqlevel)
         | 
| 202 | 
            +
            * [Hierarchical Neural Story Generation (Fan et al., 2018)](examples/stories/README.md)
         | 
| 203 | 
            +
            * [Scaling Neural Machine Translation (Ott et al., 2018)](examples/scaling_nmt/README.md)
         | 
| 204 | 
            +
            * [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](examples/conv_seq2seq/README.md)
         | 
| 205 | 
            +
            * [Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)](examples/language_model/README.conv.md)
         | 
| 206 | 
            +
             | 
| 207 | 
            +
            # Join the fairseq community
         | 
| 208 | 
            +
             | 
| 209 | 
            +
            * Twitter: https://twitter.com/fairseq
         | 
| 210 | 
            +
            * Facebook page: https://www.facebook.com/groups/fairseq.users
         | 
| 211 | 
            +
            * Google group: https://groups.google.com/forum/#!forum/fairseq-users
         | 
| 212 | 
            +
             | 
| 213 | 
            +
            # License
         | 
| 214 | 
            +
             | 
| 215 | 
            +
            fairseq(-py) is MIT-licensed.
         | 
| 216 | 
            +
            The license applies to the pre-trained models as well.
         | 
| 217 | 
            +
             | 
| 218 | 
            +
            # Citation
         | 
| 219 | 
            +
             | 
| 220 | 
            +
            Please cite as:
         | 
| 221 | 
            +
             | 
| 222 | 
            +
            ``` bibtex
         | 
| 223 | 
            +
            @inproceedings{ott2019fairseq,
         | 
| 224 | 
            +
              title = {fairseq: A Fast, Extensible Toolkit for Sequence Modeling},
         | 
| 225 | 
            +
              author = {Myle Ott and Sergey Edunov and Alexei Baevski and Angela Fan and Sam Gross and Nathan Ng and David Grangier and Michael Auli},
         | 
| 226 | 
            +
              booktitle = {Proceedings of NAACL-HLT 2019: Demonstrations},
         | 
| 227 | 
            +
              year = {2019},
         | 
| 228 | 
            +
            }
         | 
| 229 | 
            +
            ```
         | 
    	
        fairseq/docs/Makefile
    ADDED
    
    | @@ -0,0 +1,20 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Minimal makefile for Sphinx documentation
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            # You can set these variables from the command line.
         | 
| 5 | 
            +
            SPHINXOPTS    =
         | 
| 6 | 
            +
            SPHINXBUILD   = python -msphinx
         | 
| 7 | 
            +
            SPHINXPROJ    = fairseq
         | 
| 8 | 
            +
            SOURCEDIR     = .
         | 
| 9 | 
            +
            BUILDDIR      = _build
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # Put it first so that "make" without argument is like "make help".
         | 
| 12 | 
            +
            help:
         | 
| 13 | 
            +
            	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            .PHONY: help Makefile
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # Catch-all target: route all unknown targets to Sphinx using the new
         | 
| 18 | 
            +
            # "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
         | 
| 19 | 
            +
            %: Makefile
         | 
| 20 | 
            +
            	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
         | 
    	
        fairseq/docs/_static/theme_overrides.css
    ADDED
    
    | @@ -0,0 +1,9 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .wy-table-responsive table td kbd {
         | 
| 2 | 
            +
                white-space: nowrap;
         | 
| 3 | 
            +
            }
         | 
| 4 | 
            +
            .wy-table-responsive table td {
         | 
| 5 | 
            +
                white-space: normal !important;
         | 
| 6 | 
            +
            }
         | 
| 7 | 
            +
            .wy-table-responsive {
         | 
| 8 | 
            +
                overflow: visible !important;
         | 
| 9 | 
            +
            }
         | 
    	
        fairseq/docs/command_line_tools.rst
    ADDED
    
    | @@ -0,0 +1,85 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .. _Command-line Tools:
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Command-line Tools
         | 
| 4 | 
            +
            ==================
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            Fairseq provides several command-line tools for training and evaluating models:
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            - :ref:`fairseq-preprocess`: Data pre-processing: build vocabularies and binarize training data
         | 
| 9 | 
            +
            - :ref:`fairseq-train`: Train a new model on one or multiple GPUs
         | 
| 10 | 
            +
            - :ref:`fairseq-generate`: Translate pre-processed data with a trained model
         | 
| 11 | 
            +
            - :ref:`fairseq-interactive`: Translate raw text with a trained model
         | 
| 12 | 
            +
            - :ref:`fairseq-score`: BLEU scoring of generated translations against reference translations
         | 
| 13 | 
            +
            - :ref:`fairseq-eval-lm`: Language model evaluation
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            .. _fairseq-preprocess:
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            fairseq-preprocess
         | 
| 19 | 
            +
            ~~~~~~~~~~~~~~~~~~
         | 
| 20 | 
            +
            .. automodule:: fairseq_cli.preprocess
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                .. argparse::
         | 
| 23 | 
            +
                    :module: fairseq.options
         | 
| 24 | 
            +
                    :func: get_preprocessing_parser
         | 
| 25 | 
            +
                    :prog: fairseq-preprocess
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            .. _fairseq-train:
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            fairseq-train
         | 
| 31 | 
            +
            ~~~~~~~~~~~~~
         | 
| 32 | 
            +
            .. automodule:: fairseq_cli.train
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                .. argparse::
         | 
| 35 | 
            +
                    :module: fairseq.options
         | 
| 36 | 
            +
                    :func: get_training_parser
         | 
| 37 | 
            +
                    :prog: fairseq-train
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            .. _fairseq-generate:
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            fairseq-generate
         | 
| 43 | 
            +
            ~~~~~~~~~~~~~~~~
         | 
| 44 | 
            +
            .. automodule:: fairseq_cli.generate
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                .. argparse::
         | 
| 47 | 
            +
                    :module: fairseq.options
         | 
| 48 | 
            +
                    :func: get_generation_parser
         | 
| 49 | 
            +
                    :prog: fairseq-generate
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            .. _fairseq-interactive:
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            fairseq-interactive
         | 
| 55 | 
            +
            ~~~~~~~~~~~~~~~~~~~
         | 
| 56 | 
            +
            .. automodule:: fairseq_cli.interactive
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                .. argparse::
         | 
| 59 | 
            +
                    :module: fairseq.options
         | 
| 60 | 
            +
                    :func: get_interactive_generation_parser
         | 
| 61 | 
            +
                    :prog: fairseq-interactive
         | 
| 62 | 
            +
             | 
| 63 | 
            +
             | 
| 64 | 
            +
            .. _fairseq-score:
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            fairseq-score
         | 
| 67 | 
            +
            ~~~~~~~~~~~~~
         | 
| 68 | 
            +
            .. automodule:: fairseq_cli.score
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                .. argparse::
         | 
| 71 | 
            +
                    :module: fairseq_cli.score
         | 
| 72 | 
            +
                    :func: get_parser
         | 
| 73 | 
            +
                    :prog: fairseq-score
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            .. _fairseq-eval-lm:
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            fairseq-eval-lm
         | 
| 79 | 
            +
            ~~~~~~~~~~~~~~~
         | 
| 80 | 
            +
            .. automodule:: fairseq_cli.eval_lm
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                .. argparse::
         | 
| 83 | 
            +
                    :module: fairseq.options
         | 
| 84 | 
            +
                    :func: get_eval_lm_parser
         | 
| 85 | 
            +
                    :prog: fairseq-eval-lm
         | 
    	
        fairseq/docs/conf.py
    ADDED
    
    | @@ -0,0 +1,134 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/usr/bin/env python3
         | 
| 2 | 
            +
            # -*- coding: utf-8 -*-
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # fairseq documentation build configuration file, created by
         | 
| 5 | 
            +
            # sphinx-quickstart on Fri Aug 17 21:45:30 2018.
         | 
| 6 | 
            +
            #
         | 
| 7 | 
            +
            # This file is execfile()d with the current directory set to its
         | 
| 8 | 
            +
            # containing dir.
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Note that not all possible configuration values are present in this
         | 
| 11 | 
            +
            # autogenerated file.
         | 
| 12 | 
            +
            #
         | 
| 13 | 
            +
            # All configuration values have a default; values that are commented out
         | 
| 14 | 
            +
            # serve to show the default.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # If extensions (or modules to document with autodoc) are in another directory,
         | 
| 17 | 
            +
            # add these directories to sys.path here. If the directory is relative to the
         | 
| 18 | 
            +
            # documentation root, use os.path.abspath to make it absolute, like shown here.
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            import os
         | 
| 21 | 
            +
            import sys
         | 
| 22 | 
            +
            from fairseq import __version__
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            # source code directory, relative to this file, for sphinx-autobuild
         | 
| 26 | 
            +
            sys.path.insert(0, os.path.abspath(".."))
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            source_suffix = [".rst"]
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            # -- General configuration ------------------------------------------------
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            # If your documentation needs a minimal Sphinx version, state it here.
         | 
| 33 | 
            +
            #
         | 
| 34 | 
            +
            # needs_sphinx = '1.0'
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            # Add any Sphinx extension module names here, as strings. They can be
         | 
| 37 | 
            +
            # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
         | 
| 38 | 
            +
            # ones.
         | 
| 39 | 
            +
            extensions = [
         | 
| 40 | 
            +
                "sphinx.ext.autodoc",
         | 
| 41 | 
            +
                "sphinx.ext.intersphinx",
         | 
| 42 | 
            +
                "sphinx.ext.viewcode",
         | 
| 43 | 
            +
                "sphinx.ext.napoleon",
         | 
| 44 | 
            +
                "sphinxarg.ext",
         | 
| 45 | 
            +
            ]
         | 
| 46 | 
            +
             | 
| 47 | 
            +
            # Add any paths that contain templates here, relative to this directory.
         | 
| 48 | 
            +
            templates_path = ["_templates"]
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            # The master toctree document.
         | 
| 51 | 
            +
            master_doc = "index"
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            # General information about the project.
         | 
| 54 | 
            +
            project = "fairseq"
         | 
| 55 | 
            +
            copyright = "Facebook AI Research (FAIR)"
         | 
| 56 | 
            +
            author = "Facebook AI Research (FAIR)"
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            github_doc_root = "https://github.com/pytorch/fairseq/tree/main/docs/"
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            # The version info for the project you're documenting, acts as replacement for
         | 
| 61 | 
            +
            # |version| and |release|, also used in various other places throughout the
         | 
| 62 | 
            +
            # built documents.
         | 
| 63 | 
            +
            #
         | 
| 64 | 
            +
            # The short X.Y version.
         | 
| 65 | 
            +
            version = __version__
         | 
| 66 | 
            +
            # The full version, including alpha/beta/rc tags.
         | 
| 67 | 
            +
            release = __version__
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            # The language for content autogenerated by Sphinx. Refer to documentation
         | 
| 70 | 
            +
            # for a list of supported languages.
         | 
| 71 | 
            +
            #
         | 
| 72 | 
            +
            # This is also used if you do content translation via gettext catalogs.
         | 
| 73 | 
            +
            # Usually you set "language" from the command line for these cases.
         | 
| 74 | 
            +
            language = None
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            # List of patterns, relative to source directory, that match files and
         | 
| 77 | 
            +
            # directories to ignore when looking for source files.
         | 
| 78 | 
            +
            # These patterns also affect html_static_path and html_extra_path
         | 
| 79 | 
            +
            exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
         | 
| 80 | 
            +
             | 
| 81 | 
            +
            # The name of the Pygments (syntax highlighting) style to use.
         | 
| 82 | 
            +
            pygments_style = "sphinx"
         | 
| 83 | 
            +
            highlight_language = "python"
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            # If true, `todo` and `todoList` produce output, else they produce nothing.
         | 
| 86 | 
            +
            todo_include_todos = False
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            # -- Options for HTML output ----------------------------------------------
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            # The theme to use for HTML and HTML Help pages.  See the documentation for
         | 
| 92 | 
            +
            # a list of builtin themes.
         | 
| 93 | 
            +
            #
         | 
| 94 | 
            +
            html_theme = "sphinx_rtd_theme"
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            # Theme options are theme-specific and customize the look and feel of a theme
         | 
| 97 | 
            +
            # further.  For a list of options available for each theme, see the
         | 
| 98 | 
            +
            # documentation.
         | 
| 99 | 
            +
            #
         | 
| 100 | 
            +
            # html_theme_options = {}
         | 
| 101 | 
            +
             | 
| 102 | 
            +
            # Add any paths that contain custom static files (such as style sheets) here,
         | 
| 103 | 
            +
            # relative to this directory. They are copied after the builtin static files,
         | 
| 104 | 
            +
            # so a file named "default.css" will overwrite the builtin "default.css".
         | 
| 105 | 
            +
            html_static_path = ["_static"]
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            html_context = {
         | 
| 108 | 
            +
                "css_files": [
         | 
| 109 | 
            +
                    "_static/theme_overrides.css",  # override wide tables in RTD theme
         | 
| 110 | 
            +
                ],
         | 
| 111 | 
            +
            }
         | 
| 112 | 
            +
             | 
| 113 | 
            +
            # Custom sidebar templates, must be a dictionary that maps document names
         | 
| 114 | 
            +
            # to template names.
         | 
| 115 | 
            +
            #
         | 
| 116 | 
            +
            # This is required for the alabaster theme
         | 
| 117 | 
            +
            # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
         | 
| 118 | 
            +
            # html_sidebars = {
         | 
| 119 | 
            +
            #    '**': [
         | 
| 120 | 
            +
            #        'about.html',
         | 
| 121 | 
            +
            #        'navigation.html',
         | 
| 122 | 
            +
            #        'relations.html',  # needs 'show_related': True theme option to display
         | 
| 123 | 
            +
            #        'searchbox.html',
         | 
| 124 | 
            +
            #        'donate.html',
         | 
| 125 | 
            +
            #    ]
         | 
| 126 | 
            +
            # }
         | 
| 127 | 
            +
             | 
| 128 | 
            +
             | 
| 129 | 
            +
            # Example configuration for intersphinx: refer to the Python standard library.
         | 
| 130 | 
            +
            intersphinx_mapping = {
         | 
| 131 | 
            +
                "numpy": ("http://docs.scipy.org/doc/numpy/", None),
         | 
| 132 | 
            +
                "python": ("https://docs.python.org/", None),
         | 
| 133 | 
            +
                "torch": ("https://pytorch.org/docs/master/", None),
         | 
| 134 | 
            +
            }
         | 
    	
        fairseq/docs/criterions.rst
    ADDED
    
    | @@ -0,0 +1,31 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .. role:: hidden
         | 
| 2 | 
            +
                :class: hidden-section
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            .. _Criterions:
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            Criterions
         | 
| 7 | 
            +
            ==========
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            Criterions compute the loss function given the model and batch, roughly::
         | 
| 10 | 
            +
             | 
| 11 | 
            +
              loss = criterion(model, batch)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            .. automodule:: fairseq.criterions
         | 
| 14 | 
            +
                :members:
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            .. autoclass:: fairseq.criterions.FairseqCriterion
         | 
| 17 | 
            +
                :members:
         | 
| 18 | 
            +
                :undoc-members:
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            .. autoclass:: fairseq.criterions.adaptive_loss.AdaptiveLoss
         | 
| 21 | 
            +
                :members:
         | 
| 22 | 
            +
                :undoc-members:
         | 
| 23 | 
            +
            .. autoclass:: fairseq.criterions.composite_loss.CompositeLoss
         | 
| 24 | 
            +
                :members:
         | 
| 25 | 
            +
                :undoc-members:
         | 
| 26 | 
            +
            .. autoclass:: fairseq.criterions.cross_entropy.CrossEntropyCriterion
         | 
| 27 | 
            +
                :members:
         | 
| 28 | 
            +
                :undoc-members:
         | 
| 29 | 
            +
            .. autoclass:: fairseq.criterions.label_smoothed_cross_entropy.LabelSmoothedCrossEntropyCriterion
         | 
| 30 | 
            +
                :members:
         | 
| 31 | 
            +
                :undoc-members:
         | 
    	
        fairseq/docs/data.rst
    ADDED
    
    | @@ -0,0 +1,58 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .. role:: hidden
         | 
| 2 | 
            +
                :class: hidden-section
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            .. module:: fairseq.data
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            Data Loading and Utilities
         | 
| 7 | 
            +
            ==========================
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            .. _datasets:
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            Datasets
         | 
| 12 | 
            +
            --------
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            **Datasets** define the data format and provide helpers for creating
         | 
| 15 | 
            +
            mini-batches.
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            .. autoclass:: fairseq.data.FairseqDataset
         | 
| 18 | 
            +
                :members:
         | 
| 19 | 
            +
            .. autoclass:: fairseq.data.LanguagePairDataset
         | 
| 20 | 
            +
                :members:
         | 
| 21 | 
            +
            .. autoclass:: fairseq.data.MonolingualDataset
         | 
| 22 | 
            +
                :members:
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            **Helper Datasets**
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            These datasets wrap other :class:`fairseq.data.FairseqDataset` instances and
         | 
| 27 | 
            +
            provide additional functionality:
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            .. autoclass:: fairseq.data.BacktranslationDataset
         | 
| 30 | 
            +
                :members:
         | 
| 31 | 
            +
            .. autoclass:: fairseq.data.ConcatDataset
         | 
| 32 | 
            +
                :members:
         | 
| 33 | 
            +
            .. autoclass:: fairseq.data.ResamplingDataset
         | 
| 34 | 
            +
                :members:
         | 
| 35 | 
            +
            .. autoclass:: fairseq.data.RoundRobinZipDatasets
         | 
| 36 | 
            +
                :members:
         | 
| 37 | 
            +
            .. autoclass:: fairseq.data.TransformEosDataset
         | 
| 38 | 
            +
                :members:
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            Dictionary
         | 
| 42 | 
            +
            ----------
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            .. autoclass:: fairseq.data.Dictionary
         | 
| 45 | 
            +
                :members:
         | 
| 46 | 
            +
             | 
| 47 | 
            +
             | 
| 48 | 
            +
            Iterators
         | 
| 49 | 
            +
            ---------
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            .. autoclass:: fairseq.data.CountingIterator
         | 
| 52 | 
            +
                :members:
         | 
| 53 | 
            +
            .. autoclass:: fairseq.data.EpochBatchIterator
         | 
| 54 | 
            +
                :members:
         | 
| 55 | 
            +
            .. autoclass:: fairseq.data.GroupedIterator
         | 
| 56 | 
            +
                :members:
         | 
| 57 | 
            +
            .. autoclass:: fairseq.data.ShardedIterator
         | 
| 58 | 
            +
                :members:
         | 
    	
        fairseq/docs/docutils.conf
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            [writers]
         | 
| 2 | 
            +
            option-limit=0
         | 
    	
        fairseq/docs/fairseq_logo.png
    ADDED
    
    |   | 
    	
        fairseq/docs/getting_started.rst
    ADDED
    
    | @@ -0,0 +1,216 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            Evaluating Pre-trained Models
         | 
| 2 | 
            +
            =============================
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            First, download a pre-trained model along with its vocabularies:
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            .. code-block:: console
         | 
| 7 | 
            +
             | 
| 8 | 
            +
                > curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf -
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            This model uses a `Byte Pair Encoding (BPE)
         | 
| 11 | 
            +
            vocabulary <https://arxiv.org/abs/1508.07909>`__, so we'll have to apply
         | 
| 12 | 
            +
            the encoding to the source text before it can be translated. This can be
         | 
| 13 | 
            +
            done with the
         | 
| 14 | 
            +
            `apply\_bpe.py <https://github.com/rsennrich/subword-nmt/blob/master/subword_nmt/apply_bpe.py>`__
         | 
| 15 | 
            +
            script using the ``wmt14.en-fr.fconv-py/bpecodes`` file. ``@@`` is
         | 
| 16 | 
            +
            used as a continuation marker and the original text can be easily
         | 
| 17 | 
            +
            recovered with e.g. ``sed s/@@ //g`` or by passing the ``--remove-bpe``
         | 
| 18 | 
            +
            flag to :ref:`fairseq-generate`. Prior to BPE, input text needs to be tokenized
         | 
| 19 | 
            +
            using ``tokenizer.perl`` from
         | 
| 20 | 
            +
            `mosesdecoder <https://github.com/moses-smt/mosesdecoder>`__.
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            Let's use :ref:`fairseq-interactive` to generate translations interactively.
         | 
| 23 | 
            +
            Here, we use a beam size of 5 and preprocess the input with the Moses
         | 
| 24 | 
            +
            tokenizer and the given Byte-Pair Encoding vocabulary. It will automatically
         | 
| 25 | 
            +
            remove the BPE continuation markers and detokenize the output.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            .. code-block:: console
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                > MODEL_DIR=wmt14.en-fr.fconv-py
         | 
| 30 | 
            +
                > fairseq-interactive \
         | 
| 31 | 
            +
                    --path $MODEL_DIR/model.pt $MODEL_DIR \
         | 
| 32 | 
            +
                    --beam 5 --source-lang en --target-lang fr \
         | 
| 33 | 
            +
                    --tokenizer moses \
         | 
| 34 | 
            +
                    --bpe subword_nmt --bpe-codes $MODEL_DIR/bpecodes
         | 
| 35 | 
            +
                | loading model(s) from wmt14.en-fr.fconv-py/model.pt
         | 
| 36 | 
            +
                | [en] dictionary: 44206 types
         | 
| 37 | 
            +
                | [fr] dictionary: 44463 types
         | 
| 38 | 
            +
                | Type the input sentence and press return:
         | 
| 39 | 
            +
                Why is it rare to discover new marine mammal species?
         | 
| 40 | 
            +
                S-0     Why is it rare to discover new marine mam@@ mal species ?
         | 
| 41 | 
            +
                H-0     -0.0643349438905716     Pourquoi est-il rare de découvrir de nouvelles espèces de mammifères marins?
         | 
| 42 | 
            +
                P-0     -0.0763 -0.1849 -0.0956 -0.0946 -0.0735 -0.1150 -0.1301 -0.0042 -0.0321 -0.0171 -0.0052 -0.0062 -0.0015
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            This generation script produces three types of outputs: a line prefixed
         | 
| 45 | 
            +
            with *S* is a copy of the original source sentence; *H* is the
         | 
| 46 | 
            +
            hypothesis along with an average log-likelihood; and *P* is the
         | 
| 47 | 
            +
            positional score per token position, including the
         | 
| 48 | 
            +
            end-of-sentence marker which is omitted from the text.
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            Other types of output lines you might see are *D*, the detokenized hypothesis,
         | 
| 51 | 
            +
            *T*, the reference target, *A*, alignment info, and *E*, the history of generation steps.
         | 
| 52 | 
            +
             | 
| 53 | 
            +
            See the `README <https://github.com/pytorch/fairseq#pre-trained-models>`__ for a
         | 
| 54 | 
            +
            full list of pre-trained models available.
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            Training a New Model
         | 
| 57 | 
            +
            ====================
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            The following tutorial is for machine translation. For an example of how
         | 
| 60 | 
            +
            to use Fairseq for other tasks, such as :ref:`language modeling`, please see the
         | 
| 61 | 
            +
            ``examples/`` directory.
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            Data Pre-processing
         | 
| 64 | 
            +
            -------------------
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            Fairseq contains example pre-processing scripts for several translation
         | 
| 67 | 
            +
            datasets: IWSLT 2014 (German-English), WMT 2014 (English-French) and WMT
         | 
| 68 | 
            +
            2014 (English-German). To pre-process and binarize the IWSLT dataset:
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            .. code-block:: console
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                > cd examples/translation/
         | 
| 73 | 
            +
                > bash prepare-iwslt14.sh
         | 
| 74 | 
            +
                > cd ../..
         | 
| 75 | 
            +
                > TEXT=examples/translation/iwslt14.tokenized.de-en
         | 
| 76 | 
            +
                > fairseq-preprocess --source-lang de --target-lang en \
         | 
| 77 | 
            +
                    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
         | 
| 78 | 
            +
                    --destdir data-bin/iwslt14.tokenized.de-en
         | 
| 79 | 
            +
             | 
| 80 | 
            +
            This will write binarized data that can be used for model training to
         | 
| 81 | 
            +
            ``data-bin/iwslt14.tokenized.de-en``.
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            Training
         | 
| 84 | 
            +
            --------
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            Use :ref:`fairseq-train` to train a new model. Here are a few example settings that work
         | 
| 87 | 
            +
            well for the IWSLT 2014 dataset:
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            .. code-block:: console
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                > mkdir -p checkpoints/fconv
         | 
| 92 | 
            +
                > CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \
         | 
| 93 | 
            +
                    --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \
         | 
| 94 | 
            +
                    --arch fconv_iwslt_de_en --save-dir checkpoints/fconv
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the
         | 
| 97 | 
            +
            ``CUDA_VISIBLE_DEVICES`` environment variable to select specific GPUs and/or to
         | 
| 98 | 
            +
            change the number of GPU devices that will be used.
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            Also note that the batch size is specified in terms of the maximum
         | 
| 101 | 
            +
            number of tokens per batch (``--max-tokens``). You may need to use a
         | 
| 102 | 
            +
            smaller value depending on the available GPU memory on your system.
         | 
| 103 | 
            +
             | 
| 104 | 
            +
            Generation
         | 
| 105 | 
            +
            ----------
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            Once your model is trained, you can generate translations using
         | 
| 108 | 
            +
            :ref:`fairseq-generate` **(for binarized data)** or
         | 
| 109 | 
            +
            :ref:`fairseq-interactive` **(for raw text)**:
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            .. code-block:: console
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                > fairseq-generate data-bin/iwslt14.tokenized.de-en \
         | 
| 114 | 
            +
                    --path checkpoints/fconv/checkpoint_best.pt \
         | 
| 115 | 
            +
                    --batch-size 128 --beam 5
         | 
| 116 | 
            +
                | [de] dictionary: 35475 types
         | 
| 117 | 
            +
                | [en] dictionary: 24739 types
         | 
| 118 | 
            +
                | data-bin/iwslt14.tokenized.de-en test 6750 examples
         | 
| 119 | 
            +
                | model fconv
         | 
| 120 | 
            +
                | loaded checkpoint checkpoints/fconv/checkpoint_best.pt
         | 
| 121 | 
            +
                S-721   danke .
         | 
| 122 | 
            +
                T-721   thank you .
         | 
| 123 | 
            +
                ...
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            To generate translations with only a CPU, use the ``--cpu`` flag. BPE
         | 
| 126 | 
            +
            continuation markers can be removed with the ``--remove-bpe`` flag.
         | 
| 127 | 
            +
             | 
| 128 | 
            +
            Advanced Training Options
         | 
| 129 | 
            +
            =========================
         | 
| 130 | 
            +
             | 
| 131 | 
            +
            Large mini-batch training with delayed updates
         | 
| 132 | 
            +
            ----------------------------------------------
         | 
| 133 | 
            +
             | 
| 134 | 
            +
            The ``--update-freq`` option can be used to accumulate gradients from
         | 
| 135 | 
            +
            multiple mini-batches and delay updating, creating a larger effective
         | 
| 136 | 
            +
            batch size. Delayed updates can also improve training speed by reducing
         | 
| 137 | 
            +
            inter-GPU communication costs and by saving idle time caused by variance
         | 
| 138 | 
            +
            in workload across GPUs. See `Ott et al.
         | 
| 139 | 
            +
            (2018) <https://arxiv.org/abs/1806.00187>`__ for more details.
         | 
| 140 | 
            +
             | 
| 141 | 
            +
            To train on a single GPU with an effective batch size that is equivalent
         | 
| 142 | 
            +
            to training on 8 GPUs:
         | 
| 143 | 
            +
             | 
| 144 | 
            +
            .. code-block:: console
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                > CUDA_VISIBLE_DEVICES=0 fairseq-train --update-freq 8 (...)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
            Training with half precision floating point (FP16)
         | 
| 149 | 
            +
            --------------------------------------------------
         | 
| 150 | 
            +
             | 
| 151 | 
            +
            .. note::
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                FP16 training requires a Volta GPU and CUDA 9.1 or greater
         | 
| 154 | 
            +
             | 
| 155 | 
            +
            Recent GPUs enable efficient half precision floating point computation,
         | 
| 156 | 
            +
            e.g., using `Nvidia Tensor Cores
         | 
| 157 | 
            +
            <https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html>`__.
         | 
| 158 | 
            +
            Fairseq supports FP16 training with the ``--fp16`` flag:
         | 
| 159 | 
            +
             | 
| 160 | 
            +
            .. code-block:: console
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                > fairseq-train --fp16 (...)
         | 
| 163 | 
            +
             | 
| 164 | 
            +
            Distributed training
         | 
| 165 | 
            +
            --------------------
         | 
| 166 | 
            +
             | 
| 167 | 
            +
            Distributed training in fairseq is implemented on top of ``torch.distributed``.
         | 
| 168 | 
            +
            The easiest way to launch jobs is with the `torch.distributed.launch
         | 
| 169 | 
            +
            <https://pytorch.org/docs/stable/distributed.html#launch-utility>`__ tool.
         | 
| 170 | 
            +
             | 
| 171 | 
            +
            For example, to train a large English-German Transformer model on 2 nodes each
         | 
| 172 | 
            +
            with 8 GPUs (in total 16 GPUs), run the following command on each node,
         | 
| 173 | 
            +
            replacing ``node_rank=0`` with ``node_rank=1`` on the second node and making
         | 
| 174 | 
            +
            sure to update ``--master_addr`` to the IP address of the first node:
         | 
| 175 | 
            +
             | 
| 176 | 
            +
            .. code-block:: console
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                > python -m torch.distributed.launch --nproc_per_node=8 \
         | 
| 179 | 
            +
                    --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" \
         | 
| 180 | 
            +
                    --master_port=12345 \
         | 
| 181 | 
            +
                    $(which fairseq-train) data-bin/wmt16_en_de_bpe32k \
         | 
| 182 | 
            +
                    --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
         | 
| 183 | 
            +
                    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
         | 
| 184 | 
            +
                    --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
         | 
| 185 | 
            +
                    --lr 0.0005 \
         | 
| 186 | 
            +
                    --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
         | 
| 187 | 
            +
                    --max-tokens 3584 \
         | 
| 188 | 
            +
                    --max-epoch 70 \
         | 
| 189 | 
            +
                    --fp16
         | 
| 190 | 
            +
             | 
| 191 | 
            +
            On SLURM clusters, fairseq will automatically detect the number of nodes and
         | 
| 192 | 
            +
            GPUs, but a port number must be provided:
         | 
| 193 | 
            +
             | 
| 194 | 
            +
            .. code-block:: console
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                > salloc --gpus=16 --nodes 2 (...)
         | 
| 197 | 
            +
                > srun fairseq-train --distributed-port 12345 (...)
         | 
| 198 | 
            +
             | 
| 199 | 
            +
            Sharding very large datasets
         | 
| 200 | 
            +
            ----------------------------
         | 
| 201 | 
            +
             | 
| 202 | 
            +
            It can be challenging to train over very large datasets, particularly if your
         | 
| 203 | 
            +
            machine does not have much system RAM. Most tasks in fairseq support training
         | 
| 204 | 
            +
            over "sharded" datasets, in which the original dataset has been preprocessed
         | 
| 205 | 
            +
            into non-overlapping chunks (or "shards").
         | 
| 206 | 
            +
             | 
| 207 | 
            +
            For example, instead of preprocessing all your data into a single "data-bin"
         | 
| 208 | 
            +
            directory, you can split the data and create "data-bin1", "data-bin2", etc.
         | 
| 209 | 
            +
            Then you can adapt your training command like so:
         | 
| 210 | 
            +
             | 
| 211 | 
            +
            .. code-block:: console
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                > fairseq-train data-bin1:data-bin2:data-bin3 (...)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
            Training will now iterate over each shard, one by one, with each shard
         | 
| 216 | 
            +
            corresponding to an "epoch", thus reducing system memory usage.
         | 
    	
        fairseq/docs/hydra_integration.md
    ADDED
    
    | @@ -0,0 +1,284 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            ## Hydra
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            [Hydra](https://github.com/facebookresearch/hydra) is an open-source Python
         | 
| 4 | 
            +
            framework that simplifies the development of research and other complex
         | 
| 5 | 
            +
            applications. The key feature is the ability to dynamically create a
         | 
| 6 | 
            +
            hierarchical configuration by composition and override it through config files
         | 
| 7 | 
            +
            and the command line. The name Hydra comes from its ability to run multiple
         | 
| 8 | 
            +
            similar jobs - much like a Hydra with multiple heads.
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            ## Motivation
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            Until recently, all components in fairseq were configured through a shared
         | 
| 13 | 
            +
            `args` namespace that was created at application startup. Components declared
         | 
| 14 | 
            +
            their own `add_args` method to update the argparse parser, hoping that the names
         | 
| 15 | 
            +
            would not clash with arguments from other components. While this model works for
         | 
| 16 | 
            +
            smaller applications, as fairseq grew and became integrated into other
         | 
| 17 | 
            +
            applications, this became problematic. In order to determine how to configure
         | 
| 18 | 
            +
            each component, one needed to a) examine what args were added by this component,
         | 
| 19 | 
            +
            and b) read the code to figure out what shared arguments it is using that were
         | 
| 20 | 
            +
            added in other places. Reproducing models involved sharing commands that often
         | 
| 21 | 
            +
            contained dozens of command line switches.
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            The model described above is still supported by fairseq for backward
         | 
| 24 | 
            +
            compatibility, but will be deprecated some time in the future.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            New components in fairseq should now create a dataclass that encapsulates all
         | 
| 27 | 
            +
            parameters required to configure this component. The dataclass is registered
         | 
| 28 | 
            +
            along with the component, and fairseq takes care of constructing and providing
         | 
| 29 | 
            +
            this configuration object to the component's constructor. Note that sharing
         | 
| 30 | 
            +
            parameters can optionally still work, but one has to explicitly point to the
         | 
| 31 | 
            +
            "source of truth" (see inheritance example below). These changes make components
         | 
| 32 | 
            +
            in fairseq more independent and re-usable by other applications: all that is
         | 
| 33 | 
            +
            needed to create a component is to initialize its dataclass and overwrite some
         | 
| 34 | 
            +
            of the defaults.
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            While configuring fairseq through the command line (using either the legacy argparse
         | 
| 37 | 
            +
            based or the new Hydra based entry points) is still fully supported, you can now
         | 
| 38 | 
            +
            take advantage of configuring fairseq completely or piece-by-piece through
         | 
| 39 | 
            +
            hierarchical YAML configuration files. These files can also be shipped as
         | 
| 40 | 
            +
            examples that others can use to run an identically configured job.
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            Additionally, Hydra has a rich and growing [library of
         | 
| 43 | 
            +
            plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that
         | 
| 44 | 
            +
            provide functionality such as hyperparameter sweeping (including using Bayesian
         | 
| 45 | 
            +
            optimization through the [Ax](https://github.com/facebook/Ax) library), job
         | 
| 46 | 
            +
            launching across various platforms, and more.
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            ## Creating or migrating components
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            In general, each new (or updated) component should provide a companion
         | 
| 51 | 
            +
            [dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclasses are
         | 
| 52 | 
            +
            typically located in the same file as the component and are passed as arguments
         | 
| 53 | 
            +
            to the `register_*()` functions. Top-level configs that should be present in
         | 
| 54 | 
            +
            every fairseq application are placed in the
         | 
| 55 | 
            +
            [global](fairseq/dataclass/configs.py) config file and added to the
         | 
| 56 | 
            +
            `FairseqConfig` object.
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These
         | 
| 59 | 
            +
            classes are decorated with a `@dataclass` decorator, and typically inherit from
         | 
| 60 | 
            +
            `FairseqDataclass` (which adds some functionality for backward compatibility).
         | 
| 61 | 
            +
            Each field must have a type, and generally has metadata (such as a help string)
         | 
| 62 | 
            +
            and a default value. Only primitive types or other config objects are allowed as
         | 
| 63 | 
            +
            data types for each field.
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            #### Example:
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            ```python
         | 
| 68 | 
            +
            from dataclasses import dataclass, field
         | 
| 69 | 
            +
            from fairseq.dataclass import FairseqDataclass
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            @dataclass
         | 
| 72 | 
            +
            class InteractiveConfig(FairseqDataclass):
         | 
| 73 | 
            +
                buffer_size: int = field(
         | 
| 74 | 
            +
                    default=0,
         | 
| 75 | 
            +
                    metadata={
         | 
| 76 | 
            +
                        "help": "read this many sentences into a buffer before processing them"
         | 
| 77 | 
            +
                    },
         | 
| 78 | 
            +
                )
         | 
| 79 | 
            +
                input: str = field(
         | 
| 80 | 
            +
                    default="-",
         | 
| 81 | 
            +
                    metadata={"help": "file to read from; use - for stdin"},
         | 
| 82 | 
            +
                )
         | 
| 83 | 
            +
            ```
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            ### Inheriting values
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            Some components require sharing a value. For example, a learning rate scheduler
         | 
| 88 | 
            +
            and an optimizer may both need to know the initial learning rate value. One can
         | 
| 89 | 
            +
            declare a field that, by default, will inherit its value from another config
         | 
| 90 | 
            +
            node in the same hierarchy:
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            ```python
         | 
| 93 | 
            +
            @dataclass
         | 
| 94 | 
            +
            class FairseqAdamConfig(FairseqDataclass):
         | 
| 95 | 
            +
                ...
         | 
| 96 | 
            +
                lr: List[float] = II("optimization.lr")
         | 
| 97 | 
            +
                ...
         | 
| 98 | 
            +
            ```
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            `II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is
         | 
| 101 | 
            +
            the value one can use in a YAML config file or through the command line to achieve
         | 
| 102 | 
            +
            the same effect. Note that this assumes that there is an "optimization" config
         | 
| 103 | 
            +
            object in the root config and it has a field called "lr".
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            ### Tasks and Models
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            Creating Tasks and Models works the same as before, except that legacy
         | 
| 108 | 
            +
            implementations now inherit from `LegacyFairseq*` base classes, while new
         | 
| 109 | 
            +
            components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass
         | 
| 110 | 
            +
            to the `register_*()` functions.
         | 
| 111 | 
            +
             | 
| 112 | 
            +
            #### Task example:
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            ```python
         | 
| 115 | 
            +
            @dataclass
         | 
| 116 | 
            +
            class LanguageModelingConfig(FairseqDataclass):
         | 
| 117 | 
            +
                data: Optional[str] = field(
         | 
| 118 | 
            +
                    default=None, metadata={"help": "path to data directory"}
         | 
| 119 | 
            +
                )
         | 
| 120 | 
            +
                ...
         | 
| 121 | 
            +
             | 
| 122 | 
            +
            @register_task("language_modeling", dataclass=LanguageModelingConfig)
         | 
| 123 | 
            +
            class LanguageModelingTask(FairseqTask):
         | 
| 124 | 
            +
                ...
         | 
| 125 | 
            +
                @classmethod
         | 
| 126 | 
            +
                def setup_task(cls, cfg: LanguageModelingConfig):
         | 
| 127 | 
            +
                    ...
         | 
| 128 | 
            +
            ```
         | 
| 129 | 
            +
             | 
| 130 | 
            +
            #### Model example:
         | 
| 131 | 
            +
             | 
| 132 | 
            +
            ```python
         | 
| 133 | 
            +
            @dataclass
         | 
| 134 | 
            +
            class TransformerLanguageModelConfig(FairseqDataclass):
         | 
| 135 | 
            +
                activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
         | 
| 136 | 
            +
                    default="relu", metadata={"help": "activation function to use"}
         | 
| 137 | 
            +
                )
         | 
| 138 | 
            +
                dropout: float = field(default=0.1, metadata={"help": "dropout probability"})
         | 
| 139 | 
            +
                ...
         | 
| 140 | 
            +
             | 
| 141 | 
            +
            @register_model("transformer_lm", dataclass=TransformerLanguageModelConfig)
         | 
| 142 | 
            +
            class TransformerLanguageModel(FairseqLanguageModel):
         | 
| 143 | 
            +
                ...
         | 
| 144 | 
            +
                @classmethod
         | 
| 145 | 
            +
                def build_model(cls, cfg: TransformerLanguageModelConfig, task: FairseqTask):
         | 
| 146 | 
            +
                    ...
         | 
| 147 | 
            +
            ```
         | 
| 148 | 
            +
             | 
| 149 | 
            +
            ### Other components
         | 
| 150 | 
            +
             | 
| 151 | 
            +
            Other components work as before, but they now take their configuration dataclass
         | 
| 152 | 
            +
            as the only constructor argument:
         | 
| 153 | 
            +
             | 
| 154 | 
            +
            ```python
         | 
| 155 | 
            +
            @dataclass
         | 
| 156 | 
            +
            class MosesTokenizerConfig(FairseqDataclass):
         | 
| 157 | 
            +
                source_lang: str = field(default="en", metadata={"help": "source language"})
         | 
| 158 | 
            +
                ...
         | 
| 159 | 
            +
             | 
| 160 | 
            +
            @register_tokenizer("moses", dataclass=MosesTokenizerConfig)
         | 
| 161 | 
            +
            class MosesTokenizer(object):
         | 
| 162 | 
            +
                def __init__(self, cfg: MosesTokenizerConfig):
         | 
| 163 | 
            +
                    ...
         | 
| 164 | 
            +
            ```
         | 
| 165 | 
            +
             | 
| 166 | 
            +
            Note that if you are adding a new registry for a new set of components, you need
         | 
| 167 | 
            +
            to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`:
         | 
| 168 | 
            +
             | 
| 169 | 
            +
            ```python
         | 
| 170 | 
            +
            @dataclass
         | 
| 171 | 
            +
            class FairseqConfig(object):
         | 
| 172 | 
            +
                ...
         | 
| 173 | 
            +
                my_new_registry: Any = None
         | 
| 174 | 
            +
            ```
         | 
| 175 | 
            +
             | 
| 176 | 
            +
            ## Training with `fairseq-hydra-train`
         | 
| 177 | 
            +
             | 
| 178 | 
            +
            To fully take advantage of the configuration flexibility offered by Hydra, you may
         | 
| 179 | 
            +
            want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI
         | 
| 180 | 
            +
            tools such as `fairseq-train` will remain supported for the foreseeable future
         | 
| 181 | 
            +
            but will be deprecated eventually.
         | 
| 182 | 
            +
             | 
| 183 | 
            +
            On startup, Hydra will create a configuration object that contains a hierarchy
         | 
| 184 | 
            +
            of all the necessary dataclasses populated with their default values in the
         | 
| 185 | 
            +
            code. The default values are overwritten by values found in YAML files in
         | 
| 186 | 
            +
            the `fairseq/config` directory (which currently sets minimal defaults) and then
         | 
| 187 | 
            +
            further overwritten by values provided through command line arguments.
         | 
| 188 | 
            +
             | 
| 189 | 
            +
            Some of the most common use cases are shown below:
         | 
| 190 | 
            +
             | 
| 191 | 
            +
            ### 1. Override default values through the command line:
         | 
| 192 | 
            +
             | 
| 193 | 
            +
            ```shell script
         | 
| 194 | 
            +
            $ fairseq-hydra-train \
         | 
| 195 | 
            +
                distributed_training.distributed_world_size=1 \
         | 
| 196 | 
            +
                dataset.batch_size=2 \
         | 
| 197 | 
            +
                task.data=data-bin \
         | 
| 198 | 
            +
                model=transformer_lm/transformer_lm_gpt \
         | 
| 199 | 
            +
                task=language_modeling \
         | 
| 200 | 
            +
                optimization.max_update=5000
         | 
| 201 | 
            +
            ```
         | 
| 202 | 
            +
             | 
| 203 | 
            +
            Note that along with explicitly providing values for parameters such as
         | 
| 204 | 
            +
            `dataset.batch_size`, this also tells Hydra to overlay configuration found in
         | 
| 205 | 
            +
            `fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default
         | 
| 206 | 
            +
            values in the dataclass. If you want to train a model without specifying a
         | 
| 207 | 
            +
            particular architecture you can simply specify `model=transformer_lm`. This only
         | 
| 208 | 
            +
            works for migrated tasks and models.
         | 
| 209 | 
            +
             | 
| 210 | 
            +
            ### 2. Replace bundled configs with an external config:
         | 
| 211 | 
            +
             | 
| 212 | 
            +
            ```shell script
         | 
| 213 | 
            +
            $ fairseq-hydra-train \
         | 
| 214 | 
            +
                --config-dir /path/to/external/configs \
         | 
| 215 | 
            +
                --config-name wiki103
         | 
| 216 | 
            +
            ```
         | 
| 217 | 
            +
             | 
| 218 | 
            +
            where `/path/to/external/configs/wiki103.yaml` contains:
         | 
| 219 | 
            +
             | 
| 220 | 
            +
            ```yaml
         | 
| 221 | 
            +
            # @package _group_
         | 
| 222 | 
            +
             | 
| 223 | 
            +
            model:
         | 
| 224 | 
            +
              _name: transformer_lm
         | 
| 225 | 
            +
            distributed_training:
         | 
| 226 | 
            +
              distributed_world_size: 1
         | 
| 227 | 
            +
            dataset:
         | 
| 228 | 
            +
              batch_size: 2
         | 
| 229 | 
            +
            task:
         | 
| 230 | 
            +
              _name: language_modeling
         | 
| 231 | 
            +
              data: /path/to/data
         | 
| 232 | 
            +
              add_bos_token: false
         | 
| 233 | 
            +
              max_target_positions: 1024
         | 
| 234 | 
            +
            optimization:
         | 
| 235 | 
            +
              max_update: 50000
         | 
| 236 | 
            +
              lr: [ 0.25 ]
         | 
| 237 | 
            +
            criterion: cross_entropy
         | 
| 238 | 
            +
            optimizer: adam
         | 
| 239 | 
            +
            lr_scheduler:
         | 
| 240 | 
            +
              _name: cosine
         | 
| 241 | 
            +
            ```
         | 
| 242 | 
            +
             | 
| 243 | 
            +
            Note that here bundled configs from `fairseq/config` directory are not used,
         | 
| 244 | 
            +
            however the defaults from each dataclass will still be used (unless overwritten
         | 
| 245 | 
            +
            by your external config).
         | 
| 246 | 
            +
             | 
| 247 | 
            +
            Additionally, you can choose to break up your configs by creating a directory
         | 
| 248 | 
            +
            structure in the same location as your main config file, with the names of the
         | 
| 249 | 
            +
            top-level fields (such as "model", "dataset", etc), and placing config files
         | 
| 250 | 
            +
            with meaningful names that would populate that specific section of your
         | 
| 251 | 
            +
            top-level config file (for example, you might have
         | 
| 252 | 
            +
            `model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You
         | 
| 253 | 
            +
            can then specify the correct configuration via command line, defaults in the
         | 
| 254 | 
            +
            main config, or even launch all of them as a sweep (see Hydra documentation on
         | 
| 255 | 
            +
            how to do this).
         | 
| 256 | 
            +
             | 
| 257 | 
            +
            ### 3. Add an external config directory to Hydra search path:
         | 
| 258 | 
            +
             | 
| 259 | 
            +
            This allows combining default configuration (including using any bundled config
         | 
| 260 | 
            +
            files), while specifying your own config files for some parts of the
         | 
| 261 | 
            +
            configuration.
         | 
| 262 | 
            +
             | 
| 263 | 
            +
            ```shell script
         | 
| 264 | 
            +
            $ fairseq-hydra-train \
         | 
| 265 | 
            +
                distributed_training.distributed_world_size=1 \
         | 
| 266 | 
            +
                dataset.batch_size=2 \
         | 
| 267 | 
            +
                task.data=/path/to/data/ \
         | 
| 268 | 
            +
                model=transformer_lm/2_layers \
         | 
| 269 | 
            +
                task=language_modeling \
         | 
| 270 | 
            +
                optimization.max_update=5000 \
         | 
| 271 | 
            +
                --config-dir /path/to/external/configs
         | 
| 272 | 
            +
            ```
         | 
| 273 | 
            +
             | 
| 274 | 
            +
            where `/path/to/external/configs` has the following structure:
         | 
| 275 | 
            +
            ```
         | 
| 276 | 
            +
            .
         | 
| 277 | 
            +
            +-- model
         | 
| 278 | 
            +
            |   +-- transformer_lm
         | 
| 279 | 
            +
            |   |   +-- 2_layers.yaml
         | 
| 280 | 
            +
            ```
         | 
| 281 | 
            +
             | 
| 282 | 
            +
            and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with
         | 
| 283 | 
            +
            `decoder_layers` set to 2. You can add other configs to configure other
         | 
| 284 | 
            +
            components as well.
         | 
    	
        fairseq/docs/index.rst
    ADDED
    
    | @@ -0,0 +1,49 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .. fairseq documentation master file, created by
         | 
| 2 | 
            +
               sphinx-quickstart on Fri Aug 17 21:45:30 2018.
         | 
| 3 | 
            +
               You can adapt this file completely to your liking, but it should at least
         | 
| 4 | 
            +
               contain the root `toctree` directive.
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            :github_url: https://github.com/pytorch/fairseq
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            fairseq documentation
         | 
| 10 | 
            +
            =====================
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            Fairseq is a sequence modeling toolkit written in `PyTorch
         | 
| 13 | 
            +
            <http://pytorch.org/>`_ that allows researchers and developers to
         | 
| 14 | 
            +
            train custom models for translation, summarization, language modeling and other
         | 
| 15 | 
            +
            text generation tasks.
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            .. toctree::
         | 
| 18 | 
            +
                :maxdepth: 1
         | 
| 19 | 
            +
                :caption: Getting Started
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                getting_started
         | 
| 22 | 
            +
                command_line_tools
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            .. toctree::
         | 
| 25 | 
            +
                :maxdepth: 1
         | 
| 26 | 
            +
                :caption: Extending Fairseq
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                overview
         | 
| 29 | 
            +
                tutorial_simple_lstm
         | 
| 30 | 
            +
                tutorial_classifying_names
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            .. toctree::
         | 
| 33 | 
            +
                :maxdepth: 2
         | 
| 34 | 
            +
                :caption: Library Reference
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                tasks
         | 
| 37 | 
            +
                models
         | 
| 38 | 
            +
                criterions
         | 
| 39 | 
            +
                optim
         | 
| 40 | 
            +
                lr_scheduler
         | 
| 41 | 
            +
                data
         | 
| 42 | 
            +
                modules
         | 
| 43 | 
            +
             | 
| 44 | 
            +
             | 
| 45 | 
            +
            Indices and tables
         | 
| 46 | 
            +
            ==================
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            * :ref:`genindex`
         | 
| 49 | 
            +
            * :ref:`search`
         | 
    	
        fairseq/docs/lr_scheduler.rst
    ADDED
    
    | @@ -0,0 +1,34 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .. role:: hidden
         | 
| 2 | 
            +
                :class: hidden-section
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            .. _Learning Rate Schedulers:
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            Learning Rate Schedulers
         | 
| 7 | 
            +
            ========================
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            Learning Rate Schedulers update the learning rate over the course of training.
         | 
| 10 | 
            +
            Learning rates can be updated after each update via :func:`step_update` or at
         | 
| 11 | 
            +
            epoch boundaries via :func:`step`.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            .. automodule:: fairseq.optim.lr_scheduler
         | 
| 14 | 
            +
                :members:
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            .. autoclass:: fairseq.optim.lr_scheduler.FairseqLRScheduler
         | 
| 17 | 
            +
                :members:
         | 
| 18 | 
            +
                :undoc-members:
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            .. autoclass:: fairseq.optim.lr_scheduler.cosine_lr_scheduler.CosineSchedule
         | 
| 21 | 
            +
                :members:
         | 
| 22 | 
            +
                :undoc-members:
         | 
| 23 | 
            +
            .. autoclass:: fairseq.optim.lr_scheduler.fixed_schedule.FixedSchedule
         | 
| 24 | 
            +
                :members:
         | 
| 25 | 
            +
                :undoc-members:
         | 
| 26 | 
            +
            .. autoclass:: fairseq.optim.lr_scheduler.inverse_square_root_schedule.InverseSquareRootSchedule
         | 
| 27 | 
            +
                :members:
         | 
| 28 | 
            +
                :undoc-members:
         | 
| 29 | 
            +
            .. autoclass:: fairseq.optim.lr_scheduler.reduce_lr_on_plateau.ReduceLROnPlateau
         | 
| 30 | 
            +
                :members:
         | 
| 31 | 
            +
                :undoc-members:
         | 
| 32 | 
            +
            .. autoclass:: fairseq.optim.lr_scheduler.triangular_lr_scheduler.TriangularSchedule
         | 
| 33 | 
            +
                :members:
         | 
| 34 | 
            +
                :undoc-members:
         | 
    	
        fairseq/docs/make.bat
    ADDED
    
    | @@ -0,0 +1,36 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            @ECHO OFF
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            pushd %~dp0
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            REM Command file for Sphinx documentation
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            if "%SPHINXBUILD%" == "" (
         | 
| 8 | 
            +
            	set SPHINXBUILD=python -msphinx
         | 
| 9 | 
            +
            )
         | 
| 10 | 
            +
            set SOURCEDIR=.
         | 
| 11 | 
            +
            set BUILDDIR=_build
         | 
| 12 | 
            +
            set SPHINXPROJ=fairseq
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            if "%1" == "" goto help
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            %SPHINXBUILD% >NUL 2>NUL
         | 
| 17 | 
            +
            if errorlevel 9009 (
         | 
| 18 | 
            +
            	echo.
         | 
| 19 | 
            +
            	echo.The Sphinx module was not found. Make sure you have Sphinx installed,
         | 
| 20 | 
            +
            	echo.then set the SPHINXBUILD environment variable to point to the full
         | 
| 21 | 
            +
            	echo.path of the 'sphinx-build' executable. Alternatively you may add the
         | 
| 22 | 
            +
            	echo.Sphinx directory to PATH.
         | 
| 23 | 
            +
            	echo.
         | 
| 24 | 
            +
            	echo.If you don't have Sphinx installed, grab it from
         | 
| 25 | 
            +
            	echo.http://sphinx-doc.org/
         | 
| 26 | 
            +
            	exit /b 1
         | 
| 27 | 
            +
            )
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
         | 
| 30 | 
            +
            goto end
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            :help
         | 
| 33 | 
            +
            %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            :end
         | 
| 36 | 
            +
            popd
         | 
    	
        fairseq/docs/models.rst
    ADDED
    
    | @@ -0,0 +1,104 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            .. role:: hidden
         | 
| 2 | 
            +
                :class: hidden-section
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            .. module:: fairseq.models
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            .. _Models:
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            Models
         | 
| 9 | 
            +
            ======
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            A Model defines the neural network's ``forward()`` method and encapsulates all
         | 
| 12 | 
            +
            of the learnable parameters in the network. Each model also provides a set of
         | 
| 13 | 
            +
            named *architectures* that define the precise network configuration (e.g.,
         | 
| 14 | 
            +
            embedding dimension, number of layers, etc.).
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            Both the model type and architecture are selected via the ``--arch``
         | 
| 17 | 
            +
            command-line argument. Once selected, a model may expose additional command-line
         | 
| 18 | 
            +
            arguments for further configuration.
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            .. note::
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                All fairseq Models extend :class:`BaseFairseqModel`, which in turn extends
         | 
| 23 | 
            +
                :class:`torch.nn.Module`. Thus any fairseq Model can be used as a
         | 
| 24 | 
            +
                stand-alone Module in other PyTorch code.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            Convolutional Neural Networks (CNN)
         | 
| 28 | 
            +
            -----------------------------------
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            .. module:: fairseq.models.fconv
         | 
| 31 | 
            +
            .. autoclass:: fairseq.models.fconv.FConvModel
         | 
| 32 | 
            +
                :members:
         | 
| 33 | 
            +
            .. autoclass:: fairseq.models.fconv.FConvEncoder
         | 
| 34 | 
            +
                :members:
         | 
| 35 | 
            +
                :undoc-members:
         | 
| 36 | 
            +
            .. autoclass:: fairseq.models.fconv.FConvDecoder
         | 
| 37 | 
            +
                :members:
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            Long Short-Term Memory (LSTM) networks
         | 
| 41 | 
            +
            --------------------------------------
         | 
| 42 | 
            +
             | 
| 43 | 
            +
            .. module:: fairseq.models.lstm
         | 
| 44 | 
            +
            .. autoclass:: fairseq.models.lstm.LSTMModel
         | 
| 45 | 
            +
                :members:
         | 
| 46 | 
            +
            .. autoclass:: fairseq.models.lstm.LSTMEncoder
         | 
| 47 | 
            +
                :members:
         | 
| 48 | 
            +
            .. autoclass:: fairseq.models.lstm.LSTMDecoder
         | 
| 49 | 
            +
                :members:
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            Transformer (self-attention) networks
         | 
| 53 | 
            +
            -------------------------------------
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            .. module:: fairseq.models.transformer
         | 
| 56 | 
            +
            .. autoclass:: fairseq.models.transformer.TransformerModel
         | 
| 57 | 
            +
                :members:
         | 
| 58 | 
            +
            .. autoclass:: fairseq.models.transformer.TransformerEncoder
         | 
| 59 | 
            +
                :members:
         | 
| 60 | 
            +
            .. autoclass:: fairseq.models.transformer.TransformerEncoderLayer
         | 
| 61 | 
            +
                :members:
         | 
| 62 | 
            +
            .. autoclass:: fairseq.models.transformer.TransformerDecoder
         | 
| 63 | 
            +
                :members:
         | 
| 64 | 
            +
            .. autoclass:: fairseq.models.transformer.TransformerDecoderLayer
         | 
| 65 | 
            +
                :members:
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            Adding new models
         | 
| 69 | 
            +
            -----------------
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            .. currentmodule:: fairseq.models
         | 
| 72 | 
            +
            .. autofunction:: fairseq.models.register_model
         | 
| 73 | 
            +
            .. autofunction:: fairseq.models.register_model_architecture
         | 
| 74 | 
            +
            .. autoclass:: fairseq.models.BaseFairseqModel
         | 
| 75 | 
            +
                :members:
         | 
| 76 | 
            +
                :undoc-members:
         | 
| 77 | 
            +
            .. autoclass:: fairseq.models.FairseqEncoderDecoderModel
         | 
| 78 | 
            +
                :members:
         | 
| 79 | 
            +
                :undoc-members:
         | 
| 80 | 
            +
            .. autoclass:: fairseq.models.FairseqEncoderModel
         | 
| 81 | 
            +
                :members:
         | 
| 82 | 
            +
                :undoc-members:
         | 
| 83 | 
            +
            .. autoclass:: fairseq.models.FairseqLanguageModel
         | 
| 84 | 
            +
                :members:
         | 
| 85 | 
            +
                :undoc-members:
         | 
| 86 | 
            +
            .. autoclass:: fairseq.models.FairseqMultiModel
         | 
| 87 | 
            +
                :members:
         | 
| 88 | 
            +
                :undoc-members:
         | 
| 89 | 
            +
            .. autoclass:: fairseq.models.FairseqEncoder
         | 
| 90 | 
            +
                :members:
         | 
| 91 | 
            +
            .. autoclass:: fairseq.models.CompositeEncoder
         | 
| 92 | 
            +
                :members:
         | 
| 93 | 
            +
            .. autoclass:: fairseq.models.FairseqDecoder
         | 
| 94 | 
            +
                :members:
         | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
            .. _Incremental decoding:
         | 
| 98 | 
            +
             | 
| 99 | 
            +
            Incremental decoding
         | 
| 100 | 
            +
            --------------------
         | 
| 101 | 
            +
             | 
| 102 | 
            +
            .. autoclass:: fairseq.models.FairseqIncrementalDecoder
         | 
| 103 | 
            +
                :members:
         | 
| 104 | 
            +
                :undoc-members:
         | 
 
			
