author | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-01-29 14:49:13 -0800 |
---|---|---|
committer | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-01-29 14:49:13 -0800 |
commit | 8e03944589542bd64559d68989bca4a4705eed93 (patch) | |
tree | 335e453a4774fdde0a699e83401d2d9a9de84d45 /tensorflow/contrib/tensorrt/shape_fn | |
parent | e01844e65e0dbd2682a894946bec7f072d36fa27 (diff) |
Fix build (part1):
1. Change includes of "NvInfer.h" to "tensorrt/include/NvInfer.h".
2. Remove the build target "tensorrt_ops.so" (its source file doesn't exist
and the target is not used anywhere).
3. Add missing '#if GOOGLE_TENSORRT' guards.
4. Use tf_cuda_library instead of cc_library for some targets so they pick
up tf_copts naturally.
5. Revert the changes to configure.py that were accidentally made by
merging with upstream head.
6. Replace exceptions with LOG(FATAL) in
tensorflow/contrib/tensorrt/convert/convert_nodes.cc, as exceptions are
not supported.
7. Revert the reinterpret_cast change in
tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc.
8. Fix minor formatting and naming issues according to the style guide.
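Items 3 and 6 above can be sketched together. This is a hedged, self-contained illustration only: `ConvertSubgraph` is a hypothetical function, and `LOG_FATAL` is a minimal stand-in for TensorFlow's `LOG(FATAL)` logging macro; the real guard macro `GOOGLE_TENSORRT` is normally set by the build configuration, not defined in source.

```cpp
#include <cstdio>
#include <cstdlib>

// Assumption: the build defines GOOGLE_TENSORRT when TensorRT is enabled.
// Defined here only so the sketch is self-contained.
#define GOOGLE_TENSORRT 1

// Minimal stand-in for TensorFlow's LOG(FATAL): print and abort.
#define LOG_FATAL(msg)                        \
  do {                                        \
    std::fprintf(stderr, "FATAL: %s\n", msg); \
    std::abort();                             \
  } while (0)

// Hypothetical conversion entry point showing the two patterns:
// TensorRT-only code lives behind #if GOOGLE_TENSORRT, and unrecoverable
// errors use LOG(FATAL) instead of throwing, since TensorFlow is built
// without exception support.
int ConvertSubgraph(bool supported_op) {
#if GOOGLE_TENSORRT
  if (!supported_op) {
    // Previously this might have been `throw std::runtime_error(...)`;
    // with exceptions unsupported, fail hard instead.
    LOG_FATAL("unsupported op in TensorRT subgraph");
  }
  return 0;  // conversion succeeded
#else
  return -1;  // TensorRT support not compiled in
#endif
}
```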
Diffstat (limited to 'tensorflow/contrib/tensorrt/shape_fn')
-rw-r--r-- | tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc | 5 |
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
index 7624237efe..fef63c64d8 100644
--- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
+++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc
@@ -24,11 +24,12 @@ limitations under the License.
 
 namespace tensorflow {
 namespace shape_inference {
+
 tensorflow::Status TRTEngineOpShapeInference(InferenceContext* c) {
-  tensorflow::tensorrt::Logger gLogger;
+  tensorflow::tensorrt::Logger logger;
   string serialized_engine;
   c->GetAttr("serialized_engine", &serialized_engine);
-  nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(gLogger);
+  nvinfer1::IRuntime* infer = nvinfer1::createInferRuntime(logger);
   nvinfer1::ICudaEngine* trt_engine = infer->deserializeCudaEngine(
       serialized_engine.c_str(), serialized_engine.size(), nullptr);