aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/convert
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-08-16 11:32:10 -0700
committerGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-08-16 11:32:10 -0700
commit0b31ebb51bcdd4b0675d3f2166f316d0f22266b3 (patch)
tree5ba5a6333bb1a4a3617a2df314c97a63c402e762 /tensorflow/contrib/tensorrt/convert
parent6b5be9a7f33462bd20bf14b0df9ca1fcb2da6bb3 (diff)
Add c++ test and remove python test.
Diffstat (limited to 'tensorflow/contrib/tensorrt/convert')
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.h6
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph_test.cc140
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc3
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h5
4 files changed, 152 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.h b/tensorflow/contrib/tensorrt/convert/convert_graph.h
index 9d986e4890..3525202369 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.h
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -84,6 +85,11 @@ std::vector<int> GetLinkedTensorRTVersion();
// Return runtime time TensorRT library version information.
std::vector<int> GetLoadedTensorRTVersion();
+
+// Helper method for the conversion, exposed for testing.
+std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
+ const ConversionParams& params, const EngineInfo& engine);
+
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc
new file mode 100644
index 0000000000..0d3963514a
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph_test.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/tensorrt/convert/convert_graph.h"
+
+#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
+#include "tensorflow/core/public/session.h"
+
+#if GOOGLE_CUDA
+#if GOOGLE_TENSORRT
+
+namespace tensorflow {
+namespace tensorrt {
+namespace convert {
+
+class FakeCluster : public grappler::Cluster {
+ public:
+ FakeCluster() : Cluster(0) {}
+
+ void SetDeviceSet(const DeviceSet* device_set) { device_set_ = device_set; }
+
+ const DeviceSet* GetDeviceSet() const override { return device_set_; }
+
+ string type() const override { return ""; }
+ Status Provision() override { return Status::OK(); }
+ Status Initialize(const grappler::GrapplerItem& item) override {
+ return Status::OK();
+ }
+ virtual Status Run(const GraphDef& graph_def,
+ const std::vector<std::pair<string, Tensor>>& feed,
+ const std::vector<string>& fetch,
+ RunMetadata* metadata) override {
+ return Status::OK();
+ }
+
+ private:
+ const DeviceSet* device_set_;
+};
+
+TEST(ConvertGraphTest, GetDeviceAndAllocator) {
+ ConversionParams params;
+ EngineInfo engine_info;
+ {
+ // params.cluster is not set, and no gpu device is available.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(-1, result.first);
+ EXPECT_EQ(nullptr, result.second);
+ }
+
+  // Create a session with two (virtual) gpu devices.
+ SessionOptions options;
+ ConfigProto* config = &options.config;
+ GPUOptions* gpu_options = config->mutable_gpu_options();
+ auto virtual_devices =
+ gpu_options->mutable_experimental()->add_virtual_devices();
+ virtual_devices->add_memory_limit_mb(200);
+ virtual_devices->add_memory_limit_mb(200);
+ std::unique_ptr<Session> session(NewSession(options));
+
+ {
+ // params.cluster is not set, should find and return first gpu id and
+ // corresponding allocator.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(0, result.first);
+ EXPECT_NE(nullptr, result.second);
+ EXPECT_EQ("GPU_0_bfc", result.second->Name());
+ }
+
+ FakeCluster cluster;
+ params.cluster = &cluster;
+ {
+ // params.cluster->GetDeviceSet() returns null, should find and return first
+ // gpu id and corresponding allocator.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(0, result.first);
+ EXPECT_NE(nullptr, result.second);
+ EXPECT_EQ("GPU_0_bfc", result.second->Name());
+ }
+
+ // Build the DeviceSet.
+ DeviceSet device_set;
+ const DeviceMgr* device_mgr = nullptr;
+ TF_ASSERT_OK(session->LocalDeviceManager(&device_mgr));
+ for (auto d : device_mgr->ListDevices()) {
+ device_set.AddDevice(d);
+ }
+ cluster.SetDeviceSet(&device_set);
+ {
+ // engine_info.device is not set, should find and return first gpu id and
+ // corresponding allocator.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(0, result.first);
+ EXPECT_NE(nullptr, result.second);
+ EXPECT_EQ("GPU_0_bfc", result.second->Name());
+ }
+
+ engine_info.device = "/GPU:1";
+ {
+ // Set to use second device.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(0, result.first);
+ EXPECT_NE(nullptr, result.second);
+ EXPECT_EQ("GPU_1_bfc", result.second->Name());
+ }
+
+ engine_info.device = "/GPU:3";
+ {
+ // Set to use nonexistent device.
+ auto result = GetDeviceAndAllocator(params, engine_info);
+ EXPECT_EQ(-1, result.first);
+ EXPECT_EQ(nullptr, result.second);
+ }
+}
+
+} // namespace convert
+} // namespace tensorrt
+} // namespace tensorflow
+
+#endif // GOOGLE_TENSORRT
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 35fa590254..07b4efd33f 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -77,6 +77,9 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
+const char* const kInputPHName = "TensorRTInputPH_";
+const char* const kOutputPHName = "TensorRTOutputPH_";
+
namespace convert {
using ::tensorflow::str_util::Split;
using ::tensorflow::strings::StrAppend;
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index a60253740f..9274027e63 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -36,8 +36,9 @@ limitations under the License.
namespace tensorflow {
namespace tensorrt {
-static const char* kInputPHName = "TensorRTInputPH_";
-static const char* kOutputPHName = "TensorRTOutputPH_";
+extern const char* const kInputPHName;
+extern const char* const kOutputPHName;
+
namespace convert {
struct EngineConnection {