diff options
author | 2018-05-03 08:20:51 -0700 | |
---|---|---|
committer | 2018-05-03 08:20:51 -0700 | |
commit | 03de4a4a6cbfab49c2921d0cac5ccac31c0815f8 (patch) | |
tree | c942148ec022c72219ec152a2a44587d97563506 /tensorflow/contrib/tensorrt/plugin | |
parent | a2d35bddc7a1ab58b859ef396501472d7986ff0f (diff) |
Move/rename the plugin factory test file; delete duplicate test file; fix minor formatting issues.
Diffstat (limited to 'tensorflow/contrib/tensorrt/plugin')
-rw-r--r-- | tensorflow/contrib/tensorrt/plugin/trt_plugin.h | 6 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h | 9 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc (renamed from tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc) | 6 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h | 1 |
4 files changed, 14 insertions, 8 deletions
diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h index dca377c2d2..d80ec44372 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin.h @@ -19,6 +19,7 @@ limitations under the License. #include <iostream> #include <unordered_map> #include <vector> + #include "tensorflow/core/platform/types.h" #if GOOGLE_CUDA @@ -35,9 +36,11 @@ namespace tensorrt { // PluginDeserializeFunc & PluginConstructFunc through PluginFactoryTensorRT class PluginTensorRT : public nvinfer1::IPlugin { public: - PluginTensorRT() {}; + PluginTensorRT() {} PluginTensorRT(const void* serialized_data, size_t length); + virtual const string& GetPluginName() const = 0; + virtual bool Finalize() = 0; virtual bool SetAttribute(const string& key, const void* ptr, @@ -53,6 +56,7 @@ class PluginTensorRT : public nvinfer1::IPlugin { const size_t size); virtual size_t getSerializationSize() override; + virtual void serialize(void* buffer) override; protected: diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h index a088ffb842..6d2992bbbb 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory.h @@ -19,8 +19,9 @@ limitations under the License. #include <memory> #include <mutex> #include <unordered_map> -#include "trt_plugin.h" -#include "trt_plugin_utils.h" + +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" +#include "tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h" #if GOOGLE_CUDA #if GOOGLE_TENSORRT @@ -54,12 +55,12 @@ class PluginFactoryTensorRT : public nvinfer1::IPluginFactory { protected: std::unordered_map<string, - std::pair<PluginDeserializeFunc, PluginConstructFunc> > + std::pair<PluginDeserializeFunc, PluginConstructFunc>> plugin_registry_; // TODO(jie): Owned plugin should be associated with different sessions; // should really hand ownership of plugins to resource management; - std::vector<std::unique_ptr<PluginTensorRT> > owned_plugins_; + std::vector<std::unique_ptr<PluginTensorRT>> owned_plugins_; std::mutex instance_m_; }; diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc index ae5a3e8742..c5b0e75eb1 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugins_test.cc +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_factory_test.cc @@ -81,7 +81,7 @@ StubPlugin* CreateStubPluginDeserialize(const void* serialized_data, return new StubPlugin(serialized_data, length); } -class PluginTest : public ::testing::Test { +class TrtPluginFactoryTest : public ::testing::Test { public: bool RegisterStubPlugin() { if (PluginFactoryTensorRT::GetInstance()->IsPlugin( @@ -94,7 +94,7 @@ class PluginTest : public ::testing::Test { } }; -TEST_F(PluginTest, Registration) { +TEST_F(TrtPluginFactoryTest, Registration) { EXPECT_FALSE( PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); EXPECT_TRUE(RegisterStubPlugin()); @@ -103,7 +103,7 @@ TEST_F(PluginTest, Registration) { PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); } -TEST_F(PluginTest, CreationDeletion) { +TEST_F(TrtPluginFactoryTest, CreationDeletion) { EXPECT_TRUE(RegisterStubPlugin()); ASSERT_TRUE( PluginFactoryTensorRT::GetInstance()->IsPlugin(StubPlugin::kPluginName)); diff --git a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h index a94c67bba0..4ff6fbedb4 100644 --- a/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h +++ b/tensorflow/contrib/tensorrt/plugin/trt_plugin_utils.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_PLUGIN_TRT_PLUGIN_UTILS #include <functional> + #include "tensorflow/contrib/tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/platform/types.h" |