diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-14 10:37:32 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-14 10:41:15 -0800 |
commit | c174994e74b55cdcfcab82fc55de97c3e54e9ea8 (patch) | |
tree | 3838adb18e1846f0b22bf9f070f10459dd73a584 | |
parent | cc617064703957aa699bc01c65c7f03f1e6f6b22 (diff) |
Tensorflow driver to execute the given tf graphdef, and generate output
PiperOrigin-RevId: 185708396
-rw-r--r-- | tensorflow/contrib/lite/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testdata/multi_add.pb | 26 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/BUILD | 26 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/tf_driver.cc | 189 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/tf_driver.h | 75 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/tf_driver_test.cc | 56 |
6 files changed, 373 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 3520a4eaf0..44c4a7e2ca 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -10,6 +10,7 @@ exports_files(["LICENSE"]) exports_files(glob([ "testdata/*.bin", + "testdata/*.pb", "models/testdata/*", ])) diff --git a/tensorflow/contrib/lite/testdata/multi_add.pb b/tensorflow/contrib/lite/testdata/multi_add.pb new file mode 100644 index 0000000000..e95a20841f --- /dev/null +++ b/tensorflow/contrib/lite/testdata/multi_add.pb @@ -0,0 +1,26 @@ + +I +aPlaceholder"
/device:CPU:0* +shape:* +dtype0 +I +bPlaceholder"
/device:CPU:0* +dtype0* +shape: +I +cPlaceholder"
/device:CPU:0* +dtype0* +shape: +I +dPlaceholder"
/device:CPU:0* +dtype0* +shape: +& +iAddbc"
/device:CPU:0* +T0 +& +xAddai"
/device:CPU:0* +T0 +& +yAdddi"
/device:CPU:0* +T0"
\ No newline at end of file diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index b949045128..9351cf9d64 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -195,6 +195,32 @@ cc_binary( ], ) +cc_library( + name = "tf_driver", + srcs = ["tf_driver.cc"], + hdrs = ["tf_driver.h"], + deps = [ + ":split", + ":test_runner", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +cc_test( + name = "tf_driver_test", + size = "small", + srcs = ["tf_driver_test.cc"], + data = ["//tensorflow/contrib/lite:testdata/multi_add.pb"], + deps = [ + ":tf_driver", + "@com_google_googletest//:gtest_main", + ], +) + tf_cc_test( name = "generated_examples_zip_test", size = "large", diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc new file mode 100644 index 0000000000..da6c6ce7b1 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tf_driver.cc @@ -0,0 +1,189 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tf_driver.h" + +#include <fstream> +#include <iostream> + +#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tflite { +namespace testing { + +namespace { + +tensorflow::Tensor CreateTensor(const tensorflow::DataType type, + const std::vector<int64_t>& dim) { + tensorflow::TensorShape shape{gtl::ArraySlice<int64>{ + reinterpret_cast<const int64*>(dim.data()), dim.size()}}; + return {type, shape}; +} + +template <typename T> +void FillTensorWithData(tensorflow::Tensor* tensor, const string& csv_values) { + auto data = tensor->flat<T>(); + + const auto& values = testing::Split<T>(csv_values, ","); + for (int i = 0; i < values.size(); i++) { + data(i) = values[i]; + } +} + +template <typename T> +void FillTensorWithZeros(tensorflow::Tensor* tensor) { + auto data = tensor->flat<T>(); + for (int i = 0; i < tensor->NumElements(); i++) { + data(i) = 0; + } +} + +template <typename T> +string TensorDataToCsvString(const tensorflow::Tensor& tensor) { + std::stringstream stream; + const auto& data = tensor.flat<T>(); + if (data.size() == 0) { + return ""; + } + stream << data(0); + for (int i = 1; i < data.size(); i++) { + stream << "," << data(i); + } + return stream.str(); +} + +} // namespace + +TfDriver::TfDriver(const std::vector<string>& input_layer, + const std::vector<string>& input_layer_type, + const std::vector<string>& input_layer_shape, + const std::vector<string>& output_layer) + : input_names_(input_layer), output_names_(output_layer) { + CHECK_EQ(input_layer.size(), input_layer_type.size()); + CHECK_EQ(input_layer.size(), input_layer_shape.size()); + + input_ids_.resize(input_layer.size()); + input_tensors_.reserve(input_layer.size()); + input_types_.resize(input_layer.size()); + input_shapes_.resize(input_layer.size()); + for (int i = 0; i < input_layer.size(); i++) { + input_ids_[i] = i; + input_tensors_[input_layer[i]] = {}; + CHECK(DataTypeFromString(input_layer_type[i], &input_types_[i])); + input_shapes_[i] = Split<int64_t>(input_layer_shape[i], ","); + } + + output_ids_.resize(output_layer.size()); + output_tensors_.reserve(output_layer.size()); + for (int i = 0; i < output_layer.size(); i++) { + output_ids_[i] = i; + } +} + +void TfDriver::LoadModel(const string& bin_file_path) { + if (!IsValid()) return; + std::cout << std::endl << "Loading model: " << bin_file_path << std::endl; + std::ifstream model(bin_file_path); + if (model.fail()) { + Invalidate("Failed to find the model"); + return; + } + + tensorflow::GraphDef graphdef; + if (!graphdef.ParseFromIstream(&model)) { + Invalidate("Failed to parse tensorflow graphdef"); + return; + } + + tensorflow::SessionOptions options; + session_.reset(tensorflow::NewSession(options)); + auto status = session_->Create(graphdef); + if (!status.ok()) { + Invalidate("Failed to create session"); + } +} + +void TfDriver::SetInput(int id, const string& csv_values) { + if (!IsValid()) return; + + auto tensor = CreateTensor(input_types_[id], input_shapes_[id]); + switch (input_types_[id]) { + case tensorflow::DT_FLOAT: { + FillTensorWithData<float>(&tensor, csv_values); + break; + } + case tensorflow::DT_INT32: { + FillTensorWithData<int32_t>(&tensor, csv_values); + break; + } + default: + fprintf(stderr, "Unsupported type %d in SetInput\n", input_types_[id]); + Invalidate("Unsupported tensor data type"); + return; + } + input_tensors_[input_names_[id]] = tensor; +} + +void TfDriver::ResetTensor(int id) { + if (!IsValid()) return; + auto tensor = input_tensors_[input_names_[id]]; + switch (input_types_[id]) { + case tensorflow::DT_FLOAT: { + FillTensorWithZeros<float>(&tensor); + break; + } + case tensorflow::DT_INT32: { + FillTensorWithZeros<int32_t>(&tensor); + break; + } + default: + fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]); + Invalidate("Unsupported tensor data type"); + return; + } +} + +void TfDriver::ReshapeTensor(int id, const string& csv_values) { + input_shapes_[id] = Split<int64_t>(csv_values, ","); + input_tensors_[input_names_[id]] = + CreateTensor(input_types_[id], input_shapes_[id]); + ResetTensor(id); +} + +string TfDriver::ReadOutput(int id) { + if (!IsValid()) return ""; + switch (output_tensors_[id].dtype()) { + case tensorflow::DT_FLOAT: + return TensorDataToCsvString<float>(output_tensors_[id]); + case tensorflow::DT_INT32: + return TensorDataToCsvString<int32_t>(output_tensors_[id]); + default: + fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]); + Invalidate("Unsupported tensor data type"); + return ""; + } +} + +void TfDriver::Invoke() { + if (!IsValid()) return; + auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()}, + output_names_, {}, &output_tensors_); + if (!status.ok()) { + Invalidate("Failed to invoke interpreter"); + } +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tf_driver.h b/tensorflow/contrib/lite/testing/tf_driver.h new file mode 100644 index 0000000000..2928e57282 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tf_driver.h @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ + +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/contrib/lite/testing/test_runner.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session.h" + +namespace tflite { +namespace testing { + +// A test runner that feeds inputs into Tensorflow and generates outputs. +class TfDriver : public TestRunner { + public: + explicit TfDriver(const std::vector<string>& input_layer, + const std::vector<string>& input_layer_type, + const std::vector<string>& input_layer_shape, + const std::vector<string>& output_layer); + ~TfDriver() override {} + + void LoadModel(const string& bin_file_path) override; + void SetInput(int id, const string& csv_values) override; + void Invoke() override; + string ReadOutput(int id); + + const std::vector<int>& GetInputs() override { return input_ids_; } + const std::vector<int>& GetOutputs() override { return output_ids_; } + void ReshapeTensor(int id, const string& csv_values) override; + // Note: ResetTensor only works for input tensor. + void ResetTensor(int id) override; + + // no-op. SetInput will overwrite existing data . + void AllocateTensors() override {} + // no-op. Tf driver is not supposed to check the results. + void SetExpectation(int id, const string& csv_values) override {} + // tf driver is not supposed to check the results. + bool CheckResults() override { return false; } + + private: + std::unique_ptr<tensorflow::Session> session_; + std::vector<int> input_ids_; + std::vector<string> input_names_; + std::vector<std::vector<int64_t>> input_shapes_; + std::vector<tensorflow::DataType> input_types_; + std::unordered_map<string, tensorflow::Tensor> input_tensors_; + + std::vector<int> output_ids_; + std::vector<string> output_names_; + std::vector<::tensorflow::Tensor> output_tensors_; +}; + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ diff --git a/tensorflow/contrib/lite/testing/tf_driver_test.cc b/tensorflow/contrib/lite/testing/tf_driver_test.cc new file mode 100644 index 0000000000..c0faa4676a --- /dev/null +++ b/tensorflow/contrib/lite/testing/tf_driver_test.cc @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tf_driver.h" + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; + +TEST(TfDriverTest, SimpleTest) { + std::unique_ptr<TfDriver> runner( + new TfDriver({"a", "b", "c", "d"}, {"float", "float", "float", "float"}, + {"1,8,8,3", "1,8,8,3", "1,8,8,3", "1,8,8,3"}, {"x", "y"})); + + runner->LoadModel( + "third_party/tensorflow/contrib/lite/testdata/multi_add.pb"); + EXPECT_TRUE(runner->IsValid()) << runner->GetErrorMessage(); + + ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3)); + ASSERT_THAT(runner->GetOutputs(), ElementsAre(0, 1)); + + for (int i : {0, 1, 2, 3}) { + runner->ReshapeTensor(i, "1,2,2,1"); + } + ASSERT_TRUE(runner->IsValid()); + + runner->SetInput(0, "0.1,0.2,0.3,0.4"); + runner->SetInput(1, "0.001,0.002,0.003,0.004"); + runner->SetInput(2, "0.001,0.002,0.003,0.004"); + runner->SetInput(3, "0.01,0.02,0.03,0.04"); + runner->ResetTensor(2); + runner->Invoke(); + + ASSERT_EQ(runner->ReadOutput(0), "0.101,0.202,0.303,0.404"); + ASSERT_EQ(runner->ReadOutput(1), "0.011,0.022,0.033,0.044"); +} + +} // namespace +} // namespace testing +} // namespace tflite |