aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/test/test_tftrt.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/test/test_tftrt.py')
-rw-r--r--tensorflow/contrib/tensorrt/test/test_tftrt.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 090aa8bdb0..d26f260086 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -191,7 +191,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
o1 = run_graph(orig_graph, dummy_input)
o2 = run_graph(trt_graph, dummy_input)
o3 = run_graph(trt_graph, dummy_input)
@@ -206,7 +206,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
int8_calib_gdef = trt.create_inference_graph(
input_graph_def=orig_graph,
outputs=["output"],
@@ -216,7 +216,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
o4 = run_graph(fp16_graph, dummy_input)
_ = run_calibration(int8_calib_gdef, dummy_input)
int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)