aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-14 20:39:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-14 20:44:02 -0700
commit9037e241de1e64044ff55ab539ccc1fb013c178a (patch)
treef7b8bda19a5efdd57f99ce9cd7b0bf6fed211628
parent357cd4b8b2f960520fc57b6cfbf41117a2a20fc7 (diff)
Enable Add/AddN tree rewrite for symbolically equal shapes.
1) Rewrite a tree of Add/AddN ops with a single AddN, if all shapes are symbolically equal 2) Lookup shape properties using GraphProperties instead of direct access to Node attributes PiperOrigin-RevId: 189131726
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc173
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h3
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc61
-rw-r--r--tensorflow/core/grappler/utils.cc26
-rw-r--r--tensorflow/core/grappler/utils.h4
-rw-r--r--tensorflow/core/grappler/utils_test.cc41
6 files changed, 231 insertions, 77 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index c0fcfaf428..675cd8f072 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -197,35 +197,39 @@ bool IsNumberType(DataType dtype) { return kNumberTypes.Contains(dtype); }
const char kOutputShapesAttr[] = "_output_shapes";
-PartialTensorShape GetInputShape(const string& input, const NodeMap& node_map) {
- int output_pos;
- string node_name = ParseNodeName(input, &output_pos);
- const NodeDef* input_node = node_map.GetNode(node_name);
- auto attr = input_node->attr();
- if (attr.find(kOutputShapesAttr) == attr.end()) {
- return PartialTensorShape(); // unknown shape
- } else {
- return attr.at(kOutputShapesAttr).list().shape(output_pos);
- }
+// Shape is symbolically defined if it has a known rank, and each dimension is
+// defined, or is an unknown symbol (dim.size <= -2).
+bool ShapeIsSymbolicallyDefined(const TensorShapeProto& shape) {
+ return !shape.unknown_rank() &&
+ std::all_of(
+ shape.dim().begin(), shape.dim().end(),
+ [](const TensorShapeProto::Dim& dim) { return dim.size() != -1; });
+}
+
+bool ShapeIsSymbolicallyDefined(const OpInfo::TensorProperties& properties) {
+ return ShapeIsSymbolicallyDefined(properties.shape());
}
-bool ShapesEqual(const string& input_x, const string& input_y,
- const NodeMap& node_map) {
- PartialTensorShape x_shape = GetInputShape(input_x, node_map);
- PartialTensorShape y_shape = GetInputShape(input_y, node_map);
- if (x_shape.unknown_rank() || y_shape.unknown_rank() ||
- x_shape.dims() != y_shape.dims()) {
+bool ShapesSymbolicallyEqual(const TensorShapeProto& left,
+ const TensorShapeProto& right) {
+ if (left.unknown_rank() || right.unknown_rank() ||
+ left.dim_size() != right.dim_size()) {
return false;
}
- for (int i = 0; i < x_shape.dims(); ++i) {
- if (x_shape.dim_size(i) == -1 || y_shape.dim_size(i) == -1 ||
- x_shape.dim_size(i) != y_shape.dim_size(i)) {
+ for (int i = 0; i < left.dim_size(); ++i) {
+ if (left.dim(i).size() == -1 || right.dim(i).size() == -1 ||
+ left.dim(i).size() != right.dim(i).size()) {
return false;
}
}
return true;
}
+bool ShapesSymbolicallyEqual(const OpInfo::TensorProperties& left,
+ const OpInfo::TensorProperties& right) {
+ return ShapesSymbolicallyEqual(left.shape(), right.shape());
+}
+
// Returns whether `reshape` is an identity op. The tensor that `reshape`
// reshapes is the `output_pos`-th output of node `input`.
bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
@@ -290,16 +294,19 @@ NodeDef* GetTailOfValuePreservingChain(
struct ArithmeticOptimizerContext {
ArithmeticOptimizerContext(
const std::unordered_set<string>* nodes_to_preserve,
- GraphDef* optimized_graph, NodeMap* node_map, FrameMap* frame_map,
+ GraphDef* optimized_graph, GraphProperties* graph_properties,
+ NodeMap* node_map, FrameMap* frame_map,
SetVector<NodeDef*>* nodes_to_simplify)
: nodes_to_preserve(nodes_to_preserve),
optimized_graph(optimized_graph),
+ graph_properties(graph_properties),
node_map(node_map),
frame_map(frame_map),
nodes_to_simplify(nodes_to_simplify) {}
const std::unordered_set<string>* nodes_to_preserve;
GraphDef* optimized_graph;
+ GraphProperties* graph_properties;
NodeMap* node_map;
FrameMap* frame_map;
SetVector<NodeDef*>* nodes_to_simplify;
@@ -388,7 +395,7 @@ class ArithmeticOptimizerStage {
ctx_.nodes_to_simplify->PushBack(node);
}
- // Get a node by input name from a node map. Return a error if node was not
+ // Get a node by input name from a node map. Return an error if node was not
// found.
Status GetInputNode(const string& input, NodeDef** node) const {
string node_name = NodeName(input);
@@ -401,22 +408,31 @@ class ArithmeticOptimizerStage {
return Status::OK();
}
- // Get input shape from a node map. If node doesn't exists return unknown
- // shape.
- PartialTensorShape GetInputShape(const string& input) const {
- int position;
- string node_name = ParseNodeName(input, &position);
- NodeDef* node;
- Status node_status = GetInputNode(node_name, &node);
- if (!node_status.ok()) {
- return PartialTensorShape(); // unknown shape
+ // Lookup tensor properties by name. Tensor name might have non-zero port
+ // number. Return an error if tensor node doesn't exists in a graph, or it
+ // doesn't have properties defined for requested port.
+ Status GetTensorProperties(const string& tensor,
+ OpInfo::TensorProperties* properties) const {
+ int port;
+ string tensor_node_name = ParseNodeName(tensor, &port);
+ if (port < 0) {
+ return errors::InvalidArgument(
+ "Can't get tensor properties of control dependency ", tensor);
}
- auto attr = node->attr();
- if (attr.find(kOutputShapesAttr) == attr.end()) {
- return PartialTensorShape(); // unknown shape
- } else {
- return attr.at(kOutputShapesAttr).list().shape(position);
+
+ const auto& output_properties =
+ ctx_.graph_properties->GetOutputProperties(tensor_node_name);
+ auto num_outputs = output_properties.size();
+
+ if (num_outputs == 0 || port > num_outputs - 1) {
+ return errors::InvalidArgument(
+ "Node ", tensor_node_name,
+ " is missing output properties at position :", port,
+ " (num_outputs=", num_outputs, ")");
}
+
+ properties->CopyFrom(output_properties[port]);
+ return Status::OK();
}
NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) {
@@ -509,8 +525,8 @@ class ArithmeticOptimizerStage {
// Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
// original inputs of absorbed nodes.
//
-// All nodes in a Add/AddN subgraph must have fully specified and identical
-// shape. All nodes must have the same device placement.
+// All nodes in a Add/AddN subgraph must have symbolically equal shape. All
+// nodes must have the same device placement.
//
// Example:
// AddN_1
@@ -533,16 +549,12 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
if (!IsRewritable(node)) {
return false;
}
- // and must have fully defined shape
- // TODO(ezhulenev): support partially defined shapes, when we can prove that
- // unknown dimensions in the rewritten subgraph are the same.
- PartialTensorShape shape = GetInputShape(node->name());
- if (!shape.IsFullyDefined()) {
- return false;
- }
- // and must have inputs of fully defined shape identical to the output
- // TODO(ezhulenev): relax this condition to support equal unknown dimensions
- return HasAllInputsOfIdenticalShape(*node, shape);
+
+ // shape must be symbolically defined and all inputs compatible with it
+ OpInfo::TensorProperties properties;
+ Status has_properties = GetTensorProperties(node->name(), &properties);
+ return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
+ HasAllInputsOfSymbolicallyEqualShape(*node, properties);
}
Status TrySimplify(const NodeDef* node,
@@ -567,23 +579,26 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
// input_nodes: [x, y, z, w, q, e]
struct AddOpsGroup {
const NodeDef* root_node;
- PartialTensorShape root_shape;
+ TensorShapeProto root_shape;
// Add/AddN operations below the root level that were absorbed by this group
std::vector<NodeDef*> absorbed_nodes;
// Inputs of absorbed nodes that will be forwarded to rewritten AddN node
std::vector<string> inputs;
};
- // Check if all inputs are fully defined and identical to expected shape
- bool HasAllInputsOfIdenticalShape(const NodeDef& node,
- const PartialTensorShape& shape) const {
+ // Check if all inputs have symbolically equal shapes
+ bool HasAllInputsOfSymbolicallyEqualShape(
+ const NodeDef& node, const OpInfo::TensorProperties& properties) const {
const AddOpsRewriteStage* self = this;
- return std::all_of(node.input().begin(), node.input().end(),
- [self, &shape](const string& input) {
- auto input_shape = self->GetInputShape(input);
- return input_shape.IsFullyDefined() &&
- input_shape.IsIdenticalTo(shape);
- });
+ return std::all_of(
+ node.input().begin(), node.input().end(),
+ [self, &properties](const string& input) {
+ OpInfo::TensorProperties input_properties;
+ Status has_input_properties =
+ self->GetTensorProperties(input, &input_properties);
+ return has_input_properties.ok() &&
+ ShapesSymbolicallyEqual(properties, input_properties);
+ });
}
// TODO(ezhulenev): use GraphRewriter?
@@ -614,27 +629,25 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
if (!node_status.ok()) {
return false;
}
-
- PartialTensorShape shape = GetInputShape(name);
- CHECK(shape.IsIdenticalTo(group.root_shape))
- << "Cannot absorb a node of incompatible shape";
-
// check basic preconditions
if (!IsRewritable(node)) {
return false;
}
- // with a single output consumer (presumably if we reach this node from
+ // with a single output data consumer (presumably if we reach this node from
// previously absorbed or a root node, it means that this node is not used
// as an input to any other op, outside of the group)
- if (ctx_.node_map->GetOutputs(node->name()).size() != 1) {
+ if (NumNonControlDataOutputs(*node, *ctx_.node_map) != 1) {
return false;
}
// must be on the same device as a root node
if (node->device() != group.root_node->device()) {
return false;
}
- // All input shapes must be fully defined and equal to the node shape
- return HasAllInputsOfIdenticalShape(*node, shape);
+ // All input shapes must be symbolically defined and equal to the node shape
+ OpInfo::TensorProperties properties;
+ Status has_properties = GetTensorProperties(name, &properties);
+ return has_properties.ok() &&
+ HasAllInputsOfSymbolicallyEqualShape(*node, properties);
}
// Node requirements both for a root node and an absorbed node
@@ -660,15 +673,19 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
}
// Check that optimized group node name doesn't exists. It might happen if
- // graph optimized multiple times without pruning beween invocations.
+ // graph optimized multiple times without pruning between invocations.
bool IsRewritten(const AddOpsGroup& group) const {
return ctx_.node_map->NodeExists(AddOpsGroupName(group));
}
// Create an AddOpsGroup with a root in a given node
Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
+ OpInfo::TensorProperties root_node_output_properties;
+ TF_RETURN_IF_ERROR(
+ GetTensorProperties(root_node->name(), &root_node_output_properties));
+
group->root_node = root_node;
- group->root_shape = GetInputShape(root_node->name());
+ group->root_shape = root_node_output_properties.shape();
group->absorbed_nodes.reserve(root_node->input_size());
for (int i = 0; i < root_node->input_size(); ++i) {
@@ -737,6 +754,9 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
added_node->add_input(input);
}
+ // Add frame dependencies that the original node might have had.
+ AddFrameControlDeps(group.root_node, {added_node}, "", {});
+
VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
<< " Add/AddN nodes from the graph";
@@ -891,8 +911,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
mul_node->input(0) == common_factor ? 1 : 0;
unique_factors->push_back(mul_node->input(unique_factor_index));
if (i > 0 && !IsAdd(*node)) {
- *shapes_match = ShapesEqual(unique_factors->front(),
- unique_factors->back(), *ctx_.node_map);
+ OpInfo::TensorProperties lhs;
+ OpInfo::TensorProperties rhs;
+ TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
+ TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
+ *shapes_match = ShapesSymbolicallyEqual(lhs, rhs);
}
}
return Status::OK();
@@ -1627,8 +1650,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
}
const ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
- node_map_.get(), &frame_map_,
- &nodes_to_simplify);
+ graph_properties_.get(), node_map_.get(),
+ &frame_map_, &nodes_to_simplify);
std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
@@ -1660,8 +1683,10 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
const NodeDef* node = nodes_to_simplify.PopBack();
// TODO(ezhulenev): move all rewrites into separate stages
- string simplified_tensor =
- TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
+ string simplified_tensor = "";
+ if (options_.enable_try_simplify_and_replace) {
+ simplified_tensor = TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
+ }
// if it was not simplified try to run it through all configured stages
if (simplified_tensor.empty()) {
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index d5a7af5ba6..2c6b52c072 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -55,6 +55,9 @@ class ArithmeticOptimizer : public GraphOptimizer {
// Granular control for arithmetic optimizer stages
struct ArithmeticOptimizerOptions {
+ // TODO(ezhulenev): flag do disable TrySimplifyAndReplaceUses in tests.
+ // Remove when all optimizers will be migrated to separate stages.
+ bool enable_try_simplify_and_replace = true;
bool combine_add_to_addn = true;
bool hoist_common_factor_out_of_aggregation = true;
bool remove_inverse_transpose = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index e1f47625c1..d677aee589 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -89,6 +89,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
// should explicitly enable required optimization for tests isolation
void DisableAllStages(ArithmeticOptimizer* optimizer) {
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
+ options.enable_try_simplify_and_replace = false;
options.combine_add_to_addn = false;
options.hoist_common_factor_out_of_aggregation = false;
options.remove_inverse_transpose = false;
@@ -1270,7 +1271,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
EXPECT_TRUE(IsNodesDirectlyConnected(node_map, "inputs", "outputs"));
}
-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfIdenticalShape) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
tensorflow::Scope sx = s.NewSubScope("x");
tensorflow::Scope sy = s.NewSubScope("y");
@@ -1322,7 +1323,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
}
-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_MultiplePasses) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
@@ -1395,7 +1396,7 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
EXPECT_EQ(collapsed_right->name(), updated_mul->input(1));
}
-TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddInputMultipleTimes) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
@@ -1440,5 +1441,59 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
EXPECT_EQ("c", collapsed_add->input(3));
}
+TEST_F(ArithmeticOptimizerTest, AddOpsRewrite_AddOpsOfSymbolicallyEqualShape) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ // unknown input shape propagated symbolically through the graph
+ auto input = ops::Variable(s.WithOpName("input"), {-1, 2}, DT_FLOAT);
+
+ // [a, b, c] have symbolically equal shapes
+ auto a = ops::Sqrt(s.WithOpName("a"), input);
+ auto b = ops::Square(s.WithOpName("b"), input);
+ auto c = ops::Round(s.WithOpName("c"), input);
+
+ // [add_ab, add_abc] shape must be inferred from inputs
+ auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+ auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
+
+ auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyAddToAddNCombining(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ // +
+ // / \
+ // + c --> AddN(a, b, c)
+ // / \
+ // a b
+ EXPECT_EQ(6, output.node_size());
+
+ NodeMap node_map(&output);
+
+ // check add tree was replaced with AddN
+ const NodeDef* collapsed_add =
+ node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
+ ASSERT_TRUE(collapsed_add != nullptr);
+ EXPECT_EQ("AddN", collapsed_add->op());
+ EXPECT_EQ(3, collapsed_add->input_size());
+ EXPECT_EQ("a", collapsed_add->input(0));
+ EXPECT_EQ("b", collapsed_add->input(1));
+ EXPECT_EQ("c", collapsed_add->input(2));
+
+ // check output was re-wired to new node
+ const NodeDef* updated_outputs = node_map.GetNode("outputs");
+ ASSERT_TRUE(updated_outputs != nullptr);
+ EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index eb1f882ff1..829bfe9e31 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -40,6 +40,16 @@ bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
tensor->flat<T>()(0) = static_cast<T>(value);
return true;
}
+
+// Is 'node' an operator that consumes only the shape of its input, not the
+// data itself?
+// TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
+// TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
+bool IsShapeConsumer(const NodeDef& node) {
+ const string& op = node.op();
+ return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
+}
+
} // namespace
NodeMap::NodeMap(GraphDef* graph) {
@@ -270,6 +280,22 @@ int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
return num_outputs;
}
+int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
+ int num_data_outputs = 0;
+ for (const NodeDef* output : node_map.GetOutputs(node.name())) {
+ if (IsShapeConsumer(*output)) continue;
+
+ for (int i = 0; i < output->input_size(); ++i) {
+ const string& input = output->input(i);
+ if (!IsControlInput(input) && NodeName(input) == node.name()) {
+ ++num_data_outputs;
+ break;
+ }
+ }
+ }
+ return num_data_outputs;
+}
+
// Returns the data type in attribute `attr_name` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index fbd38c1531..7aa31939f5 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -144,6 +144,10 @@ int NumNonControlInputs(const NodeDef& node);
// Number of connected non-control outputs.
int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map);
+// Number of connected non-control data outputs (Ops that consume output tensor
+// data, not just it's shape).
+int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map);
+
// Removes redundant control inputs from node.
void DedupControlInputs(NodeDef* node);
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index eabce5b5ee..49a1996d25 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -292,6 +292,47 @@ TEST_F(UtilsTest, DedupControlInputs) {
EXPECT_EQ("gnu", foo.input(1));
}
+TEST_F(UtilsTest, NumNonControlOutputs) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ // *) Round node has control dependency edge from Add, which
+ // is not on this scheme (ASCII graphics limitation).
+ //
+ // *Round [Sqrt, Shape]
+ // | |
+ // | ctrl |
+ // Mul ------> Add
+ // / \ / \
+ // x y a b
+ auto x = ops::Variable(s.WithOpName("x"), {1, 2}, DT_FLOAT);
+ auto y = ops::Variable(s.WithOpName("y"), {1, 2}, DT_FLOAT);
+ auto a = ops::Variable(s.WithOpName("a"), {1, 2}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {1, 2}, DT_FLOAT);
+
+ auto mul = ops::Multiply(s.WithOpName("mul"), x, y);
+ auto add = ops::Add(s.WithOpName("add").WithControlDependencies(mul), a, b);
+
+ auto shape = ops::Shape(s.WithOpName("shape"), add);
+ auto sqrt = ops::Sqrt(s.WithOpName("sqrt"), add);
+
+ auto round =
+ ops::Round(s.WithOpName("round").WithControlDependencies(add), mul);
+
+ GraphDef graph;
+ TF_CHECK_OK(s.ToGraphDef(&graph));
+ NodeMap node_map(&graph);
+
+ const NodeDef* add_node = node_map.GetNode("add");
+ ASSERT_TRUE(add_node != nullptr);
+
+ // [a, b] are only non-control inputs
+ EXPECT_EQ(2, NumNonControlInputs(*add_node));
+ // [sqrt, shape] are non control outputs
+ EXPECT_EQ(2, NumNonControlOutputs(*add_node, node_map));
+ // sqrt is the only data output
+ EXPECT_EQ(1, NumNonControlDataOutputs(*add_node, node_map));
+}
+
TEST_F(UtilsTest, DeleteNodes) {}
} // namespace