Spaces:
Runtime error
Runtime error
Commit
·
8b07bee
1
Parent(s):
34a8736
updated weights for first rev model
Browse files
app.py
CHANGED
@@ -42,6 +42,14 @@ gpt_rev_weights_path = huggingface_hub.hf_hub_download(
|
|
42 |
"jefsnacker/surname_generator",
|
43 |
"rev_gpt_weights.pt")
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
with open(mlp_config_path, 'r') as file:
|
46 |
mlp_config = yaml.safe_load(file)
|
47 |
|
@@ -54,6 +62,9 @@ with open(gpt_micro_config_path, 'r') as file:
|
|
54 |
with open(gpt_rev_config_path, 'r') as file:
|
55 |
gpt_rev_config = yaml.safe_load(file)
|
56 |
|
|
|
|
|
|
|
57 |
##################################################################################
|
58 |
## MLP
|
59 |
##################################################################################
|
@@ -329,6 +340,10 @@ gpt_rev = GPT(gpt_rev_config)
|
|
329 |
gpt_rev.load_state_dict(torch.load(gpt_rev_weights_path))
|
330 |
gpt_rev.eval()
|
331 |
|
|
|
|
|
|
|
|
|
332 |
##################################################################################
|
333 |
## Gradio App
|
334 |
##################################################################################
|
@@ -351,9 +366,8 @@ def generate_names(name_start, name_end, number_of_names, model):
|
|
351 |
config = gpt_rev_config
|
352 |
sample_fcn = gpt_rev.sample_char
|
353 |
elif model == "GPT First Rev":
|
354 |
-
|
355 |
-
|
356 |
-
sample_fcn = gpt_rev.sample_char
|
357 |
else:
|
358 |
return "Error: Model not selected"
|
359 |
|
|
|
42 |
"jefsnacker/surname_generator",
|
43 |
"rev_gpt_weights.pt")
|
44 |
|
45 |
+
gpt_first_rev_config_path = huggingface_hub.hf_hub_download(
|
46 |
+
"jefsnacker/surname_generator",
|
47 |
+
"first_name_gpt_config.yaml")
|
48 |
+
|
49 |
+
gpt_first_rev_weights_path = huggingface_hub.hf_hub_download(
|
50 |
+
"jefsnacker/surname_generator",
|
51 |
+
"first_name_gpt_weights.pt")
|
52 |
+
|
53 |
with open(mlp_config_path, 'r') as file:
|
54 |
mlp_config = yaml.safe_load(file)
|
55 |
|
|
|
62 |
with open(gpt_rev_config_path, 'r') as file:
|
63 |
gpt_rev_config = yaml.safe_load(file)
|
64 |
|
65 |
+
with open(gpt_first_rev_config_path, 'r') as file:
|
66 |
+
gpt_first_rev_config = yaml.safe_load(file)
|
67 |
+
|
68 |
##################################################################################
|
69 |
## MLP
|
70 |
##################################################################################
|
|
|
340 |
gpt_rev.load_state_dict(torch.load(gpt_rev_weights_path))
|
341 |
gpt_rev.eval()
|
342 |
|
343 |
+
gpt_first_rev = GPT(gpt_first_rev_config)
|
344 |
+
gpt_first_rev.load_state_dict(torch.load(gpt_first_rev_weights_path))
|
345 |
+
gpt_first_rev.eval()
|
346 |
+
|
347 |
##################################################################################
|
348 |
## Gradio App
|
349 |
##################################################################################
|
|
|
366 |
config = gpt_rev_config
|
367 |
sample_fcn = gpt_rev.sample_char
|
368 |
elif model == "GPT First Rev":
|
369 |
+
config = gpt_first_rev_config
|
370 |
+
sample_fcn = gpt_first_rev.sample_char
|
|
|
371 |
else:
|
372 |
return "Error: Model not selected"
|
373 |
|