aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-01-31 22:35:45 -0800
committerGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-01-31 22:35:45 -0800
commit9884a21214a88a2da847b9ea66533108e860067d (patch)
tree2e17b8c272078dea2060c776555767a34131b895
parentd074556997f3e8aad3a1ca2bcd723dc590777aeb (diff)
Add/fix copyright and fix format for the test script and other files.
-rw-r--r--tensorflow/contrib/tensorrt/test/test_tftrt.py101
-rw-r--r--tensorflow/contrib/tensorrt/trt_conversion.i2
-rw-r--r--third_party/tensorrt/LICENSE2
3 files changed, 61 insertions, 44 deletions
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 06b6f64c4a..32f839e624 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -1,53 +1,70 @@
-# Script to test TF-TensorRT integration
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to test TF-TensorRT integration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
import numpy as np
+
def getSimpleGraphDef():
- '''
- Create a simple graph and return its graph_def
- '''
- g=tf.Graph()
- with g.as_default():
- A=tf.placeholder(dtype=tf.float32,shape=(None,24,24,2),name="input")
- e=tf.constant([ [[[ 1., 0.5, 4., 6., 0.5, 1. ],
- [ 1., 0.5, 1., 1., 0.5, 1. ]]] ],
- name="weights",dtype=tf.float32)
- conv=tf.nn.conv2d(input=A,filter=e,strides=[1,2,2,1],padding="SAME",name="conv")
- b=tf.constant([ 4., 1.5, 2., 3., 5., 7. ],
- name="bias",dtype=tf.float32)
- t=tf.nn.bias_add(conv,b,name="biasAdd")
- relu=tf.nn.relu(t,"relu")
- idty=tf.identity(relu,"ID")
- v=tf.nn.max_pool(idty,[1,2,2,1],[1,2,2,1],"VALID",name="max_pool")
- out = tf.squeeze(v,name="output")
- return g.as_graph_def()
-
-def runGraph(gdef,dumm_inp):
- gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
- tf.reset_default_graph()
- g=tf.Graph()
- with g.as_default():
- inp,out=tf.import_graph_def(graph_def=gdef,
- return_elements=["input","output"])
- inp=inp.outputs[0]
- out=out.outputs[0]
- with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options),
- graph=g) as sess:
- val=sess.run(out,{inp:dumm_inp})
- return val
-if "__main__" in __name__:
- inpDims=(100,24,24,2)
- dummy_input=np.random.random_sample(inpDims)
- gdef=getSimpleGraphDef() #get graphdef
- trt_graph=trt.CreateInferenceGraph(gdef,["output"],inpDims[0]) # get optimized graph
- o1=runGraph(gdef,dummy_input)
- o2=runGraph(trt_graph,dummy_input)
- assert(np.array_equal(o1,o2))
-
+ """Create a simple graph and return its graph_def"""
+ g = tf.Graph()
+ with g.as_default():
+ A = tf.placeholder(dtype=tf.float32, shape=(None, 24, 24, 2), name="input")
+ e = tf.constant(
+ [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
+ name="weights",
+ dtype=tf.float32)
+ conv = tf.nn.conv2d(
+ input=A, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
+ b = tf.constant([4., 1.5, 2., 3., 5., 7.], name="bias", dtype=tf.float32)
+ t = tf.nn.bias_add(conv, b, name="biasAdd")
+ relu = tf.nn.relu(t, "relu")
+ idty = tf.identity(relu, "ID")
+ v = tf.nn.max_pool(
+ idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
+ out = tf.squeeze(v, name="output")
+ return g.as_graph_def()
+
+
+def runGraph(gdef, dumm_inp):
+ gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.50)
+ tf.reset_default_graph()
+ g = tf.Graph()
+ with g.as_default():
+ inp, out = tf.import_graph_def(
+ graph_def=gdef, return_elements=["input", "output"])
+ inp = inp.outputs[0]
+ out = out.outputs[0]
+ with tf.Session(
+ config=tf.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
+ val = sess.run(out, {inp: dumm_inp})
+ return val
+
+
+if "__main__" in __name__:
+ inpDims = (100, 24, 24, 2)
+ dummy_input = np.random.random_sample(inpDims)
+ gdef = getSimpleGraphDef()
+ trt_graph = trt.CreateInferenceGraph(gdef, ["output"],
+ inpDims[0]) # Get optimized graph
+ o1 = runGraph(gdef, dummy_input)
+ o2 = runGraph(trt_graph, dummy_input)
+ assert (np.array_equal(o1, o2))
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/contrib/tensorrt/trt_conversion.i
index 828b4b35c2..f085380054 100644
--- a/tensorflow/contrib/tensorrt/trt_conversion.i
+++ b/tensorflow/contrib/tensorrt/trt_conversion.i
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/third_party/tensorrt/LICENSE b/third_party/tensorrt/LICENSE
index 96c0e3bcaf..146d9b765c 100644
--- a/third_party/tensorrt/LICENSE
+++ b/third_party/tensorrt/LICENSE
@@ -188,7 +188,7 @@ Copyright 2018 The TensorFlow Authors. All rights reserved.
same "printed page" as the copyright notice for easier
identification within third-party archives.
- Copyright 2015, The TensorFlow Authors.
+ Copyright 2018, The TensorFlow Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.