Tousifahamed committed on
Commit
85c27c2
·
verified ·
1 Parent(s): aa150c0

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. model_utils.py +4 -5
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import torch
3
  import os
4
- from model_utils import load_model, generate_text
5
 
6
  # Initialize model
7
  try:
 
1
  import gradio as gr
2
  import torch
3
  import os
4
+ from model_utils import load_model, generate_text, GPTConfig
5
 
6
  # Initialize model
7
  try:
model_utils.py CHANGED
@@ -115,11 +115,10 @@ def load_model(model_path):
115
 
116
  # Create config from the saved dictionary
117
  config_dict = checkpoint['config']
118
- if isinstance(config_dict, str):
119
- # If config was saved as string, parse it to dict
120
- import ast
121
- config_dict = ast.literal_eval(config_dict)
122
- config = GPTConfig(**config_dict)
123
 
124
  model = GPT(config)
125
  model.load_state_dict(checkpoint['model_state_dict'])
 
115
 
116
  # Create config from the saved dictionary
117
  config_dict = checkpoint['config']
118
+ if isinstance(config_dict, dict):
119
+ config = GPTConfig(**config_dict.__dict__) # Convert dataclass to dict
120
+ else:
121
+ config = config_dict # If config was already saved as GPTConfig instance
 
122
 
123
  model = GPT(config)
124
  model.load_state_dict(checkpoint['model_state_dict'])