aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar Rachel Lim <rachelim@google.com>2018-09-26 13:31:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 13:35:03 -0700
commitd600b1b55fa851648918fed7a67f61eefd554034 (patch)
treeb54ad0dafbfc7ce4b231b2cc330eca47489aed80 /tensorflow/core/grappler
parentb61ca2d62ab9792e1f386c2e598fee4d07b51f1c (diff)
[tf.data] Small utils cleanup to expose generic function
PiperOrigin-RevId: 214659488
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD2
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.cc31
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc9
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h15
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc12
5 files changed, 36 insertions, 33 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index d42a560cb2..d198a2a591 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -89,10 +89,10 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
index e95ea1a4c1..e3f6d8e1ea 100644
--- a/tensorflow/core/grappler/optimizers/data/function_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
@@ -22,23 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace function_utils {
-namespace {
-
-template <typename Predicate, typename Collection>
-std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
- const Collection& collection) {
- std::vector<int> indices = {};
- unsigned idx = 0;
- for (auto&& element : collection) {
- if (predicate(element)) {
- indices.push_back(idx);
- }
- idx++;
- }
- return indices;
-}
-
-} // namespace
FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
const string& output, int position)
@@ -152,32 +136,27 @@ bool ContainsFunctionOutputWithName(StringPiece name,
}
int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
function.signature().input_arg());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
function.signature().output_arg());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
function.node_def());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; },
function.node_def());
-
- return indices.empty() ? -1 : indices.front();
}
void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 48825d0346..3eaaf8fbef 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -201,25 +201,22 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
int FindGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&name](const FunctionDef& function) {
return function.signature().name() == name;
},
library.function());
- return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
graph.node());
- return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
- return indices.empty() ? -1 : indices.front();
}
std::vector<int> FindAllGraphNodesWithOp(const string& op,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 189a72d255..5dd7819100 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -31,6 +31,21 @@ namespace tensorflow {
namespace grappler {
namespace graph_utils {
+// Returns the index of the first element in collection that fulfills predicate.
+// If no such element exists, returns -1.
+template <typename Predicate, typename Collection>
+int GetFirstElementIndexWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ unsigned idx = 0;
+ for (auto&& element : collection) {
+ if (predicate(element)) {
+ return idx;
+ }
+ idx++;
+ }
+ return -1;
+}
+
// Adds a node to the graph.
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 6877c207c4..db986542b2 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -24,6 +24,18 @@ namespace grappler {
namespace graph_utils {
namespace {
+TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) {
+ std::vector<int> vec({1, 2, 3, 4, 5, 6});
+ auto result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 3 == 0; }, vec);
+
+ EXPECT_EQ(result, 2);
+
+ result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 7 == 0; }, vec);
+ EXPECT_EQ(result, -1);
+}
+
TEST(GraphUtilsTest, AddScalarConstNodeBool) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);