diff options
author | Pete Warden <petewarden@google.com> | 2017-03-27 16:33:37 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-27 19:38:26 -0700 |
commit | ce0a07f2479c864b0a6cc8f4a218d74e8b480746 (patch) | |
tree | bde31344445ab010e2c0e660d2a5b5c2f146fdfb /tensorflow/core/lib/wav | |
parent | 164b6a88f56a8e491b315f2747303c46f04b5c76 (diff) |
Added WAV audio file reading and writing operations
Change: 151395519
Diffstat (limited to 'tensorflow/core/lib/wav')
-rw-r--r-- | tensorflow/core/lib/wav/wav_io.cc | 125 | ||||
-rw-r--r-- | tensorflow/core/lib/wav/wav_io.h | 13 | ||||
-rw-r--r-- | tensorflow/core/lib/wav/wav_io_test.cc | 19 |
3 files changed, 153 insertions, 4 deletions
diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc index 31c81b7dde..97e218a793 100644 --- a/tensorflow/core/lib/wav/wav_io.cc +++ b/tensorflow/core/lib/wav/wav_io.cc @@ -65,21 +65,64 @@ static_assert(sizeof(WavHeader) == sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk), "TF_PACKED does not work."); +constexpr char kRiffChunkId[] = "RIFF"; +constexpr char kRiffType[] = "WAVE"; +constexpr char kFormatChunkId[] = "fmt "; +constexpr char kDataChunkId[] = "data"; + inline int16 FloatToInt16Sample(float data) { constexpr float kMultiplier = 1.0f * (1 << 15); return std::min<float>(std::max<float>(roundf(data * kMultiplier), kint16min), kint16max); } +inline float Int16SampleToFloat(int16 data) { + constexpr float kMultiplier = 1.0f / (1 << 15); + return data * kMultiplier; +} + +Status ExpectText(const string& data, const string& expected_text, + int* offset) { + const int new_offset = *offset + expected_text.size(); + if (new_offset > data.size()) { + return errors::InvalidArgument("Data too short when trying to read ", + expected_text); + } + const string found_text(data.begin() + *offset, data.begin() + new_offset); + if (found_text != expected_text) { + return errors::InvalidArgument("Header mismatch: Expected ", expected_text, + " but found ", found_text); + } + *offset = new_offset; + return Status::OK(); +} + +template <class T> +Status ReadValue(const string& data, T* value, int* offset) { + const int new_offset = *offset + sizeof(T); + if (new_offset > data.size()) { + return errors::InvalidArgument("Data too short when trying to read value"); + } + if (port::kLittleEndian) { + memcpy(value, data.data() + *offset, sizeof(T)); + } else { + *value = 0; + const uint8* data_buf = + reinterpret_cast<const uint8*>(data.data() + *offset); + int shift = 0; + for (int i = 0; i < sizeof(T); ++i, shift += 8) { + *value = *value | (data_buf[i] >> shift); + } + } + *offset = new_offset; + return Status::OK(); +} + } // namespace 
Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, size_t num_channels, size_t num_frames, string* wav_string) { - constexpr char kRiffChunkId[] = "RIFF"; - constexpr char kRiffType[] = "WAVE"; - constexpr char kFormatChunkId[] = "fmt "; - constexpr char kDataChunkId[] = "data"; constexpr size_t kFormatChunkSize = 16; constexpr size_t kCompressionCodePcm = 1; constexpr size_t kBitsPerSample = 16; @@ -153,5 +196,79 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, return Status::OK(); } +Status DecodeLin16WaveAsFloatVector(const string& wav_string, + std::vector<float>* float_values, + uint32* sample_count, uint16* channel_count, + uint32* sample_rate) { + int offset = 0; + TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset)); + uint32 total_file_size; + TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset)); + TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset)); + TF_RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset)); + uint32 format_chunk_size; + TF_RETURN_IF_ERROR( + ReadValue<uint32>(wav_string, &format_chunk_size, &offset)); + if ((format_chunk_size != 16) && (format_chunk_size != 18)) { + return errors::InvalidArgument( + "Bad file size for WAV: Expected 16 or 18, but got", format_chunk_size); + } + uint16 audio_format; + TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset)); + if (audio_format != 1) { + return errors::InvalidArgument( + "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format); + } + TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset)); + TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset)); + uint32 bytes_per_second; + TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset)); + uint16 bytes_per_sample; + TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset)); + // Confusingly, bits per sample is defined as holding the number of 
bits for + // one channel, unlike the definition of sample used elsewhere in the WAV + // spec. For example, bytes per sample is the memory needed for all channels + // for one point in time. + uint16 bits_per_sample; + TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset)); + if (bits_per_sample != 16) { + return errors::InvalidArgument( + "Can only read 16-bit WAV files, but received ", bits_per_sample); + } + const uint32 expected_bytes_per_sample = + ((bits_per_sample * *channel_count) + 7) / 8; + if (bytes_per_sample != expected_bytes_per_sample) { + return errors::InvalidArgument( + "Bad bytes per sample in WAV header: Expected ", + expected_bytes_per_sample, " but got ", bytes_per_sample); + } + const uint32 expected_bytes_per_second = + (bytes_per_sample * (*sample_rate)) / *channel_count; + if (bytes_per_second != expected_bytes_per_second) { + return errors::InvalidArgument( + "Bad bytes per second in WAV header: Expected ", + expected_bytes_per_second, " but got ", bytes_per_second, + " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample, + ")"); + } + if (format_chunk_size == 18) { + // Skip over this unused section. 
+ offset += 2; + } + TF_RETURN_IF_ERROR(ExpectText(wav_string, kDataChunkId, &offset)); + uint32 data_size; + TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &data_size, &offset)); + *sample_count = data_size / bytes_per_sample; + const uint32 data_count = *sample_count * *channel_count; + float_values->resize(data_count); + for (int i = 0; i < data_count; ++i) { + int16 single_channel_value; + TF_RETURN_IF_ERROR( + ReadValue<int16>(wav_string, &single_channel_value, &offset)); + (*float_values)[i] = Int16SampleToFloat(single_channel_value); + } + return Status::OK(); +} + } // namespace wav } // namespace tensorflow diff --git a/tensorflow/core/lib/wav/wav_io.h b/tensorflow/core/lib/wav/wav_io.h index 68629996e1..adca0ee303 100644 --- a/tensorflow/core/lib/wav/wav_io.h +++ b/tensorflow/core/lib/wav/wav_io.h @@ -19,6 +19,7 @@ limitations under the License. #define TENSORFLOW_LIB_WAV_WAV_IO_H_ #include <string> +#include <vector> #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" @@ -42,6 +43,18 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, size_t num_channels, size_t num_frames, string* wav_string); +// Decodes the little-endian signed 16-bit PCM WAV file data (aka LIN16 +// encoding) into a float Tensor. The channels are encoded as the lowest +// dimension of the tensor, with the number of frames as the second. This means +// that a four frame stereo signal will have the shape [4, 2]. The sample rate +// is read from the file header, and an error is returned if the format is not +// supported. 
+// The results are output as floats within the range -1 to 1. +Status DecodeLin16WaveAsFloatVector(const string& wav_string, + std::vector<float>* float_values, + uint32* sample_count, uint16* channel_count, + uint32* sample_rate); + } // namespace wav } // namespace tensorflow diff --git a/tensorflow/core/lib/wav/wav_io_test.cc b/tensorflow/core/lib/wav/wav_io_test.cc index 11f1bfa527..e54b9445ab 100644 --- a/tensorflow/core/lib/wav/wav_io_test.cc +++ b/tensorflow/core/lib/wav/wav_io_test.cc @@ -78,5 +78,24 @@ TEST(WavIO, BasicOdd) { EXPECT_EQ(54, result.size()); } +TEST(WavIO, EncodeThenDecode) { + float audio[] = {0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}; + string wav_data; + TF_ASSERT_OK(EncodeAudioAsS16LEWav(audio, 44100, 2, 3, &wav_data)); + std::vector<float> decoded_audio; + uint32 decoded_sample_count; + uint16 decoded_channel_count; + uint32 decoded_sample_rate; + TF_ASSERT_OK(DecodeLin16WaveAsFloatVector( + wav_data, &decoded_audio, &decoded_sample_count, &decoded_channel_count, + &decoded_sample_rate)); + EXPECT_EQ(2, decoded_channel_count); + EXPECT_EQ(3, decoded_sample_count); + EXPECT_EQ(44100, decoded_sample_rate); + for (int i = 0; i < 6; ++i) { + EXPECT_NEAR(audio[i], decoded_audio[i], 1e-4f) << "i=" << i; + } +} + } // namespace wav } // namespace tensorflow |