aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/graph
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-09-04 14:01:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 14:17:53 -0700
commit06e8109af2e5ae5bc149e25fc64fbf66d6c8b817 (patch)
treeb13b214063b3a4ba8f15c26e9170c9f44c49b854 /tensorflow/core/graph
parent8ef276fd2181fb71c2e232f60aa45ee96cb5905b (diff)
[tf.data] Add internal optimizations for executing simple functions in `MapDataset`.
PiperOrigin-RevId: 211520001
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r--tensorflow/core/graph/testlib.cc27
-rw-r--r--tensorflow/core/graph/testlib.h9
2 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index ea7788f654..0a38aa1c91 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) {
return ret;
}
+Node* CheckNumerics(Graph* g, Node* in, const string& message) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
+ .Input(in)
+ .Attr("message", message)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Arg(Graph* g, int64 index, DataType type) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
+ .Attr("T", type)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Retval(Graph* g, int64 index, Node* in) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
+ .Input(in)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
} // end namespace graph
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index 8585b35a19..bd0284d43a 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -209,6 +209,15 @@ Node* Diag(Graph* g, Node* in, DataType type);
// Add a DiagPart node in "g".
Node* DiagPart(Graph* g, Node* in, DataType type);
+// Add a CheckNumerics node in "g".
+Node* CheckNumerics(Graph* g, Node* in, const string& message);
+
+// Add an _Arg node in "g".
+Node* Arg(Graph* g, int64 index, DataType type);
+
+// Add a _Retval node in "g".
+Node* Retval(Graph* g, int64 index, Node* in);
+
} // end namespace graph
} // end namespace test
} // end namespace tensorflow