Commit b6bff08

Duplicate from h2oai/h2ogpt-chatbot

Co-authored-by: Jonathan McKinney <[email protected]>
- .gitattributes +34 -0
- LICENSE +201 -0
- README.md +14 -0
- app.py +1959 -0
- client_test.py +93 -0
- finetune.py +934 -0
- h2o-logo.svg +1 -0
- prompter.py +106 -0
- requirements.txt +48 -0
- stopping.py +139 -0
- utils.py +154 -0
    	
.gitattributes ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
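The rules above route binary and large artifacts (model weights, archives, serialized arrays) through Git LFS rather than storing them directly in the repository. As a quick illustration only (this helper is not part of the commit, and fnmatch is just an approximation of gitattributes globbing for patterns such as saved_model/**/*), one can check which LFS pattern would catch a given file:

    # Hypothetical helper, not part of this repo: collect the .gitattributes patterns
    # marked filter=lfs and test whether a filename matches one of them.
    from fnmatch import fnmatch

    def lfs_patterns(path=".gitattributes"):
        patterns = []
        with open(path) as f:
            for line in f:
                fields = line.split()
                if len(fields) > 1 and "filter=lfs" in fields[1:]:
                    patterns.append(fields[0])
        return patterns

    def is_lfs_tracked(filename, patterns):
        return any(fnmatch(filename, pattern) for pattern in patterns)

    # e.g. is_lfs_tracked("pytorch_model.bin", lfs_patterns()) -> True under the rules above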
    	
LICENSE ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
    	
README.md ADDED
@@ -0,0 +1,14 @@
+---
+title: H2ogpt Chatbot
+emoji: 📚
+colorFrom: yellow
+colorTo: yellow
+sdk: gradio
+sdk_version: 3.27.0
+app_file: app.py
+pinned: false
+license: apache-2.0
+duplicated_from: h2oai/h2ogpt-chatbot
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
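The YAML front matter above is standard Hugging Face Spaces configuration: the Hub serves app_file (app.py) with the Gradio SDK pinned to sdk_version 3.27.0, and duplicated_from records the source Space. For orientation only, a minimal sketch of what such a Gradio entry point has to do; this is not the repo's app.py, which is far larger and wires the handler to an LLM pipeline:

    # Illustrative Gradio app_file sketch (placeholder handler, not this commit's code).
    import gradio as gr

    def respond(message: str) -> str:
        # The real app.py routes this through model loading and generation.
        return "echo: " + message

    demo = gr.Interface(fn=respond, inputs="text", outputs="text", title="H2ogpt Chatbot")

    if __name__ == "__main__":
        # On a Space the platform manages host/port; locally this defaults to :7860.
        demo.launch()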
    	
app.py ADDED
@@ -0,0 +1,1959 @@
+import functools
+import inspect
+import sys
+import os
+import traceback
+import typing
+from utils import set_seed, flatten_list, clear_torch_cache, system_info_print, zip_data, save_generate_output
+
+SEED = 1236
+set_seed(SEED)
+
+os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
+from typing import Union
+import numpy as np
+import pandas as pd
+
+import fire
+import torch
+from peft import PeftModel
+from transformers import GenerationConfig, StoppingCriteriaList, AutoModel
+from accelerate import init_empty_weights, infer_auto_device_map
+
+from prompter import Prompter
+
+from finetune import get_loaders, example_data_points, generate_prompt, get_githash, prompt_types_strings, \
+    human, bot, prompt_type_to_model_name, inv_prompt_type_to_model_lower
+from stopping import CallbackToGenerator, Stream, StoppingCriteriaSub
+
+is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
+is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
+is_public = is_hf or is_gpth2oai  # multi-user case with fixed model and disclaimer
+is_low_mem = is_hf  # assumes run on 24GB consumer GPU
+admin_pass = os.getenv("ADMIN_PASS")
+# will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
+raise_generate_gpu_exceptions = True
+
+eval_extra_columns = ['prompt', 'response', 'score']
+
+def main(
+        load_8bit: bool = False,
+        load_half: bool = True,
+        infer_devices: bool = True,
+        base_model: str = '',
+        tokenizer_base_model: str = '',
+        lora_weights: str = "",
+        gpu_id: int = 0,  # if infer_devices = True and gpu_id != -1
+
+        prompt_type: Union[int, str] = None,
+        # input to generation
+        temperature: float = None,
+        top_p: float = None,
+        top_k: int = None,
+        num_beams: int = None,
+        repetition_penalty: float = None,
+        num_return_sequences: int = None,
+        do_sample: bool = None,
+        max_new_tokens: int = None,
+        min_new_tokens: int = None,
+        early_stopping: Union[bool, str] = None,
+        max_time: float = None,
+
+        llama_type: bool = None,
+        debug: bool = False,
+        save_dir: str = None,
+        share: bool = True,
+        local_files_only: bool = False,
+        resume_download: bool = True,
+        use_auth_token: Union[str, bool] = False,  # True requires CLI did huggingface-cli login before running
+
+        src_lang: str = "English",
+        tgt_lang: str = "Russian",
+
+        gradio: bool = True,
+        gradio_avoid_processing_markdown: bool = False,
+        chat: bool = True,
+        chat_history: int = 4096,  # character length of chat context/history
+        stream_output: bool = True,
+        show_examples: bool = None,
+        verbose: bool = False,
+        h2ocolors: bool = True,
+        height: int = 400,
+        show_lora: bool = True,
+        # set to True to load --base_model after client logs in,
+        # to be able to free GPU memory when model is swapped
+        login_mode_if_model0: bool = False,
+
+        sanitize_user_prompt: bool = True,
+        sanitize_bot_response: bool = True,
+
+        extra_model_options: typing.List[str] = [],
+        extra_lora_options: typing.List[str] = [],
+
+        score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
+        auto_score: bool = True,
+
+        eval_sharegpt_prompts_only: int = 0,
+        eval_sharegpt_prompts_only_seed: int = 1234,
+        eval_sharegpt_as_output: bool = False,
+):
+    # allow set token directly
+    use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
+
+    if is_public:
+        temperature = 0.4
+        top_p = 0.85
+        top_k = 70
+        do_sample = True
+        if is_low_mem:
+            base_model = 'h2oai/h2ogpt-oasst1-512-12b'
+            load_8bit = True
+        else:
+            base_model = 'h2oai/h2ogpt-oasst1-512-20b'
+    if is_low_mem:
+        load_8bit = True
+    if is_hf:
+        # must override share if in spaces
+        share = False
+    save_dir = os.getenv('SAVE_DIR', save_dir)
+
+    # get defaults
+    model_lower = base_model.lower()
+    if not gradio:
+        # force, else not single response like want to look at
+        stream_output = False
+        # else prompt removal can mess up output
+        chat = False
+
+    placeholder_instruction, placeholder_input, \
+    stream_output, show_examples, \
+    prompt_type, temperature, top_p, top_k, num_beams, \
+    max_new_tokens, min_new_tokens, early_stopping, max_time, \
+    repetition_penalty, num_return_sequences, \
+    do_sample, \
+    src_lang, tgt_lang, \
+    examples, \
+    task_info = \
+        get_generate_params(model_lower, chat,
+                            stream_output, show_examples,
+                            prompt_type, temperature, top_p, top_k, num_beams,
+                            max_new_tokens, min_new_tokens, early_stopping, max_time,
+                            repetition_penalty, num_return_sequences,
+                            do_sample,
+                            )
+
+    if not gradio:
+        if eval_sharegpt_prompts_only > 0:
+            # override default examples with shareGPT ones for human-level eval purposes only
+            eval_filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
+            if not os.path.isfile(eval_filename):
+                os.system(
+                    'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % eval_filename)
+            import json
+            data = json.load(open(eval_filename, 'rt'))
+            # focus on data that starts with human, else likely chopped from other data
+            turn_start = 0  # odd in general
+            data = [x for x in data if len(x['conversations']) > turn_start + 1 and
+                    x['conversations'][turn_start]['from'] == 'human' and
+                    x['conversations'][turn_start + 1]['from'] == 'gpt']
+            np.random.seed(eval_sharegpt_prompts_only_seed)
+            example1 = examples[-1]  # pick reference example
+            examples = []
+            responses = []
+            for i in list(np.random.randint(0, len(data), size=eval_sharegpt_prompts_only)):
+                assert data[i]['conversations'][turn_start]['from'] == 'human'
+                instruction = data[i]['conversations'][turn_start]['value']
+                assert data[i]['conversations'][turn_start + 1]['from'] == 'gpt'
+                output = data[i]['conversations'][turn_start + 1]['value']
+                examplenew = example1.copy()
+                assert not chat, "No gradio must use chat=False, uses nochat instruct"
| 170 | 
            +
                            examplenew[eval_func_param_names.index('instruction_nochat')] = instruction
         | 
| 171 | 
            +
                            examplenew[eval_func_param_names.index('iinput_nochat')] = ''  # no input
         | 
| 172 | 
            +
                            examplenew[eval_func_param_names.index('context')] = ''  # no context
         | 
| 173 | 
            +
                            examples.append(examplenew)
         | 
| 174 | 
            +
                            responses.append(output)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    num_examples = len(examples)
         | 
| 177 | 
            +
                    scoring_path = 'scoring'
         | 
| 178 | 
            +
                    os.makedirs(scoring_path, exist_ok=True)
         | 
| 179 | 
            +
                    if eval_sharegpt_as_output:
         | 
| 180 | 
            +
                        used_base_model = 'gpt35'
         | 
| 181 | 
            +
                        used_lora_weights = ''
         | 
| 182 | 
            +
                    else:
         | 
| 183 | 
            +
                        used_base_model = str(base_model.split('/')[-1])
         | 
| 184 | 
            +
                        used_lora_weights = str(lora_weights.split('/')[-1])
         | 
| 185 | 
            +
                    eval_filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
         | 
| 186 | 
            +
                                                                             eval_sharegpt_prompts_only_seed,
         | 
| 187 | 
            +
                                                                             eval_sharegpt_as_output,
         | 
| 188 | 
            +
                                                                             used_base_model,
         | 
| 189 | 
            +
                                                                             used_lora_weights)
         | 
| 190 | 
            +
                    eval_filename = os.path.join(scoring_path, eval_filename)
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    with torch.device("cuda"):
         | 
| 193 | 
            +
                        # ensure stream_output was set correctly above, before the examples were generated
         | 
| 194 | 
            +
                        assert not stream_output, "stream_output=True does not make sense with example loop"
         | 
| 195 | 
            +
                        import time
         | 
| 196 | 
            +
                        from functools import partial
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                        # get score model
         | 
| 199 | 
            +
                        smodel, stokenizer, sdevice = get_score_model(**locals())
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                        if not eval_sharegpt_as_output:
         | 
| 202 | 
            +
                            model, tokenizer, device = get_model(**locals())
         | 
| 203 | 
            +
                            model_state = [model, tokenizer, device, base_model]
         | 
| 204 | 
            +
                            fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir)
         | 
| 205 | 
            +
                        else:
         | 
| 206 | 
            +
                            assert eval_sharegpt_prompts_only > 0
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                            def get_response(*args, exi=0):
         | 
| 209 | 
            +
                                # assumes same ordering of examples and responses
         | 
| 210 | 
            +
                                yield responses[exi]
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                            fun = get_response
         | 
| 213 | 
            +
                        t0 = time.time()
         | 
| 214 | 
            +
                        score_dump = []
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                        import matplotlib.pyplot as plt
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                        for exi, ex in enumerate(examples):
         | 
| 219 | 
            +
                            instruction = ex[eval_func_param_names.index('instruction_nochat')]
         | 
| 220 | 
            +
                            iinput = ex[eval_func_param_names.index('iinput_nochat')]
         | 
| 221 | 
            +
                            context = ex[eval_func_param_names.index('context')]
         | 
| 222 | 
            +
                            clear_torch_cache()
         | 
| 223 | 
            +
                            print("")
         | 
| 224 | 
            +
                            print("START" + "=" * 100)
         | 
| 225 | 
            +
                            print("Question: %s %s" % (instruction, ('input=%s' % iinput if iinput else '')))
         | 
| 226 | 
            +
                            print("-" * 105)
         | 
| 227 | 
            +
                            # fun yields as generator, so have to iterate over it
         | 
| 228 | 
            +
                            # Also means likely do NOT want --stream_output=True, else would show all generations
         | 
| 229 | 
            +
                            for res in fun(*tuple(ex), exi=exi):
         | 
| 230 | 
            +
                                print(res)
         | 
| 231 | 
            +
                                if smodel:
         | 
| 232 | 
            +
                                    score_with_prompt = False
         | 
| 233 | 
            +
                                    if score_with_prompt:
         | 
| 234 | 
            +
                                        data_point = dict(instruction=instruction, input=iinput, context=context)
         | 
| 235 | 
            +
                                        prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
         | 
| 236 | 
            +
                                        prompt = prompter.generate_prompt(data_point)
         | 
| 237 | 
            +
                                    else:
         | 
| 238 | 
            +
                                        # just raw input and output
         | 
| 239 | 
            +
                                        assert iinput in [None, '']  # should be no iinput
         | 
| 240 | 
            +
                                        assert context in [None, '']  # should be no context
         | 
| 241 | 
            +
                                        prompt = instruction
         | 
| 242 | 
            +
                                    cutoff_len = 768 if is_low_mem else 2048
         | 
| 243 | 
            +
                                    inputs = stokenizer(prompt, res,
         | 
| 244 | 
            +
                                                        return_tensors="pt",
         | 
| 245 | 
            +
                                                        truncation=True,
         | 
| 246 | 
            +
                                                        max_length=cutoff_len)
         | 
| 247 | 
            +
                                    try:
         | 
| 248 | 
            +
                                        score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
         | 
| 249 | 
            +
                                    except torch.cuda.OutOfMemoryError as e:
         | 
| 250 | 
            +
                                        print("GPU OOM: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
         | 
| 251 | 
            +
                                        traceback.print_exc()
         | 
| 252 | 
            +
                                        score = 0.0
         | 
| 253 | 
            +
                                        clear_torch_cache()
         | 
| 254 | 
            +
                                    except (Exception, RuntimeError) as e:
         | 
| 255 | 
            +
                                        if 'Expected all tensors to be on the same device' in str(e) or \
         | 
| 256 | 
            +
                                                'expected scalar type Half but found Float' in str(e) or \
         | 
| 257 | 
            +
                                                'probability tensor contains either' in str(e) or \
         | 
| 258 | 
            +
                                                'cublasLt ran into an error!' in str(e):
         | 
| 259 | 
            +
                                            print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
         | 
| 260 | 
            +
                                                  flush=True)
         | 
| 261 | 
            +
                                            traceback.print_exc()
         | 
| 262 | 
            +
                                            score = 0.0
         | 
| 263 | 
            +
                                            clear_torch_cache()
         | 
| 264 | 
            +
                                        else:
         | 
| 265 | 
            +
                                            raise
         | 
| 266 | 
            +
                                    print("SCORE %s: %s" % (exi, score), flush=True)
         | 
| 267 | 
            +
                                    score_dump.append(ex + [prompt, res, score])
         | 
| 268 | 
            +
                                    # dump every score so far, in case of an abort
         | 
| 269 | 
            +
                                    df_scores = pd.DataFrame(score_dump,
         | 
| 270 | 
            +
                                                             columns=eval_func_param_names + eval_extra_columns)
         | 
| 271 | 
            +
                                    df_scores.to_parquet(eval_filename, index=False)
         | 
| 272 | 
            +
                                    # plot histogram so far
         | 
| 273 | 
            +
                                    plt.figure(figsize=(10, 10))
         | 
| 274 | 
            +
                                    plt.hist(df_scores['score'], bins=20)
         | 
| 275 | 
            +
                                    score_avg = np.mean(df_scores['score'])
         | 
| 276 | 
            +
                                    score_median = np.median(df_scores['score'])
         | 
| 277 | 
            +
                                    plt.title("Score avg: %s median: %s" % (score_avg, score_median))
         | 
| 278 | 
            +
                                    plt.savefig(eval_filename.replace('.parquet', '.png'))
         | 
| 279 | 
            +
                                    plt.close()
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                            print("END" + "=" * 102)
         | 
| 282 | 
            +
                            print("")
         | 
| 283 | 
            +
                            t2 = time.time()
         | 
| 284 | 
            +
                            print("Time taken so far: %.4f about %.4g per example" % (t2 - t0, (t2 - t0) / (1 + exi)))
         | 
| 285 | 
            +
                        t1 = time.time()
         | 
| 286 | 
            +
                        print("Total time taken: %.4f about %.4g per example" % (t1 - t0, (t1 - t0) / num_examples))
         | 
| 287 | 
            +
                    return eval_filename
         | 
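# A self-contained sketch of the reward-model scoring performed inside the loop above:
# tokenize (prompt, response) as a pair and squash the single sequence-classification
# logit through a sigmoid.  The model id is an assumption; any single-logit reward
# model on Hugging Face works the same way.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"  # assumed example id
stokenizer = AutoTokenizer.from_pretrained(reward_name)
smodel = AutoModelForSequenceClassification.from_pretrained(reward_name)

prompt = "What is the capital of France?"
response = "The capital of France is Paris."
inputs = stokenizer(prompt, response, return_tensors="pt", truncation=True, max_length=768)
with torch.no_grad():
    # single-logit head: sigmoid maps it to a 0..1 "goodness" score, as in the loop above
    score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().numpy()[0]
print("score: %.3f" % score)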
| 288 | 
            +
             | 
| 289 | 
            +
                if gradio:
         | 
| 290 | 
            +
                    go_gradio(**locals())
         | 
| 291 | 
            +
             | 
| 292 | 
            +
             | 
| 293 | 
            +
            def get_device():
         | 
| 294 | 
            +
                if torch.cuda.is_available():
         | 
| 295 | 
            +
                    device = "cuda"
         | 
| 296 | 
            +
                else:
         | 
| 297 | 
            +
                    raise RuntimeError("only cuda supported")
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                return device
         | 
| 300 | 
            +
             | 
| 301 | 
            +
             | 
| 302 | 
            +
            def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
         | 
| 303 | 
            +
                                   gpu_id=0,
         | 
| 304 | 
            +
                                   use_auth_token=False):
         | 
| 305 | 
            +
                """
         | 
| 306 | 
            +
                Ensure model gets on correct device
         | 
| 307 | 
            +
                :param base_model:
         | 
| 308 | 
            +
                :param model_loader:
         | 
| 309 | 
            +
                :param load_half:
         | 
| 310 | 
            +
                :param model_kwargs:
         | 
| 311 | 
            +
                :param reward_type:
         | 
| 312 | 
            +
                :param gpu_id:
         | 
| 313 | 
            +
                :param use_auth_token:
         | 
| 314 | 
            +
                :return:
         | 
| 315 | 
            +
                """
         | 
| 316 | 
            +
                with init_empty_weights():
         | 
| 317 | 
            +
                    from transformers import AutoConfig
         | 
| 318 | 
            +
                    config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
         | 
| 319 | 
            +
                    model = AutoModel.from_config(
         | 
| 320 | 
            +
                        config,
         | 
| 321 | 
            +
                    )
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
         | 
| 324 | 
            +
                # NOTE: Some models require avoiding sharding some layers,
         | 
| 325 | 
            +
                # then would pass no_split_module_classes and give list of those layers.
         | 
| 326 | 
            +
                device_map = infer_auto_device_map(
         | 
| 327 | 
            +
                    model,
         | 
| 328 | 
            +
                    dtype=torch.float16 if load_half else torch.float32,
         | 
| 329 | 
            +
                )
         | 
| 330 | 
            +
                if hasattr(model, 'model'):
         | 
| 331 | 
            +
                    device_map_model = infer_auto_device_map(
         | 
| 332 | 
            +
                        model.model,
         | 
| 333 | 
            +
                        dtype=torch.float16 if load_half else torch.float32,
         | 
| 334 | 
            +
                    )
         | 
| 335 | 
            +
                    device_map.update(device_map_model)
         | 
| 336 | 
            +
                print('device_map: %s' % device_map, flush=True)
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                if gpu_id >= 0:
         | 
| 339 | 
            +
                    # FIXME: If the model is really distributed across GPUs, we tend to get errors like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
         | 
| 340 | 
            +
                    # So avoid that for now: just put the model on the first GPU, except the score model, which goes on the last GPU
         | 
| 341 | 
            +
                    n_gpus = torch.cuda.device_count()
         | 
| 342 | 
            +
                    if reward_type:
         | 
| 343 | 
            +
                        device_map = {'': n_gpus - 1}
         | 
| 344 | 
            +
                    else:
         | 
| 345 | 
            +
                        device_map = {'': min(n_gpus - 1, gpu_id)}
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                load_in_8bit = model_kwargs.get('load_in_8bit', False)
         | 
| 348 | 
            +
                model_kwargs['device_map'] = device_map
         | 
| 349 | 
            +
             | 
| 350 | 
            +
                if load_in_8bit or not load_half:
         | 
| 351 | 
            +
                    model = model_loader.from_pretrained(
         | 
| 352 | 
            +
                        base_model,
         | 
| 353 | 
            +
                        **model_kwargs,
         | 
| 354 | 
            +
                    )
         | 
| 355 | 
            +
                else:
         | 
| 356 | 
            +
                    model = model_loader.from_pretrained(
         | 
| 357 | 
            +
                        base_model,
         | 
| 358 | 
            +
                        **model_kwargs,
         | 
| 359 | 
            +
                    ).half()
         | 
| 360 | 
            +
                return model
         | 
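# A sketch of how the max_memory / no_split_module_classes options mentioned in the
# NOTEs above could be used to genuinely shard a model across two GPUs.  The model id,
# memory budgets, and the "GPTNeoXLayer" class name are illustrative assumptions.
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModel

with init_empty_weights():
    config = AutoConfig.from_pretrained("EleutherAI/pythia-2.8b")  # assumed example model
    empty_model = AutoModel.from_config(config)

sharded_device_map = infer_auto_device_map(
    empty_model,
    max_memory={0: "20GiB", 1: "20GiB", "cpu": "64GiB"},  # per-device budgets (assumed)
    no_split_module_classes=["GPTNeoXLayer"],  # keep each transformer block on one device
    dtype=torch.float16,
)
# sharded_device_map could then be passed as model_kwargs['device_map'] above.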
| 361 | 
            +
             | 
| 362 | 
            +
             | 
| 363 | 
            +
            def get_model(
         | 
| 364 | 
            +
                    load_8bit: bool = False,
         | 
| 365 | 
            +
                    load_half: bool = True,
         | 
| 366 | 
            +
                    infer_devices: bool = True,
         | 
| 367 | 
            +
                    base_model: str = '',
         | 
| 368 | 
            +
                    tokenizer_base_model: str = '',
         | 
| 369 | 
            +
                    lora_weights: str = "",
         | 
| 370 | 
            +
                    gpu_id: int = 0,
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                    llama_type: bool = None,
         | 
| 373 | 
            +
                    reward_type: bool = None,
         | 
| 374 | 
            +
                    local_files_only: bool = False,
         | 
| 375 | 
            +
                    resume_download: bool = True,
         | 
| 376 | 
            +
                    use_auth_token: Union[str, bool] = False,
         | 
| 377 | 
            +
                    compile: bool = True,
         | 
| 378 | 
            +
                    **kwargs,
         | 
| 379 | 
            +
            ):
         | 
| 380 | 
            +
                """
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                :param load_8bit: load model in 8-bit, not supported by all models
         | 
| 383 | 
            +
                :param load_half: load model in 16-bit
         | 
| 384 | 
            +
                :param infer_devices: whether to infer an optimal placement of layers onto devices (for the non-LoRA case)
         | 
| 385 | 
            +
                       For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
         | 
| 386 | 
            +
                       So it is not the default
         | 
| 387 | 
            +
                :param base_model: name/path of base model
         | 
| 388 | 
            +
                :param tokenizer_base_model: name/path of tokenizer
         | 
| 389 | 
            +
                :param lora_weights: name/path
         | 
| 390 | 
            +
                :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
         | 
| 391 | 
            +
                :param llama_type: whether LLaMa type model
         | 
| 392 | 
            +
                :param reward_type: reward type model for sequence classification
         | 
| 393 | 
            +
                :param local_files_only: use local files instead of from HF
         | 
| 394 | 
            +
                :param resume_download: resume downloads from HF
         | 
| 395 | 
            +
                :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
         | 
| 396 | 
            +
                :param compile: whether to compile the torch model
         | 
| 397 | 
            +
                :param kwargs:
         | 
| 398 | 
            +
                :return:
         | 
| 399 | 
            +
                """
         | 
| 400 | 
            +
                print("Get %s model" % base_model, flush=True)
         | 
| 401 | 
            +
                if lora_weights is not None and lora_weights.strip():
         | 
| 402 | 
            +
                    print("Get %s lora weights" % lora_weights, flush=True)
         | 
| 403 | 
            +
                device = get_device()
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                if 'gpt2' in base_model.lower():
         | 
| 406 | 
            +
                    # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
         | 
| 407 | 
            +
                    load_8bit = False
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                assert base_model.strip(), (
         | 
| 410 | 
            +
                    "Please choose a base model with --base_model (CLI) or in Models Tab (gradio)"
         | 
| 411 | 
            +
                )
         | 
| 412 | 
            +
                llama_type = llama_type or "llama" in base_model
         | 
| 413 | 
            +
                model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
         | 
| 414 | 
            +
                if not tokenizer_base_model:
         | 
| 415 | 
            +
                    tokenizer_base_model = base_model
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
         | 
| 418 | 
            +
                    tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
         | 
| 419 | 
            +
                                                                 local_files_only=local_files_only,
         | 
| 420 | 
            +
                                                                 resume_download=resume_download,
         | 
| 421 | 
            +
                                                                 use_auth_token=use_auth_token,
         | 
| 422 | 
            +
                                                                 )
         | 
| 423 | 
            +
                else:
         | 
| 424 | 
            +
                    tokenizer = tokenizer_loader
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                if isinstance(tokenizer, str):
         | 
| 427 | 
            +
                    # already a pipeline, tokenizer_loader is string for task
         | 
| 428 | 
            +
                    model = model_loader(tokenizer,
         | 
| 429 | 
            +
                                         model=base_model,
         | 
| 430 | 
            +
                                         device=0 if device == "cuda" else -1,
         | 
| 431 | 
            +
                                         torch_dtype=torch.float16)
         | 
| 432 | 
            +
                else:
         | 
| 433 | 
            +
                    assert device == "cuda", "Unsupported device %s" % device
         | 
| 434 | 
            +
                    model_kwargs = dict(local_files_only=local_files_only,
         | 
| 435 | 
            +
                                        torch_dtype=torch.float16,
         | 
| 436 | 
            +
                                        resume_download=resume_download,
         | 
| 437 | 
            +
                                        use_auth_token=use_auth_token)
         | 
| 438 | 
            +
                    if 'mbart-' not in base_model.lower():
         | 
| 439 | 
            +
                        model_kwargs.update(dict(load_in_8bit=load_8bit,
         | 
| 440 | 
            +
                                                 device_map={"": 0} if load_8bit else "auto",
         | 
| 441 | 
            +
                                                 ))
         | 
| 442 | 
            +
                    if 'OpenAssistant/reward-model'.lower() in base_model.lower():
         | 
| 443 | 
            +
                        # could put on other GPUs
         | 
| 444 | 
            +
                        model_kwargs['device_map'] = {"": 0}
         | 
| 445 | 
            +
                        model_kwargs.pop('torch_dtype', None)
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                    if not lora_weights:
         | 
| 448 | 
            +
                        with torch.device("cuda"):
         | 
| 449 | 
            +
                            if infer_devices:
         | 
| 450 | 
            +
                                model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
         | 
| 451 | 
            +
                                                           gpu_id=gpu_id, use_auth_token=use_auth_token)
         | 
| 452 | 
            +
                            else:
         | 
| 453 | 
            +
                                if load_half and not load_8bit:
         | 
| 454 | 
            +
                                    model = model_loader.from_pretrained(
         | 
| 455 | 
            +
                                        base_model,
         | 
| 456 | 
            +
                                        **model_kwargs).half()
         | 
| 457 | 
            +
                                else:
         | 
| 458 | 
            +
                                    model = model_loader.from_pretrained(
         | 
| 459 | 
            +
                                        base_model,
         | 
| 460 | 
            +
                                        **model_kwargs)
         | 
| 461 | 
            +
                    elif load_8bit:
         | 
| 462 | 
            +
                        model = model_loader.from_pretrained(
         | 
| 463 | 
            +
                            base_model,
         | 
| 464 | 
            +
                            **model_kwargs
         | 
| 465 | 
            +
                        )
         | 
| 466 | 
            +
                        model = PeftModel.from_pretrained(
         | 
| 467 | 
            +
                            model,
         | 
| 468 | 
            +
                            lora_weights,
         | 
| 469 | 
            +
                            torch_dtype=torch.float16,
         | 
| 470 | 
            +
                            local_files_only=local_files_only,
         | 
| 471 | 
            +
                            resume_download=resume_download,
         | 
| 472 | 
            +
                            use_auth_token=use_auth_token,
         | 
| 473 | 
            +
                            device_map={"": 0},  # seems to be required
         | 
| 474 | 
            +
                        )
         | 
| 475 | 
            +
                    else:
         | 
| 476 | 
            +
                        with torch.device("cuda"):
         | 
| 477 | 
            +
                            model = model_loader.from_pretrained(
         | 
| 478 | 
            +
                                base_model,
         | 
| 479 | 
            +
                                **model_kwargs
         | 
| 480 | 
            +
                            )
         | 
| 481 | 
            +
                            model = PeftModel.from_pretrained(
         | 
| 482 | 
            +
                                model,
         | 
| 483 | 
            +
                                lora_weights,
         | 
| 484 | 
            +
                                torch_dtype=torch.float16,
         | 
| 485 | 
            +
                                local_files_only=local_files_only,
         | 
| 486 | 
            +
                                resume_download=resume_download,
         | 
| 487 | 
            +
                                use_auth_token=use_auth_token,
         | 
| 488 | 
            +
                                device_map="auto",
         | 
| 489 | 
            +
                            )
         | 
| 490 | 
            +
                            if load_half:
         | 
| 491 | 
            +
                                model.half()
         | 
| 492 | 
            +
             | 
| 493 | 
            +
                # unwind broken decapoda-research config
         | 
| 494 | 
            +
                if llama_type:
         | 
| 495 | 
            +
                    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
         | 
| 496 | 
            +
                    model.config.bos_token_id = 1
         | 
| 497 | 
            +
                    model.config.eos_token_id = 2
         | 
| 498 | 
            +
                if 'gpt2' in base_model.lower():
         | 
| 499 | 
            +
                    # add special tokens that otherwise all share the same id
         | 
| 500 | 
            +
                    tokenizer.add_special_tokens({'bos_token': '<bos>',
         | 
| 501 | 
            +
                                                  'eos_token': '<eos>',
         | 
| 502 | 
            +
                                                  'pad_token': '<pad>'})
         | 
| 503 | 
            +
             | 
| 504 | 
            +
                if not isinstance(tokenizer, str):
         | 
| 505 | 
            +
                    model.eval()
         | 
| 506 | 
            +
                    if torch.__version__ >= "2" and sys.platform != "win32" and compile:
         | 
| 507 | 
            +
                        model = torch.compile(model)
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                return model, tokenizer, device
         | 
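# A sketch of calling get_model() above directly (outside gradio) for a quick smoke
# test, assuming a CUDA GPU is available (get_device() raises otherwise).  The model
# id is a placeholder assumption; any Hugging Face causal-LM repo id fits.
model, tokenizer, device = get_model(
    load_8bit=False,
    load_half=True,
    infer_devices=True,
    base_model="EleutherAI/pythia-2.8b",  # placeholder model id
    tokenizer_base_model="",              # empty -> falls back to base_model
    lora_weights="",                      # no LoRA adapter
    gpu_id=0,
    llama_type=None,                      # auto-detected from the model name
    reward_type=False,
    compile=False,                        # skip torch.compile for a faster first load
)
print(device, type(model).__name__, type(tokenizer).__name__)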
| 510 | 
            +
             | 
| 511 | 
            +
             | 
| 512 | 
            +
            def get_score_model(**kwargs):
         | 
| 513 | 
            +
                # score model
         | 
| 514 | 
            +
                if kwargs.get('score_model') is not None and kwargs.get('score_model').strip():
         | 
| 515 | 
            +
                    score_all_kwargs = kwargs.copy()
         | 
| 516 | 
            +
                    score_all_kwargs['load_8bit'] = False
         | 
| 517 | 
            +
                    score_all_kwargs['load_half'] = False
         | 
| 518 | 
            +
                    score_all_kwargs['base_model'] = kwargs.get('score_model').strip()
         | 
| 519 | 
            +
                    score_all_kwargs['tokenizer_base_model'] = ''
         | 
| 520 | 
            +
                    score_all_kwargs['lora_weights'] = ''
         | 
| 521 | 
            +
                    score_all_kwargs['llama_type'] = False
         | 
| 522 | 
            +
                    score_all_kwargs['compile'] = False
         | 
| 523 | 
            +
                    smodel, stokenizer, sdevice = get_model(**score_all_kwargs)
         | 
| 524 | 
            +
                else:
         | 
| 525 | 
            +
                    smodel, stokenizer, sdevice = None, None, None
         | 
| 526 | 
            +
                return smodel, stokenizer, sdevice
         | 
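# get_score_model() above is just get_model() with scoring-specific overrides; a direct
# call might look like this, where the reward-model id and the reward_type behaviour of
# get_loaders() are assumptions.
smodel, stokenizer, sdevice = get_score_model(
    score_model="OpenAssistant/reward-model-deberta-v3-large-v2",  # assumed example id
    reward_type=True,  # assumed to select the sequence-classification loader in get_loaders()
    gpu_id=0,
)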
| 527 | 
            +
             | 
| 528 | 
            +
             | 
| 529 | 
            +
            def go_gradio(**kwargs):
         | 
| 530 | 
            +
                # get default model
         | 
| 531 | 
            +
                all_kwargs = kwargs.copy()
         | 
| 532 | 
            +
                all_kwargs.update(locals())
         | 
| 533 | 
            +
                if kwargs.get('base_model') and not kwargs['login_mode_if_model0']:
         | 
| 534 | 
            +
                    model0, tokenizer0, device = get_model(**all_kwargs)
         | 
| 535 | 
            +
                else:
         | 
| 536 | 
            +
                    # if empty model, then don't load anything, just get gradio up
         | 
| 537 | 
            +
                    model0, tokenizer0, device = None, None, None
         | 
| 538 | 
            +
                model_state0 = [model0, tokenizer0, device, kwargs['base_model']]
         | 
| 539 | 
            +
             | 
| 540 | 
            +
                # get score model
         | 
| 541 | 
            +
                smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                if 'mbart-' in kwargs['model_lower']:
         | 
| 544 | 
            +
                    instruction_label_nochat = "Text to translate"
         | 
| 545 | 
            +
                else:
         | 
| 546 | 
            +
                    instruction_label_nochat = "Instruction"
         | 
| 547 | 
            +
                instruction_label = "You (Shift-Enter or push Submit to send message)"
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                title = 'h2oGPT'
         | 
| 550 | 
            +
                if kwargs['verbose']:
         | 
| 551 | 
            +
                    description = f"""Model {kwargs['base_model']} Instruct dataset.
         | 
| 552 | 
            +
                                  For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).
         | 
| 553 | 
            +
                                  Command: {str(' '.join(sys.argv))}
         | 
| 554 | 
            +
                                  Hash: {get_githash()}
         | 
| 555 | 
            +
                                  """
         | 
| 556 | 
            +
                else:
         | 
| 557 | 
            +
                    description = "For more information, visit [the project's website](https://github.com/h2oai/h2ogpt).<br>"
         | 
| 558 | 
            +
                if is_public:
         | 
| 559 | 
            +
                    description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content.  Use at own risk.</i></li>"""
         | 
| 560 | 
            +
                    if kwargs['load_8bit']:
         | 
| 561 | 
            +
                        description += """<i><li> Model is loaded in 8-bit and has other restrictions on this host. UX can be worse than non-hosted version.</i></li>"""
         | 
| 562 | 
            +
                    description += """<i><li>Conversations may be used to improve h2oGPT.  Do not share sensitive information.</i></li>"""
         | 
| 563 | 
            +
                    description += """<i><li>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).</i></li></ul></p>"""
         | 
| 564 | 
            +
             | 
| 565 | 
            +
                if kwargs['verbose']:
         | 
| 566 | 
            +
                    task_info_md = f"""
         | 
| 567 | 
            +
                    ### Task: {kwargs['task_info']}"""
         | 
| 568 | 
            +
                else:
         | 
| 569 | 
            +
                    task_info_md = ''
         | 
| 570 | 
            +
             | 
| 571 | 
            +
                css_code = """footer {visibility: hidden;}
         | 
| 572 | 
            +
            body{background:linear-gradient(#f5f5f5,#e5e5e5);}
         | 
| 573 | 
            +
            body.dark{background:linear-gradient(#0d0d0d,#333333);}"""
         | 
| 574 | 
            +
             | 
| 575 | 
            +
                from gradio.themes.utils import Color, colors, fonts, sizes
         | 
| 576 | 
            +
                if kwargs['h2ocolors']:
         | 
| 577 | 
            +
                    h2o_yellow = Color(
         | 
| 578 | 
            +
                        name="yellow",
         | 
| 579 | 
            +
                        c50="#fffef2",
         | 
| 580 | 
            +
                        c100="#fff9e6",
         | 
| 581 | 
            +
                        c200="#ffecb3",
         | 
| 582 | 
            +
                        c300="#ffe28c",
         | 
| 583 | 
            +
                        c400="#ffd659",
         | 
| 584 | 
            +
                        c500="#fec925",
         | 
| 585 | 
            +
                        c600="#e6ac00",
         | 
| 586 | 
            +
                        c700="#bf8f00",
         | 
| 587 | 
            +
                        c800="#a67c00",
         | 
| 588 | 
            +
                        c900="#664d00",
         | 
| 589 | 
            +
                        c950="#403000",
         | 
| 590 | 
            +
                    )
         | 
| 591 | 
            +
                    h2o_gray = Color(
         | 
| 592 | 
            +
                        name="gray",
         | 
| 593 | 
            +
                        c50="#f2f2f2",
         | 
| 594 | 
            +
                        c100="#e5e5e5",
         | 
| 595 | 
            +
                        c200="#cccccc",
         | 
| 596 | 
            +
                        c300="#b2b2b2",
         | 
| 597 | 
            +
                        c400="#999999",
         | 
| 598 | 
            +
                        c500="#7f7f7f",
         | 
| 599 | 
            +
                        c600="#666666",
         | 
| 600 | 
            +
                        c700="#4c4c4c",
         | 
| 601 | 
            +
                        c800="#333333",
         | 
| 602 | 
            +
                        c900="#191919",
         | 
| 603 | 
            +
                        c950="#0d0d0d",
         | 
| 604 | 
            +
                    )
         | 
| 605 | 
            +
                    colors_dict = dict(primary_hue=h2o_yellow,
         | 
| 606 | 
            +
                                       secondary_hue=h2o_yellow,
         | 
| 607 | 
            +
                                       neutral_hue=h2o_gray,
         | 
| 608 | 
            +
                                       spacing_size=sizes.spacing_md,
         | 
| 609 | 
            +
                                       radius_size=sizes.radius_md,
         | 
| 610 | 
            +
                                       text_size=sizes.text_md,
         | 
| 611 | 
            +
                                       )
         | 
| 612 | 
            +
                else:
         | 
| 613 | 
            +
                    colors_dict = dict(primary_hue=colors.indigo,
         | 
| 614 | 
            +
                                       secondary_hue=colors.indigo,
         | 
| 615 | 
            +
                                       neutral_hue=colors.gray,
         | 
| 616 | 
            +
                                       spacing_size=sizes.spacing_md,
         | 
| 617 | 
            +
                                       radius_size=sizes.radius_md,
         | 
| 618 | 
            +
                                       text_size=sizes.text_md,
         | 
| 619 | 
            +
                                       )
         | 
| 620 | 
            +
             | 
| 621 | 
            +
                import gradio as gr
         | 
| 622 | 
            +
             | 
| 623 | 
            +
                if kwargs['gradio_avoid_processing_markdown']:
         | 
| 624 | 
            +
                    from gradio_client import utils as client_utils
         | 
| 625 | 
            +
                    from gradio.components import Chatbot
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                    # gradio has issue with taking too long to process input/output for markdown etc.
         | 
| 628 | 
            +
                    # Avoid for now, allow raw html to render, good enough for chatbot.
         | 
| 629 | 
            +
                    def _postprocess_chat_messages(self, chat_message: str):
         | 
| 630 | 
            +
                        if chat_message is None:
         | 
| 631 | 
            +
                            return None
         | 
| 632 | 
            +
                        elif isinstance(chat_message, (tuple, list)):
         | 
| 633 | 
            +
                            filepath = chat_message[0]
         | 
| 634 | 
            +
                            mime_type = client_utils.get_mimetype(filepath)
         | 
| 635 | 
            +
                            filepath = self.make_temp_copy_if_needed(filepath)
         | 
| 636 | 
            +
                            return {
         | 
| 637 | 
            +
                                "name": filepath,
         | 
| 638 | 
            +
                                "mime_type": mime_type,
         | 
| 639 | 
            +
                                "alt_text": chat_message[1] if len(chat_message) > 1 else None,
         | 
| 640 | 
            +
                                "data": None,  # These last two fields are filled in by the frontend
         | 
| 641 | 
            +
                                "is_file": True,
         | 
| 642 | 
            +
                            }
         | 
| 643 | 
            +
                        elif isinstance(chat_message, str):
         | 
| 644 | 
            +
                            return chat_message
         | 
| 645 | 
            +
                        else:
         | 
| 646 | 
            +
                            raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
         | 
| 647 | 
            +
             | 
| 648 | 
            +
                    Chatbot._postprocess_chat_messages = _postprocess_chat_messages
         | 
| 649 | 
            +
             | 
| 650 | 
            +
                demo = gr.Blocks(theme=gr.themes.Soft(**colors_dict), css=css_code, title="h2oGPT", analytics_enabled=False)
         | 
| 651 | 
            +
                callback = gr.CSVLogger()
         | 
| 652 | 
            +
                # css_code = 'body{background-image:url("https://h2o.ai/content/experience-fragments/h2o/us/en/site/header/master/_jcr_content/root/container/header_copy/logo.coreimg.svg/1678976605175/h2o-logo.svg");}'
         | 
| 653 | 
            +
                # demo = gr.Blocks(theme='gstaff/xkcd', css=css_code)
         | 
| 654 | 
            +
             | 
| 655 | 
            +
                model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
         | 
| 656 | 
            +
                if kwargs['base_model'].strip() not in model_options:
         | 
| 657 | 
            +
                    model_options = [kwargs['base_model'].strip()] + model_options  # ensure the current base model appears in the model dropdown
         | 
| 658 | 
            +
                lora_options = kwargs['extra_lora_options']
         | 
| 659 | 
            +
                if kwargs['lora_weights'].strip() not in lora_options:
         | 
| 660 | 
            +
                    lora_options = [kwargs['lora_weights'].strip()] + lora_options
         | 
| 661 | 
            +
                # always add in no lora case
         | 
| 662 | 
            +
                # add fake space so doesn't go away in gradio dropdown
         | 
| 663 | 
            +
                no_lora_str = no_model_str = '[None/Remove]'
         | 
| 664 | 
            +
                lora_options = [no_lora_str] + kwargs['extra_lora_options']  # FIXME: why double?
         | 
| 665 | 
            +
                # always add in no model case so can free memory
         | 
| 666 | 
            +
                # add fake space so doesn't go away in gradio dropdown
         | 
| 667 | 
            +
                model_options = [no_model_str] + model_options
         | 
| 668 | 
            +
             | 
| 669 | 
            +
                # transcribe, will be detranscribed before use by evaluate()
         | 
| 670 | 
            +
                if not kwargs['lora_weights'].strip():
         | 
| 671 | 
            +
                    kwargs['lora_weights'] = no_lora_str
         | 
| 672 | 
            +
             | 
| 673 | 
            +
                if not kwargs['base_model'].strip():
         | 
| 674 | 
            +
                    kwargs['base_model'] = no_model_str
         | 
| 675 | 
            +
             | 
| 676 | 
            +
                # transcribe for gradio
         | 
| 677 | 
            +
                kwargs['gpu_id'] = str(kwargs['gpu_id'])
         | 
| 678 | 
            +
             | 
| 679 | 
            +
                no_model_msg = 'h2oGPT [   !!! Please Load Model in Models Tab !!!   ]'
         | 
| 680 | 
            +
                output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
         | 
| 681 | 
            +
                    'base_model') else no_model_msg
         | 
| 682 | 
            +
                output_label0_model2 = no_model_msg
         | 
| 683 | 
            +
             | 
| 684 | 
            +
                with demo:
         | 
| 685 | 
            +
                    # avoid actual model/tokenizer here or anything that would be bad to deepcopy
         | 
| 686 | 
            +
                    # https://github.com/gradio-app/gradio/issues/3558
         | 
| 687 | 
            +
                    model_state = gr.State(['model', 'tokenizer', device, kwargs['base_model']])
         | 
| 688 | 
            +
                    model_state2 = gr.State([None, None, None, None])
         | 
| 689 | 
            +
                    model_options_state = gr.State([model_options])
         | 
| 690 | 
            +
                    lora_options_state = gr.State([lora_options])
         | 
| 691 | 
            +
                    gr.Markdown(
         | 
| 692 | 
            +
                        f"""
         | 
| 693 | 
            +
                        <h1 align="center"> {title}</h1>
         | 
| 694 | 
            +
             | 
| 695 | 
            +
                        {description}
         | 
| 696 | 
            +
                        {task_info_md}
         | 
| 697 | 
            +
                        """)
         | 
| 698 | 
            +
                    if is_hf:
         | 
| 699 | 
            +
                        gr.HTML(
         | 
| 700 | 
            +
                            '''<center><a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate this Space to skip the queue and run in a private space</center>''')
         | 
| 701 | 
            +
             | 
| 702 | 
            +
                    # go button visible only if a base model is set and login_mode_if_model0 is enabled
         | 
| 703 | 
            +
                    base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
         | 
| 704 | 
            +
                    go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
         | 
| 705 | 
            +
                    normal_block = gr.Row(visible=not base_wanted)
         | 
| 706 | 
            +
                    with normal_block:
         | 
| 707 | 
            +
                        with gr.Tabs():
         | 
| 708 | 
            +
                            with gr.Row():
         | 
| 709 | 
            +
                                col_nochat = gr.Column(visible=not kwargs['chat'])
         | 
| 710 | 
            +
                                with col_nochat:  # FIXME: for model comparison, and check rest
         | 
| 711 | 
            +
                                    text_output_nochat = gr.Textbox(lines=5, label=output_label0)
         | 
| 712 | 
            +
                                    instruction_nochat = gr.Textbox(
         | 
| 713 | 
            +
                                        lines=4, label=instruction_label_nochat,
         | 
| 714 | 
            +
                                        placeholder=kwargs['placeholder_instruction'],
         | 
| 715 | 
            +
                                    )
         | 
| 716 | 
            +
                                    iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
         | 
| 717 | 
            +
                                                               placeholder=kwargs['placeholder_input'])
         | 
| 718 | 
            +
                                    submit_nochat = gr.Button("Submit")
         | 
| 719 | 
            +
                                    flag_btn_nochat = gr.Button("Flag")
         | 
| 720 | 
            +
                                    if kwargs['score_model']:
         | 
| 721 | 
            +
                                        if not kwargs['auto_score']:
         | 
| 722 | 
            +
                                            with gr.Column():
         | 
| 723 | 
            +
                                                score_btn_nochat = gr.Button("Score last prompt & response")
         | 
| 724 | 
            +
                                                score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
         | 
| 725 | 
            +
                                        else:
         | 
| 726 | 
            +
                                            score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
         | 
| 727 | 
            +
                                col_chat = gr.Column(visible=kwargs['chat'])
         | 
| 728 | 
            +
                                with col_chat:
         | 
| 729 | 
            +
                                    with gr.Row():
         | 
| 730 | 
            +
                                        text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
         | 
| 731 | 
            +
                                        text_output2 = gr.Chatbot(label=output_label0_model2, visible=False).style(
         | 
| 732 | 
            +
                                            height=kwargs['height'] or 400)
         | 
| 733 | 
            +
                                    with gr.Row():
         | 
| 734 | 
            +
                                        with gr.Column(scale=50):
         | 
| 735 | 
            +
                                            instruction = gr.Textbox(
         | 
| 736 | 
            +
                                                lines=4, label=instruction_label,
         | 
| 737 | 
            +
                                                placeholder=kwargs['placeholder_instruction'],
         | 
| 738 | 
            +
                                            )
         | 
| 739 | 
            +
                                        with gr.Row():  # .style(equal_height=False, equal_width=False):
         | 
| 740 | 
            +
                                            submit = gr.Button(value='Submit').style(full_width=False, size='sm')
         | 
| 741 | 
            +
                                            stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
         | 
| 742 | 
            +
                                    with gr.Row():
         | 
| 743 | 
            +
                                        clear = gr.Button("New Conversation")
         | 
| 744 | 
            +
                                        flag_btn = gr.Button("Flag")
         | 
| 745 | 
            +
                                        if kwargs['score_model']:
         | 
| 746 | 
            +
                                            if not kwargs['auto_score']:  # FIXME: For checkbox model2
         | 
| 747 | 
            +
                                                with gr.Column():
         | 
| 748 | 
            +
                                                    with gr.Row():
         | 
| 749 | 
            +
                                                        score_btn = gr.Button("Score last prompt & response").style(
         | 
| 750 | 
            +
                                                            full_width=False, size='sm')
         | 
| 751 | 
            +
                                                        score_text = gr.Textbox("Response Score: NA", show_label=False)
         | 
| 752 | 
            +
                                                    score_res2 = gr.Row(visible=False)
         | 
| 753 | 
            +
                                                    with score_res2:
         | 
| 754 | 
            +
                                                        score_btn2 = gr.Button("Score last prompt & response 2").style(
         | 
| 755 | 
            +
                                                            full_width=False, size='sm')
         | 
| 756 | 
            +
                                                        score_text2 = gr.Textbox("Response Score2: NA", show_label=False)
         | 
| 757 | 
            +
                                            else:
         | 
| 758 | 
            +
                                                score_text = gr.Textbox("Response Score: NA", show_label=False)
         | 
| 759 | 
            +
                                                score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
         | 
| 760 | 
            +
                                        retry = gr.Button("Regenerate")
         | 
| 761 | 
            +
                                        undo = gr.Button("Undo")
         | 
| 762 | 
            +
                            with gr.TabItem("Input/Output"):
         | 
| 763 | 
            +
                                with gr.Row():
         | 
| 764 | 
            +
                                    if 'mbart-' in kwargs['model_lower']:
         | 
| 765 | 
            +
                                        src_lang = gr.Dropdown(list(languages_covered().keys()),
         | 
| 766 | 
            +
                                                               value=kwargs['src_lang'],
         | 
| 767 | 
            +
                                                               label="Input Language")
         | 
| 768 | 
            +
                                        tgt_lang = gr.Dropdown(list(languages_covered().keys()),
         | 
| 769 | 
            +
                                                               value=kwargs['tgt_lang'],
         | 
| 770 | 
            +
                                                               label="Output Language")
         | 
| 771 | 
            +
                            with gr.TabItem("Expert"):
         | 
| 772 | 
            +
                                with gr.Row():
         | 
| 773 | 
            +
                                    with gr.Column():
         | 
| 774 | 
            +
                                        stream_output = gr.components.Checkbox(label="Stream output",
         | 
| 775 | 
            +
                                                                               value=kwargs['stream_output'])
         | 
| 776 | 
            +
                                        prompt_type = gr.Dropdown(prompt_types_strings,
         | 
| 777 | 
            +
                                                                  value=kwargs['prompt_type'], label="Prompt Type",
         | 
| 778 | 
            +
                                                                  visible=not is_public)
         | 
| 779 | 
            +
                                        prompt_type2 = gr.Dropdown(prompt_types_strings,
         | 
| 780 | 
            +
                                                                   value=kwargs['prompt_type'], label="Prompt Type Model 2",
         | 
| 781 | 
            +
                                                                   visible=not is_public and False)
         | 
| 782 | 
            +
                                        do_sample = gr.Checkbox(label="Sample", info="Enable sampler, required for use of temperature, top_p, top_k",
         | 
| 783 | 
            +
                                                                value=kwargs['do_sample'])
         | 
| 784 | 
            +
                                        temperature = gr.Slider(minimum=0.01, maximum=3,
         | 
| 785 | 
            +
                                                                value=kwargs['temperature'],
         | 
| 786 | 
            +
                                                                label="Temperature",
         | 
| 787 | 
            +
                                                                info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
         | 
| 788 | 
            +
                                        top_p = gr.Slider(minimum=0, maximum=1,
         | 
| 789 | 
            +
                                                          value=kwargs['top_p'], label="Top p",
         | 
| 790 | 
            +
                                                          info="Cumulative probability of tokens to sample from")
         | 
| 791 | 
            +
                                        top_k = gr.Slider(
         | 
| 792 | 
            +
                                            minimum=0, maximum=100, step=1,
         | 
| 793 | 
            +
                                            value=kwargs['top_k'], label="Top k",
         | 
| 794 | 
            +
                                            info='Num. tokens to sample from'
         | 
| 795 | 
            +
                                        )
         | 
| 796 | 
            +
                                        max_beams = 8 if not is_low_mem else 2
         | 
| 797 | 
            +
                                        num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
         | 
| 798 | 
            +
                                                              value=min(max_beams, kwargs['num_beams']), label="Beams",
         | 
| 799 | 
            +
                                                              info="Number of searches for optimal overall probability.  "
         | 
| 800 | 
            +
                                                                   "Uses more GPU memory/compute")
         | 
| 801 | 
            +
                                        max_max_new_tokens = 2048 if not is_low_mem else kwargs['max_new_tokens']
         | 
| 802 | 
            +
                                        max_new_tokens = gr.Slider(
         | 
| 803 | 
            +
                                            minimum=1, maximum=max_max_new_tokens, step=1,
         | 
| 804 | 
            +
                                            value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
         | 
| 805 | 
            +
                                        )
         | 
| 806 | 
            +
                                        min_new_tokens = gr.Slider(
         | 
| 807 | 
            +
                                            minimum=0, maximum=max_max_new_tokens, step=1,
         | 
| 808 | 
            +
                                            value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
         | 
| 809 | 
            +
                                        )
         | 
| 810 | 
            +
                                        early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
         | 
| 811 | 
            +
                                                                     value=kwargs['early_stopping'])
         | 
| 812 | 
            +
                                        max_max_time = 60 * 5 if not is_low_mem else 60
         | 
| 813 | 
            +
                                        max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
         | 
| 814 | 
            +
                                                             value=min(max_max_time, kwargs['max_time']), label="Max. time",
         | 
| 815 | 
            +
                                                             info="Max. time to search optimal output.")
         | 
| 816 | 
            +
                                        repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
         | 
| 817 | 
            +
                                                                       value=kwargs['repetition_penalty'],
         | 
| 818 | 
            +
                                                                       label="Repetition Penalty")
         | 
| 819 | 
            +
                                        num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
         | 
| 820 | 
            +
                                                                         value=kwargs['num_return_sequences'],
         | 
| 821 | 
            +
                                                                         label="Number Returns", info="Must be <= num_beams",
         | 
| 822 | 
            +
                                                                         visible=not is_public)
         | 
| 823 | 
            +
                                        iinput = gr.Textbox(lines=4, label="Input",
         | 
| 824 | 
            +
                                                            placeholder=kwargs['placeholder_input'],
         | 
| 825 | 
            +
                                                            visible=not is_public)
         | 
| 826 | 
            +
                                        context = gr.Textbox(lines=3, label="System Pre-Context",
         | 
| 827 | 
            +
                                                             info="Directly pre-appended without prompt processing",
         | 
| 828 | 
            +
                                                             visible=not is_public and not kwargs['chat'])
         | 
| 829 | 
            +
                                        chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
         | 
| 830 | 
            +
                                                                      visible=not is_public)
         | 
| 831 | 
            +
             | 
| 832 | 
            +
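                            # Note (added): the Expert controls above correspond to standard Hugging Face
                            # generation parameters (do_sample, temperature, top_p, top_k, num_beams,
                            # max/min_new_tokens, early_stopping, max_time, repetition_penalty,
                            # num_return_sequences). A rough, illustrative sketch of how evaluate()
                            # likely consumes them downstream (not the exact call made in this file):
                            #   model.generate(**inputs, do_sample=do_sample, temperature=temperature,
                            #                  top_p=top_p, top_k=top_k, num_beams=num_beams,
                            #                  max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens,
                            #                  early_stopping=early_stopping, max_time=max_time,
                            #                  repetition_penalty=repetition_penalty,
                            #                  num_return_sequences=num_return_sequences)
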
                            with gr.TabItem("Models"):
         | 
| 833 | 
            +
                                load_msg = "Load-Unload Model/LORA" if not is_public \
         | 
| 834 | 
            +
                                    else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
         | 
| 835 | 
            +
                                load_msg2 = "Load-Unload Model/LORA 2" if not is_public \
         | 
| 836 | 
            +
                                    else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
         | 
| 837 | 
            +
                                compare_checkbox = gr.components.Checkbox(label="Compare Mode",
         | 
| 838 | 
            +
                                                                          value=False, visible=not is_public)
         | 
| 839 | 
            +
                                with gr.Row():
         | 
| 840 | 
            +
                                    n_gpus = torch.cuda.device_count()
         | 
| 841 | 
            +
                                    n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
         | 
| 842 | 
            +
                                    with gr.Column():
         | 
| 843 | 
            +
                                        with gr.Row(scale=1):
         | 
| 844 | 
            +
                                            with gr.Column(scale=50):
         | 
| 845 | 
            +
                                                model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
         | 
| 846 | 
            +
                                                                           value=kwargs['base_model'])
         | 
| 847 | 
            +
                                                lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
         | 
| 848 | 
            +
                                                                          value=kwargs['lora_weights'], visible=kwargs['show_lora'])
         | 
| 849 | 
            +
                                            with gr.Column(scale=1):
         | 
| 850 | 
            +
                                                load_model_button = gr.Button(load_msg)
         | 
| 851 | 
            +
                                                model_load8bit_checkbox = gr.components.Checkbox(
         | 
| 852 | 
            +
                                                    label="Load 8-bit [Not all models support]",
         | 
| 853 | 
            +
                                                    value=kwargs['load_8bit'])
         | 
| 854 | 
            +
                                                model_infer_devices_checkbox = gr.components.Checkbox(
         | 
| 855 | 
            +
                                                    label="Infer Devices [If GPU ID=-1 or not Checked, then will spread model over GPUs]",
         | 
| 856 | 
            +
                                                    value=kwargs['infer_devices'])
         | 
| 857 | 
            +
                                                model_gpu = gr.Dropdown(n_gpus_list, label="GPU ID [-1 = all GPUs]",
         | 
| 858 | 
            +
                                                                        value=kwargs['gpu_id'])
         | 
| 859 | 
            +
                                                model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
         | 
| 860 | 
            +
                                                lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
         | 
| 861 | 
            +
                                                                       visible=kwargs['show_lora'])
         | 
| 862 | 
            +
                                        with gr.Row(scale=1):
         | 
| 863 | 
            +
                                            with gr.Column(scale=50):
         | 
| 864 | 
            +
                                                new_model = gr.Textbox(label="New Model HF name/path")
         | 
| 865 | 
            +
                                                new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
         | 
| 866 | 
            +
                                            with gr.Column(scale=1):
         | 
| 867 | 
            +
                                                add_model_button = gr.Button("Add new model name")
         | 
| 868 | 
            +
                                                add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
         | 
| 869 | 
            +
                                    col_model2 = gr.Column(visible=False)
         | 
| 870 | 
            +
                                    with col_model2:
         | 
| 871 | 
            +
                                        with gr.Row(scale=1):
         | 
| 872 | 
            +
                                            with gr.Column(scale=50):
         | 
| 873 | 
            +
                                                model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
         | 
| 874 | 
            +
                                                                            value=no_model_str)
         | 
| 875 | 
            +
                                                lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
         | 
| 876 | 
            +
                                                                           value=no_lora_str,
         | 
| 877 | 
            +
                                                                           visible=kwargs['show_lora'])
         | 
| 878 | 
            +
                                            with gr.Column(scale=1):
         | 
| 879 | 
            +
                                                load_model_button2 = gr.Button(load_msg2)
         | 
| 880 | 
            +
                                                model_load8bit_checkbox2 = gr.components.Checkbox(
         | 
| 881 | 
            +
                                                    label="Load 8-bit 2 [Not all models support]",
         | 
| 882 | 
            +
                                                    value=kwargs['load_8bit'])
         | 
| 883 | 
            +
                                                model_infer_devices_checkbox2 = gr.components.Checkbox(
         | 
| 884 | 
            +
                                                    label="Infer Devices 2 [If GPU ID=-1 or not Checked, then will spread model over GPUs]",
         | 
| 885 | 
            +
                                                    value=kwargs[
         | 
| 886 | 
            +
                                                        'infer_devices'])
         | 
| 887 | 
            +
                                                model_gpu2 = gr.Dropdown(n_gpus_list, label="GPU ID [-1 = all GPUs]",
         | 
| 888 | 
            +
                                                                         value=kwargs['gpu_id'])
         | 
| 889 | 
            +
                                                # no model/lora loaded ever in model2 by default
         | 
| 890 | 
            +
                                                model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
         | 
| 891 | 
            +
                                                lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
         | 
| 892 | 
            +
                                                                        visible=kwargs['show_lora'])
         | 
| 893 | 
            +
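                            # Note (added): the second model column above stays hidden until "Compare Mode"
                            # is enabled (see the compare_checkbox.select(...) wiring near the end of this
                            # block), allowing two model/LORA pairs to answer side by side.
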
                            with gr.TabItem("System"):
         | 
| 894 | 
            +
                                system_row = gr.Row(visible=not is_public)
         | 
| 895 | 
            +
                                admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
         | 
| 896 | 
            +
                                admin_btn = gr.Button(value="admin", visible=is_public)
         | 
| 897 | 
            +
                                with system_row:
         | 
| 898 | 
            +
                                    with gr.Column():
         | 
| 899 | 
            +
                                        system_text = gr.Textbox(label='System Info')
         | 
| 900 | 
            +
                                        system_btn = gr.Button(value='Get System Info')
         | 
| 901 | 
            +
             | 
| 902 | 
            +
                                        zip_btn = gr.Button("Zip")
         | 
| 903 | 
            +
                                        file_output = gr.File()
         | 
| 904 | 
            +
             | 
| 905 | 
            +
                    # Get flagged data
                    zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
                    zip_btn.click(zip_data1, inputs=None, outputs=file_output)

                    def check_admin_pass(x):
                        return gr.update(visible=x == admin_pass)

                    admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row)

                    # Get inputs to evaluate()
                    inputs_list = get_inputs_list(locals(), kwargs['model_lower'])
                    from functools import partial
                    all_kwargs = kwargs.copy()
                    all_kwargs.update(locals())
                    kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
                    fun = partial(evaluate,
                                  **kwargs_evaluate)
                    fun2 = partial(evaluate,
                                   model_state2,
                                   **kwargs_evaluate)

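                    # Note (added): fun is the non-chat entrypoint; evaluate()'s fixed options come from
                    # kwargs_evaluate, while per-request values (instruction, sampling settings, etc.)
                    # arrive as Gradio inputs in inputs_list order, with model_state prepended at call
                    # time (see submit_nochat.click(fun, inputs=[model_state] + inputs_list, ...) below).
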
                    dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
                        size="sm",
                    )
                    dark_mode_btn.click(
                        None,
                        None,
                        None,
                        _js="""() => {
                        if (document.querySelectorAll('.dark').length) {
                            document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
                        } else {
                            document.querySelector('body').classList.add('dark');
                        }
                    }""",
                        api_name="dark",
                    )

                    # Control chat and non-chat blocks, which can be independently used by chat checkbox swap
                    def col_nochat_fun(x):
                        return gr.Column.update(visible=not x)

                    def col_chat_fun(x):
                        return gr.Column.update(visible=x)

                    def context_fun(x):
                        return gr.Textbox.update(visible=not x)

                    chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox") \
                        .then(col_chat_fun, chat, col_chat) \
                        .then(context_fun, chat, context)

                    # examples after submit or any other buttons for chat or no chat
                    if kwargs['examples'] is not None and kwargs['show_examples']:
                        gr.Examples(examples=kwargs['examples'], inputs=inputs_list)

                    # Score
                    def score_last_response(*args, nochat=False, model2=False):
                        """ Similar to user() """
                        args_list = list(args)

                        max_length_tokenize = 512 if is_low_mem else 2048
                        cutoff_len = max_length_tokenize * 4  # restrict deberta related to max for LLM

                        if not nochat:
                            history = args_list[-1]
                            if history is None:
                                if not model2:
                                    # maybe only doing first model, no need to complain
                                    print("Bad history in scoring last response, fix for now", flush=True)
                                history = []
                            if smodel is not None and \
                                    stokenizer is not None and \
                                    sdevice is not None and \
                                    history is not None and len(history) > 0 and \
                                    history[-1] is not None and \
                                    len(history[-1]) >= 2:
                                os.environ['TOKENIZERS_PARALLELISM'] = 'false'

                                question = history[-1][0]

                                answer = history[-1][1]
                            else:
                                return 'Response Score: NA'
                        else:
                            answer = args_list[-1]
                            instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat')
                            question = args_list[instruction_nochat_arg_id]

                        if question is None:
                            return 'Response Score: Bad Question'
                        if answer is None:
                            return 'Response Score: Bad Answer'

                        question = question[-cutoff_len:]
                        answer = answer[-cutoff_len:]

                        inputs = stokenizer(question, answer,
                                            return_tensors="pt",
                                            truncation=True,
                                            max_length=max_length_tokenize).to(smodel.device)
                        try:
                            score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
                        except torch.cuda.OutOfMemoryError as e:
                            print("GPU OOM: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
                            del inputs
                            traceback.print_exc()
                            clear_torch_cache()
                            return 'Response Score: GPU OOM'
                        except (Exception, RuntimeError) as e:
                            if 'Expected all tensors to be on the same device' in str(e) or \
                                    'expected scalar type Half but found Float' in str(e) or \
                                    'probability tensor contains either' in str(e) or \
                                    'cublasLt ran into an error!' in str(e):
                                print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
                                      flush=True)
                                traceback.print_exc()
                                clear_torch_cache()
                                return 'Response Score: GPU Error'
                            else:
                                raise
                        os.environ['TOKENIZERS_PARALLELISM'] = 'true'
                        return 'Response Score: {:.1%}'.format(score)

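                    # Note (added): score_last_response runs the question/answer pair through the
                    # separate scoring model (smodel/stokenizer), a sequence-classification reward
                    # model (DeBERTa-based, per the cutoff comment above). A minimal standalone
                    # sketch of the same idea, assuming a generic reward_model_name placeholder:
                    #   tok = AutoTokenizer.from_pretrained(reward_model_name)
                    #   rm = AutoModelForSequenceClassification.from_pretrained(reward_model_name)
                    #   enc = tok(question, answer, return_tensors="pt", truncation=True, max_length=512)
                    #   score = torch.sigmoid(rm(**enc).logits[0]).item()
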
                    if kwargs['score_model']:
                        score_args = dict(fn=score_last_response,
                                          inputs=inputs_list + [text_output],
                                          outputs=[score_text],
                                          )
                        score_args2 = dict(fn=partial(score_last_response, model2=True),
                                           inputs=inputs_list + [text_output2],
                                           outputs=[score_text2],
                                           )

                        score_args_nochat = dict(fn=partial(score_last_response, nochat=True),
                                                 inputs=inputs_list + [text_output_nochat],
                                                 outputs=[score_text_nochat],
                                                 )
                        if not kwargs['auto_score']:
                            score_event = score_btn.click(**score_args, queue=stream_output, api_name='score') \
                                .then(**score_args2, queue=stream_output, api_name='score2')
                            score_event_nochat = score_btn_nochat.click(**score_args_nochat, queue=stream_output,
                                                                        api_name='score_nochat')

                    def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
                        """
                        User that fills history for bot
                        :param args:
                        :param undo:
                        :param sanitize_user_prompt:
                        :param model2:
                        :return:
                        """
                        args_list = list(args)
                        user_message = args_list[0]
                        input1 = args_list[1]
                        context1 = args_list[2]
                        if input1 and not user_message.endswith(':'):
                            user_message1 = user_message + ":" + input1
                        elif input1:
                            user_message1 = user_message + input1
                        else:
                            user_message1 = user_message
                        if sanitize_user_prompt:
                            from better_profanity import profanity
                            user_message1 = profanity.censor(user_message1)

                        history = args_list[-1]
                        if undo and history:
                            history.pop()
                        args_list = args_list[:-1]  # FYI, even if unused currently
                        if history is None:
                            if not model2:
                                # no need to complain so often unless model1
                                print("Bad history, fix for now", flush=True)
                            history = []
                        # ensure elements not mixed across models as output,
                        # even if input is currently same source
                        history = history.copy()
                        if undo:
                            return history
                        else:
                            # FIXME: compare, same history for now
                            return history + [[user_message1, None]]

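                    # Note (added): history follows the Gradio Chatbot convention of a list of
                    # [user_message, bot_message] pairs, e.g. [["hi", "hello!"], ["2+2?", None]],
                    # where None marks a response still to be filled in by bot() below.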
                    def bot(*args, retry=False):
                        """
                        bot that consumes history for user input
                        instruction (from input_list) itself is not consumed by bot
                        :param args:
                        :param retry:
                        :return:
                        """
                        args_list = list(args).copy()
                        history = args_list[-1]  # model_state is -2
                        if retry and history:
                            history.pop()
                        if not history:
                            print("No history", flush=True)
                            return
                        # ensure output will be unique to models
                        history = history.copy()
                        instruction1 = history[-1][0]
                        context1 = ''
                        if kwargs['chat_history'] > 0:
                            prompt_type_arg_id = eval_func_param_names.index('prompt_type')
                            prompt_type1 = args_list[prompt_type_arg_id]
                            chat_arg_id = eval_func_param_names.index('chat')
                            chat1 = args_list[chat_arg_id]
                            context1 = ''
                            for histi in range(len(history) - 1):
                                data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
                                context1 += generate_prompt(data_point, prompt_type1, chat1, reduced=True)[0].replace(
                                    '<br>', '\n')
                                if not context1.endswith('\n'):
                                    context1 += '\n'
                            if context1 and not context1.endswith('\n'):
                                context1 += '\n'  # ensure if terminates abruptly, then human continues on next line
                        args_list[0] = instruction1  # override original instruction with history from user
                        # only include desired chat history
                        args_list[2] = context1[-kwargs['chat_history']:]
                        model_state1 = args_list[-2]
                        if model_state1[0] is None or model_state1[0] == no_model_str:
                            return
                        args_list = args_list[:-2]
                        fun1 = partial(evaluate,
                                       model_state1,
                                       **kwargs_evaluate)
                        try:
                            for output in fun1(*tuple(args_list)):
                                bot_message = output
                                history[-1][1] = bot_message
                                yield history
                        except StopIteration:
                            yield history
                        except RuntimeError as e:
                            if "generator raised StopIteration" in str(e):
                                # assume last entry was bad, undo
                                history.pop()
                                yield history
                            raise
                        except Exception as e:
                            # put error into user input
                            history[-1][0] = "Exception: %s" % str(e)
                            yield history
                            raise
                        return

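                    # Note (added): because bot() is a generator, Gradio streams each yielded copy of
                    # history to the Chatbot component, so partial output from evaluate() can appear
                    # live (with "Stream output" enabled).
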
                    # NORMAL MODEL
                    user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
                                     inputs=inputs_list + [text_output],
                                     outputs=text_output,
                                     )
                    bot_args = dict(fn=bot,
                                    inputs=inputs_list + [model_state] + [text_output],
                                    outputs=text_output,
                                    )
                    retry_bot_args = dict(fn=functools.partial(bot, retry=True),
                                          inputs=inputs_list + [model_state] + [text_output],
                                          outputs=text_output,
                                          )
                    undo_user_args = dict(fn=functools.partial(user, undo=True),
                                          inputs=inputs_list + [text_output],
                                          outputs=text_output,
                                          )

                    # MODEL2
                    user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True),
                                      inputs=inputs_list + [text_output2],
                                      outputs=text_output2,
                                      )
                    bot_args2 = dict(fn=bot,
                                     inputs=inputs_list + [model_state2] + [text_output2],
                                     outputs=text_output2,
                                     )
                    retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
                                           inputs=inputs_list + [model_state2] + [text_output2],
                                           outputs=text_output2,
                                           )
                    undo_user_args2 = dict(fn=functools.partial(user, undo=True),
                                           inputs=inputs_list + [text_output2],
                                           outputs=text_output2,
                                           )

                    def clear_instruct():
                        return gr.Textbox.update(value='')

                    if kwargs['auto_score']:
                        # in case of a 2nd model, consume the instruction first, so it can be cleared quickly
                        # bot doesn't consume the instruction itself, just history from user, which is why this works
                        submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction') \
                            .then(**user_args2, queue=stream_output, api_name='instruction2') \
                            .then(clear_instruct, None, instruction) \
                            .then(**bot_args, api_name='instruction_bot') \
                            .then(**score_args, api_name='instruction_bot_score') \
                            .then(**bot_args2, api_name='instruction_bot2') \
                            .then(**score_args2, api_name='instruction_bot_score2') \
                            .then(clear_torch_cache)
                        submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit') \
                            .then(**user_args2, queue=stream_output, api_name='submit2') \
                            .then(**bot_args, api_name='submit_bot') \
                            .then(clear_instruct, None, instruction) \
                            .then(**score_args, api_name='submit_bot_score') \
                            .then(**bot_args2, api_name='submit_bot2') \
                            .then(**score_args2, api_name='submit_bot_score2') \
                            .then(clear_torch_cache)
                        submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry') \
                            .then(**user_args2, queue=stream_output, api_name='retry2') \
                            .then(clear_instruct, None, instruction) \
                            .then(**retry_bot_args, api_name='retry_bot') \
                            .then(**score_args, api_name='retry_bot_score') \
                            .then(**retry_bot_args2, api_name='retry_bot2') \
                            .then(**score_args2, api_name='retry_bot_score2') \
                            .then(clear_torch_cache)
                        submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo') \
                            .then(**score_args, api_name='undo_score') \
                            .then(**undo_user_args2, queue=stream_output, api_name='undo2') \
                            .then(**score_args2, api_name='undo_score2') \
                            .then(clear_instruct, None, instruction)
                    else:
                        submit_event = instruction.submit(**user_args, queue=stream_output, api_name='instruction') \
                            .then(**user_args2, queue=stream_output, api_name='instruction2') \
                            .then(clear_instruct, None, instruction) \
                            .then(**bot_args, api_name='instruction_bot') \
                            .then(**bot_args2, api_name='instruction_bot2') \
                            .then(clear_torch_cache)
                        submit_event2 = submit.click(**user_args, queue=stream_output, api_name='submit') \
                            .then(**user_args2, queue=stream_output, api_name='submit2') \
                            .then(clear_instruct, None, instruction) \
                            .then(**bot_args, api_name='submit_bot') \
                            .then(**bot_args2, api_name='submit_bot2') \
                            .then(clear_torch_cache)
                        submit_event3 = retry.click(**user_args, queue=stream_output, api_name='retry') \
                            .then(**user_args2, queue=stream_output, api_name='retry2') \
                            .then(clear_instruct, None, instruction) \
                            .then(**retry_bot_args, api_name='retry_bot') \
                            .then(**retry_bot_args2, api_name='retry_bot2') \
                            .then(clear_torch_cache)
                        submit_event4 = undo.click(**undo_user_args, queue=stream_output, api_name='undo') \
                            .then(**undo_user_args2, queue=stream_output, api_name='undo2')

                    # does both models
                    clear.click(lambda: None, None, text_output, queue=False, api_name='clear') \
                        .then(lambda: None, None, text_output2, queue=False, api_name='clear2')
                    # FIXME: compare
                    submit_event_nochat = submit_nochat.click(fun, inputs=[model_state] + inputs_list,
                                                              outputs=text_output_nochat, api_name='submit_nochat') \
                        .then(**score_args_nochat, api_name='instruction_bot_score_nochat') \
                        .then(clear_torch_cache)

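                    # Note (added): the named endpoints above (e.g. 'submit_nochat') are callable
                    # programmatically; see client_test.py in this repo for a working client. A rough
                    # sketch, assuming the gradio_client package and a placeholder host_url:
                    #   from gradio_client import Client
                    #   client = Client(host_url)
                    #   res = client.predict(*args_matching_inputs_list, api_name='/submit_nochat')
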
                    def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id):
                        # ensure old model removed from GPU memory
                        if kwargs['debug']:
                            print("Pre-switch pre-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)

                        if isinstance(model_state_old[0], str) and model0 is not None:
                            # best can do, move model loaded at first to CPU
                            model0.cpu()

                        if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
                            try:
                                model_state_old[0].cpu()
                            except Exception as e:
                                # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
                                print("Unable to put model on CPU: %s" % str(e), flush=True)
                            del model_state_old[0]
                            model_state_old[0] = None

                        if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
                            del model_state_old[1]
                            model_state_old[1] = None

                        clear_torch_cache()
                        if kwargs['debug']:
                            print("Pre-switch post-del GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)

                        if model_name is None or model_name == no_model_str:
                            # no-op if no model, just free memory
                            # no detranscribe needed for model, never go into evaluate
                            lora_weights = no_lora_str
                            return [None, None, None, model_name], model_name, lora_weights, prompt_type_old

                        all_kwargs1 = all_kwargs.copy()
                        all_kwargs1['base_model'] = model_name.strip()
                        all_kwargs1['load_8bit'] = load_8bit
                        all_kwargs1['infer_devices'] = infer_devices
                        all_kwargs1['gpu_id'] = int(gpu_id)  # detranscribe
                        model_lower = model_name.strip().lower()
                        if model_lower in inv_prompt_type_to_model_lower:
                            prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
                        else:
                            prompt_type1 = prompt_type_old

                        # detranscribe
                        if lora_weights == no_lora_str:
                            lora_weights = ''

                        all_kwargs1['lora_weights'] = lora_weights.strip()
                        model1, tokenizer1, device1 = get_model(**all_kwargs1)
                        clear_torch_cache()

                        if kwargs['debug']:
                            print("Post-switch GPU memory: %s" % torch.cuda.memory_allocated(), flush=True)
                        return [model1, tokenizer1, device1, model_name], model_name, lora_weights, prompt_type1

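                    # Note (added): throughout this block model_state is a 4-element list of
                    # [model, tokenizer, device, base_model_name]; load_model() above returns the same
                    # layout so the chat handlers and the model dropdowns stay in sync.
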
                    def dropdown_prompt_type_list(x):
                        return gr.Dropdown.update(value=x)

                    def chatbot_list(x, model_used_in):
                        return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')

                    load_model_args = dict(fn=load_model,
                                           inputs=[model_choice, lora_choice, model_state, prompt_type,
                                                   model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu],
                                           outputs=[model_state, model_used, lora_used, prompt_type])
                    prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
                    chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
                    nochat_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output_nochat)
                    if not is_public:
                        load_model_event = load_model_button.click(**load_model_args) \
                            .then(**prompt_update_args) \
                            .then(**chatbot_update_args) \
                            .then(**nochat_update_args) \
                            .then(clear_torch_cache)

                    load_model_args2 = dict(fn=load_model,
                                            inputs=[model_choice2, lora_choice2, model_state2, prompt_type2,
                                                    model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2],
                                            outputs=[model_state2, model_used2, lora_used2, prompt_type2])
                    prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
                    chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
                    if not is_public:
                        load_model_event2 = load_model_button2.click(**load_model_args2) \
                            .then(**prompt_update_args2) \
                            .then(**chatbot_update_args2) \
                            .then(clear_torch_cache)

                    def dropdown_model_list(list0, x):
         | 
| 1343 | 
            +
                        new_state = [list0[0] + [x]]
         | 
| 1344 | 
            +
                        new_options = [*new_state[0]]
         | 
| 1345 | 
            +
                        return gr.Dropdown.update(value=x, choices=new_options), \
         | 
| 1346 | 
            +
                               gr.Dropdown.update(value=x, choices=new_options), \
         | 
| 1347 | 
            +
                               '', new_state
         | 
| 1348 | 
            +
             | 
| 1349 | 
            +
                    add_model_event = add_model_button.click(fn=dropdown_model_list,
         | 
| 1350 | 
            +
                                                             inputs=[model_options_state, new_model],
         | 
| 1351 | 
            +
                                                             outputs=[model_choice, model_choice2, new_model, model_options_state])
         | 
| 1352 | 
            +
             | 
| 1353 | 
            +
                    def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2):
         | 
| 1354 | 
            +
                        new_state = [list0[0] + [x]]
         | 
| 1355 | 
            +
                        new_options = [*new_state[0]]
         | 
| 1356 | 
            +
                        # don't switch drop-down to added lora if already have model loaded
         | 
| 1357 | 
            +
                        x1 = x if model_used1 == no_model_str else lora_used1
         | 
| 1358 | 
            +
                        x2 = x if model_used2 == no_model_str else lora_used2
         | 
| 1359 | 
            +
                        return gr.Dropdown.update(value=x1, choices=new_options), \
         | 
| 1360 | 
            +
                               gr.Dropdown.update(value=x2, choices=new_options), \
         | 
| 1361 | 
            +
                               '', new_state
         | 
| 1362 | 
            +
             | 
| 1363 | 
            +
                    add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
         | 
| 1364 | 
            +
                                                           inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2, lora_used2],
         | 
| 1365 | 
            +
                                                           outputs=[lora_choice, lora_choice2, new_lora, lora_options_state])
         | 
| 1366 | 
            +
             | 
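        # NOTE: dropdown_model_list/dropdown_lora_list above each return one
        # gr.Dropdown.update per dropdown, then '' to clear the entry textbox, then the
        # new options state, so their return order must match the `outputs=` lists in
        # the add_model_button/add_lora_button wiring.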
        go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go") \
            .then(lambda: gr.update(visible=True), None, normal_block) \
            .then(**load_model_args).then(**prompt_update_args)

        def compare_textbox_fun(x):
            return gr.Textbox.update(visible=x)

        def compare_column_fun(x):
            return gr.Column.update(visible=x)

        def compare_prompt_fun(x):
            return gr.Dropdown.update(visible=x)

        compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2, api_name="compare_checkbox") \
            .then(compare_column_fun, compare_checkbox, col_model2) \
            .then(compare_prompt_fun, compare_checkbox, prompt_type2) \
            .then(compare_textbox_fun, compare_checkbox, score_text2)
        # FIXME: add score_res2 in condition, but do better

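        # NOTE: the compare_checkbox chain above only toggles visibility of the second
        # model's output box, column, prompt dropdown and score box; the second model
        # itself is loaded through the load_model_button2 wiring further up.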
        # callback for logging flagged input/output
        callback.setup(inputs_list + [text_output], "flagged_data_points")
        flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
                       api_name='flag')
        flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output], None, preprocess=False,
                              api_name='flag_nochat')

        def get_system_info():
            return gr.Textbox.update(value=system_info_print())

        system_event = system_btn.click(get_system_info, outputs=system_text, api_name='system_info')

        # don't pass text_output, don't want to clear output, just stop it
        # FIXME: have to click once to stop output and second time to stop GPUs going
        stop_btn.click(lambda: None, None, None,
                       cancels=[submit_event_nochat, submit_event, submit_event2, submit_event3],
                       queue=False, api_name='stop').then(clear_torch_cache)

    demo.queue(concurrency_count=1)
    favicon_path = "h2o-logo.svg"
    demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
                favicon_path=favicon_path, prevent_thread_lock=True)  # , enable_queue=True)
    print("Started GUI", flush=True)
    demo.block_thread()


input_args_list = ['model_state']
inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0']


def get_inputs_list(inputs_dict, model_lower):
    """
    map gradio objects in locals() to inputs for evaluate().
    :param inputs_dict:
    :param model_lower:
    :return:
    """
    inputs_list_names = list(inspect.signature(evaluate).parameters)
    inputs_list = []
    for k in inputs_list_names:
        if k == 'kwargs':
            continue
        if k in input_args_list + inputs_kwargs_list:
            # these are added via partial, not taken as input
            continue
        if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
            continue
        inputs_list.append(inputs_dict[k])
    return inputs_list

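# Illustrative sketch (helper is hypothetical and unused by the app): the same
# inspect-based introspection get_inputs_list() relies on, shown on its own.
def _signature_param_names(func):
    """Return func's parameter names in declaration order, e.g. for evaluate()."""
    import inspect as _inspect
    return list(_inspect.signature(func).parameters)
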
eval_func_param_names = ['instruction',
                         'iinput',
                         'context',
                         'stream_output',
                         'prompt_type',
                         'temperature',
                         'top_p',
                         'top_k',
                         'num_beams',
                         'max_new_tokens',
                         'min_new_tokens',
                         'early_stopping',
                         'max_time',
                         'repetition_penalty',
                         'num_return_sequences',
                         'do_sample',
                         'chat',
                         'instruction_nochat',
                         'iinput_nochat',
                         ]

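# Illustrative sketch (helper is hypothetical and unused by the app): example rows built
# in get_generate_params() follow exactly this order, so one row maps onto evaluate()'s
# keyword arguments like so:
def _example_row_as_kwargs(example_row):
    return dict(zip(eval_func_param_names, example_row))
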
def evaluate(
        model_state,
        # START NOTE: Examples must have same order of parameters
        instruction,
        iinput,
        context,
        stream_output,
        prompt_type,
        temperature,
        top_p,
        top_k,
        num_beams,
        max_new_tokens,
        min_new_tokens,
        early_stopping,
        max_time,
        repetition_penalty,
        num_return_sequences,
        do_sample,
        chat,
        instruction_nochat,
        iinput_nochat,
        # END NOTE: Examples must have same order of parameters
        src_lang=None,
        tgt_lang=None,
        debug=False,
        save_dir=None,
        hard_stop_list=None,
        sanitize_bot_response=True,
        model_state0=None,
        **kwargs,
):
    if debug:
        locals_dict = locals().copy()
        locals_dict.pop('model_state', None)
        locals_dict.pop('model_state0', None)
        print(locals_dict)

    no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"

    if model_state0 is None:
        # e.g. for no gradio case, set dummy value, else should be set
        model_state0 = [None, None, None, None]

    if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
        # try to free-up original model (i.e. list was passed as reference)
        if model_state0 is not None and model_state0[0] is not None:
            model_state0[0].cpu()
            model_state0[0] = None
        # try to free-up original tokenizer (i.e. list was passed as reference)
        if model_state0 is not None and model_state0[1] is not None:
            model_state0[1] = None
        clear_torch_cache()
        model, tokenizer, device, base_model = model_state
    elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
        assert isinstance(model_state[0], str)
        model, tokenizer, device, base_model = model_state0
    else:
        raise AssertionError(no_model_msg)

    if base_model is None:
        raise AssertionError(no_model_msg)

    assert base_model.strip(), no_model_msg
    assert model, "Model is missing"
    assert tokenizer, "Tokenizer is missing"

    # choose chat or non-chat mode
    if not chat:
        instruction = instruction_nochat
        iinput = iinput_nochat

    data_point = dict(context=context, instruction=instruction, input=iinput)
    prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
    prompt = prompter.generate_prompt(data_point)

    if hard_stop_list is None:
        # acts like undo on user entry and bot response
        hard_stop_list = []

    if isinstance(tokenizer, str):
        # pipeline
        if tokenizer == "summarization":
            key = 'summary_text'
        else:
            raise RuntimeError("No such task type %s" % tokenizer)
        # NOTE: uses max_length only
        yield model(prompt, max_length=max_new_tokens)[0][key]
        # pipeline path is done; falling through would treat the task string as a tokenizer
        return

    if 'mbart-' in base_model.lower():
        assert src_lang is not None
        tokenizer.src_lang = languages_covered()[src_lang]

    if chat:
        # override, ignore user change
        num_return_sequences = 1
    if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
        if prompt_type == 'human_bot':
            # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
            # stopping only starts once output is beyond prompt
            # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
            stop_words = [human, bot, '\n' + human, '\n' + bot]
            encounters = [1, 2]
        elif prompt_type == 'instruct_vicuna':
            # even below is not enough, generic strings and many ways to encode
            stop_words = [
                '### Human:',
                """
### Human:""",
                """
### Human:
""",
                '### Assistant:',
                """
### Assistant:""",
                """
### Assistant:
""",
            ]
            encounters = [1, 2]
        else:
            # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
            stop_words = ['### End']
            encounters = [1]
        stop_words_ids = [
            tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
        # handle single token case
        stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
        stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
        # avoid padding in front of tokens
        if tokenizer.pad_token:
            stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
        # handle fake \n added
        stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
        # build stopper
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])
    else:
        stopping_criteria = StoppingCriteriaList()

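    # Roughly how StoppingCriteriaSub (from stopping.py) uses the ids built above: at
    # each generation step it compares the tail of the generated ids against every stop
    # sequence and requests a stop once a sequence has been matched `encounters` times;
    # see stopping.py for the exact logic.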
    # help to avoid errors like:
    # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
    # RuntimeError: expected scalar type Half but found Float
    # with - 256
    max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
    cutoff_len = max_length_tokenize * 4  # if reaches limit, then can't generate new tokens
    output_smallest = 30 * 4
    prompt = prompt[-cutoff_len - output_smallest:]
    inputs = tokenizer(prompt,
                       return_tensors="pt",
                       truncation=True,
                       max_length=max_length_tokenize)
    if debug and len(inputs["input_ids"]) > 0:
        print('input_ids length', len(inputs["input_ids"][0]), flush=True)
    input_ids = inputs["input_ids"].to(device)
    generation_config = GenerationConfig(
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=top_k,
        num_beams=num_beams,
        do_sample=do_sample,
        repetition_penalty=float(repetition_penalty),
        num_return_sequences=num_return_sequences,
        renormalize_logits=True,
        remove_invalid_values=True,
        **kwargs,
    )

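    # Note on the config above: temperature/top_p/top_k only take effect when
    # do_sample=True (otherwise decoding is greedy, or beam search when num_beams > 1),
    # while renormalize_logits/remove_invalid_values guard against inf/nan logits
    # crashing sampling.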
    gen_kwargs = dict(input_ids=input_ids,
                      generation_config=generation_config,
                      return_dict_in_generate=True,
                      output_scores=True,
                      max_new_tokens=max_new_tokens,  # prompt + new
                      min_new_tokens=min_new_tokens,  # prompt + new
                      early_stopping=early_stopping,  # False, True, "never"
                      max_time=max_time,
                      stopping_criteria=stopping_criteria,
                      )
    if 'gpt2' in base_model.lower():
        gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id))
    elif 'mbart-' in base_model.lower():
        assert tgt_lang is not None
        tgt_lang = languages_covered()[tgt_lang]
        gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
    else:
        gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))

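    # The pad_token_id fallback above follows the usual Hugging Face convention:
    # decoder-only models such as GPT-2 define no pad token, so eos_token_id is reused
    # as the pad token id for generation.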
    decoder = functools.partial(tokenizer.decode,
                                skip_special_tokens=True,
                                clean_up_tokenization_spaces=True,
                                )
    decoder_raw = functools.partial(tokenizer.decode,
                                    skip_special_tokens=False,
                                    clean_up_tokenization_spaces=True,
                                    )

    with torch.no_grad():
        # decoded tokenized prompt can deviate from prompt due to special characters
        inputs_decoded = decoder(input_ids[0])
        inputs_decoded_raw = decoder_raw(input_ids[0])
        if inputs_decoded == prompt:
            # normal
            pass
        elif inputs_decoded.lstrip() == prompt.lstrip():
            # sometimes extra space in front, make prompt same for prompt removal
            prompt = inputs_decoded
        elif inputs_decoded_raw == prompt:
            # some models specify special tokens that are part of normal prompt, so can't skip them
            inputs_decoded_raw = inputs_decoded
            decoder = decoder_raw
        else:
            print("WARNING: Special characters in prompt", flush=True)
        if stream_output:
            def generate(callback=None, **kwargs):
                # re-order stopping so Stream first and get out all chunks before stop for other reasons
                stopping_criteria0 = kwargs.get('stopping_criteria', StoppingCriteriaList()).copy()
                kwargs['stopping_criteria'] = StoppingCriteriaList()
                kwargs['stopping_criteria'].append(Stream(func=callback))
                for stopping_criteria1 in stopping_criteria0:
                    kwargs['stopping_criteria'].append(stopping_criteria1)

                try:
                    model.generate(**kwargs)
                except torch.cuda.OutOfMemoryError as e:
                    print("GPU OOM: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
                          flush=True)
                    if kwargs['input_ids'] is not None:
                        kwargs['input_ids'].cpu()
                    kwargs['input_ids'] = None
                    traceback.print_exc()
                    clear_torch_cache()
                    return
                except (Exception, RuntimeError) as e:
                    if 'Expected all tensors to be on the same device' in str(e) or \
                            'expected scalar type Half but found Float' in str(e) or \
                            'probability tensor contains either' in str(e) or \
                            'cublasLt ran into an error!' in str(e):
                        print(
                            "GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
                            flush=True)
                        traceback.print_exc()
                        clear_torch_cache()
                        if raise_generate_gpu_exceptions:
                            raise
                        return
                    else:
                        raise

            decoded_output = None
            for output in CallbackToGenerator(generate, callback=None, **gen_kwargs):
                decoded_output = decoder(output)
                if output[-1] in [tokenizer.eos_token_id]:
                    if debug:
                        print("HIT EOS", flush=True)
                    break
                if any(ele in decoded_output for ele in hard_stop_list):
                    # stop generation; raising StopIteration inside a generator would
                    # surface as RuntimeError under PEP 479
                    return
                yield prompter.get_response(decoded_output, prompt=inputs_decoded,
                                            sanitize_bot_response=sanitize_bot_response)
            if save_dir and decoded_output:
                save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
        else:
            outputs = model.generate(**gen_kwargs)
            outputs = [decoder(s) for s in outputs.sequences]
            yield prompter.get_response(outputs, prompt=inputs_decoded,
                                        sanitize_bot_response=sanitize_bot_response)
            if save_dir and outputs and len(outputs) >= 1:
                decoded_output = prompt + outputs[0]
                save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)

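# Hedged usage sketch (names and values below are hypothetical): evaluate() is a
# generator even when stream_output=False, so callers iterate it, e.g.
#
#   row_kwargs = dict(zip(eval_func_param_names, example_row))
#   for partial_text in evaluate(model_state, **row_kwargs):
#       print(partial_text)
#
# In the gradio app the server-side-only arguments (debug, save_dir, hard_stop_list,
# sanitize_bot_response, model_state0) are bound with functools.partial rather than
# taken from the UI.
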
def get_generate_params(model_lower, chat,
                        stream_output, show_examples,
                        prompt_type, temperature, top_p, top_k, num_beams,
                        max_new_tokens, min_new_tokens, early_stopping, max_time,
                        repetition_penalty, num_return_sequences,
                        do_sample):
    use_defaults = False
    use_default_examples = True
    examples = []
    task_info = f"{prompt_type}"
    if model_lower:
        print(f"Using Model {model_lower}", flush=True)
    else:
        print("No model defined yet", flush=True)

    min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
    early_stopping = early_stopping if early_stopping is not None else False
    max_time_defaults = 60 * 3
    max_time = max_time if max_time is not None else max_time_defaults

    if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
        prompt_type = inv_prompt_type_to_model_lower[model_lower]

    # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
    if show_examples is None:
        if chat:
            show_examples = False
        else:
            show_examples = True

    summarize_example1 = """Jeff: Can I train a 🤗 Transformers model on Amazon SageMaker?
Philipp: Sure you can use the new Hugging Face Deep Learning Container.
Jeff: ok.
Jeff: and how can I get started?
Jeff: where can I find documentation?
Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face"""

    if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower:
        placeholder_instruction = summarize_example1
        placeholder_input = ""
        use_defaults = True
        use_default_examples = False
        examples += [
            [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
             1.0, 1,
             False]]
        task_info = "Summarization"
    elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower:
        placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?"
        placeholder_input = ""
        use_defaults = True
        use_default_examples = True
        task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc.  Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)"
    elif 'mbart-' in model_lower:
        placeholder_instruction = "The girl has long hair."
        placeholder_input = ""
        use_defaults = True
        use_default_examples = False
        examples += [
            [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
             1.0, 1,
             False]]
    elif 'gpt2' in model_lower:
        placeholder_instruction = "The sky is"
        placeholder_input = ""
        prompt_type = prompt_type or 'plain'
        use_default_examples = True  # some will be odd "continuations" but can be ok
        examples += [
            [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
             1.0, 1,
             False]]
        task_info = "Auto-complete phrase, code, etc."
        use_defaults = True
    else:
        if chat:
            placeholder_instruction = "Enter a question or imperative."
        else:
            placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
        placeholder_input = ""
        if model_lower:
            prompt_type = prompt_type or 'human_bot'
        else:
            prompt_type = ''
        examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
                      stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1,
                      False]]
        task_info = "No task"
        if prompt_type == 'instruct':
            task_info = "Answer question or follow imperative as instruction, with optional input."
        elif prompt_type == 'plain':
            task_info = "Auto-complete phrase, code, etc."
        elif prompt_type == 'human_bot':
            if chat:
                task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)"
            else:
                task_info = "Ask question/imperative (input concatenated with instruction)"

    # revert to plain if still nothing
    prompt_type = prompt_type or 'plain'
    if use_defaults:
        temperature = 1.0 if temperature is None else temperature
        top_p = 1.0 if top_p is None else top_p
        top_k = 40 if top_k is None else top_k
        num_beams = num_beams or 1
        max_new_tokens = max_new_tokens or 128
        repetition_penalty = repetition_penalty or 1.07
        num_return_sequences = min(num_beams, num_return_sequences or 1)
        do_sample = False if do_sample is None else do_sample
    else:
        temperature = 0.1 if temperature is None else temperature
        top_p = 0.75 if top_p is None else top_p
        top_k = 40 if top_k is None else top_k
        if chat:
            num_beams = num_beams or 1
        else:
            num_beams = num_beams or 4
        max_new_tokens = max_new_tokens or 256
        repetition_penalty = repetition_penalty or 1.07
        num_return_sequences = min(num_beams, num_return_sequences or 1)
        do_sample = False if do_sample is None else do_sample
    # doesn't include chat, instruction_nochat, iinput_nochat, added later
    params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
                   early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]

| 1853 | 
            +
                if use_default_examples:
         | 
| 1854 | 
            +
                    examples += [
         | 
| 1855 | 
            +
                        ["Translate English to French", "Good morning"] + params_list,
         | 
| 1856 | 
            +
                        ["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list,
         | 
| 1857 | 
            +
                        ["Explain in detailed list, all the best practices for coding in python.", ''] + params_list,
         | 
| 1858 | 
            +
                        [
         | 
| 1859 | 
            +
                            "Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.",
         | 
| 1860 | 
            +
                            ''] + params_list,
         | 
| 1861 | 
            +
                        ['Translate to German:  My name is Arthur', ''] + params_list,
         | 
| 1862 | 
            +
                        ["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list,
         | 
| 1863 | 
            +
                        ['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.',
         | 
| 1864 | 
            +
                         ''] + params_list,
         | 
| 1865 | 
            +
                        ['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list,
         | 
| 1866 | 
            +
                        ['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list,
         | 
| 1867 | 
            +
                        ["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list,
         | 
| 1868 | 
            +
                        [
         | 
| 1869 | 
            +
                            "Premise: At my age you will probably have learnt one lesson. Hypothesis:  It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?",
         | 
| 1870 | 
            +
                            ''] + params_list,
         | 
| 1871 | 
            +
                        ['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list,
         | 
| 1872 | 
            +
                        [
         | 
| 1873 | 
            +
                            'Answer the following question by reasoning step by step.  The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apple do they have?',
         | 
| 1874 | 
            +
                            ''] + params_list,
         | 
| 1875 | 
            +
                        ["""def area_of_rectangle(a: float, b: float):
         | 
| 1876 | 
            +
                \"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list,
         | 
| 1877 | 
            +
                        ["""# a function in native python:
         | 
| 1878 | 
            +
            def mean(a):
         | 
| 1879 | 
            +
                return sum(a)/len(a)
         | 
| 1880 | 
            +
             | 
| 1881 | 
            +
            # the same function using numpy:
         | 
| 1882 | 
            +
            import numpy as np
         | 
| 1883 | 
            +
            def mean(a):""", ''] + params_list,
         | 
| 1884 | 
            +
                        ["""X = np.random.randn(100, 100)
         | 
| 1885 | 
            +
            y = np.random.randint(0, 1, 100)
         | 
| 1886 | 
            +
             | 
| 1887 | 
            +
            # fit random forest classifier with 20 estimators""", ''] + params_list,
         | 
| 1888 | 
            +
                    ]
         | 
| 1889 | 
            +
             | 
| 1890 | 
            +
                src_lang = "English"
         | 
| 1891 | 
            +
                tgt_lang = "Russian"
         | 
| 1892 | 
            +
             | 
| 1893 | 
            +
                # move to correct position
         | 
| 1894 | 
            +
                for example in examples:
         | 
| 1895 | 
            +
                    example += [chat, '', '']
         | 
| 1896 | 
            +
                    # adjust examples if non-chat mode
         | 
| 1897 | 
            +
                    if not chat:
         | 
| 1898 | 
            +
                        example[eval_func_param_names.index('instruction_nochat')] = example[
         | 
| 1899 | 
            +
                            eval_func_param_names.index('instruction')]
         | 
| 1900 | 
            +
                        example[eval_func_param_names.index('instruction')] = ''
         | 
| 1901 | 
            +
             | 
| 1902 | 
            +
                        example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
         | 
| 1903 | 
            +
                        example[eval_func_param_names.index('iinput')] = ''
         | 
| 1904 | 
            +
             | 
| 1905 | 
            +
                return placeholder_instruction, placeholder_input, \
         | 
| 1906 | 
            +
                       stream_output, show_examples, \
         | 
| 1907 | 
            +
                       prompt_type, temperature, top_p, top_k, num_beams, \
         | 
| 1908 | 
            +
                       max_new_tokens, min_new_tokens, early_stopping, max_time, \
         | 
| 1909 | 
            +
                       repetition_penalty, num_return_sequences, \
         | 
| 1910 | 
            +
                       do_sample, \
         | 
| 1911 | 
            +
                       src_lang, tgt_lang, \
         | 
| 1912 | 
            +
                       examples, \
         | 
| 1913 | 
            +
                       task_info
         | 
| 1914 | 
            +
             | 
| 1915 | 
            +
             | 
| 1916 | 
            +
            def languages_covered():
         | 
| 1917 | 
            +
                # https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered
         | 
| 1918 | 
            +
                covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
         | 
| 1919 | 
            +
                covered = covered.split(', ')
         | 
| 1920 | 
            +
                covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered}
         | 
| 1921 | 
            +
                return covered
         | 
| 1922 | 
            +
             | 
| 1923 | 
            +
             | 
| 1924 | 
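# Example of the mapping built above (directly derivable from the covered string):
#   languages_covered()['English'] -> 'en_XX'
#   languages_covered()['German']  -> 'de_DE'
# evaluate() looks src_lang/tgt_lang up in this dict for mbart models.
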
            +
            def test_test_prompt(prompt_type='instruct', data_point=0):
         | 
| 1925 | 
            +
                example_data_point = example_data_points[data_point]
         | 
| 1926 | 
            +
                example_data_point.pop('output', None)
         | 
| 1927 | 
            +
                return generate_prompt(example_data_point, prompt_type, False, False)
         | 
| 1928 | 
            +
             | 
| 1929 | 
            +
             | 
| 1930 | 
            +
            if __name__ == "__main__":
         | 
| 1931 | 
            +
                print("""
         | 
| 1932 | 
            +
                WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
         | 
| 1933 | 
            +
                python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
         | 
| 1934 | 
            +
                python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
         | 
| 1935 | 
            +
                
         | 
| 1936 | 
            +
                # generate without lora weights, no prompt
         | 
| 1937 | 
            +
                python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
         | 
| 1938 | 
            +
                python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
         | 
| 1939 | 
            +
             | 
| 1940 | 
            +
                python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
         | 
| 1941 | 
            +
                # OpenChatKit settings:
         | 
| 1942 | 
            +
                python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
         | 
| 1943 | 
            +
             | 
| 1944 | 
            +
                python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
         | 
| 1945 | 
            +
                python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
         | 
| 1946 | 
            +
                python generate.py --base_model='philschmid/bart-large-cnn-samsum'
         | 
| 1947 | 
            +
                python generate.py --base_model='philschmid/flan-t5-base-samsum'
         | 
| 1948 | 
            +
                python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
         | 
| 1949 | 
            +
             | 
| 1950 | 
            +
                python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
         | 
| 1951 | 
            +
             | 
| 1952 | 
            +
                must have 4*48GB GPU and run without 8bit in order for sharding to work with infer_devices=False
         | 
| 1953 | 
            +
                can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
         | 
| 1954 | 
            +
                python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
         | 
| 1955 | 
            +
             | 
| 1956 | 
            +
                python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b
         | 
| 1957 | 
            +
             | 
| 1958 | 
            +
                """, flush=True)
         | 
| 1959 | 
            +
                fire.Fire(main)
         | 
    	
        client_test.py
    ADDED
    
    | @@ -0,0 +1,93 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """
         | 
| 2 | 
            +
            Client test.
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            Run server:
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            python generate.py  --base_model=h2oai/h2ogpt-oig-oasst1-256-6.9b
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            NOTE: For private models, add --use-auth_token=True
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            NOTE: --infer_devices=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches.
         | 
| 11 | 
            +
            Currently, this will force model to be on a single GPU.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            Then run this client as:
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            python client_test.py
         | 
| 16 | 
            +
            """
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            debug = False
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            import os
         | 
| 21 | 
            +
            os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
         | 
| 22 | 
            +
            from gradio_client import Client
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            client = Client("http://localhost:7860")
         | 
| 25 | 
            +
            if debug:
         | 
| 26 | 
            +
                print(client.view_api(all_endpoints=True))
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            instruction = ''  # only for chat=True
         | 
| 29 | 
            +
            iinput = ''  # only for chat=True
         | 
| 30 | 
            +
            context = ''
         | 
| 31 | 
            +
            # streaming output is supported, loops over and outputs each generation in streaming mode
         | 
| 32 | 
            +
            # but leave stream_output=False for simple input/output mode
         | 
| 33 | 
            +
            stream_output = False
         | 
| 34 | 
            +
            prompt_type = 'human_bot'
         | 
| 35 | 
            +
            temperature = 0.1
         | 
| 36 | 
            +
            top_p = 0.75
         | 
| 37 | 
            +
            top_k = 40
         | 
| 38 | 
            +
            num_beams = 1
         | 
| 39 | 
            +
            max_new_tokens = 50
         | 
| 40 | 
            +
            min_new_tokens = 0
         | 
| 41 | 
            +
            early_stopping = False
         | 
| 42 | 
            +
            max_time = 20
         | 
| 43 | 
            +
            repetition_penalty = 1.0
         | 
| 44 | 
            +
            num_return_sequences = 1
         | 
| 45 | 
            +
            do_sample = True
         | 
| 46 | 
            +
            # only these 2 below used if pass chat=False
         | 
| 47 | 
            +
            chat = False
         | 
| 48 | 
            +
            instruction_nochat = "Who are you?"
         | 
| 49 | 
            +
            iinput_nochat = ''
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            def test_client_basic():
         | 
| 53 | 
            +
                args = [instruction,
         | 
| 54 | 
            +
                        iinput,
         | 
| 55 | 
            +
                        context,
         | 
| 56 | 
            +
                        stream_output,
         | 
| 57 | 
            +
                        prompt_type,
         | 
| 58 | 
            +
                        temperature,
         | 
| 59 | 
            +
                        top_p,
         | 
| 60 | 
            +
                        top_k,
         | 
| 61 | 
            +
                        num_beams,
         | 
| 62 | 
            +
                        max_new_tokens,
         | 
| 63 | 
            +
                        min_new_tokens,
         | 
| 64 | 
            +
                        early_stopping,
         | 
| 65 | 
            +
                        max_time,
         | 
| 66 | 
            +
                        repetition_penalty,
         | 
| 67 | 
            +
                        num_return_sequences,
         | 
| 68 | 
            +
                        do_sample,
         | 
| 69 | 
            +
                        chat,
         | 
| 70 | 
            +
                        instruction_nochat,
         | 
| 71 | 
            +
                        iinput_nochat,
         | 
| 72 | 
            +
                        ]
         | 
| 73 | 
            +
                api_name = '/submit_nochat'
         | 
| 74 | 
            +
                res = client.predict(
         | 
| 75 | 
            +
                    *tuple(args),
         | 
| 76 | 
            +
                    api_name=api_name,
         | 
| 77 | 
            +
                )
         | 
| 78 | 
            +
                res_dict = dict(instruction_nochat=instruction_nochat, iinput_nochat=iinput_nochat, response=md_to_text(res))
         | 
| 79 | 
            +
                print(res_dict)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
            import markdown  # pip install markdown
         | 
| 83 | 
            +
            from bs4 import BeautifulSoup  # pip install beautifulsoup4
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            def md_to_text(md):
         | 
| 87 | 
            +
                html = markdown.markdown(md)
         | 
| 88 | 
            +
                soup = BeautifulSoup(html, features='html.parser')
         | 
| 89 | 
            +
                return soup.get_text()
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            if __name__ == '__main__':
         | 
| 93 | 
            +
                test_client_basic()
         | 
    	
        finetune.py
    ADDED
    
    | @@ -0,0 +1,934 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import pathlib
         | 
| 3 | 
            +
            import random
         | 
| 4 | 
            +
            import shutil
         | 
| 5 | 
            +
            import subprocess
         | 
| 6 | 
            +
            import sys
         | 
| 7 | 
            +
            import time
         | 
| 8 | 
            +
            from datetime import datetime
         | 
| 9 | 
            +
            from typing import List, Union
         | 
| 10 | 
            +
            import fire
         | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
            import torch
         | 
| 13 | 
            +
            from datasets import load_dataset, concatenate_datasets
         | 
| 14 | 
            +
            import transformers
         | 
| 15 | 
            +
            import torch.distributed as dist
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from peft import (
         | 
| 18 | 
            +
                prepare_model_for_int8_training,
         | 
| 19 | 
            +
                LoraConfig,
         | 
| 20 | 
            +
                get_peft_model,
         | 
| 21 | 
            +
                get_peft_model_state_dict,
         | 
| 22 | 
            +
                set_peft_model_state_dict,
         | 
| 23 | 
            +
            )
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            from peft import mapping
         | 
| 26 | 
            +
            lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            def log(*args, **kwargs):
         | 
| 30 | 
            +
                if int(os.environ.get("LOCAL_RANK", 0)) == 0:
         | 
| 31 | 
            +
                    print(*args, **kwargs)
         | 
| 32 | 
            +
             | 
| 33 | 
            +
             | 
| 34 | 
            +
            try:
         | 
| 35 | 
            +
                import neptune
         | 
| 36 | 
            +
                from transformers.integrations import NeptuneCallback
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                neptune_run = neptune.init_run(
         | 
| 39 | 
            +
                    source_files=[],
         | 
| 40 | 
            +
                )
         | 
| 41 | 
            +
                log("Connected to Neptune.")
         | 
| 42 | 
            +
            except ImportError:
         | 
| 43 | 
            +
                neptune_run = None
         | 
| 44 | 
            +
                log("Please pip install neptune for tracking.")
         | 
| 45 | 
            +
            except neptune.exceptions.NeptuneMissingApiTokenException:
         | 
| 46 | 
            +
                neptune_run = None
         | 
| 47 | 
            +
                os.environ["NEPTUNE_MODE"] = 'debug'
         | 
| 48 | 
            +
                log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
         | 
| 49 | 
            +
             | 
| 50 | 
            +
            from enum import Enum
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            class PromptType(Enum):
         | 
| 54 | 
            +
                plain = 0
         | 
| 55 | 
            +
                instruct = 1
         | 
| 56 | 
            +
                quality = 2
         | 
| 57 | 
            +
                human_bot = 3
         | 
| 58 | 
            +
                dai_faq = 4
         | 
| 59 | 
            +
                summarize = 5
         | 
| 60 | 
            +
                simple_instruct = 6
         | 
| 61 | 
            +
                instruct_vicuna = 7
         | 
| 62 | 
            +
                instruct_with_end = 8
         | 
| 63 | 
            +
                human_bot_orig = 9
         | 
| 64 | 
            +
             | 
| 65 | 
            +
             | 
| 66 | 
            +
            prompt_type_to_model_name = {
         | 
| 67 | 
            +
                'plain': [
         | 
| 68 | 
            +
                    'EleutherAI/gpt-j-6B',
         | 
| 69 | 
            +
                    'EleutherAI/pythia-6.9b',
         | 
| 70 | 
            +
                    'EleutherAI/pythia-12b',
         | 
| 71 | 
            +
                    'EleutherAI/pythia-12b-deduped',
         | 
| 72 | 
            +
                    'EleutherAI/gpt-neox-20b',
         | 
| 73 | 
            +
                    'decapoda-research/llama-7b-hf',
         | 
| 74 | 
            +
                    'decapoda-research/llama-13b-hf',
         | 
| 75 | 
            +
                    'decapoda-research/llama-30b-hf',
         | 
| 76 | 
            +
                    'decapoda-research/llama-65b-hf',
         | 
| 77 | 
            +
                    'facebook/mbart-large-50-many-to-many-mmt',
         | 
| 78 | 
            +
                    'philschmid/bart-large-cnn-samsum',
         | 
| 79 | 
            +
                    'philschmid/flan-t5-base-samsum',
         | 
| 80 | 
            +
                    'gpt2',
         | 
| 81 | 
            +
                    'distilgpt2',
         | 
| 82 | 
            +
                ],
         | 
| 83 | 
            +
                'instruct': [],
         | 
| 84 | 
            +
                'instruct_with_end': ['databricks/dolly-v2-12b'],
         | 
| 85 | 
            +
                'quality': [],
         | 
| 86 | 
            +
                'human_bot': [
         | 
| 87 | 
            +
                    'h2oai/h2ogpt-oig-oasst1-256-12b',
         | 
| 88 | 
            +
                    'h2oai/h2ogpt-oasst1-512-12b',
         | 
| 89 | 
            +
                    'h2oai/h2ogpt-oasst1-256-20b',
         | 
| 90 | 
            +
                    'h2oai/h2ogpt-oasst1-512-20b',
         | 
| 91 | 
            +
                    'h2oai/h2ogpt-oig-oasst1-256-6.9b',
         | 
| 92 | 
            +
                ],
         | 
| 93 | 
            +
                'dai_faq': [],
         | 
| 94 | 
            +
                'summarize': [],
         | 
| 95 | 
            +
                'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
         | 
| 96 | 
            +
                'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
         | 
| 97 | 
            +
                'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
         | 
| 98 | 
            +
            }
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
         | 
| 101 | 
            +
            inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
         | 
| 102 | 
            +
             | 
| 103 | 
            +
            human = '<human>:'
         | 
| 104 | 
            +
            bot = "<bot>:"
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            prompt_types_strings = []
         | 
| 107 | 
            +
            for p in PromptType:
         | 
| 108 | 
            +
                prompt_types_strings.extend([p.name])
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
| 111 | 
            +
            prompt_types = []
         | 
| 112 | 
            +
            for p in PromptType:
         | 
| 113 | 
            +
                prompt_types.extend([p.name, p.value, str(p.value)])
         | 
| 114 | 
            +
             | 
| 115 | 
            +
             | 
| 116 | 
            +
            # supported by huggingface evaluate
         | 
| 117 | 
            +
            supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
         | 
| 118 | 
            +
             | 
| 119 | 
            +
             | 
| 120 | 
            +
            def train(
         | 
| 121 | 
            +
                    save_code: bool = False,
         | 
| 122 | 
            +
                    run_id: int = None,
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    base_model: str = 'h2oai/h2ogpt-oig-oasst1-512-6.9b',
         | 
| 125 | 
            +
                    # base_model: str = 'h2oai/h2ogpt-oasst1-512-12b',
         | 
| 126 | 
            +
                    # base_model: str = 'h2oai/h2ogpt-oasst1-512-20b',
         | 
| 127 | 
            +
                    # base_model: str = 'EleutherAI/gpt-neox-20b',
         | 
| 128 | 
            +
                    # base_model: str = 'EleutherAI/pythia-12b-deduped',
         | 
| 129 | 
            +
                    # base_model: str = 'togethercomputer/GPT-NeoXT-Chat-Base-20B',
         | 
| 130 | 
            +
                    # base_model: str = 'decapoda-research/llama-7b-hf',
         | 
| 131 | 
            +
                    # base_model: str = 'decapoda-research/llama-13b-hf',
         | 
| 132 | 
            +
                    # base_model: str = 'decapoda-research/llama-30b-hf',
         | 
| 133 | 
            +
                    # base_model: str = 'EleutherAI/gpt-j-6B',
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    # only needed if base_model is self-exported HF state without tokenizer
         | 
| 136 | 
            +
                    tokenizer_base_model: str = None,
         | 
| 137 | 
            +
                    # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    data_path: str = None,
         | 
| 140 | 
            +
                    data_col_dict: dict = None,
         | 
| 141 | 
            +
                    # data_path: str = "./dai_docs.train.json",
         | 
| 142 | 
            +
                    prompt_type: Union[str, int] = "plain",  # "plain", "instruct", "quality", "human_bot", "dai_faq"
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    valid_path: str = None,
         | 
| 145 | 
            +
                    # valid_path: str = "./dai_docs.valid.json",
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    # data_mix_in_path: str = "laion/OIG",  # way too big, medium quality
         | 
| 148 | 
            +
                    data_mix_in_path: str = "0-hero/OIG-small-chip2",  # high quality, 50 MB, good enough for now
         | 
| 149 | 
            +
                    data_mix_in_factor: float = 0.0,  # >1: more mix-in data, <1: more of data_path data
         | 
| 150 | 
            +
                    data_mix_in_col_dict: dict = {'user': 'instruction', 'chip2': 'output'},
         | 
| 151 | 
            +
                    data_mix_in_prompt_type: str = "instruct",  # just instruction->output, same as instruct
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    output_dir: str = None,
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                    # LoRA checkpoint continuation
         | 
| 156 | 
            +
                    lora_weights: str = "",
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                    # batching training hyperparams
         | 
| 159 | 
            +
                    batch_size: int = 128,
         | 
| 160 | 
            +
                    micro_batch_size: int = 4,
         | 
| 161 | 
            +
                    gradient_checkpointing=False,  # unnecessary with gradient accumulation enabled
         | 
| 162 | 
            +
                    fp16=True,
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    # general training hyperparams
         | 
| 165 | 
            +
                    num_epochs: float = 1,
         | 
| 166 | 
            +
                    learning_rate: float = 3e-4,
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    # validation settings
         | 
| 169 | 
            +
                    val_set_size: int = None,
         | 
| 170 | 
            +
                    val_metrics: List[str] = [],
         | 
| 171 | 
            +
                    eval_steps: int = None,  # to control eval steps via steps
         | 
| 172 | 
            +
                    eval_epochs: float = None,  # to control eval steps via epochs
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    # lora hyperparams
         | 
| 175 | 
            +
                    lora_r: int = 8,
         | 
| 176 | 
            +
                    lora_alpha: int = 16,
         | 
| 177 | 
            +
                    lora_dropout: float = 0.05,
         | 
| 178 | 
            +
                    lora_target_modules: List[str] = None,
         | 
| 179 | 
            +
                    llama_type: bool = None,
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    # llm hyperparams
         | 
| 182 | 
            +
                    train_on_inputs: bool = True,  # if False, masks out inputs in loss
         | 
| 183 | 
            +
                    group_by_length: bool = False,  # if True, faster, but produces an odd training loss curve
         | 
| 184 | 
            +
                    resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
         | 
| 185 | 
            +
                    cutoff_len: int = 1024,  # Good default, especially when have high quality non-trivial data
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    # torch training params
         | 
| 188 | 
            +
                    ddp: bool = True,  # set to False if OOM with True, for multi-GPU model parallelism
         | 
| 189 | 
            +
                    local_files_only: bool = False,  # else will download new versions, normally unwanted
         | 
| 190 | 
            +
                    resume_download: bool = True,
         | 
| 191 | 
            +
                    use_auth_token: Union[str, bool] = False,  # True requires CLI did huggingface-cli login before running
         | 
| 192 | 
            +
                    warmup_steps: int = 100,
         | 
| 193 | 
            +
                    logging_steps: int = 1,
         | 
| 194 | 
            +
                    save_steps: int = None,  # must be round multiple of eval_steps
         | 
| 195 | 
            +
                    add_eos_token: bool = False,
         | 
| 196 | 
            +
            ):
         | 
| 197 | 
            +
                # allow set token directly
         | 
| 198 | 
            +
                use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                prompt_type = str(prompt_type)  # migration from integers
         | 
| 201 | 
            +
                assert prompt_type in prompt_types
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                world_size = int(os.getenv("WORLD_SIZE", 1))
         | 
| 204 | 
            +
                local_rank = int(os.getenv("LOCAL_RANK", 0))
         | 
| 205 | 
            +
                rank = int(os.getenv("RANK", 0))
         | 
| 206 | 
            +
                print(f"local_rank: {local_rank}")
         | 
| 207 | 
            +
                print(f"global rank: {rank}")
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                gpus = max(world_size, torch.cuda.device_count())
         | 
| 210 | 
            +
                run_id = run_id or 0
         | 
| 211 | 
            +
                if not data_path:
         | 
| 212 | 
            +
                    raise ValueError("No data_path provided")
         | 
| 213 | 
            +
                if not output_dir:
         | 
| 214 | 
            +
                    output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
         | 
| 215 | 
            +
                    if os.path.exists(output_dir) and not resume_from_checkpoint:
         | 
| 216 | 
            +
                        raise FileExistsError(f"output_dir based on run_id {run_id} already exists. Please pick a different run_id.")
         | 
| 217 | 
            +
                else:
         | 
| 218 | 
            +
                    if os.path.exists(output_dir) and not resume_from_checkpoint:
         | 
| 219 | 
            +
                        raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
         | 
| 220 | 
            +
                device_map = "auto"
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                if save_code:
         | 
| 223 | 
            +
                    copy_code(run_id)
         | 
| 224 | 
            +
                if tokenizer_base_model is None:
         | 
| 225 | 
            +
                    tokenizer_base_model = base_model
         | 
| 226 | 
            +
                if llama_type is None:
         | 
| 227 | 
            +
                    llama_type = "llama" in base_model.lower()
         | 
| 228 | 
            +
                assert (
         | 
| 229 | 
            +
                    base_model
         | 
| 230 | 
            +
                ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
         | 
| 231 | 
            +
                gradient_accumulation_steps = batch_size // micro_batch_size
         | 
| 232 | 
            +
                assert gradient_accumulation_steps >= world_size, "must increase batch_size for multi-GPU"
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                device_map = "auto"
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                locals_dict = locals()
         | 
| 237 | 
            +
                locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
         | 
| 238 | 
            +
                log(f"Training model with params:\n{locals_print}")
         | 
| 239 | 
            +
                log("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()))
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                max_memory = None
         | 
| 242 | 
            +
                if gpus > 1:
         | 
| 243 | 
            +
                    if ddp:
         | 
| 244 | 
            +
                        log("Distributed: data parallel")
         | 
| 245 | 
            +
                        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
         | 
| 246 | 
            +
                        gradient_accumulation_steps = gradient_accumulation_steps // world_size
         | 
| 247 | 
            +
                    else:
         | 
| 248 | 
            +
                        free_in_GB = int(min(torch.cuda.mem_get_info()) / 1024 ** 3)
         | 
| 249 | 
            +
                        max_memory = f"{free_in_GB - 2}GB"
         | 
| 250 | 
            +
                        max_memory = {i: max_memory for i in range(gpus)}
         | 
| 251 | 
            +
                        log("world_size: %d" % world_size)
         | 
| 252 | 
            +
                        log("num_gpus: %d" % gpus)
         | 
| 253 | 
            +
                        log("max mem: %s" % max_memory)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                model = model_loader.from_pretrained(
         | 
| 258 | 
            +
                    base_model,
         | 
| 259 | 
            +
                    load_in_8bit=True,
         | 
| 260 | 
            +
                    device_map=device_map,
         | 
| 261 | 
            +
                    torch_dtype=torch.float16,
         | 
| 262 | 
            +
                    max_memory=max_memory,
         | 
| 263 | 
            +
                    local_files_only=local_files_only,
         | 
| 264 | 
            +
                    resume_download=resume_download,
         | 
| 265 | 
            +
                    use_auth_token=use_auth_token,
         | 
| 266 | 
            +
                )
         | 
| 267 | 
            +
                if gpus > 1:
         | 
| 268 | 
            +
                    if not ddp:
         | 
| 269 | 
            +
                        log("model parallel")
         | 
| 270 | 
            +
                        model.is_parallelizable = True
         | 
| 271 | 
            +
                        model.model_parallel = True
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
         | 
| 274 | 
            +
                                                             local_files_only=local_files_only,
         | 
| 275 | 
            +
                                                             resume_download=resume_download,
         | 
| 276 | 
            +
                                                             use_auth_token=use_auth_token)
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                tokenizer.pad_token_id = 0  # different from the eos token
         | 
| 279 | 
            +
                # when generating, we will use the logits of right-most token to predict the next token
         | 
| 280 | 
            +
                # so the padding should be on the left,
         | 
| 281 | 
            +
                # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
         | 
| 282 | 
            +
                tokenizer.padding_side = "left"  # Allow batched inference
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                def tokenize(prompt, add_eos_token=True):
         | 
| 285 | 
            +
                    # there's probably a way to do this with the tokenizer settings
         | 
| 286 | 
            +
                    # but again, gotta move fast
         | 
| 287 | 
            +
                    result = tokenizer(
         | 
| 288 | 
            +
                        prompt,
         | 
| 289 | 
            +
                        truncation=True,
         | 
| 290 | 
            +
                        max_length=cutoff_len,
         | 
| 291 | 
            +
                        padding=False,
         | 
| 292 | 
            +
                        return_tensors=None,
         | 
| 293 | 
            +
                    )
         | 
| 294 | 
            +
                    if (
         | 
| 295 | 
            +
                            result["input_ids"][-1] != tokenizer.eos_token_id
         | 
| 296 | 
            +
                            and len(result["input_ids"]) < cutoff_len
         | 
| 297 | 
            +
                            and add_eos_token
         | 
| 298 | 
            +
                    ):
         | 
| 299 | 
            +
                        result["input_ids"].append(tokenizer.eos_token_id)
         | 
| 300 | 
            +
                        result["attention_mask"].append(1)
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                    result["labels"] = result["input_ids"].copy()
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                    return result
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
         | 
| 307 | 
            +
                    full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
         | 
| 308 | 
            +
                    tokenized_full_prompt = tokenize(full_prompt)
         | 
| 309 | 
            +
                    if not train_on_inputs:
         | 
| 310 | 
            +
                        user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
         | 
| 311 | 
            +
                        tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos)
         | 
| 312 | 
            +
                        user_prompt_len = len(tokenized_user_prompt["input_ids"])
         | 
| 313 | 
            +
                        if add_eos:
         | 
| 314 | 
            +
                            user_prompt_len -= 1
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                        # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
         | 
| 317 | 
            +
                        tokenized_full_prompt["labels"] = [
         | 
| 318 | 
            +
                                                              -100
         | 
| 319 | 
            +
                                                          ] * user_prompt_len + tokenized_full_prompt["labels"][
         | 
| 320 | 
            +
                                                                                user_prompt_len:
         | 
| 321 | 
            +
                                                                                ]  # could be sped up, probably
         | 
| 322 | 
            +
                    return tokenized_full_prompt
         | 
| 323 | 
            +
             | 
| 324 | 
            +
                if "gpt-neox" not in base_model or True:
         | 
| 325 | 
            +
                    model = prepare_model_for_int8_training(model)
         | 
| 326 | 
            +
                else:
         | 
| 327 | 
            +
                    model = prepare_model_for_int8_training(
         | 
| 328 | 
            +
                        model,
         | 
| 329 | 
            +
                        output_embedding_layer_name="embed_out",  # keep output logits in float32
         | 
| 330 | 
            +
                        layer_norm_names=["layer_norm", "layernorm"],  # keep all layer norms in higher precision
         | 
| 331 | 
            +
                    )
         | 
| 332 | 
            +
                if lora_weights:
         | 
| 333 | 
            +
                    from peft import PeftModel
         | 
| 334 | 
            +
                    model = PeftModel.from_pretrained(
         | 
| 335 | 
            +
                        model,
         | 
| 336 | 
            +
                        lora_weights,
         | 
| 337 | 
            +
                        torch_dtype=torch.float16,
         | 
| 338 | 
            +
                        device_map=device_map,
         | 
| 339 | 
            +
                        local_files_only=local_files_only,
         | 
| 340 | 
            +
                        resume_download=resume_download,
         | 
| 341 | 
            +
                        use_auth_token=use_auth_token,
         | 
| 342 | 
            +
                    )
         | 
| 343 | 
            +
                else:
         | 
| 344 | 
            +
                    if lora_target_modules is None:
         | 
| 345 | 
            +
                        base_model_lower = base_model.lower()
         | 
| 346 | 
            +
                        if base_model_lower in lora_mappings:
         | 
| 347 | 
            +
                            lora_target_modules_cand = [lora_mappings[base_model_lower]]
         | 
| 348 | 
            +
                        else:
         | 
| 349 | 
            +
                            lora_target_modules_cand = [["query_key_value"], ["q_proj", "v_proj"]]
         | 
| 350 | 
            +
                    else:
         | 
| 351 | 
            +
                        lora_target_modules_cand = [lora_target_modules]
         | 
| 352 | 
            +
             | 
| 353 | 
            +
                    for lora_target_modules in lora_target_modules_cand:
         | 
| 354 | 
            +
                        try:
         | 
| 355 | 
            +
                            config = LoraConfig(
         | 
| 356 | 
            +
                                r=lora_r,
         | 
| 357 | 
            +
                                lora_alpha=lora_alpha,
         | 
| 358 | 
            +
                                target_modules=lora_target_modules,
         | 
| 359 | 
            +
                                lora_dropout=lora_dropout,
         | 
| 360 | 
            +
                                bias="none",
         | 
| 361 | 
            +
                                task_type="CAUSAL_LM",
         | 
| 362 | 
            +
                            )
         | 
| 363 | 
            +
                            model = get_peft_model(model, config)
         | 
| 364 | 
            +
                            break
         | 
| 365 | 
            +
                        except ValueError as e:
         | 
| 366 | 
            +
                            if "Target modules" in str(e) and "not found" in str(e):
         | 
| 367 | 
            +
                                continue
         | 
| 368 | 
            +
                            else:
         | 
| 369 | 
            +
                                raise
         | 
| 370 | 
            +
                    from peft import PeftModel
         | 
| 371 | 
            +
                    assert isinstance(model, PeftModel), "LoRA failed. Please provide --lora_target_modules explicitly."
         | 
| 372 | 
            +
                if resume_from_checkpoint:
         | 
| 373 | 
            +
                    # Check the available weights and load them
         | 
| 374 | 
            +
                    checkpoint_name = os.path.join(
         | 
| 375 | 
            +
                        resume_from_checkpoint, "pytorch_model.bin"
         | 
| 376 | 
            +
                    )  # Full checkpoint
         | 
| 377 | 
            +
                    if not os.path.exists(checkpoint_name):
         | 
| 378 | 
            +
                        checkpoint_name = os.path.join(
         | 
| 379 | 
            +
                            resume_from_checkpoint, "adapter_model.bin"
         | 
| 380 | 
            +
                        )  # only LoRA model - LoRA config above has to fit
         | 
| 381 | 
            +
                        resume_from_checkpoint = False  # So the trainer won't try loading its state
         | 
| 382 | 
            +
                    # The two files above have a different name depending on how they were saved, but are actually the same.
         | 
| 383 | 
            +
                    if os.path.exists(checkpoint_name):
         | 
| 384 | 
            +
                        log(f"Restarting from {checkpoint_name}")
         | 
| 385 | 
            +
                        adapters_weights = torch.load(checkpoint_name)
         | 
| 386 | 
            +
                        model = set_peft_model_state_dict(model, adapters_weights)
         | 
| 387 | 
            +
                    else:
         | 
| 388 | 
            +
                        log(f"Checkpoint {checkpoint_name} not found")
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                print(model)
         | 
| 391 | 
            +
                model.print_trainable_parameters()  # Be more transparent about the % of trainable params.
         | 
| 392 | 
            +
             | 
| 393 | 
            +
                metrics = {}
         | 
| 394 | 
            +
                for name in supported_metrics:
         | 
| 395 | 
            +
                    if name in val_metrics:
         | 
| 396 | 
            +
                        import evaluate  # Causes hang for 'python generate.py' on dual 4090 if imported early, 100% reproducible
         | 
| 397 | 
            +
                        metrics[name] = evaluate.load(name)
         | 
| 398 | 
            +
                log("Using Validation Metrics: %s" % str(list(metrics.keys())))
         | 
| 399 | 
            +
                log("Supported Metrics: %s" % supported_metrics)
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                if val_set_size is None:
         | 
| 402 | 
            +
                    if len(metrics) == 0:
         | 
| 403 | 
            +
                        val_set_size = 1000
         | 
| 404 | 
            +
                    else:
         | 
| 405 | 
            +
                        val_set_size = 100
         | 
| 406 | 
            +
                    log("Auto set val_set_size %s" % val_set_size)
         | 
| 407 | 
            +
                elif val_set_size < 1.0 and val_set_size != 0:
         | 
| 408 | 
            +
                    raise RuntimeError("Fractional validation size not supported.")
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                if valid_path:
         | 
| 411 | 
            +
                    data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
         | 
| 412 | 
            +
                else:
         | 
| 413 | 
            +
                    if "json" in data_path:
         | 
| 414 | 
            +
                        data = load_dataset("json", data_files={"train": data_path})
         | 
| 415 | 
            +
                    else:
         | 
| 416 | 
            +
                        data = load_dataset(data_path)
         | 
| 417 | 
            +
                        data = data.rename_columns(data_col_dict or {})
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                valid_data = None
         | 
| 420 | 
            +
                train_data_mix_in = None
         | 
| 421 | 
            +
                valid_data_mix_in = None
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                if data_mix_in_path and data_mix_in_factor > 0:
         | 
| 424 | 
            +
                    # get mix-in training/validation data - to keep model "sane"
         | 
| 425 | 
            +
                    num_rows = data["train"].num_rows
         | 
| 426 | 
            +
                    log("Loading mix-in dataset: %s" % data_mix_in_path)
         | 
| 427 | 
            +
                    if "json" in data_mix_in_path:
         | 
| 428 | 
            +
                        data_mix_in = load_dataset("json", data_files={"train": data_mix_in_path})["train"]
         | 
| 429 | 
            +
                    else:
         | 
| 430 | 
            +
                        data_mix_in = load_dataset(data_mix_in_path)["train"]  # can be large
         | 
| 431 | 
            +
                    data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                    # only get as much as we need to balance
         | 
| 434 | 
            +
                    valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
         | 
| 435 | 
            +
                    train_size = max(1, min(data_mix_in.num_rows - valid_size, int(num_rows * data_mix_in_factor)))
         | 
| 436 | 
            +
                    mixin_small = data_mix_in.train_test_split(
         | 
| 437 | 
            +
                        test_size=train_size + valid_size,
         | 
| 438 | 
            +
                        shuffle=True, seed=np.random.randint(10000),
         | 
| 439 | 
            +
                    )["test"]
         | 
| 440 | 
            +
                    if valid_size:
         | 
| 441 | 
            +
                        mixin_train_test = mixin_small.train_test_split(
         | 
| 442 | 
            +
                            test_size=valid_size, shuffle=False,
         | 
| 443 | 
            +
                        )
         | 
| 444 | 
            +
                        train_data_mix_in = mixin_train_test["train"]
         | 
| 445 | 
            +
                        valid_data_mix_in = mixin_train_test["test"]
         | 
| 446 | 
            +
                    else:
         | 
| 447 | 
            +
                        train_data_mix_in = mixin_small
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    if "prompt_type" not in train_data_mix_in.column_names:
         | 
| 450 | 
            +
                        train_data_mix_in = train_data_mix_in.add_column(
         | 
| 451 | 
            +
                            "prompt_type",
         | 
| 452 | 
            +
                            [data_mix_in_prompt_type] * train_data_mix_in.num_rows,
         | 
| 453 | 
            +
                        )
         | 
| 454 | 
            +
                        log("Added prompt type %s to mix-in training data" % data_mix_in_prompt_type)
         | 
| 455 | 
            +
        if valid_data_mix_in and "prompt_type" not in valid_data_mix_in.column_names:
            valid_data_mix_in = valid_data_mix_in.add_column(
                "prompt_type",
                [data_mix_in_prompt_type] * valid_data_mix_in.num_rows,
            )
            log("Added prompt type %s to mix-in validation data" % data_mix_in_prompt_type)
        log("Created mix-in data:\nTrain %s\nValid %s" % (train_data_mix_in, valid_data_mix_in))

    # get our own training/validation data - for fine-tuning
    if val_set_size > 0 and not valid_path and not data_mix_in_path:
        # create valid split from train
        train_val = data["train"].train_test_split(
            test_size=val_set_size, shuffle=True, seed=42
        )
        train_data = train_val["train"]
        valid_data = train_val["test"]
    else:
        train_data = data["train"]
        if valid_path:
            # use given valid split, has priority over data_mix_in_path
            valid_data = data["valid"]
    if "prompt_type" not in train_data.column_names:
        train_data = train_data.add_column(
            "prompt_type",
            [prompt_type] * train_data.num_rows,
        )
        log("Added prompt type %s to training data" % prompt_type)
    if valid_data and "prompt_type" not in valid_data.column_names:
        valid_data = valid_data.add_column(
            "prompt_type",
            [prompt_type] * valid_data.num_rows,
        )
        log("Added prompt type %s to validation data" % prompt_type)

    assert train_data is not None

    # shuffle and tokenize data
    if train_data_mix_in:
        train_data = concatenate_datasets([train_data, train_data_mix_in])
    train_data = train_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
    train_set_size = len(train_data)

    if valid_data and valid_data_mix_in:
        valid_data = concatenate_datasets([valid_data, valid_data_mix_in])
    elif valid_data_mix_in:
        valid_data = valid_data_mix_in

    if valid_data:
        valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
        val_set_size = len(valid_data)
    else:
        val_set_size = 0
    log("Final fine-tuning data:\nTrain %s\nValid %s" % (train_data, valid_data))
    sample_row_dict = train_data[:1]
    del sample_row_dict['input_ids']
    del sample_row_dict['attention_mask']
    del sample_row_dict['labels']
    log("Sample input: %s" % sample_row_dict)

    if neptune_run:
        neptune_callback = NeptuneCallback(run=neptune_run)
        callbacks = [neptune_callback]
    else:
        from transformers.integrations import TensorBoardCallback, is_tensorboard_available
        if is_tensorboard_available:
            # tensorboard --logdir=runs/
            from torch.utils.tensorboard import SummaryWriter
            tb_writer = SummaryWriter()
            callbacks = [TensorBoardCallback(tb_writer=tb_writer)]
        else:
            callbacks = []

    expected_steps = (train_set_size * num_epochs) // batch_size
    if eval_steps is None and eval_epochs is None:
        # 20 evaluations for a run
        eval_steps = max(1, int(expected_steps / 20))
        log("Auto set eval_steps to %s out of %s total training steps" % (eval_steps, expected_steps))
    elif eval_steps is None and eval_epochs is not None:
        eval_steps = max(1, int(expected_steps * eval_epochs / num_epochs))
        log("Auto converted eval_epochs=%s to eval_steps %s"
            " out of %s total training steps" % (eval_epochs, eval_steps, expected_steps))
    if save_steps is None:
        save_steps = eval_steps
        log("Auto set save_steps to %s" % save_steps)
    elif save_steps > eval_steps:
        # save_steps must be a round multiple of eval_steps
        save_steps0 = save_steps
        save_steps = max(1, (save_steps // eval_steps)) * eval_steps
        if save_steps0 != save_steps:
            log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps))

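    # Illustrative arithmetic for the auto-configuration above (annotation added here,
    # not part of the original code; the numbers are hypothetical): with
    # train_set_size=10000, num_epochs=3 and batch_size=128,
    # expected_steps = (10000 * 3) // 128 = 234, so the default of ~20 evaluations gives
    # eval_steps = max(1, int(234 / 20)) = 11, an unset save_steps also becomes 11, and a
    # requested save_steps=25 is rounded down to (25 // 11) * 11 = 22 so that every save
    # lands on an evaluation step.
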
    def compute_metrics(eval_preds):
        # e.g. see: https://huggingface.co/docs/transformers/v4.25.1/en/tasks/translation#evaluate
        inputs = eval_preds.inputs
        label_ids = eval_preds.label_ids
        predictions = eval_preds.predictions

        # inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id)
        # decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True)
        # decoded_inputs = [pred.strip() for pred in decoded_inputs]

        label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
        # tokenizer behavior like generate time
        decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True,
                                                clean_up_tokenization_spaces=True)
        decoded_labels = [pred.strip() for pred in decoded_labels]

        predictions = np.argmax(predictions, -1)
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        # tokenizer behavior like generate time
        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True,
                                                     clean_up_tokenization_spaces=True)
        decoded_predictions = [pred.strip() for pred in decoded_predictions]

        result = {}
        for metric in metrics.values():
            result1 = metric.compute(predictions=decoded_predictions, references=decoded_labels)
            # get rid of lists, for precision etc., for now
            numeric_results = {k: v for k, v in result1.items() if isinstance(v, (int, float))}
            result.update(numeric_results)
        return result

    # the callback that computes metrics of interest
    if val_metrics:
        trainer_kwargs = dict(compute_metrics=compute_metrics)
    else:
        trainer_kwargs = dict()

    trainer = transformers.Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_data,
        eval_dataset=valid_data,
        # NOTE: CausalLM does not support the Seq2Seq-specific arguments, but they are not incompatible
        args=transformers.Seq2SeqTrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            per_device_eval_batch_size=1,
            eval_accumulation_steps=10,
            # predict_with_generate=True,  # SEQ2SEQ only
            include_inputs_for_metrics=True,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=warmup_steps,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            gradient_checkpointing=gradient_checkpointing,
            fp16=fp16,
            # consider 8-bit adam: https://huggingface.co/docs/transformers/v4.18.0/en/performance#8bit-adam
            optim="adamw_torch",  # consider "adafactor" to save memory
            logging_steps=logging_steps,
            logging_strategy="steps",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=eval_steps if val_set_size > 0 else None,
            save_steps=save_steps,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            # fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None,
            # fsdp_min_num_params=20000 if gpus > 1 and not ddp else None,
            report_to='tensorboard' if not neptune_run else 'neptune',
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
        callbacks=callbacks,
        **trainer_kwargs,
    )
    model.config.use_cache = False

    old_state_dict = model.state_dict
    # have checkpoints save only the (small) PEFT/LoRA weights instead of the full model
    model.state_dict = (
        lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)
        # WIP (not generally replacing layers until pytorch 2.1)
        torch.backends.cuda.enable_flash_sdp(True)

    if gpus > 1 and not ddp:
        assert trainer.is_model_parallel
    else:
        assert not trainer.is_model_parallel
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    model.save_pretrained(output_dir)

    log("\n If there's a warning about missing keys above, please disregard :)")


def get_loaders(llama_type, model_name, reward_type):
    # NOTE: Some models need a specific new prompt_type
    # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT"
    if llama_type:
        from transformers import LlamaForCausalLM, LlamaTokenizer
        model_loader = LlamaForCausalLM
        tokenizer_loader = LlamaTokenizer
    elif 'gpt2' in model_name.lower():
        from transformers import GPT2LMHeadModel, GPT2Tokenizer
        return GPT2LMHeadModel, GPT2Tokenizer
    elif 'mbart-' in model_name.lower():
        from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
        return MBartForConditionalGeneration, MBart50TokenizerFast
    elif 't5' == model_name.lower() or \
         't5-' in model_name.lower() or \
         'flan-' in model_name.lower():
        from transformers import AutoTokenizer, T5ForConditionalGeneration
        return T5ForConditionalGeneration, AutoTokenizer
    elif 'bigbird' in model_name:
        from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
        return BigBirdPegasusForConditionalGeneration, AutoTokenizer
    elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
        from transformers import pipeline
        return pipeline, "summarization"
    elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
        from transformers import AutoModelForSequenceClassification, AutoTokenizer
        return AutoModelForSequenceClassification, AutoTokenizer
    else:
        from transformers import AutoTokenizer, AutoModelForCausalLM
        model_loader = AutoModelForCausalLM
        tokenizer_loader = AutoTokenizer
    return model_loader, tokenizer_loader


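# Minimal usage sketch for get_loaders() (illustrative annotation, not called anywhere in
# this file; the model name is just one example of a non-LLaMA causal LM):
def _example_get_loaders_usage():
    model_loader, tokenizer_loader = get_loaders(llama_type=False,
                                                 model_name='EleutherAI/gpt-j-6B',
                                                 reward_type=False)
    tokenizer = tokenizer_loader.from_pretrained('EleutherAI/gpt-j-6B')
    model = model_loader.from_pretrained('EleutherAI/gpt-j-6B')
    return model, tokenizer

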
def get_githash():
    try:
        githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
    except:
        githash = ''
    return githash


def copy_code(run_id):
    """
    copy code to track changes
    :param run_id:
    :return:
    """
    rnd_num = str(random.randint(0, 2 ** 31))
    run_id = 'run_' + str(run_id)
    os.makedirs(run_id, exist_ok=True)
    me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
    me_file = os.path.basename(__file__)
    new_me = os.path.join(run_id, me_file + '_' + get_githash())
    if os.path.isfile(new_me):
        new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
        shutil.copy(me_full, new_me)
    else:
        shutil.copy(me_full, new_me)


def get_prompt(prompt_type, chat, context, reduced):
    if prompt_type in [-1, "-1", "plain"]:
        promptA = promptB = PreInstruct = PreInput = PreResponse = ''
        terminate_response = []
    elif prompt_type == 'simple_instruct':
        promptA = promptB = PreInstruct = PreInput = PreResponse = None
        terminate_response = []
    elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
        promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
        promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''

        PreInstruct = """
### Instruction:
"""

        PreInput = """
### Input:
"""

        PreResponse = """
### Response:
"""
        if prompt_type in [7, "7", "instruct_with_end"]:
            terminate_response = ['### End']
        else:
            terminate_response = None
    elif prompt_type in [1, "1", "quality"]:
        promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
        promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''

        PreInstruct = """
### Instruction:
"""

        PreInput = """
### Input:
"""

        PreResponse = """
### Response:
"""
        terminate_response = None
    elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
        if reduced or context or prompt_type in [2, "2", "human_bot"]:
            preprompt = ''
        else:
            cur_date = time.strftime('%Y-%m-%d')
            cur_time = time.strftime('%H:%M:%S %p %Z')

            PRE_PROMPT = """\
Current Date: {}
Current Time: {}

"""
            preprompt = PRE_PROMPT.format(cur_date, cur_time)
        start = human
        promptB = promptA = '%s%s ' % (preprompt, start)

        PreInstruct = ""

        PreInput = None

        PreResponse = bot

        terminate_response = [start, PreResponse]
    elif prompt_type in [3, "3", "dai_faq"]:
        promptA = ''
        promptB = 'Answer the following Driverless AI question.\n'

        PreInstruct = """
### Driverless AI frequently asked question:
"""

        PreInput = None

        PreResponse = """
### Driverless AI documentation answer:
"""
        terminate_response = ['\n\n']
    elif prompt_type in [5, "5", "summarize"]:
        promptA = promptB = PreInput = ''
        PreInstruct = '## Main Text\n\n'
        PreResponse = '\n\n## Summary\n\n'
        terminate_response = None
    elif prompt_type in [6, "6", "instruct_vicuna"]:
        promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
            "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''

        PreInstruct = """
### Human:
"""

        PreInput = None

        PreResponse = """
### Assistant:
"""
        terminate_response = ['### Human:']  # but only allow terminate after prompt is found correctly, else can't terminate
    else:
        raise RuntimeError("No such prompt_type=%s" % prompt_type)

    return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response


def generate_prompt(data_point, prompt_type, chat, reduced):
    context = data_point.get('context')
    if context is None:
        context = ''
    instruction = data_point.get('instruction')
    input = data_point.get('input')
    output = data_point.get('output')
    prompt_type = data_point.get('prompt_type', prompt_type)
    assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
    promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)

    prompt = context

    if input and promptA:
        prompt += f"""{promptA}"""
    elif promptB:
        prompt += f"""{promptB}"""

    if instruction and PreInstruct is not None and input and PreInput is not None:
        prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif instruction and input and PreInstruct is None and PreInput is not None:
        prompt += f"""{PreInput}{instruction}
{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif input and instruction and PreInput is None and PreInstruct is not None:
        prompt += f"""{PreInstruct}{instruction}
{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif instruction and PreInstruct is not None:
        prompt += f"""{PreInstruct}{instruction}"""
        prompt = inject_newline(prompt_type, prompt)
    elif input and PreInput is not None:
        prompt += f"""{PreInput}{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif input and instruction and PreInput is not None:
        prompt += f"""{PreInput}{instruction}{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif input and instruction and PreInstruct is not None:
        prompt += f"""{PreInstruct}{instruction}{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif input and instruction:
        # i.e. for simple_instruct
        prompt += f"""{instruction}: {input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif input:
        prompt += f"""{input}"""
        prompt = inject_newline(prompt_type, prompt)
    elif instruction:
        prompt += f"""{instruction}"""
        prompt = inject_newline(prompt_type, prompt)

    if PreResponse is not None:
        prompt += f"""{PreResponse}"""
        pre_response = PreResponse  # Don't use strip
    else:
        pre_response = ''

    if output:
        prompt += f"""{output}"""

    return prompt, pre_response, terminate_response


def inject_newline(prompt_type, prompt):
    if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
        # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
        prompt += '\n'
    return prompt


example_data_point0 = dict(instruction="Summarize",
                           input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
                           output="Ducks eat and swim at the lake.")

example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
                           output="Einstein.")

example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
                           output="Einstein.")

example_data_points = [example_data_point0, example_data_point1, example_data_point2]


def test_train_prompt(prompt_type='instruct', data_point=0):
    example_data_point = example_data_points[data_point]
    return generate_prompt(example_data_point, prompt_type, False, False)


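# For reference (annotation derived from the templates above): with the 'instruct'
# prompt type, test_train_prompt() renders example_data_point0 roughly as:
#
#   Below is an instruction that describes a task, paired with an input that provides
#   further context. Write a response that appropriately completes the request.
#
#   ### Instruction:
#   Summarize
#   ### Input:
#   Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.
#
#   ### Response:
#   Ducks eat and swim at the lake.
#
# with pre_response set to the "### Response:" block and terminate_response left as None.

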
def test_debug():
    fire.Fire(train)


if __name__ == "__main__":
    CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
    CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
    log(f"""
    Example runs on 4 GPUs:
    WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-7b-hf' --data_path=data/config.json --run_id=0 &> 0.log
    WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='decapoda-research/llama-30b-hf' --data_path=data/config.json --batch_size=16 --micro_batch_size=1 --run_id=1 --save_code=True &> 1.log
    WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-j-6B' --data_path=data/config.json --run_id=2 &> 2.log
    WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='EleutherAI/gpt-neox-20b' --data_path=data/config.json --run_id=8 --batch_size=16 --micro_batch_size=4 &> 8.log
    WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --prompt_type='dai_faq' --run_id=13 --batch_size=16 --micro_batch_size=4 --num_epochs=100 --val_set_size=0 data_mix_in_path='' &> 13.log
    WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 finetune.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --data_path=data/config.json --run_id=28 --batch_size=16 --micro_batch_size=4 --num_epochs=8 --val_set_size=0 --data_mix_in_factor=0.1 --data_mix_in_prompt_type='human_bot' --save_code=True --cutoff_len=512  &> 28.log

    All metrics:
    CUDA_VISIBLE_DEVICES= finetune.py --data_mix_in_factor=0 --eval_steps=100 --warmup_steps=2 --val_set_size=100 --val_metrics="['bleu', 'rouge', 'sacrebleu', 'meteor']"

    # Fine-tune 20B on 24GB GPUs across 3 nodes with 3+2+2 GPUs
    rippa>
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1,2" torchrun --node_rank 0 --nproc_per_node=3 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank0
    ova>
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 1 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank1
    timemachine>
NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank 2 --nproc_per_node=2 --master_port=1234 --nnodes=3 --master_addr=10.10.10.2 finetune.py --data_path=merged_shuffled_OIG_87f6a1e788.json --micro_batch_size=1 --batch_size=7 --cutoff_len=512 --run_id=17 &>log.17.rank2

    """, flush=True)

    if os.environ.get("LOCAL_RANK") is None:
        # then not using torchrun, so can't do distributed, ensure CVD set
        assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"

    fire.Fire(train)

h2o-logo.svg
ADDED

prompter.py
ADDED
@@ -0,0 +1,106 @@

from finetune import generate_prompt


class Prompter(object):
    def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
                 allowed_repeat_line_length=10):
        self.prompt_type = prompt_type
        data_point = dict(instruction='', input='', output='')
        _, self.pre_response, self.terminate_response = generate_prompt(data_point, prompt_type, chat, False)
        self.debug = debug
        self.chat = chat
        self.stream_output = stream_output
        self.repeat_penalty = repeat_penalty
        self.allowed_repeat_line_length = allowed_repeat_line_length

    def generate_prompt(self, data_point):
        reduced = False
        prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
        if self.debug:
            print("prompt: ", prompt, flush=True)
        self.prompt = prompt
        return prompt

    def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
        if isinstance(outputs, str):
            outputs = [outputs]
        if self.debug:
            print("output: ", '\n\n'.join(outputs), flush=True)
        if prompt is not None:
            self.prompt = prompt

        def clean_response(response):
            meaningless_words = ['<pad>', '</s>', '<|endoftext|>', '”\n']
            for word in meaningless_words:
                response = response.replace(word, "")
            if sanitize_bot_response:
                from better_profanity import profanity
                response = profanity.censor(response)
            response = response.strip("\n")
            return response

        def clean_repeats(response):
            lines = response.split('\n')
            new_lines = []
            [new_lines.append(line) for line in lines if
             line not in new_lines or len(line) < self.allowed_repeat_line_length]
            if self.debug and len(lines) != len(new_lines):
                print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
            response = '\n'.join(new_lines)
            return response

        multi_output = len(outputs) > 1

        for oi, output in enumerate(outputs):
            if self.prompt_type in [0, '0', 'plain']:
                output = clean_response(output)
            else:
                # find first instance of pre_response
                # prompt sometimes has odd characters that mutate length,
                # so can't go by length alone
                if self.pre_response:
                    outputi = output.find(prompt)
                    if outputi >= 0:
                        output = output[outputi + len(prompt):]
                        allow_terminate = True
                    else:
                        # subtraction is risky due to space offsets sometimes, so only do if necessary
                        output = output[len(prompt) - len(self.pre_response):]
                        # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
                        if self.pre_response in output:
                            output = output.split(self.pre_response)[1]
                            allow_terminate = True
                        else:
                            print("Failure of parsing: %s" % output, flush=True)
                            allow_terminate = False
                else:
                    allow_terminate = True
                    output = output[len(prompt):]
                # clean after subtracting the prompt out, so removal of pre_response is correct
                output = clean_response(output).strip()
                if self.repeat_penalty:
                    output = clean_repeats(output).strip()
                if self.terminate_response and allow_terminate:
                    finds = []
                    for term in self.terminate_response:
                        finds.append(output.find(term))
                    finds = [x for x in finds if x >= 0]
                    if len(finds) > 0:
                        termi = finds[0]
                        output = output[:termi].strip()
                    else:
                        output = output.strip()
                else:
                    output = output.strip()
            if multi_output:
                # prefix with output counter
                output = "\n=========== Output %d\n\n" % (1 + oi) + output
                if oi > 0:
                    # postfix outputs with separator
                    output += '\n'
            outputs[oi] = output
        # join all outputs, only one extra new line between outputs
        output = '\n'.join(outputs)
        if self.debug:
            print("outputclean: ", '\n\n'.join(outputs), flush=True)
        return output

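# Illustrative usage sketch for Prompter (annotation added here, not part of the original
# file; the raw model output string below is hypothetical):
def _example_prompter_usage():
    prompter = Prompter('instruct', debug=False, chat=False)
    prompt = prompter.generate_prompt(dict(instruction="Summarize",
                                           input="Ducks eat seeds by the lake.",
                                           output=""))
    # feed `prompt` to a model, then strip the echoed prompt and stop markers:
    raw_model_output = prompt + " Ducks eat seeds."
    return prompter.get_response(raw_model_output, prompt=prompt)
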
    	
requirements.txt
ADDED
@@ -0,0 +1,48 @@

# for generate (gradio server) and finetune
datasets==2.10.1
sentencepiece==0.1.97
accelerate==0.18.0
gradio==3.27.0
huggingface_hub==0.13.4
appdirs==1.4.4
fire==0.5.0
docutils==0.19
torch==2.0.0
evaluate==0.4.0
rouge_score==0.1.2
sacrebleu==2.3.1
scikit-learn==1.2.2
alt-profanity-check==1.2.2
better-profanity==0.6.1
numpy==1.24.2
pandas==1.5.3
matplotlib==3.7.1
loralib==0.1.1
bitsandbytes==0.38.1
git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
transformers==4.28.1
tokenizers==0.13.3

# optional for generate
pynvml==11.5.0
psutil==5.9.4

# optional for finetune
tensorboard==2.12.1
neptune==1.1.1

# for gradio client
gradio_client==0.1.3
beautifulsoup4==4.12.2
markdown==3.4.1

# data and testing
pytest==7.2.2
pytest-xdist==3.2.1
nltk==3.8.1
textstat==0.7.3
pandoc==2.3
pypandoc==1.11
openpyxl==3.1.2
lm_dataformat==0.0.20
bioc==2.0

stopping.py
ADDED
@@ -0,0 +1,139 @@

import traceback
from queue import Queue
from threading import Thread
import collections.abc

import torch
from transformers import StoppingCriteria


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=[]):
        super().__init__()
        assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
        self.encounters = encounters
        self.stops = [stop.to("cuda") for stop in stops]
        self.num_stops = [0] * len(stops)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stopi, stop in enumerate(self.stops):
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                self.num_stops[stopi] += 1
                if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                    return True
        # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
        # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
        return False


class Stream(StoppingCriteria):
    """
    This class can be used to callback during generation. Keep
    in mind for decoder-only type of transformers, this will include the initial prompted tokens.

    Args:
        func (`callable`):
            A callable function to apply on first input in list every iteration of generation
    """

    def __init__(self, func=None):
        self.func = func

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.func is not None:
            # only consume first of multiple responses
            self.func(input_ids[0])
        return False


            class CallbackToGenerator(collections.abc.Generator):
         | 
| 51 | 
            +
                """
         | 
| 52 | 
            +
                A generator wrapper for a function that invokes a callback multiple times.
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                Calling `send` on the generator emits a value from one callback, and returns
         | 
| 55 | 
            +
                the next.
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                Note this starts a background thread
         | 
| 58 | 
            +
                """
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def __init__(self, func, *args, callback=None, **kwargs):
         | 
| 61 | 
            +
                    self.func = func
         | 
| 62 | 
            +
                    self.args = args
         | 
| 63 | 
            +
                    self.kwargs = kwargs
         | 
| 64 | 
            +
                    self.callback = callback
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    self._ready_queue = Queue(1)
         | 
| 67 | 
            +
                    self._done_queue = Queue(1)
         | 
| 68 | 
            +
                    self._done_holder = [False]
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    # local to avoid reference cycles
         | 
| 71 | 
            +
                    ready_queue = self._ready_queue
         | 
| 72 | 
            +
                    done_queue = self._done_queue
         | 
| 73 | 
            +
                    done_holder = self._done_holder
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    def val_callback(value):
         | 
| 76 | 
            +
                        done_queue.put((False, value))
         | 
| 77 | 
            +
                        cmd, val = ready_queue.get()
         | 
| 78 | 
            +
                        if cmd == 'send':
         | 
| 79 | 
            +
                            return val
         | 
| 80 | 
            +
                        elif cmd == 'throw':
         | 
| 81 | 
            +
                            raise val
         | 
| 82 | 
            +
                        else:
         | 
| 83 | 
            +
                            assert False  # pragma: no cover
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    def thread_func():
         | 
| 86 | 
            +
                        while True:
         | 
| 87 | 
            +
                            cmd, val = ready_queue.get()
         | 
| 88 | 
            +
                            if cmd == 'send' and val is not None:
         | 
| 89 | 
            +
                                done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
         | 
| 90 | 
            +
                                continue
         | 
| 91 | 
            +
                            break
         | 
| 92 | 
            +
                        try:
         | 
| 93 | 
            +
                            if cmd == 'throw':
         | 
| 94 | 
            +
                                raise val
         | 
| 95 | 
            +
                            ret = func(callback=val_callback, **self.kwargs)
         | 
| 96 | 
            +
                            raise StopIteration(ret) if ret is not None else StopIteration
         | 
| 97 | 
            +
                        except BaseException as e:
         | 
| 98 | 
            +
                            done_holder[0] = True
         | 
| 99 | 
            +
                            done_queue.put((True, e))
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    self._thread = Thread(target=thread_func)
         | 
| 102 | 
            +
                    self._thread.start()
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                def _put(self, *args):
         | 
| 105 | 
            +
                    if self._done_holder[0]:
         | 
| 106 | 
            +
                        raise StopIteration
         | 
| 107 | 
            +
                    self._ready_queue.put(args)
         | 
| 108 | 
            +
                    is_exception, val = self._done_queue.get()
         | 
| 109 | 
            +
                    if is_exception:
         | 
| 110 | 
            +
                        try:
         | 
| 111 | 
            +
                            raise val
         | 
| 112 | 
            +
                        finally:
         | 
| 113 | 
            +
                            # prevent val's traceback containing a reference cycle
         | 
| 114 | 
            +
                            del val
         | 
| 115 | 
            +
                    else:
         | 
| 116 | 
            +
                        return val
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                def send(self, value):
         | 
| 119 | 
            +
                    return self._put('send', value)
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                def throw(self, exc):
         | 
| 122 | 
            +
                    return self._put('throw', exc)
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                def close(self):
         | 
| 125 | 
            +
                    try:
         | 
| 126 | 
            +
                        self.throw(GeneratorExit)
         | 
| 127 | 
            +
                    except StopIteration:
         | 
| 128 | 
            +
                        self._thread.join()
         | 
| 129 | 
            +
                    except GeneratorExit:
         | 
| 130 | 
            +
                        self._thread.join()
         | 
| 131 | 
            +
                    except BaseException:
         | 
| 132 | 
            +
                        self._thread.join()
         | 
| 133 | 
            +
                        raise
         | 
| 134 | 
            +
                    else:
         | 
| 135 | 
            +
                        # yielded again, can't clean up the thread
         | 
| 136 | 
            +
                        raise RuntimeError('Task with callback ignored GeneratorExit')
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                def __del__(self):
         | 
| 139 | 
            +
                    self.close()
         | 
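Taken together, stopping.py lets generation stop on custom token sequences (StoppingCriteriaSub), stream the running token ids through a callback on every decode step (Stream), and turn that callback style into an ordinary Python generator (CallbackToGenerator). The following is a minimal usage sketch, not part of the commit: the placeholder model name, prompt, and stop phrase are assumptions, and a CUDA device is assumed because StoppingCriteriaSub moves its stop tokens to "cuda".

    # Illustrative wiring only; model, prompt, and stop phrase are placeholders.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

    from stopping import CallbackToGenerator, Stream, StoppingCriteriaSub

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

    prompt = "<human>: What is H2O?\n<bot>:"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # Stop the first time the "<human>:" turn marker is generated; how the phrase
    # tokenizes depends on the tokenizer, so real code may need to trim special tokens.
    stop_ids = tokenizer("<human>:", return_tensors="pt")["input_ids"][0]
    stop_criteria = StoppingCriteriaSub(stops=[stop_ids], encounters=[1])

    def generate_with_callback(callback=None, **kwargs):
        # Stream calls `callback` with the running token ids at every decode step.
        kwargs["stopping_criteria"] = StoppingCriteriaList([stop_criteria, Stream(func=callback)])
        with torch.no_grad():
            model.generate(**kwargs)

    # CallbackToGenerator runs generate_with_callback in a background thread and
    # re-exposes each callback value as an item of a normal generator.
    gen = CallbackToGenerator(generate_with_callback,
                              input_ids=inputs["input_ids"],
                              max_new_tokens=64)
    for token_ids in gen:
        print(tokenizer.decode(token_ids, skip_special_tokens=True))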
    	
utils.py ADDED
@@ -0,0 +1,154 @@
+import os
+import gc
+import random
+import time
+import traceback
+import zipfile
+from datetime import datetime
+import filelock
+import numpy as np
+import pandas as pd
+import torch
+
+
+def set_seed(seed: int):
+    """
+    Sets the seed of the entire notebook so results are the same every time we run.
+    This is for REPRODUCIBILITY.
+    """
+    np.random.seed(seed)
+    random_state = np.random.RandomState(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    return random_state
+
+
+def flatten_list(lis):
+    """Given a list, possibly nested to any level, return it flattened."""
+    new_lis = []
+    for item in lis:
+        if type(item) == type([]):
+            new_lis.extend(flatten_list(item))
+        else:
+            new_lis.append(item)
+    return new_lis
+
+
+def clear_torch_cache():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        gc.collect()
+
+
+def system_info():
+    import psutil
+
+    system = {}
+    # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
+    # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
+    temps = psutil.sensors_temperatures(fahrenheit=False)
+    if 'coretemp' in temps:
+        coretemp = temps['coretemp']
+        temp_dict = {k.label: k.current for k in coretemp}
+        for k, v in temp_dict.items():
+            system['CPU_C/%s' % k] = v
+
+    # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
+    from pynvml.smi import nvidia_smi
+    nvsmi = nvidia_smi.getInstance()
+
+    gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
+                      enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
+    for k, v in gpu_power_dict.items():
+        system['GPU_W/%s' % k] = v
+
+    gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
+                     enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
+    for k, v in gpu_temp_dict.items():
+        system['GPU_C/%s' % k] = v
+
+    gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
+                            enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
+    gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
+                             enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
+    gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
+    for k, v in gpu_memory_frac_dict.items():
+        system['GPU_M/%s' % k] = v
+
+    return system
+
+
+def system_info_print():
+    try:
+        df = pd.DataFrame.from_dict(system_info(), orient='index')
+        # avoid slamming GPUs
+        time.sleep(1)
+        return df.to_markdown()
+    except Exception as e:
+        return "Error: %s" % str(e)
+
+
+def zip_data(root_dirs=None, zip_file=None, base_dir='./'):
+    try:
+        return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
+    except Exception as e:
+        traceback.print_exc()
+        print('Exception in zipping: %s' % str(e))
+
+
+def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
+    if zip_file is None:
+        datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
+        host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
+        zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
+    assert root_dirs is not None
+
+    with zipfile.ZipFile(zip_file, "w") as expt_zip:
+        for root_dir in root_dirs:
+            if root_dir is None:
+                continue
+            for root, d, files in os.walk(root_dir):
+                for file in files:
+                    file_to_archive = os.path.join(root, file)
+                    assert os.path.exists(file_to_archive)
+                    path_to_archive = os.path.relpath(file_to_archive, base_dir)
+                    expt_zip.write(filename=file_to_archive, arcname=path_to_archive)
+    return zip_file
+
+
+def save_generate_output(output=None, base_model=None, save_dir=None):
+    try:
+        return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
+    except Exception as e:
+        traceback.print_exc()
+        print('Exception in saving: %s' % str(e))
+
+
+def _save_generate_output(output=None, base_model=None, save_dir=None):
+    """
+    Save conversation to .json, row by row.
+    json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
+    Appends if file exists
+    """
+    assert save_dir, "save_dir must be provided"
+    if os.path.exists(save_dir) and not os.path.isdir(save_dir):
+        raise RuntimeError("save_dir already exists and is not a directory!")
+    os.makedirs(save_dir, exist_ok=True)
+    import json
+    if output[-10:] == '\n\n<human>:':
+        # remove trailing <human>:
+        output = output[:-10]
+    with filelock.FileLock("save_dir.lock"):
+        # lock logging in case have concurrency
+        with open(os.path.join(save_dir, "history.json"), "a") as f:
+            # just add [ at start, and ] at end, and have proper JSON dataset
+            f.write(
+                "  " + json.dumps(
+                    dict(text=output, time=time.ctime(), base_model=base_model)
+                ) + ",\n"
+            )
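utils.py collects small operational helpers for the app: seeding for reproducibility, clearing the CUDA cache, reading CPU/GPU telemetry via psutil and pynvml, zipping log directories, and appending each generation to a file-locked history.json under a save directory. Below is a short sketch of the intended call pattern; the directory name, base model string, and example output are illustrative assumptions, not values from this commit.

    # Illustrative calls only; paths and strings are placeholders.
    from utils import set_seed, clear_torch_cache, save_generate_output, system_info_print, zip_data

    set_seed(1234)                  # seed python, numpy and torch for reproducibility
    print(system_info_print())      # markdown table of CPU/GPU temps, power and free memory

    # Append one model response to save_logs/history.json (directory is created,
    # writes are serialized with a file lock, trailing "\n\n<human>:" is stripped).
    save_generate_output(output="<human>: hi\n<bot>: Hello!\n\n<human>:",
                         base_model="some/base-model",
                         save_dir="save_logs")

    # Bundle the accumulated logs into a timestamped zip; returns the zip file name.
    print(zip_data(root_dirs=["save_logs"]))

    clear_torch_cache()             # release cached CUDA memory between generations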
 
			

