diff options
author | Rachel Lim <rachelim@google.com> | 2018-10-04 09:26:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 09:35:51 -0700 |
commit | 1fb84c2e41c454939a02a69093cb214673eab343 (patch) | |
tree | e3ee1c19e3a73e1d1cddbc76d5573b7800b1048b | |
parent | ac22e1583aed390d78d2e87a4bf8a6ec39400ec4 (diff) |
Add ability to vectorize nodes that do not derive from function arguments. (This indirectly handles "Const" outputs automagically, since they are always unstacked.)
PiperOrigin-RevId: 215749824
5 files changed, 492 insertions, 13 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 755af3361e..ee7c14e3ab 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -524,6 +524,7 @@ cc_library( deps = [ ":function_utils", ":graph_utils", + "//tensorflow/cc:ops", "@com_google_absl//absl/strings", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index 9328a7ca99..ba521e79bc 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -44,7 +44,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, // Function inputs and outputs are the same as original, just // with different shapes. *vectorized_func->mutable_signature() = orig_func.signature(); - graph_utils::SetUniqueGraphFunctionName("vectorized_function", library, + graph_utils::SetUniqueGraphFunctionName("naively_vectorized_fn", library, vectorized_func); // Add MapDefun node diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index 2d6cf562b1..344c420902 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -14,10 +14,10 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" -#include <memory> #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" #include "absl/strings/str_join.h" +#include "tensorflow/cc/framework/ops.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device_base.h" @@ -28,13 +28,13 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { namespace grappler { @@ -45,6 +45,22 @@ namespace { // Describes a tensor with its operation Node and output position typedef std::pair<Node*, int> TensorDesc; +// Equivalent to python Pfor's WrappedTensor struct +struct WrappedTensor { + TensorDesc tensor; + + // Whether the tensor is stacked, i.e. represents the results of applying + // the operation on all slices of the input, where each row i of the + // tensor corresponds to the op's output on slice i of the input. False + // if the tensor is not stacked, i.e. represents the result of the op on + // a single slice of the input, where the result does not vary between + // slices. + bool stacked; + + WrappedTensor(TensorDesc&& tensor, bool stacked) + : tensor(std::move(tensor)), stacked(stacked) {} +}; + const char* const kRetValOp = "_Retval"; void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src, @@ -132,7 +148,8 @@ class Vectorization { const NodeDef& map_defun_node, FunctionDef** result); private: - // Converts FunctionDefs to Graphs. + // Converts FunctionDefs to Graphs and adds mappings from + // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_. Status Initialize(const FunctionDef& outer_scope, const NodeDef& map_defun_node); @@ -162,9 +179,30 @@ class Vectorization { // the conversion map. Status AddConversionMapping(Node* op_node); - // Maps a tensor to the corresponding vectorized tensor. For example, - // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0} - std::map<TensorDesc, TensorDesc> conversion_map_; + // Given a tensor t in `unstacked`, stacks it by doing the equivalent of + // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of + // inputs to `map_defun_node_`. This stacked tensor will be compatible with + // the expected output shape of `map_defun_node_`. + // This is equivalent to the _stack function in python Pfor. + Status StackTensor(WrappedTensor* unstacked, TensorDesc* result); + + // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by + // doing a depth-first search from the ret nodes. Lifts nodes that are + // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly + // and add mappings to `conversion_map_`. + Status AddUnstackedNodeMappings(); + + // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor + // is unstacked. + bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status); + + // Add mappings from `map_defun_fn_` arg nodes to `map_defun_node_` input + // nodes to `conversion_map_`. + Status AddArgNodeMappings(); + + // Maps a tensor to the corresponding WrappedTensor. For example, + // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true) + std::map<TensorDesc, WrappedTensor> conversion_map_; // Unconvertible ret nodes std::set<Node*> unconvertible_; @@ -180,6 +218,10 @@ class Vectorization { std::unique_ptr<Graph> outer_scope_; std::unique_ptr<FunctionBody> map_defun_fn_; Node* map_defun_node_ = nullptr; // Owned by `outer_scope` + + // Caches the loop_len_node_ needed for tiling unstacked output. This + // corresponds to a vector with one element. + Node* loop_len_node_ = nullptr; // Owned by `outer_scope` Status status_; }; @@ -224,7 +266,7 @@ Status Vectorization::AddConversionMapping(Node* op_node) { // Add output mappings. for (size_t i = 0; i < op_node->num_outputs(); ++i) { - conversion_map_.insert({{op_node, i}, std::move(output_ports[i])}); + conversion_map_.insert({{op_node, i}, {std::move(output_ports[i]), true}}); } return Status::OK(); @@ -242,10 +284,22 @@ Status Vectorization::ConvertOutput(int output_position) { if (auto found = gtl::FindOrNull(conversion_map_, output)) { // It's possible the output already has a mapping, if it comes from a node // that has already been converted. - converted_output = *found; + if (found->stacked) { + converted_output = found->tensor; + } else { + // Some outputs may be unstacked if they don't derive from arg nodes + // (for example, if a function returns a constant). For these, we + // have to add extra nodes to tile it in the 0th dimension. + TF_RETURN_IF_ERROR(StackTensor(found, &converted_output)); + } } else { + // Note: All unstacked nodes are converted ahead of time in `Initialize`, + // and here we assume that all op vectorizers create only stacked outputs. + // This may not hold in the future, as more vectorizers are added that + // may actually create unstacked outputs. For example, see the `Shape` + // converter in third_party/tensorflow/python/ops/parallel_for/pfor.py TF_RETURN_IF_ERROR(AddConversionMapping(output.first)); - converted_output = conversion_map_.at(output); + converted_output = conversion_map_.at(output).tensor; } ReplaceEdgeSources({map_defun_node_, output_position}, converted_output, @@ -297,6 +351,7 @@ void Vectorization::VectorizeHelper() { map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types); } } + Status Vectorization::Initialize(const FunctionDef& outer_scope, const NodeDef& map_defun_node) { // Convert outer_scope and map_defun_fn to FunctionBodys so we can @@ -337,16 +392,184 @@ Status Vectorization::Initialize(const FunctionDef& outer_scope, } map_defun_node_ = outer_scope_->FindNodeId(node_id); - // Add mappings from map_defun_fn_ arg nodes to map_defun_node_ input nodes to - // the conversion map + TF_RETURN_IF_ERROR(AddArgNodeMappings()); + + TF_RETURN_IF_ERROR(AddUnstackedNodeMappings()); + loop_len_node_ = nullptr; + + return Status::OK(); +} + +// TODO(rachelim): It might be profitable to use the C++ API for this instead of +// NodeBuilder +Status Vectorization::StackTensor(WrappedTensor* unstacked, + TensorDesc* result) { + // Note that all these nodes are necessary as the size of the batch may not be + // constant. + if (unstacked->stacked) { + return errors::Internal("Can only stack unstacked tensor."); + } + + Graph* g = outer_scope_.get(); + auto node_builder = [](StringPiece op) { + return NodeBuilder(strings::StrCat("vectorized/stack/", op), op); + }; + + auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph, + Node** result) { + TF_RETURN_IF_ERROR(val.status); + return node_builder("Const") + .Attr("value", val.tensor) + .Attr("dtype", val.tensor.dtype()) + .Finalize(graph, result); + }; + + // If loop_len_node_ hasn't been created yet, add the node and cache it. + if (loop_len_node_ == nullptr) { + Node* input_node; + TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node)); + + Node* shape_node; + TF_RETURN_IF_ERROR( + node_builder("Shape").Input(input_node).Finalize(g, &shape_node)); + + Node* const_vec_0; + TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0)); + Node* const_vec_1; + TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1)); + + Node* strided_slice_node; + TF_RETURN_IF_ERROR(node_builder("StridedSlice") + .Input(shape_node) // input + .Input(const_vec_0) // begin + .Input(const_vec_1) // end + .Input(const_vec_1) // strides + .Finalize(g, &strided_slice_node)); + + // Produces a vector of length 1 + TF_RETURN_IF_ERROR(node_builder("Reshape") + .Input(strided_slice_node) // tensor + .Input(const_vec_1) // shape + .Finalize(g, &loop_len_node_)); + } + + Node* ones_shape; + TF_RETURN_IF_ERROR(node_builder("Shape") + .Input(unstacked->tensor.first) // input + .Finalize(g, &ones_shape)); + + Node* ones; + TF_RETURN_IF_ERROR( + node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones)); + + Node* const_0; + TF_RETURN_IF_ERROR(make_const(0, g, &const_0)); + + Node* multiples; + TF_RETURN_IF_ERROR(node_builder("Concat") + .Input(const_0) // concat_dim + .Input({{loop_len_node_, 0}, {ones, 0}}) // values + .Finalize(g, &multiples)); + + Node* expand_dims; + TF_RETURN_IF_ERROR(node_builder("ExpandDims") + .Input(unstacked->tensor.first) // input + .Input(const_0) // dim + .Finalize(g, &expand_dims)); + + TF_RETURN_IF_ERROR(node_builder("Tile") + .Input(expand_dims) // input + .Input(multiples) // multiples + .Finalize(g, &result->first)); + result->second = 0; + return Status::OK(); +} + +Status Vectorization::AddArgNodeMappings() { for (auto arg_node : map_defun_fn_->arg_nodes) { Node* input_node; TF_RETURN_IF_ERROR(map_defun_node_->input_node( arg_node->attrs().Find("index")->i(), &input_node)); - conversion_map_.insert({{arg_node, 0}, {input_node, 0}}); + conversion_map_.insert({{arg_node, 0}, {{input_node, 0}, true}}); + + // Control inputs + conversion_map_.insert({{arg_node, Graph::kControlSlot}, + {{input_node, Graph::kControlSlot}, true}}); } + return Status::OK(); +} +bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, + Status* status) { + if (auto found = gtl::FindOrNull(conversion_map_, tensor)) { + return !found->stacked; + } + + if (tensor.first->op_def().is_stateful()) { + // We don't lift stateful nodes directly out of the MapDefun, since they may + // have to be executed N times. + return false; + } + + bool is_unstacked = true; + for (auto edge : tensor.first->in_edges()) { + // Ignore Source nodes. Note that these are also ignored in the + // GraphToFunctionDef conversion. + if (edge->src()->IsSource()) continue; + + // A node is unstacked if all of its inputs are unstacked + is_unstacked &= AddUnstackedNodeMappingsHelper( + {edge->src(), edge->src_output()}, status); + } + + if (!is_unstacked) { + return false; + } + + // If the node is unstacked, we copy it into outer_scope_ and + // add it to the map. Note that we don't clean up the nodes that are copied + // in map_defun_fn_, and rely on them being pruned out later. + Node* node = outer_scope_->AddNode(tensor.first->def(), status); + if (!status->ok()) return true; + + // Add input edges to nodes that should already have been lifted. + for (auto edge : tensor.first->in_edges()) { + // Ignore Source nodes. Note that these are also ignored in the + // GraphToFunctionDef conversion. + if (edge->src()->IsSource()) continue; + + if (auto found = gtl::FindOrNull(conversion_map_, + {edge->src(), edge->src_output()})) { + outer_scope_->AddEdge(found->tensor.first, found->tensor.second, node, + edge->dst_input()); + } else { + status->Update(errors::Internal( + "Could not find input conversion even though we did depth first " + "conversion.")); + } + } + + // Add output mappings + for (int i = 0; i < tensor.first->num_outputs(); ++i) { + conversion_map_.insert( + {{tensor.first, i}, WrappedTensor({node, i}, false)}); + } + conversion_map_.insert({{tensor.first, Graph::kControlSlot}, + WrappedTensor({node, Graph::kControlSlot}, false)}); + + return true; +} + +Status Vectorization::AddUnstackedNodeMappings() { + SetVector<Node*> unstacked_nodes; + Status s; + for (const auto& ret_node : map_defun_fn_->ret_nodes) { + const Edge* in_edge = nullptr; + TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge)); + AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s); + TF_RETURN_IF_ERROR(s); + } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index 1ff62217dd..a958d706c1 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -670,6 +670,257 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { cast_node.input(1) == control_input); } +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | | | +// | | +------+ | | +// | | |Const | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | +------+ | +// | |Const | | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// | |Stack*| | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. +// +TEST(VectorizeMapDefunTest, VectorizeConst) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2)}, + {{"ret0", "Const:output:0"}}); + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized)); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | | | +// | | +------+ | | +// | | |Const | | | +// | | +---+--+ | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | +------+ | +// | |Const | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | |Stack*| | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. +// +TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2)}, + {{"ret0", "Cast:y:0"}}); + AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner); + + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + auto const_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Const", *vectorized)); + auto cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')), + const_node.name()); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | +------+ +------+ | | +// | | |Const | |Const | | | +// | | +---+--+ +---+--+ | | +// | | : +---v--+ | | +// | | ::::::> Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | | +// | +------+ | +// | +------+ |Const | | +// | |Const | +---+--+ | +// | +---+--+ | | +// | : +---v--+ | +// | ::::::> Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | +Stack*+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. +// +TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2), + FunctionDefHelper::Const("ConstDep", 3)}, + {{"ret0", "Cast:y:0"}}); + AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64, + false, &inner); + + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + + auto find_const = [vectorized](int val) -> const NodeDef* { + for (const auto& n : vectorized->node_def()) { + if (n.attr().at("value").tensor().int_val(0) == val) { + return &n; + } + } + return nullptr; + }; + + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + auto const_node = find_const(2); + auto const_dep_node = find_const(3); + auto cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')), + const_node->name()); + EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name())); +} + // TODO(rachelim): More test cases when we get around to implementing them: // [] A badly defined converter, e.g. doesn't produce nodes that have the // same number of outputs/inputs as the nodes to be converted diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py index 32ebc49c40..971a2d94b9 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py @@ -78,6 +78,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ("Basic", lambda x: (x, x + 1), None), + ("Const", lambda x: 2, 12), ("Parallel", lambda x: (x, x + 1), 12), ("Gather", lambda x: array_ops.gather(x, 0), 12), ) @@ -207,6 +208,9 @@ class MapVectorizationBenchmark(test.Benchmark): def benchmarkAddConst(self): self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const") + def benchmarkReturnConst(self): + self._benchmark_helper(lambda *args: [constant_op.constant(2)], "ret_const") + def benchmarkSelect(self): self._benchmark_helper(lambda *args: args[0], "select") |