Diffstat (limited to 'tensorflow/contrib/tensorrt/test/test_tftrt.py')
-rw-r--r--  tensorflow/contrib/tensorrt/test/test_tftrt.py | 57
1 file changed, 3 insertions(+), 54 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 0b661bd536..c78f6f2224 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -60,7 +60,6 @@ def get_simple_graph_def():
 
 
 def run_graph(gdef, dumm_inp):
-  """Run given graphdef once."""
   gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
   ops.reset_default_graph()
   g = ops.Graph()
@@ -75,65 +74,15 @@ def run_graph(gdef, dumm_inp):
   return val
 
 
-# Use real data that is representatitive of the inference dataset
-# for calibration. For this test script it is random data.
-def run_calibration(gdef, dumm_inp):
-  """Run given calibration graph multiple times."""
-  gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
-  ops.reset_default_graph()
-  g = ops.Graph()
-  with g.as_default():
-    inp, out = importer.import_graph_def(
-        graph_def=gdef, return_elements=["input", "output"])
-    inp = inp.outputs[0]
-    out = out.outputs[0]
-  with csess.Session(
-      config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
-    # run over real calibration data here, we are mimicking a calibration set of
-    # 30 different batches. Use as much calibration data as you want
-    for _ in range(30):
-      val = sess.run(out, {inp: dumm_inp})
-  return val
-
-
 if "__main__" in __name__:
   inp_dims = (100, 24, 24, 2)
   dummy_input = np.random.random_sample(inp_dims)
-  orig_graph = get_simple_graph_def()  # use a frozen graph for inference
+  gdef = get_simple_graph_def()
   # Get optimized graph
-  trt_graph = trt.create_inference_graph(
-      input_graph_def=orig_graph,
-      outputs=["output"],
-      max_batch_size=inp_dims[0],
-      max_workspace_size_bytes=1 << 25,
-      precision_mode="FP32",  # TRT Engine precision "FP32","FP16" or "INT8"
-      minimum_segment_size=2  # minimum number of nodes in an engine
-  )
-  o1 = run_graph(orig_graph, dummy_input)
+  trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0])
+  o1 = run_graph(gdef, dummy_input)
   o2 = run_graph(trt_graph, dummy_input)
   o3 = run_graph(trt_graph, dummy_input)
   assert np.array_equal(o1, o2)
   assert np.array_equal(o3, o2)  # sanity check
-  fp16_graph = trt.create_inference_graph(
-      input_graph_def=orig_graph,
-      outputs=["output"],
-      max_batch_size=inp_dims[0],
-      max_workspace_size_bytes=1 << 25,
-      precision_mode="FP16",  # TRT Engine precision "FP32","FP16" or "INT8"
-      minimum_segment_size=2  # minimum number of nodes in an engine
-  )
-  int8_calib_gdef = trt.create_inference_graph(
-      input_graph_def=orig_graph,
-      outputs=["output"],
-      max_batch_size=inp_dims[0],
-      max_workspace_size_bytes=1 << 25,
-      precision_mode="INT8",  # TRT Engine precision "FP32","FP16" or "INT8"
-      minimum_segment_size=2  # minimum number of nodes in an engine
-  )
-  o4 = run_graph(fp16_graph, dummy_input)
-  _ = run_calibration(int8_calib_gdef, dummy_input)
-  int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
-  o5 = run_graph(int8_graph, dummy_input)
-  assert np.allclose(o1, o4)
-  assert np.allclose(o1, o5)
   print("Pass")