about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/tensorrt
diff options
context:
space:
mode:
author: gracehoney <31743510+aaroey@users.noreply.github.com> 2018-07-30 12:04:22 -0700
committer: gracehoney <31743510+aaroey@users.noreply.github.com> 2018-07-30 12:04:22 -0700
commit: e8e2cc72f3367aee1789dc0f5bcbd8f027c7180f (patch)
tree: 703cd2a889e68ba8a1e7fa9fc26ebbb8200fb38d /tensorflow/contrib/tensorrt
parent: 4158295eef9489610ddcbfa8ba3d8bda43e65194 (diff)
Add more tests
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r-- tensorflow/contrib/tensorrt/convert/convert_graph.cc 20
-rw-r--r-- tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc 8
-rw-r--r-- tensorflow/contrib/tensorrt/segment/segment.cc 27
-rw-r--r-- tensorflow/contrib/tensorrt/test/base_test.py 104
-rw-r--r-- tensorflow/contrib/tensorrt/test/memory_alignment_test.py 2
-rw-r--r-- tensorflow/contrib/tensorrt/test/neighboring_engine_test.py 13
-rw-r--r-- tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py 42
7 files changed, 152 insertions, 64 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 1e6300578d..e06704f5d1 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -326,32 +326,12 @@ tensorflow::Status GetEngineInfo(
}
VLOG(1) << "Adding const node " << input_node->name();
QCHECK(subgraph_node_names.insert(input_node->name()).second);
-#if 1
// Since we duplicate the const input node in both the segment graphdef
// and the engine, the segment node doesn't depend on it anymore, so we
// add a control dependency instead.
info->connections.emplace_back(
input_node->name(), input_node->id(), node_name, node_id,
/*input_edge=*/true);
-#else
- // Add control inputs to the const node as control input connections to
- // the engine.
- for (const auto const_in_edge : input_node->in_edges()) {
- QCHECK(const_in_edge->IsControlEdge()); // Must be control edge.
- auto const_in_node = const_in_edge->src();
- QCHECK(!segment_nodes.count(const_in_node->name()))
- << "Loop found between segment and non-segment nodes, from "
- "segment node "
- << const_in_node->name() << " to non-segment node "
- << input_node->name() << " to segment node " << node->name();
- if (const_in_node->IsSource()) continue;
- VLOG(1) << "Control edge from node " << const_in_node->name()
- << " to " << input_node->name();
- info->connections.emplace_back(
- const_in_node->name(), const_in_node->id(), input_node->name(),
- input_node->id(), /*input_edge=*/true);
- }
-#endif
} else {
// Non-const data input.
int port = Graph::kControlSlot - 1;
diff --git a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
index 044c736c03..f33f2cc4d6 100644
--- a/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/contrib/tensorrt/convert/trt_optimization_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stacktrace.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
@@ -189,9 +190,6 @@ tensorflow::Status TRTOptimizationPass::Optimize(
tensorflow::grappler::Cluster* cluster,
const tensorflow::grappler::GrapplerItem& item, GraphDef* optimized_graph) {
VLOG(1) << "Called TRTOptimization Pass " << name_;
- if (VLOG_IS_ON(1)) {
- PrintDebugInfo(cluster, item);
- }
// This is a hack to workaround optimizer issue. MetaOptimizer calls
// optimization passes on function objects as well, we should not modify
// generated funcdefs! This is fragile but we don't have any other option
@@ -203,6 +201,10 @@ tensorflow::Status TRTOptimizationPass::Optimize(
*optimized_graph = item.graph;
return tensorflow::Status::OK();
}
+ if (VLOG_IS_ON(1)) {
+ VLOG(2) << CurrentStackTrace();
+ PrintDebugInfo(cluster, item);
+ }
int max_dim = -1;
if (item.feed.size()) {
for (const auto& f : item.feed) {
diff --git a/tensorflow/contrib/tensorrt/segment/segment.cc b/tensorflow/contrib/tensorrt/segment/segment.cc
index 008fffc954..e1ed7ebf6c 100644
--- a/tensorflow/contrib/tensorrt/segment/segment.cc
+++ b/tensorflow/contrib/tensorrt/segment/segment.cc
@@ -558,27 +558,36 @@ tensorflow::Status SegmentGraph(
// then after doing this operation the resulting subgraph will keep the
// same properties 1 and 2.
//
- // For simplicity we use heuristics: for input nodes remove all its
- // input, for output nodes remove all its output. In this way, for common
- // cases the number of removed nodes should be minimum.
+ // For simplicity we use a heuristic: for input nodes and const output
+ // nodes, remove all their inputs; for non-const output nodes, remove all
+ // their outputs. In this way, for common cases the number of removed
+ // nodes should be minimal.
auto remove_nodes = [&segment_nodes](
bool is_input_nodes,
std::deque<const tensorflow::Node*>* que) {
// Run a BFS on the queue to find all the input/output nodes.
std::set<const tensorflow::Node*> visited;
+ std::set<const tensorflow::Node*> logged(que->begin(), que->end());
while (!que->empty()) {
auto node = que->front();
que->pop_front();
if (!visited.insert(node).second) continue;
segment_nodes.erase(node);
- for (auto in :
- is_input_nodes ? node->in_nodes() : node->out_nodes()) {
+ for (auto in : (is_input_nodes || node->type_string() == "Const")
+ ? node->in_nodes()
+ : node->out_nodes()) {
if (segment_nodes.count(in)) {
que->push_back(in);
- VLOG(2) << "Need to remove node " << in->name()
- << " because one of its "
- << (is_input_nodes ? "output" : "input")
- << " nodes in the graph was removed: " << node->name();
+ if (VLOG_IS_ON(2)) {
+ if (!logged.count(in)) {
+ VLOG(2) << "----> Need to remove node " << in->name()
+ << " because one of its "
+ << (is_input_nodes ? "output" : "input")
+ << " nodes in the graph was removed: "
+ << node->name();
+ logged.insert(in);
+ }
+ }
}
}
}
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py
index 9d14e635f4..e765ae3661 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/contrib/tensorrt/test/base_test.py
@@ -234,5 +234,109 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase):
allclose_atol=1.e-06,
allclose_rtol=1.e-06)
+
+class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing a single segment."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ c = constant_op.constant(1.0, name="c")
+ n = math_ops.add(n, c, name="add")
+ n = math_ops.mul(n, n, name="mul")
+ n = math_ops.add(n, n, name="add1")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]},
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing multiple segments."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ n = inp
+ c = constant_op.constant(1.0, name="c")
+ n = math_ops.add(n, c, name="add")
+ n = math_ops.mul(n, n, name="mul")
+ n = math_ops.add(n, n, name="add1")
+ n = self.trt_incompatible_op(n, name="incompatible1")
+ n = math_ops.add(n, c, name="add2")
+ n = math_ops.mul(n, n, name="mul1")
+ n = math_ops.add(n, n, name="add3")
+ array_ops.squeeze(n, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ "my_trt_op_0": ["add2", "add3", "mul1"],
+ "my_trt_op_1": ["add", "add1", "mul"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
+class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Create a graph containing multiple segments and control dependencies."""
+ input_name = "input"
+ input_dims = [2, 32, 32, 3]
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=input_dims, name=input_name)
+ with g.device("/GPU:0"):
+ c1 = constant_op.constant(1.0, name="c1")
+ c2 = constant_op.constant(1.0, name="c2")
+ d1 = constant_op.constant(1.0, name="d1")
+ d2 = self.trt_incompatible_op(inp, name="d2")
+ with g.control_dependencies([d1, d2]):
+ add = math_ops.add(inp, c1, name="add")
+ with g.control_dependencies([d1, d2]):
+ mul = math_ops.mul(add, add, name="mul")
+ with g.control_dependencies([d1, d2]):
+ add1 = math_ops.add(mul, mul, name="add1")
+ edge = self.trt_incompatible_op(add1, name="incompatible")
+ with g.control_dependencies([d1, d2, add, mul]):
+ add2 = math_ops.add(edge, c2, name="add2")
+ with g.control_dependencies([d1, d2, add1, mul]):
+ mul1 = math_ops.mul(add2, add2, name="mul1")
+ with g.control_dependencies([d1, d2, add, add1]):
+ add3 = math_ops.add(mul1, mul1, name="add3")
+ array_ops.squeeze(add3, name=self.output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[input_dims],
+ expected_engines={
+ "my_trt_op_0": ["c1", "add", "add1", "mul"],
+ "my_trt_op_1": ["c2", "add2", "add3", "mul1"]
+ },
+ expected_output_dims=tuple(input_dims),
+ allclose_atol=1.e-06,
+ allclose_rtol=1.e-06)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
index 3dd95c6f62..66eb6be757 100644
--- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
+++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
@@ -62,7 +62,7 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- num_expected_engines=1,
+ expected_engines=["my_trt_op_0"],
expected_output_dims=(2, 15, 15, 10),
allclose_atol=1.e-02,
allclose_rtol=1.e-02)
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
index 97e0d23b18..51c905a50b 100644
--- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
@@ -25,7 +25,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.platform import test
@@ -51,15 +51,18 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase):
name="conv")
b = constant_op.constant(
np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype)
- t = conv * b
- e = gen_math_ops.tan(conv)
- t = t - e
+ t = math_ops.mul(conv, b, name="mul")
+ e = self.trt_incompatible_op(conv, name="incompatible")
+ t = math_ops.sub(t, e, name="sub")
array_ops.squeeze(t, name=self.output_name)
return trt_test.TfTrtIntegrationTestParams(
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0", "my_trt_op_1"],
+ expected_engines={
+ "my_trt_op_0": ["bias", "mul", "sub"],
+ "my_trt_op_1": ["weights", "conv"]
+ },
expected_output_dims=(2, 4, 5, 4),
allclose_atol=1.e-03,
allclose_rtol=1.e-03)
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 5968af28ae..a35facaf12 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -23,6 +23,7 @@ import itertools
import warnings
import numpy as np
import six
+import os
from tensorflow.contrib.tensorrt.python import trt_convert
# pylint: disable=unused-import
@@ -151,7 +152,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
rewriter_cfg.optimizers.extend(["constfold", "layout"])
custom_op = rewriter_cfg.custom_optimizers.add()
custom_op.name = "TensorRTOptimizer"
- custom_op.parameter_map["minimum_segment_size"].i = 3
+ custom_op.parameter_map["minimum_segment_size"].i = 2
custom_op.parameter_map["max_batch_size"].i = max(
[dims[0] for dims in params.input_dims])
custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine
@@ -162,23 +163,6 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
else:
graph_options = config_pb2.GraphOptions()
- # Disable all other optimizations which can affect the converted graph.
- off = rewriter_config_pb2.RewriterConfig.OFF
- graph_options.optimizer_options.opt_level = config_pb2.OptimizerOptions.L0
- graph_options.rewrite_options.layout_optimizer = off
- graph_options.rewrite_options.constant_folding = off
- graph_options.rewrite_options.shape_optimization = off
- graph_options.rewrite_options.remapping = off
- graph_options.rewrite_options.arithmetic_optimization = off
- graph_options.rewrite_options.dependency_optimization = off
- graph_options.rewrite_options.loop_optimization = off
- graph_options.rewrite_options.function_optimization = off
- graph_options.rewrite_options.debug_stripper = off
- graph_options.rewrite_options.disable_model_pruning = True
- graph_options.rewrite_options.scoped_allocator_optimization = off
- graph_options.rewrite_options.memory_optimization = (
- rewriter_config_pb2.RewriterConfig.NO_MEM_OPT)
-
gpu_options = config_pb2.GPUOptions()
gpu_options.allow_growth = True
if trt_convert.get_linked_tensorrt_version()[0] == 3:
@@ -188,9 +172,14 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
gpu_options=gpu_options, graph_options=graph_options)
return config
- def _ExpectTestValue(self, engine_name, method, value):
+ def _ExpectTestValue(self, engine_name, method, expected_value):
+ label = "%s:%s" % (engine_name, method)
+ actual_value = trt_convert.get_test_value(label)
self.assertEqual(
- value, trt_convert.get_test_value("%s:%s" % (engine_name, method)))
+ expected_value,
+ actual_value,
+ msg="Unexpected test value with label %s. Actual: %s; expected: %s" %
+ (label, actual_value, expected_value))
def _ExpectCalibration(self, engine_name, value):
self._ExpectTestValue(engine_name, "ExecuteCalibration", value)
@@ -257,8 +246,9 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
graph_name = (
self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
".pbtxt")
- logging.info("Writing graph to %s/%s", self.get_temp_dir(), graph_name)
- graph_io.write_graph(gdef, self.get_temp_dir(), graph_name)
+ temp_dir = os.getenv('TRT_TEST_TMPDIR', self.get_temp_dir())
+ logging.info("Writing graph to %s/%s", temp_dir, graph_name)
+ graph_io.write_graph(gdef, temp_dir, graph_name)
def _VerifyConnections(self, params, converted_gdef):
old_to_new_node_map = {
@@ -314,8 +304,8 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
self.assertEqual(
expected_input_map,
actual_input_map,
- msg="expected:\n%s\nvs actual:\n%s" % (expected_input_map,
- actual_input_map))
+ msg="expected:\n%s\nvs actual:\n%s" % (sorted(
+ expected_input_map.items()), sorted(actual_input_map.items())))
def _VerifyGraphDef(self, params, run_params, gdef, graph_state):
self._WriteGraph(params, run_params, gdef, graph_state)
@@ -432,7 +422,7 @@ def _AddTests(test_class):
logging.info(
"Running test %s with parameters: use_optimizer=%s, "
"precision_mode=%s, dynamic_engine=%s",
- "testTfTRT_" + run_params.test_name, run_params.use_optimizer,
+ "testTfTrt_" + run_params.test_name, run_params.use_optimizer,
run_params.precision_mode, run_params.dynamic_engine)
self.RunTest(params, run_params)
@@ -461,7 +451,7 @@ def _AddTests(test_class):
precision_mode=precision_mode,
dynamic_engine=dynamic_engine,
test_name=test_name)
- setattr(test_class, "testTfTRT_" + test_name, _GetTest(run_params))
+ setattr(test_class, "testTfTrt_" + test_name, _GetTest(run_params))
if trt_convert.is_tensorrt_enabled():