aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/lib/wav
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2017-03-27 16:33:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-27 19:38:26 -0700
commitce0a07f2479c864b0a6cc8f4a218d74e8b480746 (patch)
treebde31344445ab010e2c0e660d2a5b5c2f146fdfb /tensorflow/core/lib/wav
parent164b6a88f56a8e491b315f2747303c46f04b5c76 (diff)
Added WAV audio file reading and writing operations
Change: 151395519
Diffstat (limited to 'tensorflow/core/lib/wav')
-rw-r--r--tensorflow/core/lib/wav/wav_io.cc125
-rw-r--r--tensorflow/core/lib/wav/wav_io.h13
-rw-r--r--tensorflow/core/lib/wav/wav_io_test.cc19
3 files changed, 153 insertions, 4 deletions
diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc
index 31c81b7dde..97e218a793 100644
--- a/tensorflow/core/lib/wav/wav_io.cc
+++ b/tensorflow/core/lib/wav/wav_io.cc
@@ -65,21 +65,64 @@ static_assert(sizeof(WavHeader) ==
sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk),
"TF_PACKED does not work.");
+constexpr char kRiffChunkId[] = "RIFF";    // Tag opening the file.
+constexpr char kRiffType[] = "WAVE";       // Tag after the 32-bit size field.
+constexpr char kFormatChunkId[] = "fmt ";  // Tag opening the format chunk.
+constexpr char kDataChunkId[] = "data";    // Tag opening the sample data.
+
inline int16 FloatToInt16Sample(float data) {
constexpr float kMultiplier = 1.0f * (1 << 15);
return std::min<float>(std::max<float>(roundf(data * kMultiplier), kint16min),
kint16max);
}
+// Maps a signed 16-bit PCM sample onto a float by scaling with 1 / 2^15, so
+// the most negative sample becomes -1.0f and the most positive one lands just
+// below 1.0f.
+inline float Int16SampleToFloat(int16 data) {
+  constexpr float kInt16Scale = 32768.0f;  // 2^15; the division is exact.
+  return static_cast<float>(data) / kInt16Scale;
+}
+
+// Checks that `data` contains exactly `expected_text` starting at `*offset`.
+// On a match, moves `*offset` just past the text and returns OK. Otherwise
+// returns InvalidArgument and leaves `*offset` unchanged.
+Status ExpectText(const string& data, const string& expected_text,
+                  int* offset) {
+  const int start = *offset;
+  const int end = start + expected_text.size();
+  if (end > data.size()) {
+    return errors::InvalidArgument("Data too short when trying to read ",
+                                   expected_text);
+  }
+  // Compare in place; the mismatching text is only copied out for the error.
+  if (data.compare(start, expected_text.size(), expected_text) != 0) {
+    const string found_text = data.substr(start, expected_text.size());
+    return errors::InvalidArgument("Header mismatch: Expected ", expected_text,
+                                   " but found ", found_text);
+  }
+  *offset = end;
+  return Status::OK();
+}
+
+// Reads a little-endian integer of type T from `data` at `*offset`, stores it
+// in `*value`, and advances `*offset` by sizeof(T). Returns InvalidArgument,
+// leaving `*value` and `*offset` untouched, if fewer than sizeof(T) bytes
+// remain. T is expected to be an integer type of at most 8 bytes.
+template <class T>
+Status ReadValue(const string& data, T* value, int* offset) {
+  const int new_offset = *offset + sizeof(T);
+  if (new_offset > data.size()) {
+    return errors::InvalidArgument("Data too short when trying to read value");
+  }
+  if (port::kLittleEndian) {
+    // Host byte order matches the file, so the bytes can be copied directly.
+    memcpy(value, data.data() + *offset, sizeof(T));
+  } else {
+    // Assemble the value byte by byte: byte i of the little-endian encoding
+    // supplies bits [8 * i, 8 * i + 8), so it must be shifted LEFT into place.
+    // Accumulating in an unsigned 64-bit integer keeps the shifts away from
+    // any sign bit.
+    const uint8* data_buf =
+        reinterpret_cast<const uint8*>(data.data() + *offset);
+    uint64 assembled = 0;
+    for (size_t i = 0; i < sizeof(T); ++i) {
+      assembled |= static_cast<uint64>(data_buf[i]) << (8 * i);
+    }
+    *value = static_cast<T>(assembled);
+  }
+  *offset = new_offset;
+  return Status::OK();
+}
+
} // namespace
Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
size_t num_channels, size_t num_frames,
string* wav_string) {
- constexpr char kRiffChunkId[] = "RIFF";
- constexpr char kRiffType[] = "WAVE";
- constexpr char kFormatChunkId[] = "fmt ";
- constexpr char kDataChunkId[] = "data";
constexpr size_t kFormatChunkSize = 16;
constexpr size_t kCompressionCodePcm = 1;
constexpr size_t kBitsPerSample = 16;
@@ -153,5 +196,79 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
return Status::OK();
}
+// Parses a 16-bit PCM WAV file held in `wav_string`.
+//
+// On success, `*float_values` holds one float per (frame, channel) value in
+// the order the values appear in the data chunk, `*sample_count` holds the
+// number of frames, and `*channel_count` / `*sample_rate` hold the values read
+// from the format chunk. Returns InvalidArgument if the header is malformed,
+// describes an unsupported format, or promises more sample data than
+// `wav_string` actually contains.
+Status DecodeLin16WaveAsFloatVector(const string& wav_string,
+                                    std::vector<float>* float_values,
+                                    uint32* sample_count, uint16* channel_count,
+                                    uint32* sample_rate) {
+  int offset = 0;
+  TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset));
+  // The overall size field is read only to step past it; it is not validated.
+  uint32 total_file_size;
+  TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset));
+  TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset));
+  TF_RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset));
+  uint32 format_chunk_size;
+  TF_RETURN_IF_ERROR(
+      ReadValue<uint32>(wav_string, &format_chunk_size, &offset));
+  if ((format_chunk_size != 16) && (format_chunk_size != 18)) {
+    return errors::InvalidArgument(
+        "Bad format chunk size for WAV: Expected 16 or 18, but got ",
+        format_chunk_size);
+  }
+  uint16 audio_format;
+  TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset));
+  if (audio_format != 1) {
+    return errors::InvalidArgument(
+        "Bad audio format for WAV: Expected 1 (PCM), but got ", audio_format);
+  }
+  TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
+  // The channel count is used as a divisor below, so zero must be rejected.
+  if (*channel_count < 1) {
+    return errors::InvalidArgument(
+        "Bad number of channels for WAV: Expected at least 1, but got ",
+        *channel_count);
+  }
+  TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
+  uint32 bytes_per_second;
+  TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
+  uint16 bytes_per_sample;
+  TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset));
+  // Confusingly, bits per sample is defined as holding the number of bits for
+  // one channel, unlike the definition of sample used elsewhere in the WAV
+  // spec. For example, bytes per sample is the memory needed for all channels
+  // for one point in time.
+  uint16 bits_per_sample;
+  TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset));
+  if (bits_per_sample != 16) {
+    return errors::InvalidArgument(
+        "Can only read 16-bit WAV files, but received ", bits_per_sample);
+  }
+  const uint32 expected_bytes_per_sample =
+      ((bits_per_sample * *channel_count) + 7) / 8;
+  if (bytes_per_sample != expected_bytes_per_sample) {
+    return errors::InvalidArgument(
+        "Bad bytes per sample in WAV header: Expected ",
+        expected_bytes_per_sample, " but got ", bytes_per_sample);
+  }
+  // The WAV format defines the byte rate as (bytes for all channels of one
+  // frame) * (sample rate). The per-channel figure, which this check used to
+  // require, is still accepted so that data which decoded before keeps
+  // decoding.
+  const uint32 expected_bytes_per_second = bytes_per_sample * (*sample_rate);
+  const uint32 per_channel_bytes_per_second =
+      expected_bytes_per_second / *channel_count;
+  if ((bytes_per_second != expected_bytes_per_second) &&
+      (bytes_per_second != per_channel_bytes_per_second)) {
+    return errors::InvalidArgument(
+        "Bad bytes per second in WAV header: Expected ",
+        expected_bytes_per_second, " but got ", bytes_per_second,
+        " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample,
+        ")");
+  }
+  if (format_chunk_size == 18) {
+    // Skip over this unused section.
+    offset += 2;
+  }
+  TF_RETURN_IF_ERROR(ExpectText(wav_string, kDataChunkId, &offset));
+  uint32 data_size;
+  TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &data_size, &offset));
+  // bytes_per_sample is at least 2 here: it equals 2 * channel_count and the
+  // channel count was checked to be non-zero above.
+  *sample_count = data_size / bytes_per_sample;
+  const uint32 data_count = *sample_count * *channel_count;
+  // Check the declared size against the bytes actually present before
+  // allocating, so a small corrupt file cannot request a huge buffer.
+  const uint64 bytes_needed = static_cast<uint64>(data_count) * sizeof(int16);
+  const uint64 bytes_available = wav_string.size() - offset;
+  if (bytes_needed > bytes_available) {
+    return errors::InvalidArgument(
+        "Data too short: WAV data chunk needs ", bytes_needed,
+        " bytes of samples but only ", bytes_available, " are present");
+  }
+  float_values->resize(data_count);
+  for (uint32 i = 0; i < data_count; ++i) {
+    int16 single_channel_value;
+    TF_RETURN_IF_ERROR(
+        ReadValue<int16>(wav_string, &single_channel_value, &offset));
+    (*float_values)[i] = Int16SampleToFloat(single_channel_value);
+  }
+  return Status::OK();
+}
+
} // namespace wav
} // namespace tensorflow
diff --git a/tensorflow/core/lib/wav/wav_io.h b/tensorflow/core/lib/wav/wav_io.h
index 68629996e1..adca0ee303 100644
--- a/tensorflow/core/lib/wav/wav_io.h
+++ b/tensorflow/core/lib/wav/wav_io.h
@@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_LIB_WAV_WAV_IO_H_
#include <string>
+#include <vector>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
@@ -42,6 +43,18 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
size_t num_channels, size_t num_frames,
string* wav_string);
+// Decodes little-endian signed 16-bit PCM WAV file data (aka LIN16 encoding)
+// into a flat vector of floats. Values are stored in the order they appear in
+// the file's data chunk, giving sample_count * channel_count entries in total,
+// so a four frame stereo signal produces eight floats. The frame count,
+// channel count and sample rate are read from the file header and returned via
+// the output pointers; an error is returned if the format is not supported.
+// The results are output as floats within the range -1 to 1.
+Status DecodeLin16WaveAsFloatVector(const string& wav_string,
+                                    std::vector<float>* float_values,
+                                    uint32* sample_count, uint16* channel_count,
+                                    uint32* sample_rate);
+
} // namespace wav
} // namespace tensorflow
diff --git a/tensorflow/core/lib/wav/wav_io_test.cc b/tensorflow/core/lib/wav/wav_io_test.cc
index 11f1bfa527..e54b9445ab 100644
--- a/tensorflow/core/lib/wav/wav_io_test.cc
+++ b/tensorflow/core/lib/wav/wav_io_test.cc
@@ -78,5 +78,24 @@ TEST(WavIO, BasicOdd) {
EXPECT_EQ(54, result.size());
}
+// Round-trips a three frame stereo clip through the encoder and decoder and
+// checks that the header fields and sample values survive.
+TEST(WavIO, EncodeThenDecode) {
+  float audio[] = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
+  string wav_data;
+  TF_ASSERT_OK(EncodeAudioAsS16LEWav(audio, 44100, 2, 3, &wav_data));
+  std::vector<float> decoded_audio;
+  uint32 decoded_sample_count;
+  uint16 decoded_channel_count;
+  uint32 decoded_sample_rate;
+  TF_ASSERT_OK(DecodeLin16WaveAsFloatVector(
+      wav_data, &decoded_audio, &decoded_sample_count, &decoded_channel_count,
+      &decoded_sample_rate));
+  EXPECT_EQ(2, decoded_channel_count);
+  EXPECT_EQ(3, decoded_sample_count);
+  EXPECT_EQ(44100, decoded_sample_rate);
+  // Stop here on a size mismatch so the loop below cannot index out of range.
+  ASSERT_EQ(6, decoded_audio.size());
+  for (int i = 0; i < 6; ++i) {
+    // 16-bit quantization limits the round-trip precision, hence the tolerance.
+    EXPECT_NEAR(audio[i], decoded_audio[i], 1e-4f) << "i=" << i;
+  }
+}
+
} // namespace wav
} // namespace tensorflow