diff options
Diffstat (limited to 'tensorflow/core/lib/wav/wav_io.cc')
-rw-r--r-- | tensorflow/core/lib/wav/wav_io.cc | 125 |
1 files changed, 121 insertions, 4 deletions
diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc index 31c81b7dde..97e218a793 100644 --- a/tensorflow/core/lib/wav/wav_io.cc +++ b/tensorflow/core/lib/wav/wav_io.cc @@ -65,21 +65,64 @@ static_assert(sizeof(WavHeader) == sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk), "TF_PACKED does not work."); +constexpr char kRiffChunkId[] = "RIFF"; +constexpr char kRiffType[] = "WAVE"; +constexpr char kFormatChunkId[] = "fmt "; +constexpr char kDataChunkId[] = "data"; + inline int16 FloatToInt16Sample(float data) { constexpr float kMultiplier = 1.0f * (1 << 15); return std::min<float>(std::max<float>(roundf(data * kMultiplier), kint16min), kint16max); } +inline float Int16SampleToFloat(int16 data) { + constexpr float kMultiplier = 1.0f / (1 << 15); + return data * kMultiplier; +} + +Status ExpectText(const string& data, const string& expected_text, + int* offset) { + const int new_offset = *offset + expected_text.size(); + if (new_offset > data.size()) { + return errors::InvalidArgument("Data too short when trying to read ", + expected_text); + } + const string found_text(data.begin() + *offset, data.begin() + new_offset); + if (found_text != expected_text) { + return errors::InvalidArgument("Header mismatch: Expected ", expected_text, + " but found ", found_text); + } + *offset = new_offset; + return Status::OK(); +} + +template <class T> +Status ReadValue(const string& data, T* value, int* offset) { + const int new_offset = *offset + sizeof(T); + if (new_offset > data.size()) { + return errors::InvalidArgument("Data too short when trying to read value"); + } + if (port::kLittleEndian) { + memcpy(value, data.data() + *offset, sizeof(T)); + } else { + *value = 0; + const uint8* data_buf = + reinterpret_cast<const uint8*>(data.data() + *offset); + int shift = 0; + for (int i = 0; i < sizeof(T); ++i, shift += 8) { + *value = *value | (data_buf[i] >> shift); + } + } + *offset = new_offset; + return Status::OK(); +} + } // namespace Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, size_t num_channels, size_t num_frames, string* wav_string) { - constexpr char kRiffChunkId[] = "RIFF"; - constexpr char kRiffType[] = "WAVE"; - constexpr char kFormatChunkId[] = "fmt "; - constexpr char kDataChunkId[] = "data"; constexpr size_t kFormatChunkSize = 16; constexpr size_t kCompressionCodePcm = 1; constexpr size_t kBitsPerSample = 16; @@ -153,5 +196,79 @@ Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, return Status::OK(); } +Status DecodeLin16WaveAsFloatVector(const string& wav_string, + std::vector<float>* float_values, + uint32* sample_count, uint16* channel_count, + uint32* sample_rate) { + int offset = 0; + TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset)); + uint32 total_file_size; + TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset)); + TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset)); + TF_RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset)); + uint32 format_chunk_size; + TF_RETURN_IF_ERROR( + ReadValue<uint32>(wav_string, &format_chunk_size, &offset)); + if ((format_chunk_size != 16) && (format_chunk_size != 18)) { + return errors::InvalidArgument( + "Bad file size for WAV: Expected 16 or 18, but got", format_chunk_size); + } + uint16 audio_format; + TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset)); + if (audio_format != 1) { + return errors::InvalidArgument( + "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format); + } + TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset)); + TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset)); + uint32 bytes_per_second; + TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset)); + uint16 bytes_per_sample; + TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset)); + // Confusingly, bits per sample is defined as holding the number of bits for + // one channel, unlike the definition of sample used elsewhere in the WAV + // spec. For example, bytes per sample is the memory needed for all channels + // for one point in time. + uint16 bits_per_sample; + TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset)); + if (bits_per_sample != 16) { + return errors::InvalidArgument( + "Can only read 16-bit WAV files, but received ", bits_per_sample); + } + const uint32 expected_bytes_per_sample = + ((bits_per_sample * *channel_count) + 7) / 8; + if (bytes_per_sample != expected_bytes_per_sample) { + return errors::InvalidArgument( + "Bad bytes per sample in WAV header: Expected ", + expected_bytes_per_sample, " but got ", bytes_per_sample); + } + const uint32 expected_bytes_per_second = + (bytes_per_sample * (*sample_rate)) / *channel_count; + if (bytes_per_second != expected_bytes_per_second) { + return errors::InvalidArgument( + "Bad bytes per second in WAV header: Expected ", + expected_bytes_per_second, " but got ", bytes_per_second, + " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample, + ")"); + } + if (format_chunk_size == 18) { + // Skip over this unused section. + offset += 2; + } + TF_RETURN_IF_ERROR(ExpectText(wav_string, kDataChunkId, &offset)); + uint32 data_size; + TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &data_size, &offset)); + *sample_count = data_size / bytes_per_sample; + const uint32 data_count = *sample_count * *channel_count; + float_values->resize(data_count); + for (int i = 0; i < data_count; ++i) { + int16 single_channel_value; + TF_RETURN_IF_ERROR( + ReadValue<int16>(wav_string, &single_channel_value, &offset)); + (*float_values)[i] = Int16SampleToFloat(single_channel_value); + } + return Status::OK(); +} + } // namespace wav } // namespace tensorflow |