aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/spectrogram_test_utils.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/spectrogram_test_utils.cc')
-rw-r--r--tensorflow/core/kernels/spectrogram_test_utils.cc288
1 files changed, 288 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/spectrogram_test_utils.cc b/tensorflow/core/kernels/spectrogram_test_utils.cc
new file mode 100644
index 0000000000..a2141c649f
--- /dev/null
+++ b/tensorflow/core/kernels/spectrogram_test_utils.cc
@@ -0,0 +1,288 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/spectrogram_test_utils.h"
+
+#include <math.h>
+#include <stddef.h>
+
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/wav/wav_io.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+bool ReadWaveFileToVector(const string& file_name, std::vector<double>* data) {
+ string wav_data;
+ if (!ReadFileToString(Env::Default(), file_name, &wav_data).ok()) {
+ LOG(ERROR) << "Wave file read failed for " << file_name;
+ return false;
+ }
+ std::vector<float> decoded_data;
+ uint32 decoded_sample_count;
+ uint16 decoded_channel_count;
+ uint32 decoded_sample_rate;
+ if (!wav::DecodeLin16WaveAsFloatVector(
+ wav_data, &decoded_data, &decoded_sample_count,
+ &decoded_channel_count, &decoded_sample_rate)
+ .ok()) {
+ return false;
+ }
+ // Convert from float to double for the output value.
+ data->resize(decoded_data.size());
+ for (int i = 0; i < decoded_data.size(); ++i) {
+ (*data)[i] = decoded_data[i];
+ }
+ return true;
+}
+
+bool ReadRawFloatFileToComplexVector(
+ const string& file_name, int row_length,
+ std::vector<std::vector<std::complex<double> > >* data) {
+ data->clear();
+ string data_string;
+ if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) {
+ LOG(ERROR) << "Failed to open file " << file_name;
+ return false;
+ }
+ float real_out;
+ float imag_out;
+ const int kBytesPerValue = 4;
+ CHECK_EQ(sizeof(real_out), kBytesPerValue);
+ std::vector<std::complex<double> > data_row;
+ int row_counter = 0;
+ int offset = 0;
+ const int end = data_string.size();
+ while (offset < end) {
+ memcpy(&real_out, data_string.data() + offset, kBytesPerValue);
+ offset += kBytesPerValue;
+ memcpy(&imag_out, data_string.data() + offset, kBytesPerValue);
+ offset += kBytesPerValue;
+ if (row_counter >= row_length) {
+ data->push_back(data_row);
+ data_row.clear();
+ row_counter = 0;
+ }
+ data_row.push_back(std::complex<double>(real_out, imag_out));
+ ++row_counter;
+ }
+ if (row_counter >= row_length) {
+ data->push_back(data_row);
+ }
+ return true;
+}
+
+void ReadCSVFileToComplexVectorOrDie(
+ const string& file_name,
+ std::vector<std::vector<std::complex<double> > >* data) {
+ data->clear();
+ string data_string;
+ if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) {
+ LOG(FATAL) << "Failed to open file " << file_name;
+ return;
+ }
+ std::vector<string> lines = str_util::Split(data_string, '\n');
+ for (const string& line : lines) {
+ if (line == "") {
+ continue;
+ }
+ std::vector<std::complex<double> > data_line;
+ std::vector<string> values = str_util::Split(line, ',');
+ for (std::vector<string>::const_iterator i = values.begin();
+ i != values.end(); ++i) {
+ // each element of values may be in the form:
+ // 0.001+0.002i, 0.001, 0.001i, -1.2i, -1.2-3.2i, 1.5, 1.5e-03+21.0i
+ std::vector<string> parts;
+ // Find the first instance of + or - after the second character
+ // in the string, that does not immediately follow an 'e'.
+ size_t operator_index = i->find_first_of("+-", 2);
+ if (operator_index < i->size() &&
+ i->substr(operator_index - 1, 1) == "e") {
+ operator_index = i->find_first_of("+-", operator_index + 1);
+ }
+ parts.push_back(i->substr(0, operator_index));
+ if (operator_index < i->size()) {
+ parts.push_back(i->substr(operator_index, string::npos));
+ }
+
+ double real_part = 0.0;
+ double imaginary_part = 0.0;
+ for (std::vector<string>::const_iterator j = parts.begin();
+ j != parts.end(); ++j) {
+ if (j->find_first_of("ij") != string::npos) {
+ strings::safe_strtod((*j).c_str(), &imaginary_part);
+ } else {
+ strings::safe_strtod((*j).c_str(), &real_part);
+ }
+ }
+ data_line.push_back(std::complex<double>(real_part, imaginary_part));
+ }
+ data->push_back(data_line);
+ }
+}
+
+void ReadCSVFileToArrayOrDie(const string& filename,
+ std::vector<std::vector<float> >* array) {
+ string contents;
+ TF_CHECK_OK(ReadFileToString(Env::Default(), filename, &contents));
+ std::vector<string> lines = str_util::Split(contents, '\n');
+ contents.clear();
+
+ array->clear();
+ std::vector<float> values;
+ for (int l = 0; l < lines.size(); ++l) {
+ values.clear();
+ CHECK(str_util::SplitAndParseAsFloats(lines[l], ',', &values));
+ array->push_back(values);
+ }
+}
+
+bool WriteDoubleVectorToFile(const string& file_name,
+ const std::vector<double>& data) {
+ std::unique_ptr<WritableFile> file;
+ if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
+ LOG(ERROR) << "Failed to open file " << file_name;
+ return false;
+ }
+ for (int i = 0; i < data.size(); ++i) {
+ if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
+ sizeof(data[i])))
+ .ok()) {
+ LOG(ERROR) << "Failed to append to file " << file_name;
+ return false;
+ }
+ }
+ if (!file->Close().ok()) {
+ LOG(ERROR) << "Failed to close file " << file_name;
+ return false;
+ }
+ return true;
+}
+
+bool WriteFloatVectorToFile(const string& file_name,
+ const std::vector<float>& data) {
+ std::unique_ptr<WritableFile> file;
+ if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
+ LOG(ERROR) << "Failed to open file " << file_name;
+ return false;
+ }
+ for (int i = 0; i < data.size(); ++i) {
+ if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
+ sizeof(data[i])))
+ .ok()) {
+ LOG(ERROR) << "Failed to append to file " << file_name;
+ return false;
+ }
+ }
+ if (!file->Close().ok()) {
+ LOG(ERROR) << "Failed to close file " << file_name;
+ return false;
+ }
+ return true;
+}
+
+bool WriteDoubleArrayToFile(const string& file_name, int size,
+ const double* data) {
+ std::unique_ptr<WritableFile> file;
+ if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
+ LOG(ERROR) << "Failed to open file " << file_name;
+ return false;
+ }
+ for (int i = 0; i < size; ++i) {
+ if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
+ sizeof(data[i])))
+ .ok()) {
+ LOG(ERROR) << "Failed to append to file " << file_name;
+ return false;
+ }
+ }
+ if (!file->Close().ok()) {
+ LOG(ERROR) << "Failed to close file " << file_name;
+ return false;
+ }
+ return true;
+}
+
+bool WriteFloatArrayToFile(const string& file_name, int size,
+ const float* data) {
+ std::unique_ptr<WritableFile> file;
+ if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
+ LOG(ERROR) << "Failed to open file " << file_name;
+ return false;
+ }
+ for (int i = 0; i < size; ++i) {
+ if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
+ sizeof(data[i])))
+ .ok()) {
+ LOG(ERROR) << "Failed to append to file " << file_name;
+ return false;
+ }
+ }
+ if (!file->Close().ok()) {
+ LOG(ERROR) << "Failed to close file " << file_name;
+ return false;
+ }
+ return true;
+}
+
+bool WriteComplexVectorToRawFloatFile(
+ const string& file_name,
+ const std::vector<std::vector<std::complex<double> > >& data) {
+ std::unique_ptr<WritableFile> file;
+ if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
+ LOG(ERROR) << "Failed to open file " << file_name;
+ return false;
+ }
+ for (int i = 0; i < data.size(); ++i) {
+ for (int j = 0; j < data[i].size(); ++j) {
+ const float real_part(real(data[i][j]));
+ if (!file->Append(StringPiece(reinterpret_cast<const char*>(&real_part),
+ sizeof(real_part)))
+ .ok()) {
+ LOG(ERROR) << "Failed to append to file " << file_name;
+ return false;
+ }
+
+ const float imag_part(imag(data[i][j]));
+ if (!file->Append(StringPiece(reinterpret_cast<const char*>(&imag_part),
+ sizeof(imag_part)))
+ .ok()) {
+ LOG(ERROR) << "Failed to append to file " << file_name;
+ return false;
+ }
+ }
+ }
+ if (!file->Close().ok()) {
+ LOG(ERROR) << "Failed to close file " << file_name;
+ return false;
+ }
+ return true;
+}
+
+void SineWave(int sample_rate, float frequency, float duration_seconds,
+ std::vector<double>* data) {
+ data->clear();
+ for (int i = 0; i < static_cast<int>(sample_rate * duration_seconds); ++i) {
+ data->push_back(
+ sin(2.0 * M_PI * i * frequency / static_cast<double>(sample_rate)));
+ }
+}
+
+} // namespace tensorflow