File size: 2,564 Bytes
c508d7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
diff --git a/tensorboardX/summary.py b/tensorboardX/summary.py
index 27d99ea..f5bf234 100644
--- a/tensorboardX/summary.py
+++ b/tensorboardX/summary.py
@@ -373,36 +373,24 @@ def make_video(tensor, fps):
 
 def audio(tag, tensor, sample_rate=44100):
     tensor = make_np(tensor)
-    tensor = tensor.squeeze()
     if abs(tensor).max() > 1:
         print('warning: audio amplitude out of range, auto clipped.')
         tensor = tensor.clip(-1, 1)
-    assert(tensor.ndim == 1), 'input tensor should be 1 dimensional.'
-
-    tensor_list = [int(32767.0 * x) for x in tensor]
+    assert(tensor.ndim == 2), 'input tensor should be 2 dimensional.'
+    length_frames, num_channels = tensor.shape
+    assert num_channels == 1 or num_channels == 2, f'Expected 1/2 channels, got {num_channels}'
+    import soundfile
     import io
-    import wave
-    import struct
-    fio = io.BytesIO()
-    Wave_write = wave.open(fio, 'wb')
-    Wave_write.setnchannels(1)
-    Wave_write.setsampwidth(2)
-    Wave_write.setframerate(sample_rate)
-    tensor_enc = b''
-    tensor_enc += struct.pack("<" + "h" * len(tensor_list), *tensor_list)
-
-    Wave_write.writeframes(tensor_enc)
-    Wave_write.close()
-    audio_string = fio.getvalue()
-    fio.close()
+    with io.BytesIO() as fio:
+        soundfile.write(fio, tensor, samplerate=sample_rate, format='wav')
+        audio_string = fio.getvalue()
     audio = Summary.Audio(sample_rate=sample_rate,
-                          num_channels=1,
-                          length_frames=len(tensor_list),
+                          num_channels=num_channels,
+                          length_frames=length_frames,
                           encoded_audio_string=audio_string,
                           content_type='audio/wav')
     return Summary(value=[Summary.Value(tag=tag, audio=audio)])
 
-
 def custom_scalars(layout):
     categoriesnames = layout.keys()
     categories = []
diff --git a/tensorboardX/writer.py b/tensorboardX/writer.py
index 06337a7..58d57a1 100644
--- a/tensorboardX/writer.py
+++ b/tensorboardX/writer.py
@@ -716,7 +716,7 @@ class SummaryWriter(object):
             sample_rate (int): sample rate in Hz
             walltime (float): Optional override default walltime (time.time()) of event
         Shape:
-            snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
+            snd_tensor: :math:`(L, c)`. The values should lie between [-1, 1].
         """
         if self._check_caffe2_blob(snd_tensor):
             snd_tensor = workspace.FetchBlob(snd_tensor)