Diffstat (limited to 'tensorflow/contrib/tensorrt/test/test_tftrt.py')
tensorflow/contrib/tensorrt/test/test_tftrt.py | 57
1 file changed, 3 insertions(+), 54 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 0b661bd536..c78f6f2224 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -60,7 +60,6 @@ def get_simple_graph_def():
 
 
 def run_graph(gdef, dumm_inp):
-  """Run given graphdef once."""
   gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
   ops.reset_default_graph()
   g = ops.Graph()
@@ -75,65 +74,15 @@ def run_graph(gdef, dumm_inp):
   return val
 
 
-# Use real data that is representatitive of the inference dataset
-# for calibration. For this test script it is random data.
-def run_calibration(gdef, dumm_inp):
-  """Run given calibration graph multiple times."""
-  gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
-  ops.reset_default_graph()
-  g = ops.Graph()
-  with g.as_default():
-    inp, out = importer.import_graph_def(
-        graph_def=gdef, return_elements=["input", "output"])
-    inp = inp.outputs[0]
-    out = out.outputs[0]
-    with csess.Session(
-        config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
-      # run over real calibration data here, we are mimicking a calibration set of
-      # 30 different batches. Use as much calibration data as you want
-      for _ in range(30):
-        val = sess.run(out, {inp: dumm_inp})
-  return val
-
-
 if "__main__" in __name__:
   inp_dims = (100, 24, 24, 2)
   dummy_input = np.random.random_sample(inp_dims)
-  orig_graph = get_simple_graph_def()  # use a frozen graph for inference
+  gdef = get_simple_graph_def()
   # Get optimized graph
-  trt_graph = trt.create_inference_graph(
-      input_graph_def=orig_graph,
-      outputs=["output"],
-      max_batch_size=inp_dims[0],
-      max_workspace_size_bytes=1 << 25,
-      precision_mode="FP32",  # TRT Engine precision "FP32","FP16" or "INT8"
-      minimum_segment_size=2  # minimum number of nodes in an engine
-  )
-  o1 = run_graph(orig_graph, dummy_input)
+  trt_graph = trt.create_inference_graph(gdef, ["output"], inp_dims[0])
+  o1 = run_graph(gdef, dummy_input)
   o2 = run_graph(trt_graph, dummy_input)
   o3 = run_graph(trt_graph, dummy_input)
   assert np.array_equal(o1, o2)
   assert np.array_equal(o3, o2)  # sanity check
-  fp16_graph = trt.create_inference_graph(
-      input_graph_def=orig_graph,
-      outputs=["output"],
-      max_batch_size=inp_dims[0],
-      max_workspace_size_bytes=1 << 25,
-      precision_mode="FP16",  # TRT Engine precision "FP32","FP16" or "INT8"
-      minimum_segment_size=2  # minimum number of nodes in an engine
-  )
-  int8_calib_gdef = trt.create_inference_graph(
-      input_graph_def=orig_graph,
-      outputs=["output"],
-      max_batch_size=inp_dims[0],
-      max_workspace_size_bytes=1 << 25,
-      precision_mode="INT8",  # TRT Engine precision "FP32","FP16" or "INT8"
-      minimum_segment_size=2  # minimum number of nodes in an engine
-  )
-  o4 = run_graph(fp16_graph, dummy_input)
-  _ = run_calibration(int8_calib_gdef, dummy_input)
-  int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
-  o5 = run_graph(int8_graph, dummy_input)
-  assert np.allclose(o1, o4)
-  assert np.allclose(o1, o5)
   print("Pass")
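For reference, the INT8 path this commit deletes followed a calibrate-then-convert pattern: build a calibration graph in INT8 mode, run it over representative data, then convert it into an inference graph. The sketch below reconstructs that flow from the removed lines only; the import aliases (trt, cpb2, csess, importer, ops) are assumed to match the ones used elsewhere in test_tftrt.py, and orig_graph is a placeholder for any frozen GraphDef with "input" and "output" nodes (such as the one built by get_simple_graph_def() in this file).

import numpy as np

# Assumed import aliases, matching the usage visible in the diff above.
from tensorflow.contrib import tensorrt as trt
from tensorflow.core.protobuf import config_pb2 as cpb2
from tensorflow.python.client import session as csess
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops

def run_calibration(gdef, dumm_inp, num_batches=30):
  """Run the calibration graph over representative data (random here)."""
  gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
  ops.reset_default_graph()
  g = ops.Graph()
  with g.as_default():
    inp, out = importer.import_graph_def(
        graph_def=gdef, return_elements=["input", "output"])
    inp = inp.outputs[0]
    out = out.outputs[0]
    with csess.Session(
        config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
      # Feed real inference data in practice; this mimics num_batches batches.
      for _ in range(num_batches):
        val = sess.run(out, {inp: dumm_inp})
  return val

# orig_graph: a frozen GraphDef with "input"/"output" nodes (placeholder here).
inp_dims = (100, 24, 24, 2)
dummy_input = np.random.random_sample(inp_dims)

# 1) Build a calibration graph in INT8 mode.
int8_calib_gdef = trt.create_inference_graph(
    input_graph_def=orig_graph,
    outputs=["output"],
    max_batch_size=inp_dims[0],
    max_workspace_size_bytes=1 << 25,
    precision_mode="INT8",
    minimum_segment_size=2)
# 2) Execute it so TensorRT can collect calibration statistics.
_ = run_calibration(int8_calib_gdef, dummy_input)
# 3) Convert the calibrated graph into an INT8 inference graph.
int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)

The simplified positional call kept in the test, trt.create_inference_graph(gdef, ["output"], inp_dims[0]), supplies only the graph, output names, and max batch size, relying on the function's defaults for the workspace size, precision mode, and minimum segment size that the removed keyword calls spelled out explicitly.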