diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-28 07:23:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-28 07:23:24 -0700 |
commit | bf12134843638748c5541aed1dbb8647ebf504fd (patch) | |
tree | e837e01edcb4e9d2a0aa5d49c15a8a2cbc999ae6 /tensorflow/contrib/tensorrt | |
parent | e3095dc262bfdda08d01fce105680515a3d1a7f4 (diff) | |
parent | 8ae93f93c0d7df9ad9b143b5d7c888026759ab85 (diff) |
Merge pull request #21131 from jjsjann123:alignment_test
PiperOrigin-RevId: 206438376
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r-- | tensorflow/contrib/tensorrt/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/test/memory_alignment_test.py | 72 |
2 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 033d5207f6..46f3c36e3d 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -393,6 +393,7 @@ cuda_py_tests( # "test/unary_test.py", # Blocked by trt4 installation # "test/vgg_block_nchw_test.py", # "test/vgg_block_test.py", + "test/memory_alignment_test.py", ], additional_deps = [ ":tf_trt_integration_test_base", diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py new file mode 100644 index 0000000000..3dd95c6f62 --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py @@ -0,0 +1,72 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model script to test TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test + + +class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): + + def GetParams(self): + """Testing conversion of BatchMatMul in TF-TRT conversion.""" + dtype = dtypes.float32 + input_name = "input" + input_dims = [2, 15, 15, 3] + g = ops.Graph() + with g.as_default(): + inp = array_ops.placeholder( + dtype=dtype, shape=[None] + input_dims[1:], name=input_name) + with g.device("/GPU:0"): + e1 = constant_op.constant( + np.random.randn(1, 1, 3, 5), name="kernel_1", dtype=dtype) + e2 = constant_op.constant( + np.random.randn(1, 1, 5, 10), name="kernel_2", dtype=dtype) + conv = nn.conv2d( + input=inp, + filter=e1, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + out = nn.conv2d( + input=conv, + filter=e2, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv_2") + array_ops.squeeze(out, name=self.output_name) + return trt_test.TfTrtIntegrationTestParams( + gdef=g.as_graph_def(), + input_names=[input_name], + input_dims=[input_dims], + num_expected_engines=1, + expected_output_dims=(2, 15, 15, 10), + allclose_atol=1.e-02, + allclose_rtol=1.e-02) + + +if __name__ == "__main__": + test.main() |