yasserrmd commited on
Commit
e95ccf0
·
verified ·
1 Parent(s): 1868337

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -29,7 +29,7 @@ class VibeVoiceDemo:
29
  self.processor = None
30
  self.model = None
31
  self.available_voices = {}
32
- #self.load_model()
33
  self.setup_voice_presets()
34
  self.load_example_scripts()
35
 
@@ -38,11 +38,10 @@ class VibeVoiceDemo:
38
  self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
39
  self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
40
  self.model_path,
41
- torch_dtype=torch.bfloat16,
42
- device_map=self.device
43
  )
44
- self.model.eval()
45
- self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
46
 
47
  def setup_voice_presets(self):
48
  voices_dir = os.path.join(os.path.dirname(__file__), "voices")
@@ -82,7 +81,13 @@ class VibeVoiceDemo:
82
  This is a non-streaming function.
83
  """
84
  try:
85
- self.load_model()
 
 
 
 
 
 
86
  # 1. Set generating state and validate inputs
87
  self.is_generating = True
88
 
 
29
  self.processor = None
30
  self.model = None
31
  self.available_voices = {}
32
+ self.load_model()
33
  self.setup_voice_presets()
34
  self.load_example_scripts()
35
 
 
38
  self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
39
  self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
40
  self.model_path,
41
+ torch_dtype=torch.bfloat16
 
42
  )
43
+ # self.model.eval()
44
+ # self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
45
 
46
  def setup_voice_presets(self):
47
  voices_dir = os.path.join(os.path.dirname(__file__), "voices")
 
81
  This is a non-streaming function.
82
  """
83
  try:
84
+ self.model = self.model.to(self.device)
85
+
86
+ print(f"Model successfully moved to device: {self.device.upper()}")
87
+
88
+ # Step 3: Continue with the rest of your setup.
89
+ self.model.eval()
90
+ self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
91
  # 1. Set generating state and validate inputs
92
  self.is_generating = True
93