aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/ops/trt_engine_op.cc')
-rw-r--r--tensorflow/contrib/tensorrt/ops/trt_engine_op.cc10
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
index 383635f428..e0c7b62723 100644
--- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc
@@ -42,8 +42,14 @@ REGISTER_OP("TRTEngineOp")
.Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}")
.Attr("calibration_data: string = ''")
.Input("in_tensor: InT")
- .Output("out_tensor: OutT")
- .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
+ .Output("out_tensor: OutT");
+// TODO(jie): TF requires concrete output shape for concrete input shapes.
+// This is tricky for batch dimension, since we cannot ensure which input
+// would carry the correct batch dimension (for the current stage of the
+// implementation, we do require all input tensor to carry the same batch
+// size, but this could change in the future). Hence we disable shape
+// inference function as a workaround.
+// .SetShapeFn(shape_inference::TRTEngineOpShapeInference);
} // namespace tensorflow