diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/test/test_tftrt.py')
-rw-r--r-- | tensorflow/contrib/tensorrt/test/test_tftrt.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index 090aa8bdb0..d26f260086 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -191,7 +191,7 @@ def user(multi_engine, minimum_segment_size=2, # minimum number of nodes in an engine is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batches=[]) + cached_engine_batch_sizes=[]) o1 = run_graph(orig_graph, dummy_input) o2 = run_graph(trt_graph, dummy_input) o3 = run_graph(trt_graph, dummy_input) @@ -206,7 +206,7 @@ def user(multi_engine, minimum_segment_size=2, # minimum number of nodes in an engine is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batches=[]) + cached_engine_batch_sizes=[]) int8_calib_gdef = trt.create_inference_graph( input_graph_def=orig_graph, outputs=["output"], @@ -216,7 +216,7 @@ def user(multi_engine, minimum_segment_size=2, # minimum number of nodes in an engine is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batches=[]) + cached_engine_batch_sizes=[]) o4 = run_graph(fp16_graph, dummy_input) _ = run_calibration(int8_calib_gdef, dummy_input) int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef) |