diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-26 13:31:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 13:35:03 -0700 |
commit | d600b1b55fa851648918fed7a67f61eefd554034 (patch) | |
tree | b54ad0dafbfc7ce4b231b2cc330eca47489aed80 /tensorflow/core/grappler | |
parent | b61ca2d62ab9792e1f386c2e598fee4d07b51f1c (diff) |
[tf.data] Small utils cleanup to expose generic function
PiperOrigin-RevId: 214659488
Diffstat (limited to 'tensorflow/core/grappler')
5 files changed, 36 insertions, 33 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index d42a560cb2..d198a2a591 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -89,10 +89,10 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":graph_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", - "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", ] + tf_protos_all(), ) diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc index e95ea1a4c1..e3f6d8e1ea 100644 --- a/tensorflow/core/grappler/optimizers/data/function_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -22,23 +23,6 @@ limitations under the License. namespace tensorflow { namespace grappler { namespace function_utils { -namespace { - -template <typename Predicate, typename Collection> -std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate, - const Collection& collection) { - std::vector<int> indices = {}; - unsigned idx = 0; - for (auto&& element : collection) { - if (predicate(element)) { - indices.push_back(idx); - } - idx++; - } - return indices; -} - -} // namespace FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name, const string& output, int position) @@ -152,32 +136,27 @@ bool ContainsFunctionOutputWithName(StringPiece name, } int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&name](const OpDef_ArgDef& arg) { return arg.name() == name; }, function.signature().input_arg()); - return indices.empty() ? -1 : indices.front(); } int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&name](const OpDef_ArgDef& arg) { return arg.name() == name; }, function.signature().output_arg()); - return indices.empty() ? -1 : indices.front(); } int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&name](const NodeDef& node) { return node.name() == name; }, function.node_def()); - return indices.empty() ? -1 : indices.front(); } int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&op](const NodeDef& node) { return node.op() == op; }, function.node_def()); - - return indices.empty() ? -1 : indices.front(); } void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 48825d0346..3eaaf8fbef 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -201,25 +201,22 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) { int FindGraphFunctionWithName(StringPiece name, const FunctionDefLibrary& library) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&name](const FunctionDef& function) { return function.signature().name() == name; }, library.function()); - return indices.empty() ? -1 : indices.front(); } int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&name](const NodeDef& node) { return node.name() == name; }, graph.node()); - return indices.empty() ? -1 : indices.front(); } int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&op](const NodeDef& node) { return node.op() == op; }, graph.node()); - return indices.empty() ? -1 : indices.front(); } std::vector<int> FindAllGraphNodesWithOp(const string& op, diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 189a72d255..5dd7819100 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -31,6 +31,21 @@ namespace tensorflow { namespace grappler { namespace graph_utils { +// Returns the index of the first element in collection that fulfills predicate. +// If no such element exists, returns -1. +template <typename Predicate, typename Collection> +int GetFirstElementIndexWithPredicate(const Predicate& predicate, + const Collection& collection) { + unsigned idx = 0; + for (auto&& element : collection) { + if (predicate(element)) { + return idx; + } + idx++; + } + return -1; +} + // Adds a node to the graph. NodeDef* AddNode(StringPiece name, StringPiece op, const std::vector<string>& inputs, diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 6877c207c4..db986542b2 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -24,6 +24,18 @@ namespace grappler { namespace graph_utils { namespace { +TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) { + std::vector<int> vec({1, 2, 3, 4, 5, 6}); + auto result = GetFirstElementIndexWithPredicate( + [](int elem) { return elem % 3 == 0; }, vec); + + EXPECT_EQ(result, 2); + + result = GetFirstElementIndexWithPredicate( + [](int elem) { return elem % 7 == 0; }, vec); + EXPECT_EQ(result, -1); +} + TEST(GraphUtilsTest, AddScalarConstNodeBool) { GraphDef graph_def; MutableGraphView graph(&graph_def); |