diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/ops/trt_engine_op.cc')
-rw-r--r-- | tensorflow/contrib/tensorrt/ops/trt_engine_op.cc | 10 |
1 file changed, 8 insertions, 2 deletions
diff --git a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc index 383635f428..e0c7b62723 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/contrib/tensorrt/ops/trt_engine_op.cc @@ -42,8 +42,14 @@ REGISTER_OP("TRTEngineOp") .Attr("precision_mode: {'FP32', 'FP16', 'INT8', 'INT8CALIB'}") .Attr("calibration_data: string = ''") .Input("in_tensor: InT") - .Output("out_tensor: OutT") - .SetShapeFn(shape_inference::TRTEngineOpShapeInference); + .Output("out_tensor: OutT"); +// TODO(jie): TF requires concrete output shape for concrete input shapes. +// This is tricky for batch dimension, since we cannot ensure which input +// would carry the correct batch dimension (for the current stage of the +// implementation, we do require all input tensor to carry the same batch +// size, but this could change in the future). Hence we disable shape +// inference function as a workaround. +// .SetShapeFn(shape_inference::TRTEngineOpShapeInference); } // namespace tensorflow |