diff options
author | Derek Murray <mrry@google.com> | 2018-09-04 14:01:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-04 14:17:53 -0700 |
commit | 06e8109af2e5ae5bc149e25fc64fbf66d6c8b817 (patch) | |
tree | b13b214063b3a4ba8f15c26e9170c9f44c49b854 /tensorflow/core/graph | |
parent | 8ef276fd2181fb71c2e232f60aa45ee96cb5905b (diff) |
[tf.data] Add internal optimizations for executing simple functions in `MapDataset`.
PiperOrigin-RevId: 211520001
Diffstat (limited to 'tensorflow/core/graph')
-rw-r--r-- | tensorflow/core/graph/testlib.cc | 27 | ||||
-rw-r--r-- | tensorflow/core/graph/testlib.h | 9 |
2 files changed, 36 insertions, 0 deletions
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index ea7788f654..0a38aa1c91 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) { return ret; } +Node* CheckNumerics(Graph* g, Node* in, const string& message) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics") + .Input(in) + .Attr("message", message) + .Finalize(g, &ret)); + return ret; +} + +Node* Arg(Graph* g, int64 index, DataType type) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg") + .Attr("T", type) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + +Node* Retval(Graph* g, int64 index, Node* in) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval") + .Input(in) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } } // end namespace graph diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h index 8585b35a19..bd0284d43a 100644 --- a/tensorflow/core/graph/testlib.h +++ b/tensorflow/core/graph/testlib.h @@ -209,6 +209,15 @@ Node* Diag(Graph* g, Node* in, DataType type); // Add a DiagPart node in "g". Node* DiagPart(Graph* g, Node* in, DataType type); +// Add a CheckNumerics node in "g". +Node* CheckNumerics(Graph* g, Node* in, const string& message); + +// Add an _Arg node in "g". +Node* Arg(Graph* g, int64 index, DataType type); + +// Add a _Retval node in "g". +Node* Retval(Graph* g, int64 index, Node* in); + } // end namespace graph } // end namespace test } // end namespace tensorflow |