aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/shape_fn
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-01-29 14:49:13 -0800
committerGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-01-29 14:49:13 -0800
commit8e03944589542bd64559d68989bca4a4705eed93 (patch)
tree335e453a4774fdde0a699e83401d2d9a9de84d45 /tensorflow/contrib/tensorrt/shape_fn
parente01844e65e0dbd2682a894946bec7f072d36fa27 (diff)
Fix build (part1):
1. Changed includes of "NvInfer.h" to "tensorrt/include/NvInfer.h" 2. Remove build target "tensorrt_ops.so" (src file doesn't exist and the target is not used anywhere) 3. Add missing '#if GOOGLE_TENSORRT's 4. Use tf_cuda_library instead of cc_library for some targets to get the tf_copts naturally. 5. Revert the changes that was accidentally made (by merging with upstream head) from configure.py 6. Replace exception with LOG(FATAL) in tensorflow/contrib/tensorrt/convert/convert_nodes.cc as exception is not supported. 7. Revert the reinterprete_cast change in tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc 8. Fix minor formatting and naming issues according to the style guide.
Diffstat (limited to 'tensorflow/contrib/tensorrt/shape_fn')
-rw-r--r--tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
index 7624237efe..fef63c64d8 100644
--- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -24,11 +24,12 @@ limitations under the License.
namespace tensorflow {
namespace shape_inference {
+
tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) {
- tensorflow::tensorrt::Logger gLogger;
+ tensorflow::tensorrt::Logger logger;
string serialized_engine;
c->GetAttr("serialized_engine", &serialized_engine);
- nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(gLogger);
+ nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
serialized_engine.c_str(), serialized_engine.size(), nullptr);