jefsnacker commited on
Commit
8b07bee
·
1 Parent(s): 34a8736

updated weights for first rev model

Browse files
Files changed (1) hide show
  1. app.py +17 -3
app.py CHANGED
@@ -42,6 +42,14 @@ gpt_rev_weights_path = huggingface_hub.hf_hub_download(
42
  "jefsnacker/surname_generator",
43
  "rev_gpt_weights.pt")
44
 
 
 
 
 
 
 
 
 
45
  with open(mlp_config_path, 'r') as file:
46
  mlp_config = yaml.safe_load(file)
47
 
@@ -54,6 +62,9 @@ with open(gpt_micro_config_path, 'r') as file:
54
  with open(gpt_rev_config_path, 'r') as file:
55
  gpt_rev_config = yaml.safe_load(file)
56
 
 
 
 
57
  ##################################################################################
58
  ## MLP
59
  ##################################################################################
@@ -329,6 +340,10 @@ gpt_rev = GPT(gpt_rev_config)
329
  gpt_rev.load_state_dict(torch.load(gpt_rev_weights_path))
330
  gpt_rev.eval()
331
 
 
 
 
 
332
  ##################################################################################
333
  ## Gradio App
334
  ##################################################################################
@@ -351,9 +366,8 @@ def generate_names(name_start, name_end, number_of_names, model):
351
  config = gpt_rev_config
352
  sample_fcn = gpt_rev.sample_char
353
  elif model == "GPT First Rev":
354
- # TODO: Change model!
355
- config = gpt_rev_config
356
- sample_fcn = gpt_rev.sample_char
357
  else:
358
  return "Error: Model not selected"
359
 
 
42
  "jefsnacker/surname_generator",
43
  "rev_gpt_weights.pt")
44
 
45
+ gpt_first_rev_config_path = huggingface_hub.hf_hub_download(
46
+ "jefsnacker/surname_generator",
47
+ "first_name_gpt_config.yaml")
48
+
49
+ gpt_first_rev_weights_path = huggingface_hub.hf_hub_download(
50
+ "jefsnacker/surname_generator",
51
+ "first_name_gpt_weights.pt")
52
+
53
  with open(mlp_config_path, 'r') as file:
54
  mlp_config = yaml.safe_load(file)
55
 
 
62
  with open(gpt_rev_config_path, 'r') as file:
63
  gpt_rev_config = yaml.safe_load(file)
64
 
65
+ with open(gpt_first_rev_config_path, 'r') as file:
66
+ gpt_first_rev_config = yaml.safe_load(file)
67
+
68
  ##################################################################################
69
  ## MLP
70
  ##################################################################################
 
340
  gpt_rev.load_state_dict(torch.load(gpt_rev_weights_path))
341
  gpt_rev.eval()
342
 
343
+ gpt_first_rev = GPT(gpt_first_rev_config)
344
+ gpt_first_rev.load_state_dict(torch.load(gpt_first_rev_weights_path))
345
+ gpt_first_rev.eval()
346
+
347
  ##################################################################################
348
  ## Gradio App
349
  ##################################################################################
 
366
  config = gpt_rev_config
367
  sample_fcn = gpt_rev.sample_char
368
  elif model == "GPT First Rev":
369
+ config = gpt_first_rev_config
370
+ sample_fcn = gpt_first_rev.sample_char
 
371
  else:
372
  return "Error: Model not selected"
373