aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-14 10:37:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-14 10:41:15 -0800
commitc174994e74b55cdcfcab82fc55de97c3e54e9ea8 (patch)
tree3838adb18e1846f0b22bf9f070f10459dd73a584
parentcc617064703957aa699bc01c65c7f03f1e6f6b22 (diff)
Tensorflow driver to execute the given tf graphdef, and generate output
PiperOrigin-RevId: 185708396
-rw-r--r--tensorflow/contrib/lite/BUILD1
-rw-r--r--tensorflow/contrib/lite/testdata/multi_add.pb26
-rw-r--r--tensorflow/contrib/lite/testing/BUILD26
-rw-r--r--tensorflow/contrib/lite/testing/tf_driver.cc189
-rw-r--r--tensorflow/contrib/lite/testing/tf_driver.h75
-rw-r--r--tensorflow/contrib/lite/testing/tf_driver_test.cc56
6 files changed, 373 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 3520a4eaf0..44c4a7e2ca 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -10,6 +10,7 @@ exports_files(["LICENSE"])
exports_files(glob([
"testdata/*.bin",
+ "testdata/*.pb",
"models/testdata/*",
]))
diff --git a/tensorflow/contrib/lite/testdata/multi_add.pb b/tensorflow/contrib/lite/testdata/multi_add.pb
new file mode 100644
index 0000000000..e95a20841f
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/multi_add.pb
@@ -0,0 +1,26 @@
+
+I
+a Placeholder" /device:CPU:0*
+shape:*
+dtype0
+I
+b Placeholder" /device:CPU:0*
+dtype0*
+shape:
+I
+c Placeholder" /device:CPU:0*
+dtype0*
+shape:
+I
+d Placeholder" /device:CPU:0*
+dtype0*
+shape:
+&
+iAddbc" /device:CPU:0*
+T0
+&
+xAddai" /device:CPU:0*
+T0
+&
+yAdddi" /device:CPU:0*
+T0" \ No newline at end of file
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index b949045128..9351cf9d64 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -195,6 +195,32 @@ cc_binary(
],
)
+cc_library(
+ name = "tf_driver",
+ srcs = ["tf_driver.cc"],
+ hdrs = ["tf_driver.h"],
+ deps = [
+ ":split",
+ ":test_runner",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ ],
+)
+
+cc_test(
+ name = "tf_driver_test",
+ size = "small",
+ srcs = ["tf_driver_test.cc"],
+ data = ["//tensorflow/contrib/lite:testdata/multi_add.pb"],
+ deps = [
+ ":tf_driver",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
tf_cc_test(
name = "generated_examples_zip_test",
size = "large",
diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc
new file mode 100644
index 0000000000..da6c6ce7b1
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tf_driver.cc
@@ -0,0 +1,189 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tf_driver.h"
+
+#include <fstream>
+#include <iostream>
+
+#include "tensorflow/contrib/lite/testing/split.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tflite {
+namespace testing {
+
+namespace {
+
+tensorflow::Tensor CreateTensor(const tensorflow::DataType type,
+ const std::vector<int64_t>& dim) {
+ tensorflow::TensorShape shape{gtl::ArraySlice<int64>{
+ reinterpret_cast<const int64*>(dim.data()), dim.size()}};
+ return {type, shape};
+}
+
+template <typename T>
+void FillTensorWithData(tensorflow::Tensor* tensor, const string& csv_values) {
+ auto data = tensor->flat<T>();
+
+ const auto& values = testing::Split<T>(csv_values, ",");
+ for (int i = 0; i < values.size(); i++) {
+ data(i) = values[i];
+ }
+}
+
+template <typename T>
+void FillTensorWithZeros(tensorflow::Tensor* tensor) {
+ auto data = tensor->flat<T>();
+ for (int i = 0; i < tensor->NumElements(); i++) {
+ data(i) = 0;
+ }
+}
+
+template <typename T>
+string TensorDataToCsvString(const tensorflow::Tensor& tensor) {
+ std::stringstream stream;
+ const auto& data = tensor.flat<T>();
+ if (data.size() == 0) {
+ return "";
+ }
+ stream << data(0);
+ for (int i = 1; i < data.size(); i++) {
+ stream << "," << data(i);
+ }
+ return stream.str();
+}
+
+} // namespace
+
+TfDriver::TfDriver(const std::vector<string>& input_layer,
+ const std::vector<string>& input_layer_type,
+ const std::vector<string>& input_layer_shape,
+ const std::vector<string>& output_layer)
+ : input_names_(input_layer), output_names_(output_layer) {
+ CHECK_EQ(input_layer.size(), input_layer_type.size());
+ CHECK_EQ(input_layer.size(), input_layer_shape.size());
+
+ input_ids_.resize(input_layer.size());
+ input_tensors_.reserve(input_layer.size());
+ input_types_.resize(input_layer.size());
+ input_shapes_.resize(input_layer.size());
+ for (int i = 0; i < input_layer.size(); i++) {
+ input_ids_[i] = i;
+ input_tensors_[input_layer[i]] = {};
+ CHECK(DataTypeFromString(input_layer_type[i], &input_types_[i]));
+ input_shapes_[i] = Split<int64_t>(input_layer_shape[i], ",");
+ }
+
+ output_ids_.resize(output_layer.size());
+ output_tensors_.reserve(output_layer.size());
+ for (int i = 0; i < output_layer.size(); i++) {
+ output_ids_[i] = i;
+ }
+}
+
+void TfDriver::LoadModel(const string& bin_file_path) {
+ if (!IsValid()) return;
+ std::cout << std::endl << "Loading model: " << bin_file_path << std::endl;
+ std::ifstream model(bin_file_path);
+ if (model.fail()) {
+ Invalidate("Failed to find the model");
+ return;
+ }
+
+ tensorflow::GraphDef graphdef;
+ if (!graphdef.ParseFromIstream(&model)) {
+ Invalidate("Failed to parse tensorflow graphdef");
+ return;
+ }
+
+ tensorflow::SessionOptions options;
+ session_.reset(tensorflow::NewSession(options));
+ auto status = session_->Create(graphdef);
+ if (!status.ok()) {
+ Invalidate("Failed to create session");
+ }
+}
+
+void TfDriver::SetInput(int id, const string& csv_values) {
+ if (!IsValid()) return;
+
+ auto tensor = CreateTensor(input_types_[id], input_shapes_[id]);
+ switch (input_types_[id]) {
+ case tensorflow::DT_FLOAT: {
+ FillTensorWithData<float>(&tensor, csv_values);
+ break;
+ }
+ case tensorflow::DT_INT32: {
+ FillTensorWithData<int32_t>(&tensor, csv_values);
+ break;
+ }
+ default:
+ fprintf(stderr, "Unsupported type %d in SetInput\n", input_types_[id]);
+ Invalidate("Unsupported tensor data type");
+ return;
+ }
+ input_tensors_[input_names_[id]] = tensor;
+}
+
+void TfDriver::ResetTensor(int id) {
+ if (!IsValid()) return;
+ auto tensor = input_tensors_[input_names_[id]];
+ switch (input_types_[id]) {
+ case tensorflow::DT_FLOAT: {
+ FillTensorWithZeros<float>(&tensor);
+ break;
+ }
+ case tensorflow::DT_INT32: {
+ FillTensorWithZeros<int32_t>(&tensor);
+ break;
+ }
+ default:
+ fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]);
+ Invalidate("Unsupported tensor data type");
+ return;
+ }
+}
+
+void TfDriver::ReshapeTensor(int id, const string& csv_values) {
+ input_shapes_[id] = Split<int64_t>(csv_values, ",");
+ input_tensors_[input_names_[id]] =
+ CreateTensor(input_types_[id], input_shapes_[id]);
+ ResetTensor(id);
+}
+
+string TfDriver::ReadOutput(int id) {
+ if (!IsValid()) return "";
+ switch (output_tensors_[id].dtype()) {
+ case tensorflow::DT_FLOAT:
+ return TensorDataToCsvString<float>(output_tensors_[id]);
+ case tensorflow::DT_INT32:
+ return TensorDataToCsvString<int32_t>(output_tensors_[id]);
+ default:
+ fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]);
+ Invalidate("Unsupported tensor data type");
+ return "";
+ }
+}
+
+void TfDriver::Invoke() {
+ if (!IsValid()) return;
+ auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()},
+ output_names_, {}, &output_tensors_);
+ if (!status.ok()) {
+ Invalidate("Failed to invoke interpreter");
+ }
+}
+
+} // namespace testing
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tf_driver.h b/tensorflow/contrib/lite/testing/tf_driver.h
new file mode 100644
index 0000000000..2928e57282
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tf_driver.h
@@ -0,0 +1,75 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_
+#define TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/testing/split.h"
+#include "tensorflow/contrib/lite/testing/test_runner.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tflite {
+namespace testing {
+
+// A test runner that feeds inputs into Tensorflow and generates outputs.
+class TfDriver : public TestRunner {
+ public:
+ explicit TfDriver(const std::vector<string>& input_layer,
+ const std::vector<string>& input_layer_type,
+ const std::vector<string>& input_layer_shape,
+ const std::vector<string>& output_layer);
+ ~TfDriver() override {}
+
+ void LoadModel(const string& bin_file_path) override;
+ void SetInput(int id, const string& csv_values) override;
+ void Invoke() override;
+ string ReadOutput(int id);
+
+ const std::vector<int>& GetInputs() override { return input_ids_; }
+ const std::vector<int>& GetOutputs() override { return output_ids_; }
+ void ReshapeTensor(int id, const string& csv_values) override;
+ // Note: ResetTensor only works for input tensor.
+ void ResetTensor(int id) override;
+
+ // no-op. SetInput will overwrite existing data .
+ void AllocateTensors() override {}
+ // no-op. Tf driver is not supposed to check the results.
+ void SetExpectation(int id, const string& csv_values) override {}
+ // tf driver is not supposed to check the results.
+ bool CheckResults() override { return false; }
+
+ private:
+ std::unique_ptr<tensorflow::Session> session_;
+ std::vector<int> input_ids_;
+ std::vector<string> input_names_;
+ std::vector<std::vector<int64_t>> input_shapes_;
+ std::vector<tensorflow::DataType> input_types_;
+ std::unordered_map<string, tensorflow::Tensor> input_tensors_;
+
+ std::vector<int> output_ids_;
+ std::vector<string> output_names_;
+ std::vector<::tensorflow::Tensor> output_tensors_;
+};
+
+} // namespace testing
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_
diff --git a/tensorflow/contrib/lite/testing/tf_driver_test.cc b/tensorflow/contrib/lite/testing/tf_driver_test.cc
new file mode 100644
index 0000000000..c0faa4676a
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/tf_driver_test.cc
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/testing/tf_driver.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace testing {
+namespace {
+
+using ::testing::ElementsAre;
+
+TEST(TfDriverTest, SimpleTest) {
+ std::unique_ptr<TfDriver> runner(
+ new TfDriver({"a", "b", "c", "d"}, {"float", "float", "float", "float"},
+ {"1,8,8,3", "1,8,8,3", "1,8,8,3", "1,8,8,3"}, {"x", "y"}));
+
+ runner->LoadModel(
+ "third_party/tensorflow/contrib/lite/testdata/multi_add.pb");
+ EXPECT_TRUE(runner->IsValid()) << runner->GetErrorMessage();
+
+ ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3));
+ ASSERT_THAT(runner->GetOutputs(), ElementsAre(0, 1));
+
+ for (int i : {0, 1, 2, 3}) {
+ runner->ReshapeTensor(i, "1,2,2,1");
+ }
+ ASSERT_TRUE(runner->IsValid());
+
+ runner->SetInput(0, "0.1,0.2,0.3,0.4");
+ runner->SetInput(1, "0.001,0.002,0.003,0.004");
+ runner->SetInput(2, "0.001,0.002,0.003,0.004");
+ runner->SetInput(3, "0.01,0.02,0.03,0.04");
+ runner->ResetTensor(2);
+ runner->Invoke();
+
+ ASSERT_EQ(runner->ReadOutput(0), "0.101,0.202,0.303,0.404");
+ ASSERT_EQ(runner->ReadOutput(1), "0.011,0.022,0.033,0.044");
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite