hexgrad commited on
Commit
c8ab947
·
verified ·
1 Parent(s): 1a61201

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -13,6 +13,30 @@ import spaces
13
  import torch
14
  import yaml
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  random_texts = {}
17
  for lang in ['en', 'ja']:
18
  with open(f'{lang}.txt', 'r') as r:
@@ -86,25 +110,6 @@ VOCAB = get_vocab()
86
  def tokenize(ps):
87
  return [i for i in map(VOCAB.get, ps) if i is not None]
88
 
89
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
90
-
91
- snapshot = snapshot_download(repo_id='hexgrad/kokoro', allow_patterns=['*.pt', '*.pth', '*.yml'], use_auth_token=os.environ['TOKEN'])
92
- config = yaml.safe_load(open(os.path.join(snapshot, 'config.yml')))
93
- model = build_model(config['model_params'])
94
- for key, value in model.items():
95
- for module in value.children():
96
- if isinstance(module, torch.nn.RNNBase):
97
- module.flatten_parameters()
98
- _ = [model[key].eval() for key in model]
99
- _ = [model[key].to(device) for key in model]
100
- for key, state_dict in torch.load(os.path.join(snapshot, 'net.pth'), map_location='cpu', weights_only=True)['net'].items():
101
- assert key in model, key
102
- try:
103
- model[key].load_state_dict(state_dict)
104
- except:
105
- state_dict = {k[7:]: v for k, v in state_dict.items()}
106
- model[key].load_state_dict(state_dict, strict=False)
107
-
108
  CHOICES = {
109
  '🇺🇸 🚺 American Female 0': 'af_0',
110
  '🇺🇸 🚺 Bella': 'af_bella',
 
13
  import torch
14
  import yaml
15
 
16
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
+
18
+ snapshot = snapshot_download(repo_id='hexgrad/kokoro', allow_patterns=['*.pt', '*.pth', '*.yml'], use_auth_token=os.environ['TOKEN'])
19
+ config = yaml.safe_load(open(os.path.join(snapshot, 'config.yml')))
20
+ model = build_model(config['model_params'])
21
+ for key, value in model.items():
22
+ for module in value.children():
23
+ if isinstance(module, torch.nn.RNNBase):
24
+ module.flatten_parameters()
25
+
26
+ _ = [model[key].eval() for key in model]
27
+ _ = [model[key].to(device) for key in model]
28
+ for key, state_dict in torch.load(os.path.join(snapshot, 'net.pth'), map_location='cpu', weights_only=True)['net'].items():
29
+ assert key in model, key
30
+ try:
31
+ model[key].load_state_dict(state_dict)
32
+ except:
33
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
34
+ model[key].load_state_dict(state_dict, strict=False)
35
+
36
+ PARAM_COUNT = sum(p.numel() for value in model.values() for p in value.parameters())
37
+ print('PARAM_COUNT', PARAM_COUNT)
38
+ assert PARAM_COUNT < 82_000_000, PARAM_COUNT
39
+
40
  random_texts = {}
41
  for lang in ['en', 'ja']:
42
  with open(f'{lang}.txt', 'r') as r:
 
110
  def tokenize(ps):
111
  return [i for i in map(VOCAB.get, ps) if i is not None]
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  CHOICES = {
114
  '🇺🇸 🚺 American Female 0': 'af_0',
115
  '🇺🇸 🚺 Bella': 'af_bella',