diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-14 17:51:17 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-14 17:55:32 -0800 |
commit | 61207fad1557da22a0d06bb22e35f23bd9ca4938 (patch) | |
tree | b1ed41f647d75ee11d92cea693da1bc0c9a658ff /tensorflow | |
parent | 5b3b689cc307632d38f08b1c7f6952bbe7a7ad7e (diff) |
Extract ReadOutput method to test runner
PiperOrigin-RevId: 185773920
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/lite/testing/BUILD | 16 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/join.h | 42 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/join_test.cc | 43 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/test_runner.h | 4 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/test_runner_test.cc | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/tf_driver.cc | 11 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/tf_driver.h | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/tflite_driver.h | 1 |
8 files changed, 110 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index 8739ffb34c..87945353cf 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -123,6 +123,21 @@ cc_test( ) cc_library( + name = "join", + hdrs = ["join.h"], +) + +cc_test( + name = "join_test", + size = "small", + srcs = ["join_test.cc"], + deps = [ + ":join", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( name = "tflite_driver", srcs = ["tflite_driver.cc"], hdrs = ["tflite_driver.h"], @@ -201,6 +216,7 @@ cc_library( srcs = ["tf_driver.cc"], hdrs = ["tf_driver.h"], deps = [ + ":join", ":split", ":test_runner", "//tensorflow/core:core_cpu", diff --git a/tensorflow/contrib/lite/testing/join.h b/tensorflow/contrib/lite/testing/join.h new file mode 100644 index 0000000000..ce8c072a21 --- /dev/null +++ b/tensorflow/contrib/lite/testing/join.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ + +#include <cstdlib> +#include <sstream> +#include <string> + +namespace tflite { +namespace testing { + +// Join a list of data separated by delimieter. +template <typename T> +string Join(T* data, size_t len, const string& delimiter) { + if (len == 0 || data == nullptr) { + return ""; + } + std::stringstream result; + result << data[0]; + for (int i = 1; i < len; i++) { + result << delimiter << data[i]; + } + return result.str(); +} + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ diff --git a/tensorflow/contrib/lite/testing/join_test.cc b/tensorflow/contrib/lite/testing/join_test.cc new file mode 100644 index 0000000000..bd04528381 --- /dev/null +++ b/tensorflow/contrib/lite/testing/join_test.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/join.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace tflite { +namespace testing { +namespace { + +TEST(JoinTest, JoinInt) { + std::vector<int> data = {1, 2, 3}; + EXPECT_EQ(Join(data.data(), data.size(), ","), "1,2,3"); +} + +TEST(JoinTest, JoinFloat) { + float data[] = {1.0, -3, 2.3, 1e-5}; + EXPECT_EQ(Join(data, 4, " "), "1 -3 2.3 1e-05"); +} + +TEST(JoinTest, JoinNullData) { EXPECT_THAT(Join<int>(nullptr, 3, ","), ""); } + +TEST(JoinTest, JoinZeroData) { + std::vector<int> data; + EXPECT_THAT(Join(data.data(), 0, ","), ""); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h index 60eaafa474..05770beee2 100644 --- a/tensorflow/contrib/lite/testing/test_runner.h +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -68,6 +68,10 @@ class TestRunner { // satisfied. virtual bool CheckResults() = 0; + // Read contents of tensor into csv format. + // The given 'id' is guaranteed to be one of the ids returned by GetOutputs(). + virtual string ReadOutput(int id) = 0; + // Set the base path for loading models. void SetModelBaseDir(const string& path) { model_base_dir_ = path; diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/contrib/lite/testing/test_runner_test.cc index f712a5347a..3f04aa20bd 100644 --- a/tensorflow/contrib/lite/testing/test_runner_test.cc +++ b/tensorflow/contrib/lite/testing/test_runner_test.cc @@ -31,6 +31,7 @@ class ConcreteTestRunner : public TestRunner { void ResetTensor(int id) override {} void SetInput(int id, const string& csv_values) override {} void SetExpectation(int id, const string& csv_values) override {} + string ReadOutput(int id) override { return ""; } void Invoke() override {} bool CheckResults() override { return true; } bool CheckFloatSizes(size_t bytes, size_t values) { diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc index da6c6ce7b1..2c253bb198 100644 --- a/tensorflow/contrib/lite/testing/tf_driver.cc +++ b/tensorflow/contrib/lite/testing/tf_driver.cc @@ -17,6 +17,7 @@ limitations under the License. #include <fstream> #include <iostream> +#include "tensorflow/contrib/lite/testing/join.h" #include "tensorflow/contrib/lite/testing/split.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -52,16 +53,8 @@ void FillTensorWithZeros(tensorflow::Tensor* tensor) { template <typename T> string TensorDataToCsvString(const tensorflow::Tensor& tensor) { - std::stringstream stream; const auto& data = tensor.flat<T>(); - if (data.size() == 0) { - return ""; - } - stream << data(0); - for (int i = 1; i < data.size(); i++) { - stream << "," << data(i); - } - return stream.str(); + return Join(data.data(), data.size(), ","); } } // namespace diff --git a/tensorflow/contrib/lite/testing/tf_driver.h b/tensorflow/contrib/lite/testing/tf_driver.h index 2928e57282..b766f85c4d 100644 --- a/tensorflow/contrib/lite/testing/tf_driver.h +++ b/tensorflow/contrib/lite/testing/tf_driver.h @@ -41,7 +41,7 @@ class TfDriver : public TestRunner { void LoadModel(const string& bin_file_path) override; void SetInput(int id, const string& csv_values) override; void Invoke() override; - string ReadOutput(int id); + string ReadOutput(int id) override; const std::vector<int>& GetInputs() override { return input_ids_; } const std::vector<int>& GetOutputs() override { return output_ids_; } diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h index 25689a9fb4..02b7de1534 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.h +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -45,6 +45,7 @@ class TfLiteDriver : public TestRunner { void SetExpectation(int id, const string& csv_values) override; void Invoke() override; bool CheckResults() override; + string ReadOutput(int id) override { return "no-op"; } private: class Expectation; |