diff options
16 files changed, 522 insertions, 62 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 2de1a79d28..751cde2b10 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -209,8 +209,11 @@ py_test( size = "small", srcs = ["optimize_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:optimization", + "//tensorflow/contrib/data/python/ops:stats_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python/data/ops:dataset_ops", @@ -431,8 +434,8 @@ py_test( tags = ["no_pip"], deps = [ ":reader_dataset_ops_test_base", + ":stats_dataset_test_base", "//tensorflow/contrib/data/python/ops:stats_ops", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", @@ -442,6 +445,16 @@ py_test( ], ) +py_library( + name = "stats_dataset_test_base", + srcs = ["stats_dataset_test_base.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "threadpool_dataset_ops_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py index d8156dc9c7..2427935c73 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py @@ -19,7 +19,9 @@ from __future__ import print_function from absl.testing import parameterized +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base from tensorflow.contrib.data.python.ops import optimization +from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.platform import test @@ -160,5 +162,34 @@ class OptimizeDatasetTest(test.TestCase, parameterized.TestCase): sess.run(get_next) +class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): + + def testLatencyStatsOptimization(self): + + stats_aggregator = stats_ops.StatsAggregator() + dataset = dataset_ops.Dataset.from_tensors(1).apply( + optimization.assert_next( + ["LatencyStats", "Map", "LatencyStats", "Prefetch", + "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply( + optimization.optimize(["latency_all_edges"])).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + iterator = dataset.make_initializable_iterator() + get_next = iterator.get_next() + summary_t = stats_aggregator.get_summary() + + with self.test_session() as sess: + sess.run(iterator.initializer) + self.assertEqual(1 * 1, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + summary_str = sess.run(summary_t) + self._assertSummaryHasCount(summary_str, + "record_latency_TensorDataset/_1", 1) + self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4", + 1) + self._assertSummaryHasCount(summary_str, + "record_latency_PrefetchDataset/_6", 1) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index b4945685c1..a41d21f8c1 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base +from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base from tensorflow.contrib.data.python.ops import stats_ops -from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -29,28 +29,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class StatsDatasetTestBase(test.TestCase): - - def _assertSummaryHasCount(self, summary_str, tag, expected_value): - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_str) - for value in summary_proto.value: - if tag == value.tag: - self.assertEqual(expected_value, value.histo.num) - return - self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) - - def _assertSummaryHasSum(self, summary_str, tag, expected_value): - summary_proto = summary_pb2.Summary() - summary_proto.ParseFromString(summary_str) - for value in summary_proto.value: - if tag == value.tag: - self.assertEqual(expected_value, value.histo.sum) - return - self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) - - -class StatsDatasetTest(StatsDatasetTestBase): +class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase): def testBytesProduced(self): stats_aggregator = stats_ops.StatsAggregator() @@ -197,7 +176,7 @@ class StatsDatasetTest(StatsDatasetTestBase): class FeatureStatsDatasetTest( - StatsDatasetTestBase, + stats_dataset_test_base.StatsDatasetTestBase, reader_dataset_ops_test_base.ReadBatchFeaturesTestBase): def testFeaturesStats(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py new file mode 100644 index 0000000000..9a13acf8f0 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing the input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.core.framework import summary_pb2 +from tensorflow.python.platform import test + + +class StatsDatasetTestBase(test.TestCase): + """Base class for testing statistics gathered in `StatsAggregator`.""" + + def _assertSummaryHasCount(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.num) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) + + def _assertSummaryHasSum(self, summary_str, tag, expected_value): + summary_proto = summary_pb2.Summary() + summary_proto.ParseFromString(summary_str) + for value in summary_proto.value: + if tag == value.tag: + self.assertEqual(expected_value, value.histo.sum) + return + self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto)) diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc index 7998f0a902..a6b6b6f8b2 100644 --- a/tensorflow/core/grappler/graph_view.cc +++ b/tensorflow/core/grappler/graph_view.cc @@ -22,9 +22,7 @@ namespace grappler { GraphView::GraphView(GraphDef* graph) : graph_(graph) { for (int i = 0; i < graph_->node_size(); i++) { auto node = graph_->mutable_node(i); - auto result = nodes_.emplace(node->name(), node); - // Check that the graph doesn't contain multiple nodes with the same name. - CHECK(result.second) << "Non unique node name detected: " << node->name(); + AddUniqueNodeOrDie(node); } for (NodeDef& node : *graph_->mutable_node()) { @@ -32,6 +30,12 @@ GraphView::GraphView(GraphDef* graph) : graph_(graph) { } } +void GraphView::AddUniqueNodeOrDie(NodeDef* node) { + auto result = nodes_.emplace(node->name(), node); + // Check that the graph doesn't contain multiple nodes with the same name. + CHECK(result.second) << "Non unique node name detected: " << node->name(); +} + void GraphView::AddFanouts(NodeDef* node) { for (int i = 0; i < node->input_size(); ++i) { OutputPort fanin; diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index 050789d2e2..ac260f85a0 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -115,6 +115,8 @@ class GraphView { const NodeDef& node, bool include_controlling_edges) const; protected: + // Add a new `node` to the graph. + void AddUniqueNodeOrDie(NodeDef* node); // Add fanout to every `node` input. void AddFanouts(NodeDef* node); std::unordered_map<string, NodeDef*>* MutableNodes() { return &nodes_; } diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc index 6abafe11a2..f0aff90c6c 100644 --- a/tensorflow/core/grappler/mutable_graph_view.cc +++ b/tensorflow/core/grappler/mutable_graph_view.cc @@ -23,10 +23,22 @@ NodeDef* MutableGraphView::AddNode(NodeDef&& node) { auto* node_in_graph = GetGraph()->add_node(); *node_in_graph = std::move(node); - auto result = MutableNodes()->emplace(node_in_graph->name(), node_in_graph); - // Check that the graph doesn't contain multiple nodes with the same name. - CHECK(result.second) << "Non unique node name detected: " - << node_in_graph->name(); + AddUniqueNodeOrDie(node_in_graph); + + AddFanouts(node_in_graph); + return node_in_graph; +} + +NodeDef* MutableGraphView::InsertNode(const NodeDef& input_node, NodeDef&& node, + const int output_port_id) { + auto* node_in_graph = GetGraph()->add_node(); + *node_in_graph = std::move(node); + + AddUniqueNodeOrDie(node_in_graph); + + // replace input for the output nodes of `input_node` with `node` + ReplaceInput(input_node, *node_in_graph, output_port_id); + AddFanouts(node_in_graph); return node_in_graph; } diff --git a/tensorflow/core/grappler/mutable_graph_view.h b/tensorflow/core/grappler/mutable_graph_view.h index 105eb972e8..971e5503d4 100644 --- a/tensorflow/core/grappler/mutable_graph_view.h +++ b/tensorflow/core/grappler/mutable_graph_view.h @@ -29,9 +29,16 @@ class MutableGraphView : public GraphView { using GraphView::GraphView; GraphDef* GetGraph() { return MutableGraph(); } + // Adds a new node to graph and updates the view. NodeDef* AddNode(NodeDef&& node); + // Inserts a new node to the graph after `input` node and updates the view. + // This adds `node` to the graph and replaces the input for the output + // nodes of `input` with a port `output_port_id` with the new node. + NodeDef* InsertNode(const NodeDef& input, NodeDef&& node, + int output_port_id = 0); + // Replaces the input for the output nodes of 'old_input' with a port // `output_port_id` with 'new_input'. // diff --git a/tensorflow/core/grappler/mutable_graph_view_test.cc b/tensorflow/core/grappler/mutable_graph_view_test.cc index f09dfb8271..2536bec35d 100644 --- a/tensorflow/core/grappler/mutable_graph_view_test.cc +++ b/tensorflow/core/grappler/mutable_graph_view_test.cc @@ -23,7 +23,18 @@ namespace tensorflow { namespace grappler { namespace { -TEST(MutableGraphViewTest, AddAndReplaceInput) { +bool FindChildWithName(const MutableGraphView& graph, + const string& output_port_name, + const string& input_name) { + GraphView::OutputPort output_port = graph.GetOutputPort(output_port_name, 0); + auto fanout = graph.GetFanout(output_port); + for (auto& input_port : fanout) { + if (input_port.node->name() == input_name) return true; + } + return false; +} + +TrivialTestGraphInputYielder SimpleGraph() { // This outputs simple graph like: // x // / \ @@ -35,7 +46,13 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) { // AddN AddN_1 // \ / // y - TrivialTestGraphInputYielder fake_input(2, 2, 2, false, {"/CPU:0", "/GPU:0"}); + TrivialTestGraphInputYielder simple_graph(2, 2, 2, false, + {"/CPU:0", "/GPU:0"}); + return simple_graph; +} + +TEST(MutableGraphViewTest, AddAndReplaceInput) { + TrivialTestGraphInputYielder fake_input = SimpleGraph(); GrapplerItem item; CHECK(fake_input.NextItem(&item)); @@ -49,18 +66,7 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) { EXPECT_EQ("Square", fanin.node->name()); EXPECT_EQ(0, fanin.port_id); - auto find_child_with_name = [&graph](string output_port_name, - string input_name) { - GraphView::OutputPort output_port = - graph.GetOutputPort(output_port_name, 0); - auto fanout = graph.GetFanout(output_port); - for (auto& input_port : fanout) { - if (input_port.node->name() == input_name) return true; - } - return false; - }; - - EXPECT_FALSE(find_child_with_name("Square", "new_node")); + EXPECT_FALSE(FindChildWithName(graph, "Square", "new_node")); NodeDef new_node = *input.node; new_node.set_name("new_node"); @@ -70,13 +76,40 @@ TEST(MutableGraphViewTest, AddAndReplaceInput) { EXPECT_NE(graph.GetNode("new_node"), nullptr); graph.ReplaceInput(*input.node, *node_in_graph); - EXPECT_TRUE(find_child_with_name("Square", "new_node")); - EXPECT_TRUE(find_child_with_name("new_node", "y")); + EXPECT_TRUE(FindChildWithName(graph, "Square", "new_node")); + EXPECT_TRUE(FindChildWithName(graph, "new_node", "y")); +} + +TEST(MutableGraphViewTest, InsertNodes) { + TrivialTestGraphInputYielder fake_input = SimpleGraph(); + + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + GraphDef new_graph = item.graph; + MutableGraphView graph(&new_graph); + + GraphView::InputPort input = graph.GetInputPort("AddN", 0); + + NodeDef new_node = *input.node; + new_node.set_name("new_node"); + new_node.set_input(0, input.node->name()); + + EXPECT_EQ(graph.GetNode("new_node"), nullptr); + graph.InsertNode(*input.node, std::move(new_node)); + EXPECT_NE(graph.GetNode("new_node"), nullptr); + EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN")); + EXPECT_TRUE(FindChildWithName(graph, "Square", "AddN_1")); + EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN")); + EXPECT_TRUE(FindChildWithName(graph, "Square_1", "AddN_1")); + EXPECT_TRUE(FindChildWithName(graph, "AddN", "new_node")); + EXPECT_TRUE(FindChildWithName(graph, "AddN_1", "y")); + EXPECT_TRUE(FindChildWithName(graph, "new_node", "y")); } TEST(MutableGraphViewTest, DeleteNodes) { // Outputs simple graph as described in first test. - TrivialTestGraphInputYielder fake_input(2, 2, 2, false, {"/CPU:0", "/GPU:0"}); + TrivialTestGraphInputYielder fake_input = SimpleGraph(); GrapplerItem item; CHECK(fake_input.NextItem(&item)); diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index d7ac58c99d..451ef6cabb 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -70,6 +70,26 @@ tf_cc_test( ) cc_library( + name = "latency_all_edges", + srcs = ["latency_all_edges.cc"], + hdrs = [ + "latency_all_edges.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_utils", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ] + tf_protos_all(), +) + +cc_library( name = "map_and_batch_fusion", srcs = ["map_and_batch_fusion.cc"], hdrs = [ @@ -213,6 +233,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":function_rename", + ":latency_all_edges", ":map_and_batch_fusion", ":map_fusion", ":noop_elimination", @@ -220,3 +241,17 @@ cc_library( ], alwayslink = 1, ) + +tf_cc_test( + name = "latency_all_edges_test", + srcs = ["latency_all_edges_test.cc"], + deps = [ + ":graph_utils", + ":latency_all_edges", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + ], +) diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 6ce6533369..838787d2a5 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -27,11 +27,17 @@ namespace { constexpr char kConstOpName[] = "Const"; template <typename Predicate, typename Collection> -int GetElementIdxWithPredicate(const Predicate& predicate, - const Collection& collection) { - auto it = std::find_if(collection.begin(), collection.end(), predicate); - if (it == collection.end()) return -1; - return std::distance(collection.begin(), it); +std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate, + const Collection& collection) { + std::vector<int> indices = {}; + unsigned idx = 0; + for (auto&& element : collection) { + if (predicate(element)) { + indices.push_back(idx); + } + idx++; + } + return indices; } std::vector<int> CreateNameIndex(const GraphDef& graph) { @@ -189,29 +195,39 @@ bool ContainsFunctionNodeWithName(const string& name, } int FindGraphNodeWithName(const string& name, const GraphDef& graph) { - return GetElementIdxWithPredicate( + std::vector<int> indices = GetElementIndicesWithPredicate( [&name](const NodeDef& node) { return node.name() == name; }, graph.node()); + return indices.empty() ? -1 : indices.front(); } int FindNodeWithOp(const string& op, const GraphDef& graph) { - return GetElementIdxWithPredicate( + std::vector<int> indices = GetElementIndicesWithPredicate( + [&op](const NodeDef& node) { return node.op() == op; }, graph.node()); + return indices.empty() ? -1 : indices.front(); +} + +std::vector<int> FindAllGraphNodesWithOp(const string& op, + const GraphDef& graph) { + return GetElementIndicesWithPredicate( [&op](const NodeDef& node) { return node.op() == op; }, graph.node()); } int FindGraphFunctionWithName(const string& name, const FunctionDefLibrary& library) { - return GetElementIdxWithPredicate( + std::vector<int> indices = GetElementIndicesWithPredicate( [&name](const FunctionDef& function) { return function.signature().name() == name; }, library.function()); + return indices.empty() ? -1 : indices.front(); } int FindFunctionNodeWithName(const string& name, const FunctionDef& function) { - return GetElementIdxWithPredicate( + std::vector<int> indices = GetElementIndicesWithPredicate( [&name](const NodeDef& node) { return node.name() == name; }, function.node_def()); + return indices.empty() ? -1 : indices.front(); } void SetUniqueGraphNodeName(const string& prefix, GraphDef* graph, @@ -219,7 +235,12 @@ void SetUniqueGraphNodeName(const string& prefix, GraphDef* graph, string name = prefix; int id = graph->node_size(); while (ContainsGraphNodeWithName(name, *graph)) { - name = strings::StrCat(prefix, "/_", id); + if (name.rfind("_generated") != std::string::npos && + (name.rfind("_generated") == (name.size() - strlen("_generated")))) { + name.insert(name.rfind("_generated"), strings::StrCat("/_", id)); + } else { + name = strings::StrCat(prefix, "/_", id); + } ++id; } node->set_name(std::move(name)); diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 0847748802..39c687b501 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -90,10 +90,15 @@ int FindGraphFunctionWithName(const string& name, // function node does not exist. int FindFunctionNodeWithName(const string& name, const FunctionDef& function); -// Returns the index of a node with the given op or -1 if no such node +// Returns the index of the first node with the given op or -1 if no such node // exists. int FindNodeWithOp(const string& op, const GraphDef& graph); +// Returns the list of indices of all nodes with the given op or empty list if +// no such node exists. +std::vector<int> FindAllGraphNodesWithOp(const string& op, + const GraphDef& graph); + // Sets the node name using `prefix` as a prefix while guaranteeing the name // is unique across the graph. void SetUniqueGraphNodeName(const string& prefix, GraphDef* graph, diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 59ed79ab8f..e6789d47b5 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -167,10 +167,34 @@ TEST(GraphUtilsTest, FindNodeWithOp) { EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1); AddNode("A", "OpA", {}, {}, &graph); - EXPECT_NE(FindNodeWithOp("OpA", *graph.GetGraph()), -1); + AddNode("B", "OpB", {"A"}, {}, &graph); + AddNode("A2", "OpA", {"B"}, {}, &graph); + EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), 0); - graph.DeleteNodes({"A"}); + graph.DeleteNodes({"B"}); + EXPECT_EQ(FindNodeWithOp("OpB", *graph.GetGraph()), -1); + EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1); +} + +TEST(GraphUtilsTest, FindAllGraphNodesWithOp) { + GraphDef graph_def; + MutableGraphView graph(&graph_def); EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1); + + AddNode("A", "OpA", {}, {}, &graph); + AddNode("B", "OpB", {"A"}, {}, &graph); + AddNode("A2", "OpA", {"B"}, {}, &graph); + std::vector<int> result_indices = + FindAllGraphNodesWithOp("OpA", *graph.GetGraph()); + EXPECT_EQ(result_indices.size(), 2); + EXPECT_EQ(result_indices.at(0), 0); + EXPECT_EQ(result_indices.at(1), 2); + + graph.DeleteNodes({"A2"}); + std::vector<int> result_indices_new = + FindAllGraphNodesWithOp("OpA", *graph.GetGraph()); + EXPECT_EQ(result_indices_new.size(), 1); + EXPECT_EQ(result_indices_new.at(0), 0); } TEST(GraphUtilsTest, SetUniqueGraphNodeName) { diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc new file mode 100644 index 0000000000..0b25b1ea9d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/latency_all_edges.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace grappler { +namespace { + +constexpr char kInsertOpName[] = "LatencyStatsDataset"; + +NodeDef make_latency_node(const NodeDef& node, MutableGraphView* graph) { + NodeDef new_node; + new_node.set_op(kInsertOpName); + graph_utils::SetUniqueGraphNodeName( + strings::StrCat(kInsertOpName, "_generated"), graph->GetGraph(), + &new_node); + // Set the input of LatencyDataset node as `node` + new_node.add_input(node.name()); + + NodeDef* tag = graph_utils::AddScalarConstNode<StringPiece>( + StringPiece("record_latency_" + node.name()), graph); + new_node.add_input(tag->name()); + + // Set `output_types` and `output_shapes` attributes. + for (auto key : {"output_shapes", "output_types"}) { + if (node.attr().find(key) != node.attr().end()) { + (*new_node.mutable_attr())[key] = node.attr().at(key); + } else { + const char* kInferredAttrPrefix = "T"; + if (node.attr().find(strings::StrCat(kInferredAttrPrefix, key)) != + node.attr().end()) { + (*new_node.mutable_attr())[key] = + node.attr().at(strings::StrCat(kInferredAttrPrefix, key)); + } + } + } + return new_node; +} + +} // namespace + +Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + *output = item.graph; + MutableGraphView graph(output); + + // Add LatencyDatasetOp node after each node. + // TODO(shivaniagrawal): Add Op to return Latency for the particular Op than + // for the edge (e2 - e1?). + for (const NodeDef& node : item.graph.node()) { + if (node.op().rfind("Dataset") != node.op().size() - strlen("Dataset") || + node.attr().empty() || + node.name().rfind("_generated") == + node.name().size() - strlen("_generated")) { + // TODO(b/111805951): Replace this with non-approximate way to check if + // node corresponds to a `Dataset` op. + continue; + } + GraphView::OutputPort output_port = graph.GetOutputPort(node.name(), 0); + auto fanout = graph.GetFanout(output_port); + if (fanout.size() > 1) { + LOG(WARNING) << node.name() << " has fanout size " << fanout.size(); + continue; + } else { // fanout will have size 0 for last dataset node in the pipeline. + if (fanout.size() == 1) { + NodeDef* output_node = (*(fanout.begin())).node; + if (output_node->name().rfind("_generated") == + output_node->name().size() - strlen("_generated")) { + continue; + } + } + } + + graph.InsertNode(node, make_latency_node(node, &graph)); + } + return Status::OK(); +} + +void LatencyAllEdges::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(LatencyAllEdges, "latency_all_edges"); + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.h b/tensorflow/core/grappler/optimizers/data/latency_all_edges.h new file mode 100644 index 0000000000..f6c71a9ec7 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_LATENCY_ALL_EDGES_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_LATENCY_ALL_EDGES_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +class LatencyAllEdges : public CustomGraphOptimizer { + public: + LatencyAllEdges() = default; + ~LatencyAllEdges() override = default; + + string name() const override { return "latency_all_edges"; }; + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_LATENCY_ALL_EDGES_H_ diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges_test.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges_test.cc new file mode 100644 index 0000000000..6789cf5bd6 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges_test.cc @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/latency_all_edges.h" + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +TEST(LatencyAllEdgesTest, AddLatenciesAfterTensorMapPrefetch) { + using test::function::NDef; + GrapplerItem item; + NodeDef component_node = + NDef("component_nodes", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}); + NodeDef from_tensor_node = + NDef("from_tensor_nodes", "TensorDataset", {"component_nodes"}, + {{"Toutput_types", {}}, {"output_shapes", {}}}); + + NodeDef captured_input_node = NDef("captured_input_node", "Const", {}, + {{"value", ""}, {"dtype", DT_STRING}}); + NodeDef map_node = NDef("map_node", "MapDataset", + {"from_tensor_node", "captured_input_node"}, + {{"f", {}}, + {"Targumemts", {}}, + {"output_shapes", {}}, + {"output_types", {}}}); + NodeDef buffer_size_node = NDef("buffer_size_node", "Const", {}, + {{"value", 1}, {"dtype", DT_INT32}}); + NodeDef prefetch_node = NDef("prefetch_node", "Prefetch_Dataset", + {"map_node", "buffer_size_node"}, + {{"output_shapes", {}}, {"output_types", {}}}); + + item.graph = test::function::GDef({component_node, from_tensor_node, + captured_input_node, map_node, + buffer_size_node, prefetch_node}); + + LatencyAllEdges optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("LatencyStatsDataset", output)); + std::vector<int> latency_node_indices = + graph_utils::FindAllGraphNodesWithOp("LatencyStatsDataset", output); + EXPECT_EQ(latency_node_indices.size(), 3); + std::vector<NodeDef> dataset_nodes = {std::move(from_tensor_node), + std::move(map_node), + std::move(prefetch_node)}; + for (int i = 0; i < latency_node_indices.size(); i++) { + NodeDef latency_node = output.node(latency_node_indices[i]); + EXPECT_EQ(latency_node.input_size(), 2); + EXPECT_EQ(latency_node.input(0), dataset_nodes[i].name()); + EXPECT_TRUE( + AreAttrValuesEqual(latency_node.attr().at("output_shapes"), + dataset_nodes[i].attr().at("output_shapes"))); + if (dataset_nodes[i].attr().find("output_types") != + dataset_nodes[i].attr().end()) { + EXPECT_TRUE( + AreAttrValuesEqual(latency_node.attr().at("output_types"), + dataset_nodes[i].attr().at("output_types"))); + } else { + if (dataset_nodes[i].attr().find("Toutput_types") != + dataset_nodes[i].attr().end()) { + EXPECT_TRUE( + AreAttrValuesEqual(latency_node.attr().at("output_types"), + dataset_nodes[i].attr().at("Toutput_types"))); + } + } + } +} + +} // namespace +} // namespace grappler +} // namespace tensorflow |