aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-18 09:16:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-18 09:19:25 -0700
commit5362d9dbedea6f93b1065e31cd0ee16ab110c7e1 (patch)
tree29f3a465600fbd31cffe7f1f27feb41998f80f19
parentc0dbd7e456a18d8431930b18c399f78dc66dda0c (diff)
Data holder for the eager delegate.
PiperOrigin-RevId: 205088322
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD28
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_data.cc44
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_data.h46
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc42
4 files changed, 160 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 9d8c20e96f..23d8f543e5 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -39,6 +39,34 @@ cc_test(
)
cc_library(
+ name = "delegate_data",
+ srcs = ["delegate_data.cc"],
+ hdrs = ["delegate_data.h"],
+ deps = [
+ ":buffer_map",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/common_runtime/eager:context",
+ ],
+)
+
+cc_test(
+ name = "delegate_data_test",
+ size = "small",
+ srcs = ["delegate_data_test.cc"],
+ tags = [
+ "tflite_not_portable",
+ ],
+ deps = [
+ ":delegate_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
new file mode 100644
index 0000000000..b2516379e7
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.cc
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tflite {
+tensorflow::Status DelegateData::Create(std::unique_ptr<DelegateData>* data) {
+ std::vector<tensorflow::Device*> devices;
+
+ TF_RETURN_IF_ERROR(tensorflow::DeviceFactory::AddDevices(
+ tensorflow::SessionOptions(), "/device:cpu:*", &devices));
+
+ std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
+ new tensorflow::DeviceMgr(devices));
+ // Note that Rendezvous is ref-counted so it will be automatically deleted.
+ tensorflow::Rendezvous* rendezvous =
+ new tensorflow::IntraProcessRendezvous(device_mgr.get());
+ data->reset(new DelegateData(new tensorflow::EagerContext(
+ tensorflow::SessionOptions(),
+ tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
+ /*async=*/false, std::move(device_mgr), rendezvous)));
+ return tensorflow::Status();
+}
+
+DelegateData::DelegateData(tensorflow::EagerContext* eager_context)
+ : eager_context_(eager_context) {}
+
+DelegateData::~DelegateData() {}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data.h b/tensorflow/contrib/lite/delegates/eager/delegate_data.h
new file mode 100644
index 0000000000..053d174c08
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
+
+#include "tensorflow/contrib/lite/delegates/eager/buffer_map.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+
+namespace tflite {
+
+// Data kept by the Eager delegate for the lifetime of an Interpreter.
+class DelegateData {
+ public:
+ // Create a new DelegateData, initialized with a newly-created EagerContext.
+ static tensorflow::Status Create(std::unique_ptr<DelegateData>* data);
+
+ ~DelegateData();
+
+ // The EagerContext that is required for execution of Eager Ops.
+ tensorflow::EagerContext* GetEagerContext() { return eager_context_.get(); }
+
+ // Map from TF Lite tensor index to TensorFlow tensor.
+ BufferMap* GetBufferMap() { return &buffer_map_; }
+
+ private:
+ explicit DelegateData(tensorflow::EagerContext* eager_context);
+
+ std::unique_ptr<tensorflow::EagerContext> eager_context_;
+ BufferMap buffer_map_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_DATA_H_
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
new file mode 100644
index 0000000000..cf8bc27d04
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+namespace {
+
+TEST(DelegateDataTest, Basic) {
+ std::unique_ptr<DelegateData> data;
+ // We only check for success because it is hard to make initialization fail.
+ // It only happens if we manage to not link the CPU device factory into the
+ // binary.
+ EXPECT_TRUE(DelegateData::Create(&data).ok());
+
+ EXPECT_NE(data->GetEagerContext(), nullptr);
+ EXPECT_NE(data->GetBufferMap(), nullptr);
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}