aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-14 17:51:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-14 17:55:32 -0800
commit61207fad1557da22a0d06bb22e35f23bd9ca4938 (patch)
treeb1ed41f647d75ee11d92cea693da1bc0c9a658ff /tensorflow
parent5b3b689cc307632d38f08b1c7f6952bbe7a7ad7e (diff)
Extract ReadOutput method to test runner
PiperOrigin-RevId: 185773920
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/lite/testing/BUILD16
-rw-r--r--tensorflow/contrib/lite/testing/join.h42
-rw-r--r--tensorflow/contrib/lite/testing/join_test.cc43
-rw-r--r--tensorflow/contrib/lite/testing/test_runner.h4
-rw-r--r--tensorflow/contrib/lite/testing/test_runner_test.cc1
-rw-r--r--tensorflow/contrib/lite/testing/tf_driver.cc11
-rw-r--r--tensorflow/contrib/lite/testing/tf_driver.h2
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h1
8 files changed, 110 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 8739ffb34c..87945353cf 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -123,6 +123,21 @@ cc_test(
)
cc_library(
+ name = "join",
+ hdrs = ["join.h"],
+)
+
+cc_test(
+ name = "join_test",
+ size = "small",
+ srcs = ["join_test.cc"],
+ deps = [
+ ":join",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+cc_library(
name = "tflite_driver",
srcs = ["tflite_driver.cc"],
hdrs = ["tflite_driver.h"],
@@ -201,6 +216,7 @@ cc_library(
srcs = ["tf_driver.cc"],
hdrs = ["tf_driver.h"],
deps = [
+ ":join",
":split",
":test_runner",
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/contrib/lite/testing/join.h b/tensorflow/contrib/lite/testing/join.h
new file mode 100644
index 0000000000..ce8c072a21
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/join.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_
+#define TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_
+
+#include <cstdlib>
+#include <sstream>
+#include <string>
+
+namespace tflite {
+namespace testing {
+
+// Join a list of data separated by delimieter.
+template <typename T>
+string Join(T* data, size_t len, const string& delimiter) {
+ if (len == 0 || data == nullptr) {
+ return "";
+ }
+ std::stringstream result;
+ result << data[0];
+ for (int i = 1; i < len; i++) {
+ result << delimiter << data[i];
+ }
+ return result.str();
+}
+
+} // namespace testing
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_
diff --git a/tensorflow/contrib/lite/testing/join_test.cc b/tensorflow/contrib/lite/testing/join_test.cc
new file mode 100644
index 0000000000..bd04528381
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/join_test.cc
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/join.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+TEST(JoinTest, JoinInt) {
+ std::vector<int> data = {1, 2, 3};
+ EXPECT_EQ(Join(data.data(), data.size(), ","), "1,2,3");
+}
+
+TEST(JoinTest, JoinFloat) {
+ float data[] = {1.0, -3, 2.3, 1e-5};
+ EXPECT_EQ(Join(data, 4, " "), "1 -3 2.3 1e-05");
+}
+
+TEST(JoinTest, JoinNullData) { EXPECT_THAT(Join<int>(nullptr, 3, ","), ""); }
+
+TEST(JoinTest, JoinZeroData) {
+ std::vector<int> data;
+ EXPECT_THAT(Join(data.data(), 0, ","), "");
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h
index 60eaafa474..05770beee2 100644
--- a/tensorflow/contrib/lite/testing/test_runner.h
+++ b/tensorflow/contrib/lite/testing/test_runner.h
@@ -68,6 +68,10 @@ class TestRunner {
// satisfied.
virtual bool CheckResults() = 0;
+ // Read contents of tensor into csv format.
+ // The given 'id' is guaranteed to be one of the ids returned by GetOutputs().
+ virtual string ReadOutput(int id) = 0;
+
// Set the base path for loading models.
void SetModelBaseDir(const string& path) {
model_base_dir_ = path;
diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/contrib/lite/testing/test_runner_test.cc
index f712a5347a..3f04aa20bd 100644
--- a/tensorflow/contrib/lite/testing/test_runner_test.cc
+++ b/tensorflow/contrib/lite/testing/test_runner_test.cc
@@ -31,6 +31,7 @@ class ConcreteTestRunner : public TestRunner {
void ResetTensor(int id) override {}
void SetInput(int id, const string& csv_values) override {}
void SetExpectation(int id, const string& csv_values) override {}
+ string ReadOutput(int id) override { return ""; }
void Invoke() override {}
bool CheckResults() override { return true; }
bool CheckFloatSizes(size_t bytes, size_t values) {
diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc
index da6c6ce7b1..2c253bb198 100644
--- a/tensorflow/contrib/lite/testing/tf_driver.cc
+++ b/tensorflow/contrib/lite/testing/tf_driver.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <fstream>
#include <iostream>
+#include "tensorflow/contrib/lite/testing/join.h"
#include "tensorflow/contrib/lite/testing/split.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -52,16 +53,8 @@ void FillTensorWithZeros(tensorflow::Tensor* tensor) {
template <typename T>
string TensorDataToCsvString(const tensorflow::Tensor& tensor) {
- std::stringstream stream;
const auto& data = tensor.flat<T>();
- if (data.size() == 0) {
- return "";
- }
- stream << data(0);
- for (int i = 1; i < data.size(); i++) {
- stream << "," << data(i);
- }
- return stream.str();
+ return Join(data.data(), data.size(), ",");
}
} // namespace
diff --git a/tensorflow/contrib/lite/testing/tf_driver.h b/tensorflow/contrib/lite/testing/tf_driver.h
index 2928e57282..b766f85c4d 100644
--- a/tensorflow/contrib/lite/testing/tf_driver.h
+++ b/tensorflow/contrib/lite/testing/tf_driver.h
@@ -41,7 +41,7 @@ class TfDriver : public TestRunner {
void LoadModel(const string& bin_file_path) override;
void SetInput(int id, const string& csv_values) override;
void Invoke() override;
- string ReadOutput(int id);
+ string ReadOutput(int id) override;
const std::vector<int>& GetInputs() override { return input_ids_; }
const std::vector<int>& GetOutputs() override { return output_ids_; }
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
index 25689a9fb4..02b7de1534 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.h
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -45,6 +45,7 @@ class TfLiteDriver : public TestRunner {
void SetExpectation(int id, const string& csv_values) override;
void Invoke() override;
bool CheckResults() override;
+ string ReadOutput(int id) override { return "no-op"; }
private:
class Expectation;