From a24c6b842d982de8a38ae5058ace91cb47ee3cef Mon Sep 17 00:00:00 2001 From: Pete Warden Date: Wed, 5 Apr 2017 17:10:48 -0800 Subject: Add AudioSpectrogram op to TensorFlow for audio feature generation Change: 152332221 --- tensorflow/BUILD | 1 + tensorflow/contrib/cmake/CMakeLists.txt | 3 + tensorflow/contrib/cmake/external/fft2d.cmake | 37 +++ .../contrib/cmake/patches/fft2d/CMakeLists.txt | 17 ++ tensorflow/core/BUILD | 9 +- tensorflow/core/kernels/BUILD | 111 +++++++ tensorflow/core/kernels/spectrogram.cc | 212 +++++++++++++ tensorflow/core/kernels/spectrogram.h | 112 +++++++ .../core/kernels/spectrogram_convert_test_data.cc | 56 ++++ tensorflow/core/kernels/spectrogram_op.cc | 120 ++++++++ tensorflow/core/kernels/spectrogram_op_test.cc | 104 +++++++ tensorflow/core/kernels/spectrogram_test.cc | 340 +++++++++++++++++++++ .../core/kernels/spectrogram_test_data/README | 8 + .../spectrogram_test_data/short_test_segment.wav | Bin 0 -> 91784 bytes .../short_test_segment_spectrogram.csv.bin | Bin 0 -> 365968 bytes .../short_test_segment_spectrogram_400_200.csv.bin | Bin 0 -> 468768 bytes tensorflow/core/kernels/spectrogram_test_utils.cc | 288 +++++++++++++++++ tensorflow/core/kernels/spectrogram_test_utils.h | 81 +++++ tensorflow/core/lib/core/bits.h | 13 + tensorflow/core/ops/audio_ops.cc | 79 +++++ .../core/platform/default/build_config/BUILD | 1 + tensorflow/core/util/command_line_flags.cc | 28 ++ tensorflow/core/util/command_line_flags.h | 4 +- tensorflow/core/util/command_line_flags_test.cc | 35 ++- tensorflow/examples/wav_to_spectrogram/BUILD | 68 +++++ tensorflow/examples/wav_to_spectrogram/README.md | 49 +++ tensorflow/examples/wav_to_spectrogram/main.cc | 66 ++++ .../wav_to_spectrogram/wav_to_spectrogram.cc | 97 ++++++ .../wav_to_spectrogram/wav_to_spectrogram.h | 31 ++ .../wav_to_spectrogram/wav_to_spectrogram_test.cc | 37 +++ tensorflow/tools/lib_package/BUILD | 4 + tensorflow/tools/pip_package/BUILD | 2 + tensorflow/workspace.bzl | 10 + third_party/fft2d/BUILD | 30 ++ third_party/fft2d/LICENSE | 3 + third_party/fft2d/fft.h | 36 +++ third_party/fft2d/fft2d.BUILD | 36 +++ 37 files changed, 2120 insertions(+), 8 deletions(-) create mode 100644 tensorflow/contrib/cmake/external/fft2d.cmake create mode 100644 tensorflow/contrib/cmake/patches/fft2d/CMakeLists.txt create mode 100644 tensorflow/core/kernels/spectrogram.cc create mode 100644 tensorflow/core/kernels/spectrogram.h create mode 100644 tensorflow/core/kernels/spectrogram_convert_test_data.cc create mode 100644 tensorflow/core/kernels/spectrogram_op.cc create mode 100644 tensorflow/core/kernels/spectrogram_op_test.cc create mode 100644 tensorflow/core/kernels/spectrogram_test.cc create mode 100644 tensorflow/core/kernels/spectrogram_test_data/README create mode 100644 tensorflow/core/kernels/spectrogram_test_data/short_test_segment.wav create mode 100644 tensorflow/core/kernels/spectrogram_test_data/short_test_segment_spectrogram.csv.bin create mode 100644 tensorflow/core/kernels/spectrogram_test_data/short_test_segment_spectrogram_400_200.csv.bin create mode 100644 tensorflow/core/kernels/spectrogram_test_utils.cc create mode 100644 tensorflow/core/kernels/spectrogram_test_utils.h create mode 100644 tensorflow/examples/wav_to_spectrogram/BUILD create mode 100644 tensorflow/examples/wav_to_spectrogram/README.md create mode 100644 tensorflow/examples/wav_to_spectrogram/main.cc create mode 100644 tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.cc create mode 100644 tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h create mode 100644 tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram_test.cc create mode 100644 third_party/fft2d/BUILD create mode 100644 third_party/fft2d/LICENSE create mode 100644 third_party/fft2d/fft.h create mode 100644 third_party/fft2d/fft2d.BUILD diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 37dc8e265f..5d2b1e74df 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -276,6 +276,7 @@ filegroup( "//tensorflow/examples/tutorials/estimators:all_files", "//tensorflow/examples/tutorials/mnist:all_files", "//tensorflow/examples/tutorials/word2vec:all_files", + "//tensorflow/examples/wav_to_spectrogram:all_files", "//tensorflow/go:all_files", "//tensorflow/java:all_files", "//tensorflow/java/src/main/java/org/tensorflow/examples:all_files", diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index e27df6898e..31a3d45a98 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -108,6 +108,7 @@ include(eigen) include(gemmlowp) include(jsoncpp) include(farmhash) +include(fft2d) include(highwayhash) include(protobuf) if (tensorflow_BUILD_CC_TESTS) @@ -121,6 +122,7 @@ set(tensorflow_EXTERNAL_LIBRARIES ${jpeg_STATIC_LIBRARIES} ${jsoncpp_STATIC_LIBRARIES} ${farmhash_STATIC_LIBRARIES} + ${fft2d_STATIC_LIBRARIES} ${highwayhash_STATIC_LIBRARIES} ${protobuf_STATIC_LIBRARIES} ) @@ -135,6 +137,7 @@ set(tensorflow_EXTERNAL_DEPENDENCIES protobuf eigen gemmlowp + fft2d ) include_directories( diff --git a/tensorflow/contrib/cmake/external/fft2d.cmake b/tensorflow/contrib/cmake/external/fft2d.cmake new file mode 100644 index 0000000000..50c6b91684 --- /dev/null +++ b/tensorflow/contrib/cmake/external/fft2d.cmake @@ -0,0 +1,37 @@ +include (ExternalProject) + +set(fft2d_URL http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz) +set(fft2d_HASH SHA256=52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296) +set(fft2d_BUILD ${CMAKE_CURRENT_BINARY_DIR}/fft2d/) +set(fft2d_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/fft2d/src) + +if(WIN32) + set(fft2d_STATIC_LIBRARIES ${fft2d_BUILD}/src/lib/fft2d.lib) + + ExternalProject_Add(fft2d + PREFIX fft2d + URL ${fft2d_URL} + URL_HASH ${fft2d_HASH} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 + PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/fft2d/CMakeLists.txt ${fft2d_BUILD}/src/fft2d/CMakeLists.txt + INSTALL_DIR ${fft2d_INSTALL} + CMAKE_CACHE_ARGS + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -DCMAKE_INSTALL_PREFIX:STRING=${fft2d_INSTALL}) +else() + set(fft2d_STATIC_LIBRARIES ${fft2d_BUILD}/src/fft2d/libfft2d.a) + + ExternalProject_Add(fft2d + PREFIX fft2d + URL ${fft2d_URL} + URL_HASH ${fft2d_HASH} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + BUILD_IN_SOURCE 1 + PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/fft2d/CMakeLists.txt ${fft2d_BUILD}/src/fft2d/CMakeLists.txt + INSTALL_DIR $(fft2d_INSTALL) + INSTALL_COMMAND echo + BUILD_COMMAND $(MAKE)) + +endif() diff --git a/tensorflow/contrib/cmake/patches/fft2d/CMakeLists.txt b/tensorflow/contrib/cmake/patches/fft2d/CMakeLists.txt new file mode 100644 index 0000000000..b31ea3ed98 --- /dev/null +++ b/tensorflow/contrib/cmake/patches/fft2d/CMakeLists.txt @@ -0,0 +1,17 @@ +cmake_minimum_required(VERSION 2.8.3) + +project(fft2d) + +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +set(FFT2D_SRCS + "fftsg.c" +) + +include_directories("${CMAKE_CURRENT_SOURCE_DIR}") + +add_library(fft2d ${FFT2D_SRCS}) + +install(TARGETS fft2d + LIBRARY DESTINATION lib COMPONENT RuntimeLibraries + ARCHIVE DESTINATION lib COMPONENT Development) diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index ba761cd7c6..6b7e297c4f 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -494,7 +494,6 @@ cc_library( tf_gen_op_libs( op_lib_names = [ "array_ops", - "audio_ops", "candidate_sampling_ops", "control_flow_ops", "ctc_ops", @@ -524,6 +523,13 @@ tf_gen_op_libs( ], ) +tf_gen_op_libs( + op_lib_names = [ + "audio_ops", + ], + deps = [":lib"], +) + cc_library( name = "debug_ops_op_lib", srcs = ["ops/debug_ops.cc"], @@ -686,6 +692,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:audio", "//tensorflow/core/kernels:bincount_op", "//tensorflow/core/kernels:candidate_sampler_ops", "//tensorflow/core/kernels:control_flow_ops", diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 9c47d520d9..0a4fd0f256 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3537,6 +3537,117 @@ tf_kernel_library( ], ) +filegroup( + name = "spectrogram_test_data", + srcs = [ + "spectrogram_test_data/short_test_segment.wav", + "spectrogram_test_data/short_test_segment_spectrogram.csv.bin", + "spectrogram_test_data/short_test_segment_spectrogram_400_200.csv.bin", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "spectrogram", + srcs = ["spectrogram.cc"], + hdrs = ["spectrogram.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/fft2d:fft2d_headers", + "@fft2d//:fft2d", + ], +) + +cc_library( + name = "spectrogram_test_utils", + testonly = 1, + srcs = ["spectrogram_test_utils.cc"], + hdrs = ["spectrogram_test_utils.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + +cc_binary( + name = "spectrogram_convert_test_data", + testonly = 1, + srcs = ["spectrogram_convert_test_data.cc"], + deps = [ + ":spectrogram_test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "spectrogram_test", + size = "medium", + srcs = ["spectrogram_test.cc"], + data = [":spectrogram_test_data"], + deps = [ + ":spectrogram", + ":spectrogram_test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:lib_test_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "spectrogram_op", + prefix = "spectrogram_op", + deps = [ + ":spectrogram", + "//tensorflow/core:audio_ops_op_lib", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], + alwayslink = 1, +) + +tf_cuda_cc_test( + name = "spectrogram_op_test", + size = "small", + srcs = ["spectrogram_op_test.cc"], + deps = [ + ":ops_util", + ":spectrogram_op", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "audio", + deps = [ + ":decode_wav_op", + ":encode_wav_op", + ":spectrogram_op", + ], +) + # Android libraries ----------------------------------------------------------- # Changes to the Android srcs here should be replicated in diff --git a/tensorflow/core/kernels/spectrogram.cc b/tensorflow/core/kernels/spectrogram.cc new file mode 100644 index 0000000000..7531d5d64a --- /dev/null +++ b/tensorflow/core/kernels/spectrogram.cc @@ -0,0 +1,212 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/spectrogram.h" + +#include + +#include "third_party/fft2d/fft.h" +#include "tensorflow/core/lib/core/bits.h" + +namespace tensorflow { + +using std::complex; + +namespace { +// Returns the default Hann window function for the spectrogram. +void GetPeriodicHann(int window_length, std::vector* window) { + // Some platforms don't have M_PI, so define a local constant here. + const double pi = std::atan(1) * 4; + window->resize(window_length); + for (int i = 0; i < window_length; ++i) { + (*window)[i] = 0.5 - 0.5 * cos((2 * pi * i) / window_length); + } +} +} // namespace + +bool Spectrogram::Initialize(int window_length, int step_length) { + std::vector window; + GetPeriodicHann(window_length, &window); + return Initialize(window, step_length); +} + +bool Spectrogram::Initialize(const std::vector& window, + int step_length) { + window_length_ = window.size(); + window_ = window; // Copy window. + if (window_length_ < 2) { + LOG(ERROR) << "Window length too short."; + initialized_ = false; + return false; + } + + step_length_ = step_length; + if (step_length_ < 1) { + LOG(ERROR) << "Step length must be positive."; + initialized_ = false; + return false; + } + + fft_length_ = NextPowerOfTwo(window_length_); + CHECK(fft_length_ >= window_length_); + output_frequency_channels_ = 1 + fft_length_ / 2; + + // Allocate 2 more than what rdft needs, so we can rationalize the layout. + fft_input_output_.assign(fft_length_ + 2, 0.0); + + int half_fft_length = fft_length_ / 2; + fft_double_working_area_.assign(half_fft_length, 0.0); + fft_integer_working_area_.assign(2 + static_cast(sqrt(half_fft_length)), + 0); + // Set flag element to ensure that the working areas are initialized + // on the first call to cdft. It's redundant given the assign above, + // but keep it as a reminder. + fft_integer_working_area_[0] = 0; + input_queue_.clear(); + samples_to_next_step_ = window_length_; + initialized_ = true; + return true; +} + +template +bool Spectrogram::ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>* output) { + if (!initialized_) { + LOG(ERROR) << "ComputeComplexSpectrogram() called before successful call " + << "to Initialize()."; + return false; + } + CHECK(output); + output->clear(); + int input_start = 0; + while (GetNextWindowOfSamples(input, &input_start)) { + DCHECK_EQ(input_queue_.size(), window_length_); + ProcessCoreFFT(); // Processes input_queue_ to fft_input_output_. + // Add a new slice vector onto the output, to save new result to. + output->resize(output->size() + 1); + // Get a reference to the newly added slice to fill in. + auto& spectrogram_slice = output->back(); + spectrogram_slice.resize(output_frequency_channels_); + for (int i = 0; i < output_frequency_channels_; ++i) { + // This will convert double to float if it needs to. + spectrogram_slice[i] = complex( + fft_input_output_[2 * i], fft_input_output_[2 * i + 1]); + } + } + return true; +} +// Instantiate it four ways: +template bool Spectrogram::ComputeComplexSpectrogram( + const std::vector& input, std::vector>>*); +template bool Spectrogram::ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>*); +template bool Spectrogram::ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>*); +template bool Spectrogram::ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>*); + +template +bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, + std::vector>* output) { + if (!initialized_) { + LOG(ERROR) << "ComputeSquaredMagnitudeSpectrogram() called before " + << "successful call to Initialize()."; + return false; + } + CHECK(output); + output->clear(); + int input_start = 0; + while (GetNextWindowOfSamples(input, &input_start)) { + DCHECK_EQ(input_queue_.size(), window_length_); + ProcessCoreFFT(); // Processes input_queue_ to fft_input_output_. + // Add a new slice vector onto the output, to save new result to. + output->resize(output->size() + 1); + // Get a reference to the newly added slice to fill in. + auto& spectrogram_slice = output->back(); + spectrogram_slice.resize(output_frequency_channels_); + for (int i = 0; i < output_frequency_channels_; ++i) { + // Similar to the Complex case, except storing the norm. + // But the norm function is known to be a performance killer, + // so do it this way with explicit real and imagninary temps. + const double re = fft_input_output_[2 * i]; + const double im = fft_input_output_[2 * i + 1]; + // Which finally converts double to float if it needs to. + spectrogram_slice[i] = re * re + im * im; + } + } + return true; +} +// Instantiate it four ways: +template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, std::vector>*); +template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, std::vector>*); +template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, std::vector>*); +template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, std::vector>*); + +// Return true if a full window of samples is prepared; manage the queue. +template +bool Spectrogram::GetNextWindowOfSamples(const std::vector& input, + int* input_start) { + auto input_it = input.begin() + *input_start; + int input_remaining = input.end() - input_it; + if (samples_to_next_step_ > input_remaining) { + // Copy in as many samples are left and return false, no full window. + input_queue_.insert(input_queue_.end(), input_it, input.end()); + *input_start += input_remaining; // Increases it to input.size(). + samples_to_next_step_ -= input_remaining; + return false; // Not enough for a full window. + } else { + // Copy just enough into queue to make a new window, then trim the + // front off the queue to make it window-sized. + input_queue_.insert(input_queue_.end(), input_it, + input_it + samples_to_next_step_); + *input_start += samples_to_next_step_; + input_queue_.erase( + input_queue_.begin(), + input_queue_.begin() + input_queue_.size() - window_length_); + DCHECK_EQ(window_length_, input_queue_.size()); + samples_to_next_step_ = step_length_; // Be ready for next time. + return true; // Yes, input_queue_ now contains exactly a window-full. + } +} + +void Spectrogram::ProcessCoreFFT() { + for (int j = 0; j < window_length_; ++j) { + fft_input_output_[j] = input_queue_[j] * window_[j]; + } + // Zero-pad the rest of the input buffer. + for (int j = window_length_; j < fft_length_; ++j) { + fft_input_output_[j] = 0.0; + } + const int kForwardFFT = 1; // 1 means forward; -1 reverse. + // This real FFT is a fair amount faster than using cdft here. + rdft(fft_length_, kForwardFFT, &fft_input_output_[0], + &fft_integer_working_area_[0], &fft_double_working_area_[0]); + // Make rdft result look like cdft result; + // unpack the last real value from the first position's imag slot. + fft_input_output_[fft_length_] = fft_input_output_[1]; + fft_input_output_[fft_length_ + 1] = 0; + fft_input_output_[1] = 0; +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/spectrogram.h b/tensorflow/core/kernels/spectrogram.h new file mode 100644 index 0000000000..5476a0a961 --- /dev/null +++ b/tensorflow/core/kernels/spectrogram.h @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Class for generating spectrogram slices from a waveform. +// Initialize() should be called before calls to other functions. Once +// Initialize() has been called and returned true, The Compute*() functions can +// be called repeatedly with sequential input data (ie. the first element of the +// next input vector directly follows the last element of the previous input +// vector). Whenever enough audio samples are buffered to produce a +// new frame, it will be placed in output. Output is cleared on each +// call to Compute*(). This class is thread-unsafe, and should only be +// called from one thread at a time. +// With the default parameters, the output of this class should be very +// close to the results of the following MATLAB code: +// overlap_samples = window_length_samples - step_samples; +// window = hann(window_length_samples, 'periodic'); +// S = abs(spectrogram(audio, window, overlap_samples)).^2; + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_ + +#include +#include +#include + +#include "third_party/fft2d/fft.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +class Spectrogram { + public: + Spectrogram() : initialized_(false) {} + ~Spectrogram() {} + + // Initializes the class with a given window length and step length + // (both in samples). Internally a Hann window is used as the window + // function. Returns true on success, after which calls to Process() + // are possible. window_length must be greater than 1 and step + // length must be greater than 0. + bool Initialize(int window_length, int step_length); + + // Initialize with an explicit window instead of a length. + bool Initialize(const std::vector& window, int step_length); + + // Processes an arbitrary amount of audio data (contained in input) + // to yield complex spectrogram frames. After a successful call to + // Initialize(), Process() may be called repeatedly with new input data + // each time. The audio input is buffered internally, and the output + // vector is populated with as many temporally-ordered spectral slices + // as it is possible to generate from the input. The output is cleared + // on each call before the new frames (if any) are added. + // + // The template parameters can be float or double. + template + bool ComputeComplexSpectrogram( + const std::vector& input, + std::vector>>* output); + + // This function works as the one above, but returns the power + // (the L2 norm, or the squared magnitude) of each complex value. + template + bool ComputeSquaredMagnitudeSpectrogram( + const std::vector& input, + std::vector>* output); + + // Return reference to the window function used internally. + const std::vector& GetWindow() const { return window_; } + + // Return the number of frequency channels in the spectrogram. + int output_frequency_channels() const { return output_frequency_channels_; } + + private: + template + bool GetNextWindowOfSamples(const std::vector& input, + int* input_start); + void ProcessCoreFFT(); + + int fft_length_; + int output_frequency_channels_; + int window_length_; + int step_length_; + bool initialized_; + int samples_to_next_step_; + + std::vector window_; + std::vector fft_input_output_; + std::deque input_queue_; + + // Working data areas for the FFT routines. + std::vector fft_integer_working_area_; + std::vector fft_double_working_area_; + + TF_DISALLOW_COPY_AND_ASSIGN(Spectrogram); +}; + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_H_ diff --git a/tensorflow/core/kernels/spectrogram_convert_test_data.cc b/tensorflow/core/kernels/spectrogram_convert_test_data.cc new file mode 100644 index 0000000000..bae13c0213 --- /dev/null +++ b/tensorflow/core/kernels/spectrogram_convert_test_data.cc @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/spectrogram_test_utils.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace wav { + +// This takes a CSV file representing an array of complex numbers, and saves out +// a version using a binary format to save space in the repository. +Status ConvertCsvToRaw(const string& input_filename) { + std::vector>> input_data; + ReadCSVFileToComplexVectorOrDie(input_filename, &input_data); + const string output_filename = input_filename + ".bin"; + if (!WriteComplexVectorToRawFloatFile(output_filename, input_data)) { + return errors::InvalidArgument("Failed to write raw float file ", + input_filename); + } + LOG(INFO) << "Wrote raw file to " << output_filename; + return Status::OK(); +} + +} // namespace wav +} // namespace tensorflow + +int main(int argc, char* argv[]) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc < 2) { + LOG(ERROR) << "You must supply a CSV file as the first argument"; + return 1; + } + tensorflow::string filename(argv[1]); + tensorflow::Status status = tensorflow::wav::ConvertCsvToRaw(filename); + if (!status.ok()) { + LOG(ERROR) << "Error processing '" << filename << "':" << status; + return 1; + } + return 0; +} diff --git a/tensorflow/core/kernels/spectrogram_op.cc b/tensorflow/core/kernels/spectrogram_op.cc new file mode 100644 index 0000000000..98d9bb1ad1 --- /dev/null +++ b/tensorflow/core/kernels/spectrogram_op.cc @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/audio_ops.cc + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/spectrogram.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Create a spectrogram frequency visualization from audio data. +class SpectrogramOp : public OpKernel { + public: + explicit SpectrogramOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("window_size", &window_size_)); + OP_REQUIRES_OK(context, context->GetAttr("stride", &stride_)); + OP_REQUIRES_OK(context, + context->GetAttr("magnitude_squared", &magnitude_squared_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + OP_REQUIRES(context, input.dims() == 2, + errors::InvalidArgument("input must be 2-dimensional", + input.shape().DebugString())); + Spectrogram spectrogram; + OP_REQUIRES(context, spectrogram.Initialize(window_size_, stride_), + errors::InvalidArgument( + "Spectrogram initialization failed for window size ", + window_size_, " and stride ", stride_)); + + const auto input_as_matrix = input.matrix(); + + const int64 sample_count = input.dim_size(0); + const int64 channel_count = input.dim_size(1); + + const int64 output_width = spectrogram.output_frequency_channels(); + const int64 length_minus_window = (sample_count - window_size_); + int64 output_height; + if (length_minus_window < 0) { + output_height = 0; + } else { + output_height = 1 + (length_minus_window / stride_); + } + const int64 output_slices = channel_count; + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK( + context, + context->allocate_output( + 0, TensorShape({output_slices, output_height, output_width}), + &output_tensor)); + auto output_flat = output_tensor->flat().data(); + + std::vector input_for_channel(sample_count); + for (int64 channel = 0; channel < channel_count; ++channel) { + float* output_slice = + output_flat + (channel * output_height * output_width); + for (int i = 0; i < sample_count; ++i) { + input_for_channel[i] = input_as_matrix(i, channel); + } + std::vector> spectrogram_output; + OP_REQUIRES(context, + spectrogram.ComputeSquaredMagnitudeSpectrogram( + input_for_channel, &spectrogram_output), + errors::InvalidArgument("Spectrogram compute failed")); + OP_REQUIRES(context, (spectrogram_output.size() == output_height), + errors::InvalidArgument( + "Spectrogram size calculation failed: Expected height ", + output_height, " but got ", spectrogram_output.size())); + OP_REQUIRES(context, + spectrogram_output.empty() || + (spectrogram_output[0].size() == output_width), + errors::InvalidArgument( + "Spectrogram size calculation failed: Expected width ", + output_width, " but got ", spectrogram_output[0].size())); + for (int row_index = 0; row_index < output_height; ++row_index) { + const std::vector& spectrogram_row = + spectrogram_output[row_index]; + DCHECK_EQ(spectrogram_row.size(), output_width); + float* output_row = output_slice + (row_index * output_width); + if (magnitude_squared_) { + for (int i = 0; i < output_width; ++i) { + output_row[i] = spectrogram_row[i]; + } + } else { + for (int i = 0; i < output_width; ++i) { + output_row[i] = sqrtf(spectrogram_row[i]); + } + } + } + } + } + + private: + int32 window_size_; + int32 stride_; + bool magnitude_squared_; +}; +REGISTER_KERNEL_BUILDER(Name("AudioSpectrogram").Device(DEVICE_CPU), + SpectrogramOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/spectrogram_op_test.cc b/tensorflow/core/kernels/spectrogram_op_test.cc new file mode 100644 index 0000000000..5c3cbeeeb9 --- /dev/null +++ b/tensorflow/core/kernels/spectrogram_op_test.cc @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/ops/audio_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +using namespace ops; // NOLINT(build/namespaces) + +TEST(SpectrogramOpTest, SimpleTest) { + Scope root = Scope::NewRootScope(); + + Tensor audio_tensor(DT_FLOAT, TensorShape({8, 1})); + test::FillValues(&audio_tensor, + {-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f}); + + Output audio_const_op = Const(root.WithOpName("audio_const_op"), + Input::Initializer(audio_tensor)); + + AudioSpectrogram spectrogram_op = + AudioSpectrogram(root.WithOpName("spectrogram_op"), audio_const_op, 8, 1); + + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), + {spectrogram_op.spectrogram}, &outputs)); + + const Tensor& spectrogram_tensor = outputs[0]; + + EXPECT_EQ(3, spectrogram_tensor.dims()); + EXPECT_EQ(5, spectrogram_tensor.dim_size(2)); + EXPECT_EQ(1, spectrogram_tensor.dim_size(1)); + EXPECT_EQ(1, spectrogram_tensor.dim_size(0)); + + test::ExpectTensorNear( + spectrogram_tensor, + test::AsTensor({0, 1, 2, 1, 0}, TensorShape({1, 1, 5})), 1e-3); +} + +TEST(SpectrogramOpTest, SquaredTest) { + Scope root = Scope::NewRootScope(); + + Tensor audio_tensor(DT_FLOAT, TensorShape({8, 1})); + test::FillValues(&audio_tensor, + {-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f}); + + Output audio_const_op = Const(root.WithOpName("audio_const_op"), + Input::Initializer(audio_tensor)); + + AudioSpectrogram spectrogram_op = + AudioSpectrogram(root.WithOpName("spectrogram_op"), audio_const_op, 8, 1, + AudioSpectrogram::Attrs().MagnitudeSquared(true)); + + TF_ASSERT_OK(root.status()); + + ClientSession session(root); + std::vector outputs; + + TF_EXPECT_OK(session.Run(ClientSession::FeedType(), + {spectrogram_op.spectrogram}, &outputs)); + + const Tensor& spectrogram_tensor = outputs[0]; + + EXPECT_EQ(3, spectrogram_tensor.dims()); + EXPECT_EQ(5, spectrogram_tensor.dim_size(2)); + EXPECT_EQ(1, spectrogram_tensor.dim_size(1)); + EXPECT_EQ(1, spectrogram_tensor.dim_size(0)); + + test::ExpectTensorNear( + spectrogram_tensor, + test::AsTensor({0, 1, 4, 1, 0}, TensorShape({1, 1, 5})), 1e-3); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/spectrogram_test.cc b/tensorflow/core/kernels/spectrogram_test.cc new file mode 100644 index 0000000000..73175a91a0 --- /dev/null +++ b/tensorflow/core/kernels/spectrogram_test.cc @@ -0,0 +1,340 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The MATLAB test data were generated using GenerateTestData.m. + +#include "tensorflow/core/kernels/spectrogram.h" + +#include +#include + +#include "tensorflow/core/kernels/spectrogram_test_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +using ::std::complex; + +const char kInputFilename[] = + "core/kernels/spectrogram_test_data/short_test_segment.wav"; + +const char kExpectedFilename[] = + "core/kernels/spectrogram_test_data/short_test_segment_spectrogram.csv.bin"; +const int kDataVectorLength = 257; +const int kNumberOfFramesInTestData = 178; + +const char kExpectedNonPowerOfTwoFilename[] = + "core/kernels/spectrogram_test_data/" + "short_test_segment_spectrogram_400_200.csv.bin"; +const int kNonPowerOfTwoDataVectorLength = 257; +const int kNumberOfFramesInNonPowerOfTwoTestData = 228; + +TEST(SpectrogramTest, TooLittleDataYieldsNoFrames) { + Spectrogram sgram; + sgram.Initialize(400, 200); + std::vector input; + // Generate 44 samples of audio. + SineWave(44100, 1000.0, 0.001, &input); + EXPECT_EQ(44, input.size()); + std::vector>> output; + sgram.ComputeComplexSpectrogram(input, &output); + EXPECT_EQ(0, output.size()); +} + +TEST(SpectrogramTest, StepSizeSmallerThanWindow) { + Spectrogram sgram; + EXPECT_TRUE(sgram.Initialize(400, 200)); + std::vector input; + // Generate 661 samples of audio. + SineWave(44100, 1000.0, 0.015, &input); + EXPECT_EQ(661, input.size()); + std::vector>> output; + sgram.ComputeComplexSpectrogram(input, &output); + EXPECT_EQ(2, output.size()); +} + +TEST(SpectrogramTest, StepSizeBiggerThanWindow) { + Spectrogram sgram; + EXPECT_TRUE(sgram.Initialize(200, 400)); + std::vector input; + // Generate 882 samples of audio. + SineWave(44100, 1000.0, 0.02, &input); + EXPECT_EQ(882, input.size()); + std::vector>> output; + sgram.ComputeComplexSpectrogram(input, &output); + EXPECT_EQ(2, output.size()); +} + +TEST(SpectrogramTest, StepSizeBiggerThanWindow2) { + Spectrogram sgram; + EXPECT_TRUE(sgram.Initialize(200, 400)); + std::vector input; + // Generate more than 600 but fewer than 800 samples of audio. + SineWave(44100, 1000.0, 0.016, &input); + EXPECT_GT(input.size(), 600); + EXPECT_LT(input.size(), 800); + std::vector>> output; + sgram.ComputeComplexSpectrogram(input, &output); + EXPECT_EQ(2, output.size()); +} + +TEST(SpectrogramTest, + MultipleCallsToComputeComplexSpectrogramMayYieldDifferentNumbersOfFrames) { + // Repeatedly pass inputs with "extra" samples beyond complete windows + // and check that the excess points cumulate to eventually cause an + // extra output frame. + Spectrogram sgram; + sgram.Initialize(200, 400); + std::vector input; + // Generate 882 samples of audio. + SineWave(44100, 1000.0, 0.02, &input); + EXPECT_EQ(882, input.size()); + std::vector>> output; + const std::vector expected_output_sizes = { + 2, // One pass of input leaves 82 samples buffered after two steps of + // 400. + 2, // Passing in 882 samples again will now leave 164 samples buffered. + 3, // Third time gives 246 extra samples, triggering an extra output + // frame. + }; + for (int expected_output_size : expected_output_sizes) { + sgram.ComputeComplexSpectrogram(input, &output); + EXPECT_EQ(expected_output_size, output.size()); + } +} + +TEST(SpectrogramTest, CumulatingExcessInputsForOverlappingFrames) { + // Input frames that don't fit into whole windows are cumulated even when + // the windows have overlap (similar to + // MultipleCallsToComputeComplexSpectrogramMayYieldDifferentNumbersOfFrames + // but with window size/hop size swapped). + Spectrogram sgram; + sgram.Initialize(400, 200); + std::vector input; + // Generate 882 samples of audio. + SineWave(44100, 1000.0, 0.02, &input); + EXPECT_EQ(882, input.size()); + std::vector>> output; + const std::vector expected_output_sizes = { + 3, // Windows 0..400, 200..600, 400..800 with 82 samples buffered. + 4, // 1764 frames input; outputs from 600, 800, 1000, 1200..1600. + 5, // 2646 frames in; outputs from 1400, 1600, 1800, 2000, 2200..2600. + }; + for (int expected_output_size : expected_output_sizes) { + sgram.ComputeComplexSpectrogram(input, &output); + EXPECT_EQ(expected_output_size, output.size()); + } +} + +TEST(SpectrogramTest, StepSizeEqualToWindowWorks) { + Spectrogram sgram; + sgram.Initialize(200, 200); + std::vector input; + // Generate 2205 samples of audio. + SineWave(44100, 1000.0, 0.05, &input); + EXPECT_EQ(2205, input.size()); + std::vector>> output; + sgram.ComputeComplexSpectrogram(input, &output); + EXPECT_EQ(11, output.size()); +} + +template +void CompareComplexData( + const std::vector>>& expected, + const std::vector>>& actual, + double tolerance) { + ASSERT_EQ(actual.size(), expected.size()); + for (int i = 0; i < expected.size(); ++i) { + ASSERT_EQ(expected[i].size(), actual[i].size()); + for (int j = 0; j < expected[i].size(); ++j) { + ASSERT_NEAR(real(expected[i][j]), real(actual[i][j]), tolerance) + << ": where i=" << i << " and j=" << j << "."; + ASSERT_NEAR(imag(expected[i][j]), imag(actual[i][j]), tolerance) + << ": where i=" << i << " and j=" << j << "."; + } + } +} + +template +double GetMaximumAbsolute(const std::vector>& spectrogram) { + double max_absolute = 0.0; + for (int i = 0; i < spectrogram.size(); ++i) { + for (int j = 0; j < spectrogram[i].size(); ++j) { + double absolute_value = std::abs(spectrogram[i][j]); + if (absolute_value > max_absolute) { + max_absolute = absolute_value; + } + } + } + return max_absolute; +} + +template +void CompareMagnitudeData( + const std::vector>>& + expected_complex_output, + const std::vector>& actual_squared_magnitude, + double tolerance) { + ASSERT_EQ(actual_squared_magnitude.size(), expected_complex_output.size()); + for (int i = 0; i < expected_complex_output.size(); ++i) { + ASSERT_EQ(expected_complex_output[i].size(), + actual_squared_magnitude[i].size()); + for (int j = 0; j < expected_complex_output[i].size(); ++j) { + ASSERT_NEAR(norm(expected_complex_output[i][j]), + actual_squared_magnitude[i][j], tolerance) + << ": where i=" << i << " and j=" << j << "."; + } + } +} + +TEST(SpectrogramTest, ReInitializationWorks) { + Spectrogram sgram; + sgram.Initialize(512, 256); + std::vector input; + CHECK(ReadWaveFileToVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename), + &input)); + std::vector>> first_output; + std::vector>> second_output; + sgram.Initialize(512, 256); + sgram.ComputeComplexSpectrogram(input, &first_output); + // Re-Initialize it. + sgram.Initialize(512, 256); + sgram.ComputeComplexSpectrogram(input, &second_output); + // Verify identical outputs. + ASSERT_EQ(first_output.size(), second_output.size()); + int slice_size = first_output[0].size(); + for (int i = 0; i < first_output.size(); ++i) { + ASSERT_EQ(slice_size, first_output[i].size()); + ASSERT_EQ(slice_size, second_output[i].size()); + for (int j = 0; j < slice_size; ++j) { + ASSERT_EQ(first_output[i][j], second_output[i][j]); + } + } +} + +TEST(SpectrogramTest, ComputedComplexDataAgreeWithMatlab) { + const int kInputDataLength = 45870; + Spectrogram sgram; + sgram.Initialize(512, 256); + std::vector input; + CHECK(ReadWaveFileToVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename), + &input)); + EXPECT_EQ(kInputDataLength, input.size()); + std::vector>> expected_output; + ASSERT_TRUE(ReadRawFloatFileToComplexVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename), + kDataVectorLength, &expected_output)); + EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size()); + EXPECT_EQ(kDataVectorLength, expected_output[0].size()); + std::vector>> output; + sgram.ComputeComplexSpectrogram(input, &output); + CompareComplexData(expected_output, output, 1e-5); +} + +TEST(SpectrogramTest, ComputedFloatComplexDataAgreeWithMatlab) { + const int kInputDataLength = 45870; + Spectrogram sgram; + sgram.Initialize(512, 256); + std::vector double_input; + CHECK(ReadWaveFileToVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename), + &double_input)); + std::vector input; + input.assign(double_input.begin(), double_input.end()); + EXPECT_EQ(kInputDataLength, input.size()); + std::vector>> expected_output; + ASSERT_TRUE(ReadRawFloatFileToComplexVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename), + kDataVectorLength, &expected_output)); + EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size()); + EXPECT_EQ(kDataVectorLength, expected_output[0].size()); + std::vector>> output; + sgram.ComputeComplexSpectrogram(input, &output); + CompareComplexData(expected_output, output, 1e-4); +} + +TEST(SpectrogramTest, ComputedSquaredMagnitudeDataAgreeWithMatlab) { + const int kInputDataLength = 45870; + Spectrogram sgram; + sgram.Initialize(512, 256); + std::vector input; + CHECK(ReadWaveFileToVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename), + &input)); + EXPECT_EQ(kInputDataLength, input.size()); + std::vector>> expected_output; + ASSERT_TRUE(ReadRawFloatFileToComplexVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename), + kDataVectorLength, &expected_output)); + EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size()); + EXPECT_EQ(kDataVectorLength, expected_output[0].size()); + std::vector> output; + sgram.ComputeSquaredMagnitudeSpectrogram(input, &output); + CompareMagnitudeData(expected_output, output, 1e-3); +} + +TEST(SpectrogramTest, ComputedFloatSquaredMagnitudeDataAgreeWithMatlab) { + const int kInputDataLength = 45870; + Spectrogram sgram; + sgram.Initialize(512, 256); + std::vector double_input; + CHECK(ReadWaveFileToVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename), + &double_input)); + EXPECT_EQ(kInputDataLength, double_input.size()); + std::vector input; + input.assign(double_input.begin(), double_input.end()); + std::vector>> expected_output; + ASSERT_TRUE(ReadRawFloatFileToComplexVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kExpectedFilename), + kDataVectorLength, &expected_output)); + EXPECT_EQ(kNumberOfFramesInTestData, expected_output.size()); + EXPECT_EQ(kDataVectorLength, expected_output[0].size()); + std::vector> output; + sgram.ComputeSquaredMagnitudeSpectrogram(input, &output); + double max_absolute = GetMaximumAbsolute(output); + EXPECT_GT(max_absolute, 2300.0); // Verify that we have some big numbers. + // Squaring increases dynamic range; max square is about 2300, + // so 2e-4 is about 7 decimal digits; not bad for a float. + CompareMagnitudeData(expected_output, output, 2e-4); +} + +TEST(SpectrogramTest, ComputedNonPowerOfTwoComplexDataAgreeWithMatlab) { + const int kInputDataLength = 45870; + Spectrogram sgram; + sgram.Initialize(400, 200); + std::vector input; + CHECK(ReadWaveFileToVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), kInputFilename), + &input)); + EXPECT_EQ(kInputDataLength, input.size()); + std::vector>> expected_output; + ASSERT_TRUE(ReadRawFloatFileToComplexVector( + tensorflow::io::JoinPath(testing::TensorFlowSrcRoot(), + kExpectedNonPowerOfTwoFilename), + kNonPowerOfTwoDataVectorLength, &expected_output)); + EXPECT_EQ(kNumberOfFramesInNonPowerOfTwoTestData, expected_output.size()); + EXPECT_EQ(kNonPowerOfTwoDataVectorLength, expected_output[0].size()); + std::vector>> output; + sgram.ComputeComplexSpectrogram(input, &output); + CompareComplexData(expected_output, output, 1e-5); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/spectrogram_test_data/README b/tensorflow/core/kernels/spectrogram_test_data/README new file mode 100644 index 0000000000..271238e0c9 --- /dev/null +++ b/tensorflow/core/kernels/spectrogram_test_data/README @@ -0,0 +1,8 @@ +The CSV spectrogram files in this directory are generated from the +matlab code in ./matlab/GenerateTestData.m +To save space in the repo, you'll then need to convert them into a binary packed +format using the convert_test_data.cc command line tool. + + +short_test_segment.wav is approximately 1s of music audio. + diff --git a/tensorflow/core/kernels/spectrogram_test_data/short_test_segment.wav b/tensorflow/core/kernels/spectrogram_test_data/short_test_segment.wav new file mode 100644 index 0000000000..7339dfd08c Binary files /dev/null and b/tensorflow/core/kernels/spectrogram_test_data/short_test_segment.wav differ diff --git a/tensorflow/core/kernels/spectrogram_test_data/short_test_segment_spectrogram.csv.bin b/tensorflow/core/kernels/spectrogram_test_data/short_test_segment_spectrogram.csv.bin new file mode 100644 index 0000000000..67b9e2487c Binary files /dev/null and b/tensorflow/core/kernels/spectrogram_test_data/short_test_segment_spectrogram.csv.bin differ diff --git a/tensorflow/core/kernels/spectrogram_test_data/short_test_segment_spectrogram_400_200.csv.bin b/tensorflow/core/kernels/spectrogram_test_data/short_test_segment_spectrogram_400_200.csv.bin new file mode 100644 index 0000000000..d5e4cc5dd6 Binary files /dev/null and b/tensorflow/core/kernels/spectrogram_test_data/short_test_segment_spectrogram_400_200.csv.bin differ diff --git a/tensorflow/core/kernels/spectrogram_test_utils.cc b/tensorflow/core/kernels/spectrogram_test_utils.cc new file mode 100644 index 0000000000..a2141c649f --- /dev/null +++ b/tensorflow/core/kernels/spectrogram_test_utils.cc @@ -0,0 +1,288 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/kernels/spectrogram_test_utils.h" + +#include +#include + +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/wav/wav_io.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +bool ReadWaveFileToVector(const string& file_name, std::vector* data) { + string wav_data; + if (!ReadFileToString(Env::Default(), file_name, &wav_data).ok()) { + LOG(ERROR) << "Wave file read failed for " << file_name; + return false; + } + std::vector decoded_data; + uint32 decoded_sample_count; + uint16 decoded_channel_count; + uint32 decoded_sample_rate; + if (!wav::DecodeLin16WaveAsFloatVector( + wav_data, &decoded_data, &decoded_sample_count, + &decoded_channel_count, &decoded_sample_rate) + .ok()) { + return false; + } + // Convert from float to double for the output value. + data->resize(decoded_data.size()); + for (int i = 0; i < decoded_data.size(); ++i) { + (*data)[i] = decoded_data[i]; + } + return true; +} + +bool ReadRawFloatFileToComplexVector( + const string& file_name, int row_length, + std::vector > >* data) { + data->clear(); + string data_string; + if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) { + LOG(ERROR) << "Failed to open file " << file_name; + return false; + } + float real_out; + float imag_out; + const int kBytesPerValue = 4; + CHECK_EQ(sizeof(real_out), kBytesPerValue); + std::vector > data_row; + int row_counter = 0; + int offset = 0; + const int end = data_string.size(); + while (offset < end) { + memcpy(&real_out, data_string.data() + offset, kBytesPerValue); + offset += kBytesPerValue; + memcpy(&imag_out, data_string.data() + offset, kBytesPerValue); + offset += kBytesPerValue; + if (row_counter >= row_length) { + data->push_back(data_row); + data_row.clear(); + row_counter = 0; + } + data_row.push_back(std::complex(real_out, imag_out)); + ++row_counter; + } + if (row_counter >= row_length) { + data->push_back(data_row); + } + return true; +} + +void ReadCSVFileToComplexVectorOrDie( + const string& file_name, + std::vector > >* data) { + data->clear(); + string data_string; + if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) { + LOG(FATAL) << "Failed to open file " << file_name; + return; + } + std::vector lines = str_util::Split(data_string, '\n'); + for (const string& line : lines) { + if (line == "") { + continue; + } + std::vector > data_line; + std::vector values = str_util::Split(line, ','); + for (std::vector::const_iterator i = values.begin(); + i != values.end(); ++i) { + // each element of values may be in the form: + // 0.001+0.002i, 0.001, 0.001i, -1.2i, -1.2-3.2i, 1.5, 1.5e-03+21.0i + std::vector parts; + // Find the first instance of + or - after the second character + // in the string, that does not immediately follow an 'e'. + size_t operator_index = i->find_first_of("+-", 2); + if (operator_index < i->size() && + i->substr(operator_index - 1, 1) == "e") { + operator_index = i->find_first_of("+-", operator_index + 1); + } + parts.push_back(i->substr(0, operator_index)); + if (operator_index < i->size()) { + parts.push_back(i->substr(operator_index, string::npos)); + } + + double real_part = 0.0; + double imaginary_part = 0.0; + for (std::vector::const_iterator j = parts.begin(); + j != parts.end(); ++j) { + if (j->find_first_of("ij") != string::npos) { + strings::safe_strtod((*j).c_str(), &imaginary_part); + } else { + strings::safe_strtod((*j).c_str(), &real_part); + } + } + data_line.push_back(std::complex(real_part, imaginary_part)); + } + data->push_back(data_line); + } +} + +void ReadCSVFileToArrayOrDie(const string& filename, + std::vector >* array) { + string contents; + TF_CHECK_OK(ReadFileToString(Env::Default(), filename, &contents)); + std::vector lines = str_util::Split(contents, '\n'); + contents.clear(); + + array->clear(); + std::vector values; + for (int l = 0; l < lines.size(); ++l) { + values.clear(); + CHECK(str_util::SplitAndParseAsFloats(lines[l], ',', &values)); + array->push_back(values); + } +} + +bool WriteDoubleVectorToFile(const string& file_name, + const std::vector& data) { + std::unique_ptr file; + if (!Env::Default()->NewWritableFile(file_name, &file).ok()) { + LOG(ERROR) << "Failed to open file " << file_name; + return false; + } + for (int i = 0; i < data.size(); ++i) { + if (!file->Append(StringPiece(reinterpret_cast(&(data[i])), + sizeof(data[i]))) + .ok()) { + LOG(ERROR) << "Failed to append to file " << file_name; + return false; + } + } + if (!file->Close().ok()) { + LOG(ERROR) << "Failed to close file " << file_name; + return false; + } + return true; +} + +bool WriteFloatVectorToFile(const string& file_name, + const std::vector& data) { + std::unique_ptr file; + if (!Env::Default()->NewWritableFile(file_name, &file).ok()) { + LOG(ERROR) << "Failed to open file " << file_name; + return false; + } + for (int i = 0; i < data.size(); ++i) { + if (!file->Append(StringPiece(reinterpret_cast(&(data[i])), + sizeof(data[i]))) + .ok()) { + LOG(ERROR) << "Failed to append to file " << file_name; + return false; + } + } + if (!file->Close().ok()) { + LOG(ERROR) << "Failed to close file " << file_name; + return false; + } + return true; +} + +bool WriteDoubleArrayToFile(const string& file_name, int size, + const double* data) { + std::unique_ptr file; + if (!Env::Default()->NewWritableFile(file_name, &file).ok()) { + LOG(ERROR) << "Failed to open file " << file_name; + return false; + } + for (int i = 0; i < size; ++i) { + if (!file->Append(StringPiece(reinterpret_cast(&(data[i])), + sizeof(data[i]))) + .ok()) { + LOG(ERROR) << "Failed to append to file " << file_name; + return false; + } + } + if (!file->Close().ok()) { + LOG(ERROR) << "Failed to close file " << file_name; + return false; + } + return true; +} + +bool WriteFloatArrayToFile(const string& file_name, int size, + const float* data) { + std::unique_ptr file; + if (!Env::Default()->NewWritableFile(file_name, &file).ok()) { + LOG(ERROR) << "Failed to open file " << file_name; + return false; + } + for (int i = 0; i < size; ++i) { + if (!file->Append(StringPiece(reinterpret_cast(&(data[i])), + sizeof(data[i]))) + .ok()) { + LOG(ERROR) << "Failed to append to file " << file_name; + return false; + } + } + if (!file->Close().ok()) { + LOG(ERROR) << "Failed to close file " << file_name; + return false; + } + return true; +} + +bool WriteComplexVectorToRawFloatFile( + const string& file_name, + const std::vector > >& data) { + std::unique_ptr file; + if (!Env::Default()->NewWritableFile(file_name, &file).ok()) { + LOG(ERROR) << "Failed to open file " << file_name; + return false; + } + for (int i = 0; i < data.size(); ++i) { + for (int j = 0; j < data[i].size(); ++j) { + const float real_part(real(data[i][j])); + if (!file->Append(StringPiece(reinterpret_cast(&real_part), + sizeof(real_part))) + .ok()) { + LOG(ERROR) << "Failed to append to file " << file_name; + return false; + } + + const float imag_part(imag(data[i][j])); + if (!file->Append(StringPiece(reinterpret_cast(&imag_part), + sizeof(imag_part))) + .ok()) { + LOG(ERROR) << "Failed to append to file " << file_name; + return false; + } + } + } + if (!file->Close().ok()) { + LOG(ERROR) << "Failed to close file " << file_name; + return false; + } + return true; +} + +void SineWave(int sample_rate, float frequency, float duration_seconds, + std::vector* data) { + data->clear(); + for (int i = 0; i < static_cast(sample_rate * duration_seconds); ++i) { + data->push_back( + sin(2.0 * M_PI * i * frequency / static_cast(sample_rate))); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/spectrogram_test_utils.h b/tensorflow/core/kernels/spectrogram_test_utils.h new file mode 100644 index 0000000000..59a903549e --- /dev/null +++ b/tensorflow/core/kernels/spectrogram_test_utils.h @@ -0,0 +1,81 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +// Reads a wav format file into a vector of floating-point values with range +// -1.0 to 1.0. +bool ReadWaveFileToVector(const string& file_name, std::vector* data); + +// Reads a binary file containing 32-bit floating point values in the +// form [real_1, imag_1, real_2, imag_2, ...] into a rectangular array +// of complex values where row_length is the length of each inner vector. +bool ReadRawFloatFileToComplexVector( + const string& file_name, int row_length, + std::vector > >* data); + +// Reads a CSV file of numbers in the format 1.1+2.2i,1.1,2.2i,3.3j into data. +void ReadCSVFileToComplexVectorOrDie( + const string& file_name, + std::vector > >* data); + +// Reads a 2D array of floats from an ASCII text file, where each line is a row +// of the array, and elements are separated by commas. +void ReadCSVFileToArrayOrDie(const string& filename, + std::vector >* array); + +// Write a binary file containing 64-bit floating-point values for +// reading by, for example, MATLAB. +bool WriteDoubleVectorToFile(const string& file_name, + const std::vector& data); + +// Write a binary file containing 32-bit floating-point values for +// reading by, for example, MATLAB. +bool WriteFloatVectorToFile(const string& file_name, + const std::vector& data); + +// Write a binary file containing 64-bit floating-point values for +// reading by, for example, MATLAB. +bool WriteDoubleArrayToFile(const string& file_name, int size, + const double* data); + +// Write a binary file containing 32-bit floating-point values for +// reading by, for example, MATLAB. +bool WriteFloatArrayToFile(const string& file_name, int size, + const float* data); + +// Write a binary file in the format read by +// ReadRawDoubleFileToComplexVector above. +bool WriteComplexVectorToRawFloatFile( + const string& file_name, + const std::vector > >& data); + +// Generate a sine wave with the provided parameters, and populate +// data with the samples. +void SineWave(int sample_rate, float frequency, float duration_seconds, + std::vector* data); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SPECTROGRAM_TEST_UTILS_H_ diff --git a/tensorflow/core/lib/core/bits.h b/tensorflow/core/lib/core/bits.h index 30ad0c2bea..1110ef5c2a 100644 --- a/tensorflow/core/lib/core/bits.h +++ b/tensorflow/core/lib/core/bits.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_LIB_CORE_BITS_H_ #define TENSORFLOW_LIB_CORE_BITS_H_ +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -91,6 +92,18 @@ inline int Log2Ceiling64(uint64 n) { return floor + 1; } +inline uint32 NextPowerOfTwo(uint32 value) { + int exponent = Log2Ceiling(value); + DCHECK_LT(exponent, std::numeric_limits::digits); + return 1 << exponent; +} + +inline uint64 NextPowerOfTwo64(uint64 value) { + int exponent = Log2Ceiling(value); + DCHECK_LT(exponent, std::numeric_limits::digits); + return 1LL << exponent; +} + } // namespace tensorflow #endif // TENSORFLOW_LIB_CORE_BITS_H_ diff --git a/tensorflow/core/ops/audio_ops.cc b/tensorflow/core/ops/audio_ops.cc index d6dedc3820..2f55e45e37 100644 --- a/tensorflow/core/ops/audio_ops.cc +++ b/tensorflow/core/ops/audio_ops.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/bits.h" namespace tensorflow { @@ -66,6 +67,39 @@ Status EncodeWavShapeFn(InferenceContext* c) { return Status::OK(); } +Status SpectrogramShapeFn(InferenceContext* c) { + ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); + int32 window_size; + TF_RETURN_IF_ERROR(c->GetAttr("window_size", &window_size)); + int32 stride; + TF_RETURN_IF_ERROR(c->GetAttr("stride", &stride)); + + DimensionHandle input_channels = c->Dim(input, 0); + DimensionHandle input_length = c->Dim(input, 1); + + DimensionHandle output_length; + if (!c->ValueKnown(input_length)) { + output_length = c->UnknownDim(); + } else { + const int64 input_length_value = c->Value(input_length); + const int64 length_minus_window = (input_length_value - window_size); + int64 output_length_value; + if (length_minus_window < 0) { + output_length_value = 0; + } else { + output_length_value = 1 + (length_minus_window / stride); + } + output_length = c->MakeDim(output_length_value); + } + + DimensionHandle output_channels = + c->MakeDim(1 + NextPowerOfTwo(window_size) / 2); + c->set_output(0, + c->MakeShape({input_channels, output_length, output_channels})); + return Status::OK(); +} + } // namespace REGISTER_OP("DecodeWav") @@ -121,4 +155,49 @@ sample_rate: Scalar containing the sample frequency. contents: 0-D. WAV-encoded file contents. )doc"); +REGISTER_OP("AudioSpectrogram") + .Input("input: float") + .Attr("window_size: int") + .Attr("stride: int") + .Attr("magnitude_squared: bool = false") + .Output("spectrogram: float") + .SetShapeFn(SpectrogramShapeFn) + .Doc(R"doc( +Produces a visualization of audio data over time. + +Spectrograms are a standard way of representing audio information as a series of +slices of frequency information, one slice for each window of time. By joining +these together into a sequence, they form a distinctive fingerprint of the sound +over time. + +This op expects to receive audio data as an input, stored as floats in the range +-1 to 1, together with a window width in samples, and a stride specifying how +far to move the window between slices. From this it generates a three +dimensional output. The lowest dimension has an amplitude value for each +frequency during that time slice. The next dimension is time, with successive +frequency slices. The final dimension is for the channels in the input, so a +stereo audio input would have two here for example. + +This means the layout when converted and saved as an image is rotated 90 degrees +clockwise from a typical spectrogram. Time is descending down the Y axis, and +the frequency decreases from left to right. + +Each value in the result represents the square root of the sum of the real and +imaginary parts of an FFT on the current window of samples. In this way, the +lowest dimension represents the power of each frequency in the current window, +and adjacent windows are concatenated in the next dimension. + +To get a more intuitive and visual look at what this operation does, you can run +tensorflow/examples/wav_to_spectrogram to read in an audio file and save out the +resulting spectrogram as a PNG image. + +input: Float representation of audio data. +window_size: How wide the input window is in samples. For the highest efficiency + this should be a power of two, but other values are accepted. +stride: How widely apart the center of adjacent sample windows should be. +magnitude_squared: Whether to return the squared magnitude or just the + magnitude. Using squared magnitude can avoid extra calculations. +spectrogram: 3D representation of the audio frequencies as an image. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD index 0857010f7c..62af852173 100644 --- a/tensorflow/core/platform/default/build_config/BUILD +++ b/tensorflow/core/platform/default/build_config/BUILD @@ -92,6 +92,7 @@ cc_library( "//tensorflow/core:protos_cc", "@com_googlesource_code_re2//:re2", "@farmhash_archive//:farmhash", + "@fft2d//:fft2d", "@highwayhash//:sip_hash", "@png_archive//:png", ], diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc index 03eb076f30..8373eb1f9e 100644 --- a/tensorflow/core/util/command_line_flags.cc +++ b/tensorflow/core/util/command_line_flags.cc @@ -93,6 +93,22 @@ bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, return false; } +bool ParseFloatFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag, + float* dst, bool* value_parsing_ok) { + *value_parsing_ok = true; + if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) { + char extra; + if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) { + LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag + << "."; + *value_parsing_ok = false; + } + return true; + } + + return false; +} + } // namespace Flag::Flag(const char* name, tensorflow::int32* dst, const string& usage_text) @@ -116,6 +132,12 @@ Flag::Flag(const char* name, string* dst, const string& usage_text) string_value_(dst), usage_text_(usage_text) {} +Flag::Flag(const char* name, float* dst, const string& usage_text) + : name_(name), + type_(TYPE_FLOAT), + float_value_(dst), + usage_text_(usage_text) {} + bool Flag::Parse(string arg, bool* value_parsing_ok) const { bool result = false; if (type_ == TYPE_INT) { @@ -126,6 +148,8 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const { result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok); } else if (type_ == TYPE_STRING) { result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok); + } else if (type_ == TYPE_FLOAT) { + result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok); } return result; } @@ -195,6 +219,10 @@ bool Flag::Parse(string arg, bool* value_parsing_ok) const { type_name = "string"; flag_string = strings::Printf("--%s=\"%s\"", flag.name_.c_str(), flag.string_value_->c_str()); + } else if (flag.type_ == Flag::TYPE_FLOAT) { + type_name = "float"; + flag_string = + strings::Printf("--%s=%f", flag.name_.c_str(), *flag.float_value_); } strings::Appendf(&usage_text, "\t%-33s\t%s\t%s\n", flag_string.c_str(), type_name, flag.usage_text_.c_str()); diff --git a/tensorflow/core/util/command_line_flags.h b/tensorflow/core/util/command_line_flags.h index 2c77d7874f..f349df16fd 100644 --- a/tensorflow/core/util/command_line_flags.h +++ b/tensorflow/core/util/command_line_flags.h @@ -65,6 +65,7 @@ class Flag { Flag(const char* name, int64* dst1, const string& usage_text); Flag(const char* name, bool* dst, const string& usage_text); Flag(const char* name, string* dst, const string& usage_text); + Flag(const char* name, float* dst, const string& usage_text); private: friend class Flags; @@ -72,11 +73,12 @@ class Flag { bool Parse(string arg, bool* value_parsing_ok) const; string name_; - enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING } type_; + enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_; int* int_value_; int64* int64_value_; bool* bool_value_; string* string_value_; + float* float_value_; string usage_text_; }; diff --git a/tensorflow/core/util/command_line_flags_test.cc b/tensorflow/core/util/command_line_flags_test.cc index b002e35899..62025463af 100644 --- a/tensorflow/core/util/command_line_flags_test.cc +++ b/tensorflow/core/util/command_line_flags_test.cc @@ -32,29 +32,35 @@ std::vector CharPointerVectorFromStrings( } return result; } -} +} // namespace TEST(CommandLineFlagsTest, BasicUsage) { int some_int = 10; int64 some_int64 = 21474836470; // max int32 is 2147483647 bool some_switch = false; string some_name = "something"; - int argc = 5; - std::vector argv_strings = { - "program_name", "--some_int=20", "--some_int64=214748364700", - "--some_switch", "--some_name=somethingelse"}; + float some_float = -23.23f; + int argc = 6; + std::vector argv_strings = {"program_name", + "--some_int=20", + "--some_int64=214748364700", + "--some_switch", + "--some_name=somethingelse", + "--some_float=42.0"}; std::vector argv_array = CharPointerVectorFromStrings(argv_strings); bool parsed_ok = Flags::Parse(&argc, argv_array.data(), {Flag("some_int", &some_int, "some int"), Flag("some_int64", &some_int64, "some int64"), Flag("some_switch", &some_switch, "some switch"), - Flag("some_name", &some_name, "some name")}); + Flag("some_name", &some_name, "some name"), + Flag("some_float", &some_float, "some float")}); EXPECT_EQ(true, parsed_ok); EXPECT_EQ(20, some_int); EXPECT_EQ(214748364700, some_int64); EXPECT_EQ(true, some_switch); EXPECT_EQ("somethingelse", some_name); + EXPECT_NEAR(42.0f, some_float, 1e-5f); EXPECT_EQ(argc, 1); } @@ -85,6 +91,21 @@ TEST(CommandLineFlagsTest, BadBoolValue) { EXPECT_EQ(argc, 1); } +TEST(CommandLineFlagsTest, BadFloatValue) { + float some_float = -23.23f; + int argc = 2; + std::vector argv_strings = {"program_name", + "--some_float=notanumber"}; + std::vector argv_array = CharPointerVectorFromStrings(argv_strings); + bool parsed_ok = + Flags::Parse(&argc, argv_array.data(), + {Flag("some_float", &some_float, "some float")}); + + EXPECT_EQ(false, parsed_ok); + EXPECT_NEAR(-23.23f, some_float, 1e-5f); + EXPECT_EQ(argc, 1); +} + // Return whether str==pat, but allowing any whitespace in pat // to match zero or more whitespace characters in str. static bool MatchWithAnyWhitespace(const string &str, const string &pat) { @@ -111,6 +132,8 @@ TEST(CommandLineFlagsTest, UsageString) { int64 some_int64 = 21474836470; // max int32 is 2147483647 bool some_switch = false; string some_name = "something"; + // Don't test float in this case, because precision is hard to predict and + // match against, and we don't want a flakey test. const string tool_name = "some_tool_name"; string usage = Flags::Usage(tool_name + "", {Flag("some_int", &some_int, "some int"), diff --git a/tensorflow/examples/wav_to_spectrogram/BUILD b/tensorflow/examples/wav_to_spectrogram/BUILD new file mode 100644 index 0000000000..1e72324fb0 --- /dev/null +++ b/tensorflow/examples/wav_to_spectrogram/BUILD @@ -0,0 +1,68 @@ +# Description: +# TensorFlow C++ inference example for labeling images. + +package( + default_visibility = ["//tensorflow:internal"], + features = [ + "-layering_check", + "-parse_headers", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +cc_library( + name = "wav_to_spectrogram_lib", + srcs = [ + "wav_to_spectrogram.cc", + ], + hdrs = [ + "wav_to_spectrogram.h", + ], + deps = [ + "//tensorflow/cc:cc_ops", + "//tensorflow/core:framework_internal", + "//tensorflow/core:tensorflow", + ], +) + +cc_binary( + name = "wav_to_spectrogram", + srcs = [ + "main.cc", + ], + deps = [ + ":wav_to_spectrogram_lib", + "//tensorflow/core:framework_internal", + "//tensorflow/core:tensorflow", + ], +) + +cc_test( + name = "wav_to_spectrogram_test", + size = "medium", + srcs = ["wav_to_spectrogram_test.cc"], + deps = [ + ":wav_to_spectrogram_lib", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "bin/**", + "gen/**", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/examples/wav_to_spectrogram/README.md b/tensorflow/examples/wav_to_spectrogram/README.md new file mode 100644 index 0000000000..7f7eb43700 --- /dev/null +++ b/tensorflow/examples/wav_to_spectrogram/README.md @@ -0,0 +1,49 @@ +# TensorFlow Spectrogram Example + +This example shows how you can load audio from a .wav file, convert it to a +spectrogram, and then save it out as a PNG image. A spectrogram is a +visualization of the frequencies in sound over time, and can be useful as a +feature for neural network recognition on noise or speech. + +## Building + +To build it, run this command: + +```bash +bazel build tensorflow/examples/wav_to_spectrogram/... +``` + +That should build a binary executable that you can then run like this: + +```bash +bazel-bin/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram +``` + +This uses a default test audio file that's part of the TensorFlow source code, +and writes out the image to the current directory as spectrogram.png. + +## Options + +To load your own audio, you need to supply a .wav file in LIN16 format, and use +the `--input_audio` flag to pass in the path. + +To control how the spectrogram is created, you can specify the `--window_size` +and `--stride` arguments, which control how wide the window used to estimate +frequencies is, and how widely adjacent windows are spaced. + +The `--output_image` flag sets the path to save the image file to. This is +always written out in PNG format, even if you specify a different file +extension. + +If your result seems too dark, try using the `--brightness` flag to make the +output image easier to see. + +Here's an example of how to use all of them together: + +```bash +bazel-bin/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram \ +--input_wav=/tmp/my_audio.wav \ +--window=1024 \ +--stride=512 \ +--output_image=/tmp/my_spectrogram.png +``` diff --git a/tensorflow/examples/wav_to_spectrogram/main.cc b/tensorflow/examples/wav_to_spectrogram/main.cc new file mode 100644 index 0000000000..539e6c4fe4 --- /dev/null +++ b/tensorflow/examples/wav_to_spectrogram/main.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +int main(int argc, char* argv[]) { + // These are the command-line flags the program can understand. + // They define where the graph and input data is located, and what kind of + // input the model expects. If you train your own model, or use something + // other than inception_v3, then you'll need to update these. + tensorflow::string input_wav = + "tensorflow/core/kernels/spectrogram_test_data/short_test_segment.wav"; + tensorflow::int32 window_size = 256; + tensorflow::int32 stride = 128; + float brightness = 64.0f; + tensorflow::string output_image = "spectrogram.png"; + std::vector flag_list = { + tensorflow::Flag("input_wav", &input_wav, "audio file to load"), + tensorflow::Flag("window_size", &window_size, + "frequency sample window width"), + tensorflow::Flag("stride", &stride, + "how far apart to place frequency windows"), + tensorflow::Flag("brightness", &brightness, + "controls how bright the output image is"), + tensorflow::Flag("output_image", &output_image, + "where to save the spectrogram image to"), + }; + tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << usage; + return -1; + } + + // We need to call this to set up global state for TensorFlow. + tensorflow::port::InitMain(argv[0], &argc, &argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return -1; + } + + tensorflow::Status wav_status = WavToSpectrogram( + input_wav, window_size, stride, brightness, output_image); + if (!wav_status.ok()) { + LOG(ERROR) << "WavToSpectrogram failed with " << wav_status; + return -1; + } + + return 0; +} diff --git a/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.cc b/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.cc new file mode 100644 index 0000000000..c69a359637 --- /dev/null +++ b/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h" + +#include + +#include "tensorflow/cc/ops/audio_ops.h" +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/default_device.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" + +using tensorflow::DT_FLOAT; +using tensorflow::DT_UINT8; +using tensorflow::Output; +using tensorflow::TensorShape; + +// Runs a TensorFlow graph to convert an audio file into a visualization. +tensorflow::Status WavToSpectrogram(const tensorflow::string& input_wav, + tensorflow::int32 window_size, + tensorflow::int32 stride, float brightness, + const tensorflow::string& output_image) { + auto root = tensorflow::Scope::NewRootScope(); + using namespace tensorflow::ops; // NOLINT(build/namespaces) + // The following block creates a TensorFlow graph that: + // - Reads and decodes the audio file into a tensor of float samples. + // - Creates a float spectrogram from those samples. + // - Scales, clamps, and converts that spectrogram to 0 to 255 uint8's. + // - Reshapes the tensor so that it's [height, width, 1] for imaging. + // - Encodes it as a PNG stream and saves it out to a file. + Output file_reader = ReadFile(root.WithOpName("input_wav"), input_wav); + DecodeWav wav_decoder = + DecodeWav(root.WithOpName("wav_decoder"), file_reader); + Output spectrogram = AudioSpectrogram(root.WithOpName("spectrogram"), + wav_decoder.audio, window_size, stride); + Output brightness_placeholder = + Placeholder(root.WithOpName("brightness_placeholder"), DT_FLOAT, + Placeholder::Attrs().Shape(TensorShape({}))); + Output mul = Mul(root.WithOpName("mul"), spectrogram, brightness_placeholder); + Output min_const = Const(root.WithOpName("min_const"), 255.0f); + Output min = Minimum(root.WithOpName("min"), mul, min_const); + Output cast = Cast(root.WithOpName("cast"), min, DT_UINT8); + Output expand_dims_const = Const(root.WithOpName("expand_dims_const"), -1); + Output expand_dims = + ExpandDims(root.WithOpName("expand_dims"), cast, expand_dims_const); + Output squeeze = Squeeze(root.WithOpName("squeeze"), expand_dims, + Squeeze::Attrs().SqueezeDims({0})); + Output png_encoder = EncodePng(root.WithOpName("png_encoder"), squeeze); + WriteFile file_writer = + WriteFile(root.WithOpName("output_image"), output_image, png_encoder); + tensorflow::GraphDef graph; + TF_RETURN_IF_ERROR(root.ToGraphDef(&graph)); + + // Build a session object from this graph definition. The power of TensorFlow + // is that you can reuse complex computations like this, so usually we'd run a + // lot of different inputs through it. In this example, we're just doing a + // one-off run, so we'll create it and then use it immediately. + std::unique_ptr session( + tensorflow::NewSession(tensorflow::SessionOptions())); + TF_RETURN_IF_ERROR(session->Create(graph)); + + // We're passing in the brightness as an input, so create a tensor to hold the + // value. + tensorflow::Tensor brightness_tensor(DT_FLOAT, TensorShape({})); + brightness_tensor.scalar()() = brightness; + + // Run the session to analyze the audio and write out the file. + TF_RETURN_IF_ERROR( + session->Run({{"brightness_placeholder", brightness_tensor}}, {}, + {"output_image"}, nullptr)); + return tensorflow::Status::OK(); +} diff --git a/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h b/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h new file mode 100644 index 0000000000..fa8cb0abe9 --- /dev/null +++ b/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_ +#define THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_ + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +// Runs a TensorFlow graph to convert an audio file into a visualization. Takes +// in the path to the audio file, the window size and stride parameters +// controlling the spectrogram creation, the brightness scaling to use, and a +// path to save the output PNG file to. +tensorflow::Status WavToSpectrogram(const tensorflow::string& input_wav, + tensorflow::int32 window_size, + tensorflow::int32 stride, float brightness, + const tensorflow::string& output_image); + +#endif // THIRD_PARTY_TENSORFLOW_EXAMPLES_WAV_TO_SPECTROGRAM_WAV_TO_SPECTROGRAM_H_ diff --git a/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram_test.cc b/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram_test.cc new file mode 100644 index 0000000000..e599711445 --- /dev/null +++ b/tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram_test.cc @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/examples/wav_to_spectrogram/wav_to_spectrogram.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/wav/wav_io.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +TEST(WavToSpectrogramTest, WavToSpectrogramTest) { + const tensorflow::string input_wav = + tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "input_wav.wav"); + const tensorflow::string output_image = tensorflow::io::JoinPath( + tensorflow::testing::TmpDir(), "output_image.png"); + float audio[8] = {-1.0f, 0.0f, 1.0f, 0.0f, -1.0f, 0.0f, 1.0f, 0.0f}; + tensorflow::string wav_string; + TF_ASSERT_OK( + tensorflow::wav::EncodeAudioAsS16LEWav(audio, 44100, 1, 8, &wav_string)); + TF_ASSERT_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), + input_wav, wav_string)); + TF_ASSERT_OK(WavToSpectrogram(input_wav, 4, 4, 64.0f, output_image)); + TF_EXPECT_OK(tensorflow::Env::Default()->FileExists(output_image)); +} diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 3e049724f6..a8e6ecdbf0 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -79,11 +79,13 @@ genrule( srcs = [ "//third_party/hadoop:LICENSE.txt", "//third_party/eigen3:LICENSE", + "//third_party/fft2d:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", + "@fft2d//:fft/readme.txt", "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", "@highwayhash//:LICENSE", @@ -106,11 +108,13 @@ genrule( srcs = [ "//third_party/hadoop:LICENSE.txt", "//third_party/eigen3:LICENSE", + "//third_party/fft2d:LICENSE", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", + "@fft2d//:fft/readme.txt", "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", "@highwayhash//:LICENSE", diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 2a96e80ccb..c17a7f7fb1 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -91,12 +91,14 @@ filegroup( name = "licenses", data = [ "//third_party/eigen3:LICENSE", + "//third_party/fft2d:LICENSE", "//third_party/hadoop:LICENSE.txt", "@boringssl//:LICENSE", "@com_googlesource_code_re2//:LICENSE", "@curl//:COPYING", "@eigen_archive//:COPYING.MPL2", "@farmhash_archive//:COPYING", + "@fft2d//:fft/readme.txt", "@gemmlowp//:LICENSE", "@gif_archive//:COPYING", "@grpc//:LICENSE", diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 7bcdb1613d..dd42c69dd2 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -500,6 +500,16 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name="zlib", actual="@zlib_archive//:zlib",) + native.new_http_archive( + name = "fft2d", + urls = [ + "http://bazel-mirror.storage.googleapis.com/www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", + "http://www.kurims.kyoto-u.ac.jp/~ooura/fft.tgz", + ], + sha256 = "52bb637c70b971958ec79c9c8752b1df5ff0218a4db4510e60826e0cb79b5296", + build_file = str(Label("//third_party/fft2d:fft2d.BUILD")), + ) + temp_workaround_http_archive( name="snappy", urls=[ diff --git a/third_party/fft2d/BUILD b/third_party/fft2d/BUILD new file mode 100644 index 0000000000..93ea06e81b --- /dev/null +++ b/third_party/fft2d/BUILD @@ -0,0 +1,30 @@ +# Headers for 2D Fast Fourier Transform package +# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html +# This is a separate package because the original downloaded archive doesn't +# contain any header files. + +package( + default_visibility = ["//visibility:public"], +) + +# Unrestricted use; can only distribute original package. +# See fft/readme.txt +licenses(["notice"]) + +exports_files(["LICENSE"]) + +cc_library( + name = "fft2d_headers", + srcs = ["fft.h"], +) + +objc_library( + name = "fft2d_headersd_ios", + srcs = ["fft.h"], +) + +# Export the source code so that it could be compiled for Andoid native apps. +filegroup( + name = "fft2d_headers_srcs", + srcs = ["fft.h"], +) diff --git a/third_party/fft2d/LICENSE b/third_party/fft2d/LICENSE new file mode 100644 index 0000000000..2bd85506a8 --- /dev/null +++ b/third_party/fft2d/LICENSE @@ -0,0 +1,3 @@ +Copyright(C) 1997,2001 Takuya OOURA (email: ooura@kurims.kyoto-u.ac.jp). +You may use, copy, modify this code for any purpose and +without fee. You may distribute this ORIGINAL package. diff --git a/third_party/fft2d/fft.h b/third_party/fft2d/fft.h new file mode 100644 index 0000000000..252cc01fec --- /dev/null +++ b/third_party/fft2d/fft.h @@ -0,0 +1,36 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Declarations for 1D FFT routines in third_party/fft2d/fft. + +#ifndef THIRD_PARTY_FFT2D_FFT_H__ +#define THIRD_PARTY_FFT2D_FFT_H__ + +#ifdef __cplusplus +extern "C" { +#endif + +extern void cdft(int, int, double *, int *, double *); +extern void rdft(int, int, double *, int *, double *); +extern void ddct(int, int, double *, int *, double *); +extern void ddst(int, int, double *, int *, double *); +extern void dfct(int, double *, double *, int *, double *); +extern void dfst(int, double *, double *, int *, double *); + +#ifdef __cplusplus +} +#endif + +#endif // THIRD_PARTY_FFT2D_FFT_H__ diff --git a/third_party/fft2d/fft2d.BUILD b/third_party/fft2d/fft2d.BUILD new file mode 100644 index 0000000000..3dbd36aec0 --- /dev/null +++ b/third_party/fft2d/fft2d.BUILD @@ -0,0 +1,36 @@ +# 2D Fast Fourier Transform package +# from http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html + +package( + default_visibility = ["//visibility:public"], +) + +# Unrestricted use; can only distribute original package. +licenses(["notice"]) + +exports_files(["fft/readme.txt"]) + +FFT2D_SRCS = [ + "fft/fftsg.c", +] + +# This is the main 2D FFT library. The 2D FFTs in this library call +# 1D FFTs. In addition, fast DCTs are provided for the special case +# of 8x8 and 16x16. This code in this library is referred to as +# "Version II" on http://momonga.t.u-tokyo.ac.jp/~ooura/fft.html. +cc_library( + name = "fft2d", + srcs = FFT2D_SRCS, + linkopts = ["-lm"], +) + +objc_library( + name = "fft2d_ios", + srcs = FFT2D_SRCS, +) + +# Export the source code so that it could be compiled for Andoid native apps. +filegroup( + name = "fft2d_srcs", + srcs = FFT2D_SRCS, +) -- cgit v1.2.3