Upload 2 files
Browse files
- app.py +1 -1
- model_utils.py +4 -5
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import os
|
4 |
-
from model_utils import load_model, generate_text
|
5 |
|
6 |
# Initialize model
|
7 |
try:
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
import os
|
4 |
+
from model_utils import load_model, generate_text, GPTConfig
|
5 |
|
6 |
# Initialize model
|
7 |
try:
|
model_utils.py
CHANGED
@@ -115,11 +115,10 @@ def load_model(model_path):
|
|
115 |
|
116 |
# Create config from the saved dictionary
|
117 |
config_dict = checkpoint['config']
|
118 |
-
if isinstance(config_dict,
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
config = GPTConfig(**config_dict)
|
123 |
|
124 |
model = GPT(config)
|
125 |
model.load_state_dict(checkpoint['model_state_dict'])
|
|
|
115 |
|
116 |
# Create config from the saved dictionary
|
117 |
config_dict = checkpoint['config']
|
118 |
+
if isinstance(config_dict, dict):
|
119 |
+
config = GPTConfig(**config_dict.__dict__) # Convert dataclass to dict
|
120 |
+
else:
|
121 |
+
config = config_dict # If config was already saved as GPTConfig instance
|
|
|
122 |
|
123 |
model = GPT(config)
|
124 |
model.load_state_dict(checkpoint['model_state_dict'])
|