MasonCriner / tensorboardX /tensorboardX.patch
MasonCrinr's picture
Upload 331 files
8026e91
diff --git a/tensorboardX/summary.py b/tensorboardX/summary.py
index 27d99ea..f5bf234 100644
--- a/tensorboardX/summary.py
+++ b/tensorboardX/summary.py
@@ -373,36 +373,24 @@ def make_video(tensor, fps):
def audio(tag, tensor, sample_rate=44100):
tensor = make_np(tensor)
- tensor = tensor.squeeze()
if abs(tensor).max() > 1:
print('warning: audio amplitude out of range, auto clipped.')
tensor = tensor.clip(-1, 1)
- assert(tensor.ndim == 1), 'input tensor should be 1 dimensional.'
-
- tensor_list = [int(32767.0 * x) for x in tensor]
+ assert(tensor.ndim == 2), 'input tensor should be 2 dimensional.'
+ length_frames, num_channels = tensor.shape
+ assert num_channels == 1 or num_channels == 2, f'Expected 1/2 channels, got {num_channels}'
+ import soundfile
import io
- import wave
- import struct
- fio = io.BytesIO()
- Wave_write = wave.open(fio, 'wb')
- Wave_write.setnchannels(1)
- Wave_write.setsampwidth(2)
- Wave_write.setframerate(sample_rate)
- tensor_enc = b''
- tensor_enc += struct.pack("<" + "h" * len(tensor_list), *tensor_list)
-
- Wave_write.writeframes(tensor_enc)
- Wave_write.close()
- audio_string = fio.getvalue()
- fio.close()
+ with io.BytesIO() as fio:
+ soundfile.write(fio, tensor, samplerate=sample_rate, format='wav')
+ audio_string = fio.getvalue()
audio = Summary.Audio(sample_rate=sample_rate,
- num_channels=1,
- length_frames=len(tensor_list),
+ num_channels=num_channels,
+ length_frames=length_frames,
encoded_audio_string=audio_string,
content_type='audio/wav')
return Summary(value=[Summary.Value(tag=tag, audio=audio)])
-
def custom_scalars(layout):
categoriesnames = layout.keys()
categories = []
diff --git a/tensorboardX/writer.py b/tensorboardX/writer.py
index 06337a7..58d57a1 100644
--- a/tensorboardX/writer.py
+++ b/tensorboardX/writer.py
@@ -716,7 +716,7 @@ class SummaryWriter(object):
sample_rate (int): sample rate in Hz
walltime (float): Optional override default walltime (time.time()) of event
Shape:
- snd_tensor: :math:`(1, L)`. The values should lie between [-1, 1].
+ snd_tensor: :math:`(L, c)`: L frames by c channels (c is 1 or 2). The values should lie between [-1, 1].
"""
if self._check_caffe2_blob(snd_tensor):
snd_tensor = workspace.FetchBlob(snd_tensor)