diff options
author | Brennan Saeta <saeta@google.com> | 2018-10-09 11:54:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 11:58:43 -0700 |
commit | 072fcb995a3fd658ee2461b59b159498c710513d (patch) | |
tree | f3def3d3ac6e270ad32e428889a79d662c8bc9cf /tensorflow/core/grappler/optimizers/data/graph_test_utils.cc | |
parent | 12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (diff) |
[tf.data] NUMA-aware MapAndBatch dataset.
PiperOrigin-RevId: 216395709
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/graph_test_utils.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/graph_test_utils.cc | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc index b2eec7220e..1f03c6515c 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { namespace grappler { @@ -44,6 +45,21 @@ NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, {"output_types", gtl::ArraySlice<TensorShape>{}}}); } +NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name, + StringPiece batch_size_node_name, + StringPiece num_parallel_calls_node_name, + StringPiece drop_remainder_node_name, + StringPiece function_name) { + return test::function::NDef( + name, "MapAndBatchDatasetV2", + {string(input_node_name), "", string(batch_size_node_name), + string(num_parallel_calls_node_name), string(drop_remainder_node_name)}, + {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<TensorShape>{}}}); +} + } // end namespace graph_tests_utils } // end namespace grappler } // end namespace tensorflow |