Fix model name from InternVL2-8B to InternVL2
    	
app.py CHANGED
    
@@ -104,7 +104,7 @@ def create_in_memory_flash_attn_mock():
     # Add attributes used by transformers checks
     flash_attn.__version__ = "0.0.0-mocked"
 
-    # Create common submodules - 
+    # Create common submodules - without 'parent' parameter
     for submodule in ['flash_attn.flash_attn_interface', 'flash_attn.flash_attn_triton']:
         parts = submodule.split('.')
         parent_name = '.'.join(parts[:-1])
@@ -170,7 +170,8 @@ def load_model():
     print("\nLoading InternVL2 model...")
 
     # Load the model and tokenizer
-    model_path = "OpenGVLab/InternVL2-8B"
+    # FIXED: Corrected model name from InternVL2-8B to InternVL2
+    model_path = "OpenGVLab/InternVL2"
 
     # Print downloading status
     print("Downloading model shards. This may take some time...")
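
For context, the first hunk touches a helper that fakes the flash_attn package so imports succeed on hardware where FlashAttention is not installed. Below is a minimal sketch of how such an in-memory mock can be registered; the helper name, the submodule list, and the comments mirror the diff, while the module-creation details (types.ModuleType, sys.modules registration, returning the mock) are assumptions rather than the app's actual code.

    import sys
    import types

    def create_in_memory_flash_attn_mock():
        # Sketch (assumption): build a stand-in top-level module so "import flash_attn" succeeds
        flash_attn = types.ModuleType("flash_attn")
        # Add attributes used by transformers checks
        flash_attn.__version__ = "0.0.0-mocked"
        sys.modules["flash_attn"] = flash_attn

        # Create common submodules - without 'parent' parameter
        for submodule in ['flash_attn.flash_attn_interface', 'flash_attn.flash_attn_triton']:
            parts = submodule.split('.')
            parent_name = '.'.join(parts[:-1])
            child = types.ModuleType(submodule)
            sys.modules[submodule] = child
            # Also expose the child as an attribute of its parent module
            setattr(sys.modules[parent_name], parts[-1], child)
        return flash_attn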
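The second hunk only swaps the Hugging Face repo id used by load_model(). A hedged sketch of what that loading path might look like follows; only model_path, the comments, and the print statements come from the diff, while the AutoTokenizer/AutoModel calls, the dtype, and trust_remote_code are assumptions about how the app loads the model.

    import torch
    from transformers import AutoModel, AutoTokenizer

    def load_model():
        print("\nLoading InternVL2 model...")

        # Load the model and tokenizer
        # FIXED: Corrected model name from InternVL2-8B to InternVL2
        model_path = "OpenGVLab/InternVL2"

        # Print downloading status
        print("Downloading model shards. This may take some time...")
        # Sketch (assumption): typical remote-code loading for InternVL2-style repos
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()
        return model, tokenizer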