aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2018-10-09 11:54:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 11:58:43 -0700
commit072fcb995a3fd658ee2461b59b159498c710513d (patch)
treef3def3d3ac6e270ad32e428889a79d662c8bc9cf /tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
parent12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (diff)
[tf.data] NUMA-aware MapAndBatch dataset.
PiperOrigin-RevId: 216395709
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/graph_test_utils.cc')
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.cc16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
index b2eec7220e..1f03c6515c 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
namespace grappler {
@@ -44,6 +45,21 @@ NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
{"output_types", gtl::ArraySlice<TensorShape>{}}});
}
+NodeDef MakeMapAndBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece batch_size_node_name,
+ StringPiece num_parallel_calls_node_name,
+ StringPiece drop_remainder_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapAndBatchDatasetV2",
+ {string(input_node_name), "", string(batch_size_node_name),
+ string(num_parallel_calls_node_name), string(drop_remainder_node_name)},
+ {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<TensorShape>{}}});
+}
+
} // end namespace graph_tests_utils
} // end namespace grappler
} // end namespace tensorflow