aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/plugin
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-05-03 08:20:51 -0700
committerGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-05-03 08:20:51 -0700
commit03de4a4a6cbfab49c2921d0cac5ccac31c0815f8 (patch)
treec942148ec022c72219ec152a2a44587d97563506 /tensorflow/contrib/tensorrt/plugin
parenta2d35bddc7a1ab58b859ef396501472d7986ff0f (diff)
Move/rename the plugin factory test file; delete duplicate test file; fix minor formatting issues.
Diffstat (limited to 'tensorflow/contrib/tensorrt/plugin')
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin.h6
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h9
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc (renamed from tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc)6
-rw-r--r--tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h1
4 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h
index dca377c2d2..d80ec44372 100644
--- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <iostream>
#include <unordered_map>
#include <vector>
+
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
@@ -35,9 +36,11 @@ namespace tensorrt {
// PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT
class PluginTensorRT : public nvinfer1::IPlugin {
public:
- PluginTensorRT() {};
+ PluginTensorRT() {}
PluginTensorRT(const void* serialized_data, size_t length);
+
virtual const string& GetPluginName() const = 0;
+
virtual bool Finalize() = 0;
virtual bool SetAttribute(const string& key, const void* ptr,
@@ -53,6 +56,7 @@ class PluginTensorRT : public nvinfer1::IPlugin {
const size_t size);
virtual size_t getSerializationSize() override;
+
virtual void serialize(void* buffer) override;
protected:
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
index a088ffb842..6d2992bbbb 100644
--- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h
@@ -19,8 +19,9 @@ limitations under the License.
#include <memory>
#include <mutex>
#include <unordered_map>
-#include "trt_plugin.h"
-#include "trt_plugin_utils.h"
+
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
+#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
@@ -54,12 +55,12 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory {
protected:
std::unordered_map<string,
- std::pair<PluginDeserializeFunc, PluginConstructFunc> >
+ std::pair<PluginDeserializeFunc, PluginConstructFunc>>
plugin_registry_;
// TODO(jie): Owned plugin should be associated with different sessions;
// should really hand ownership of plugins to resource management;
- std::vector<std::unique_ptr<PluginTensorRT> > owned_plugins_;
+ std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_;
std::mutex instance_m_;
};
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc
index ae5a3e8742..c5b0e75eb1 100644
--- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc
@@ -81,7 +81,7 @@ StubPlugin* CreateStubPluginDeserialize(const void* serialized_data,
return new StubPlugin(serialized_data, length);
}
-class PluginTest : public ::testing::Test {
+class TrtPluginFactoryTest : public ::testing::Test {
public:
bool RegisterStubPlugin() {
if (PluginFactoryTensorRT::GetInstance()->IsPlugin(
@@ -94,7 +94,7 @@ class PluginTest : public ::testing::Test {
}
};
-TEST_F(PluginTest, Registration) {
+TEST_F(TrtPluginFactoryTest, Registration) {
EXPECT_FALSE(
PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName));
EXPECT_TRUE(RegisterStubPlugin());
@@ -103,7 +103,7 @@ TEST_F(PluginTest, Registration) {
PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName));
}
-TEST_F(PluginTest, CreationDeletion) {
+TEST_F(TrtPluginFactoryTest, CreationDeletion) {
EXPECT_TRUE(RegisterStubPlugin());
ASSERT_TRUE(
PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName));
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
index a94c67bba0..4ff6fbedb4 100644
--- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
+++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS
#include <functional>
+
#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h"
#include "tensorflow/core/platform/types.h"