From d6a3d6a8295359364c86aecc479e6392bcde0ce4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 10 Oct 2018 02:42:39 -0700 Subject: Automated rollback of commit 950cf87104bfee28e2165fe368f66337b8a1336d PiperOrigin-RevId: 216500702 --- tensorflow/core/graph/graph.cc | 2 +- .../grappler/optimizers/data/vectorization/BUILD | 34 ++--- .../data/vectorization/add_vectorizer.cc | 150 --------------------- .../optimizers/data/vectorization_utils.cc | 21 ++- .../optimizers/data/vectorization_utils_test.cc | 103 ++------------ .../optimization/map_vectorization_test.py | 1 - 6 files changed, 31 insertions(+), 280 deletions(-) delete mode 100644 tensorflow/core/grappler/optimizers/data/vectorization/add_vectorizer.cc diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index a17491d4f7..6f068546d2 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -34,7 +34,7 @@ namespace tensorflow { const int Graph::kControlSlot = -1; -struct NodeProperties { +class NodeProperties { public: NodeProperties(const OpDef* op_def, const NodeDef& node_def, const DataTypeSlice inputs, const DataTypeSlice outputs) diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD index 09018d0124..985d6c6c3a 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD +++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD @@ -9,11 +9,7 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all") VECTORIZER_DEPS = [ ":vectorizer_registry", - "//tensorflow/cc:ops", "//tensorflow/core/grappler/optimizers/data:graph_utils", - "//tensorflow/core:core_cpu", - "//tensorflow/cc:scope_internal", - "//tensorflow/cc:cc_ops", ] + tf_protos_all() cc_library( @@ -46,24 +42,6 @@ cc_library( ], ) -tf_cc_test( - name = "vectorizer_registry_test", - srcs = ["vectorizer_registry_test.cc"], - deps = [ - ":vectorizer_registry", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core:testlib", - ] + tf_protos_all(), -) - -cc_library( - name = "add_vectorizer", - srcs = ["add_vectorizer.cc"], - deps = VECTORIZER_DEPS, - alwayslink = 1, -) - cc_library( name = "cast_vectorizer", srcs = ["cast_vectorizer.cc"], @@ -83,10 +61,20 @@ cc_library( hdrs = ["vectorizer_registry.h"], visibility = ["//visibility:public"], deps = [ - ":add_vectorizer", ":cast_vectorizer", ":unpack_vectorizer", ":vectorizer", ":vectorizer_registry", ], ) + +tf_cc_test( + name = "vectorizer_registry_test", + srcs = ["vectorizer_registry_test.cc"], + deps = [ + ":vectorizer_registry", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + tf_protos_all(), +) diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/add_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/add_vectorizer.cc deleted file mode 100644 index d90a51b01a..0000000000 --- a/tensorflow/core/grappler/optimizers/data/vectorization/add_vectorizer.cc +++ /dev/null @@ -1,150 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/cc/framework/ops.h" -#include "tensorflow/cc/framework/scope_internal.h" -#include "tensorflow/cc/ops/array_ops.h" -#include "tensorflow/cc/ops/math_ops.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" - -namespace tensorflow { -namespace grappler { - -namespace { - -const char* const kExpandDimsPrefix = "vectorized/expanddims/"; - -// Reshapes stacked inputs for broadcast. Stacked inputs have an extra leading -// dimension, which may cause automatic broadcasting rules to expand the -// input dimensions wrongly when the unstacked shapes have different ranks. -// To avoid that, we reshape stacked inputs to the maximum rank they need -// to be broadcasted to. -// -// For example, suppose we have inputs A and B, where A is a stacked tensor with -// shape [n, 5] (where n is the stack size) and B is an unstacked tensor with -// shape [12, 7, 5]. If we added them directly, tensorflow broadcasting rules -// would expand the dimensions of A to [1, n, 5], then (incorrectly) check that -// the dimensions n and 7 are compatible, and if so, create an output of shape -// [12, 7, 5]. However, correct addition of these inputs would create an output -// with shape [n, 12, 7, 5]: we need to manually expand the dimensions of A -// *after* the leading dimension, i.e. expand A to the shape [n, 1, 1, 5] before -// broadcasting. -Status ExpandDimsForBroadcast(std::vector* inputs, Graph* g) { - Status status; - Scope parent = NewInternalScope(g, &status, nullptr); - Scope s = parent.NewSubScope(kExpandDimsPrefix); - - // TODO(rachelim): We can potentially get rid of all these ops if shapes are - // known statically - - Output const_0 = ops::Const(s, 0); - Output const_1 = ops::Const(s, 1); - - std::vector ranks; - ranks.reserve(inputs->size()); - - // Get the stacked rank of each input - for (const auto& input : *inputs) { - Output rank = ops::Rank(s, Output(input.node, input.output_index)); - - if (!input.stacked) { - // If the input is unstacked, add 1 - rank = ops::Add(s, rank, const_1); - } - - ranks.push_back(rank); - } - - // Pack the ranks into one tensor to get the max - Output packed_ranks = ops::Stack(s, ranks); - - Output max_rank = - ops::Max(s, packed_ranks, const_0, ops::Max::Attrs().KeepDims(true)); - - std::vector expanded_inputs; - expanded_inputs.reserve(inputs->size()); - - // For all inputs that are stacked, expand dimensions after dim 0. - for (size_t i = 0; i < inputs->size(); ++i) { - if (!inputs->at(i).stacked) { - expanded_inputs.push_back(inputs->at(i)); - continue; - } - - Output input(inputs->at(i).node, inputs->at(i).output_index); - - // Number of dimensions to expand - Output rank_diff = ops::Sub(s, max_rank, ranks[i]); - - // [1] * rank_diff - Output ones = ops::Tile(s, ops::Const(s, {1}), rank_diff); - - Output const_vec_1 = ops::Const(s, {1}); - - Output shape = ops::Shape(s, input); - - // shape[:1] - Output concat_pre = - ops::StridedSlice(s, shape, const_vec_1, const_vec_1, const_vec_1, - ops::StridedSlice::Attrs().BeginMask(1)); - - // shape[1:] - Output concat_post = - ops::StridedSlice(s, shape, const_vec_1, const_vec_1, const_vec_1, - ops::StridedSlice::Attrs().EndMask(1)); - - // tf.concat([shape[:1], ones, shape[1:]], 0) - Output new_shape = ops::Concat(s, {concat_pre, ones, concat_post}, const_0); - - Output result = ops::Reshape(s, input, new_shape); - - expanded_inputs.push_back({result.node(), 0, true}); - } - - inputs->swap(expanded_inputs); - return status; -} - -class AddVectorizer : public Vectorizer { - public: - Status Vectorize(const Node& node, Graph* outer_scope, - std::vector&& inputs, - std::vector* outputs) override { - if (node.num_inputs() != 2) { - return errors::Internal("Add op should only have two inputs."); - } - - TF_RETURN_IF_ERROR(ExpandDimsForBroadcast(&inputs, outer_scope)); - - // Add new Add node with the same op and attrs as the original node - Node* new_add_node; - TF_RETURN_IF_ERROR(NodeBuilder("Add", "Add") - .Input(inputs[0].node, inputs[0].output_index) - .Input(inputs[1].node, inputs[1].output_index) - .Finalize(outer_scope, &new_add_node)); - - // Add output mappings - outputs->push_back({new_add_node, 0, true}); - return Status::OK(); - } -}; - -REGISTER_VECTORIZER("Add", AddVectorizer); - -} // namespace -} // namespace grappler -} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index 8b93b1f2b8..d977ff3198 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -64,18 +64,9 @@ void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src, } } -// Update node attrs to keep its properties consistent with the function -void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) { - map_defun_node->AddAttr("output_types", map_defun_fn->ret_types); - - // TODO(rachelim): Propagate precise shapes if they're known, which may enable - // subsequent optimizations. - map_defun_node->AddAttr("output_shapes", std::vector( - map_defun_fn->ret_types.size())); -} - Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node, const TensorDesc& output) { + // Note that we don't update MapDefun attrs as we go, only when we are done DataType type = output.first->output_type(output.second); int index = map_defun_fn->ret_nodes.size(); @@ -92,13 +83,13 @@ Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node, map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0); map_defun_fn->ret_nodes.push_back(ret_node); map_defun_fn->ret_types.push_back(type); - UpdateMapDefunAttrs(map_defun_fn, map_defun_node); return s; } void RemoveMapDefunOutput(int output_position, Graph* outer_scope, FunctionBody* map_defun_fn, Node* map_defun_node) { + // Note that we don't update MapDefun attrs as we go, only when we are done DCHECK_LT(output_position, map_defun_fn->ret_nodes.size()) << "Trying to remove output that doesn't exist. Output number: " << output_position; @@ -111,7 +102,6 @@ void RemoveMapDefunOutput(int output_position, Graph* outer_scope, output_position); map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() + output_position); - UpdateMapDefunAttrs(map_defun_fn, map_defun_node); // Renumber the nodes and edges that come after for (int i = 0; i < num_later_outputs; ++i) { @@ -352,6 +342,13 @@ void Vectorization::VectorizeHelper() { // need the MapDefun node and can delete it. if (map_defun_fn_->ret_nodes.empty()) { outer_scope_->RemoveNode(map_defun_node_); + } else { + // Update MapDefun node attrs accordingly + DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size()); + map_defun_node_->AddAttr( + "output_shapes", + std::vector(map_defun_fn_->ret_types.size())); + map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types); } } diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index be498d150b..a6020e36bb 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -145,7 +145,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { FunctionDef* vectorized; Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized); LOG(ERROR) << s; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); EXPECT_EQ(GetRetval(*vectorized, 0), "ret0"); @@ -237,7 +237,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { *lib.add_function() = outer; *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); auto map_defun_node = vectorized->node_def( function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized)); @@ -311,7 +311,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) { *lib.add_function() = outer; *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); const NodeDef& cast_node = vectorized->node_def( @@ -389,7 +389,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) { *lib.add_function() = outer; *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); const NodeDef& cast_node = vectorized->node_def( @@ -475,7 +475,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) { *lib.add_function() = outer; *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); const NodeDef& unpack_node = vectorized->node_def( @@ -574,7 +574,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) { *lib.add_function() = outer; *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); const NodeDef& cast_node = vectorized->node_def( @@ -654,7 +654,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { *lib.add_function() = outer; *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); // They should be unchanged // We check this somewhat manually as the names of nodes may have changed EXPECT_EQ(vectorized->node_def_size(), 1); @@ -738,7 +738,7 @@ TEST(VectorizeMapDefunTest, VectorizeConst) { *lib.add_function() = outer; *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized)); @@ -817,7 +817,7 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) { *lib.add_function() = outer; *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); EXPECT_TRUE( !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); auto const_node = vectorized->node_def( @@ -902,7 +902,7 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) { *lib.add_function() = inner; FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); auto find_const = [vectorized](int val) -> const NodeDef* { for (const auto& n : vectorized->node_def()) { @@ -924,89 +924,6 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) { EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name())); } -// Before: -// -// +------+ -// +-----------------+ Arg0 +----------------------+ -// | +---+--+ | -// | | | -// | +---v--+ | -// | +-------------+ Arg0 +------------------+ | -// | | +---+--+ | | -// | | | | | -// | | | +-----+ | | -// | | | |Const| | | -// | | | +-+---+ | | -// | | | | | | -// | | | +--------+ | | -// | | | | | | -// | | +-v---v-+ | | -// | | | Add | | | -// | | +-+-----+ | | -// | | | | | -// | | | | | -// | | MapDefun +-v----+ | | -// | +---------------| Ret |----------------+ | -// | +--v---+ | -// | | | -// | | | -// | +--v---- | -// +-------------------| Ret |--------------------+ -// +------+ -// -// -// After: -// -// +------+ -// +------------+ Arg0 +----------------------+ -// | +---+--+ | -// | | | -// | | +-----+ | -// | | |Const| | -// | +-v---------+ +--+--+ | -// | |ExpandDims*| | | -// | +-----+-----+ | | -// | | | | -// | +-----+ +-----+ | -// | | | | -// | +-v-v-+ | -// | | Add | | -// | +--+--+ | -// | | | -// | +---v--+ | -// +-----------------------+ Ret +-----------+ -// +------+ -// -TEST(VectorizeMapDefunTest, VectorizeDefunAdd) { - // Note that this checks that the "Add" vectorizer is successful, but does not - // check that the transformed function is correct (i.e. produces the same - // output as the unvectorized map defun). For the latter, the tests are in - // tensorflow/python/data/experimental/kernel_tests/optimization/ - // map_vectorization_test.py - FunctionDef inner = FunctionDefHelper::Create( - "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */}, - {/* nodes */ FunctionDefHelper::Const("Const", 2), - {{"Add"}, "Add", {"arg0", "Const:output:0"}, {{"T", DT_INT32}}}}, - {{"ret0", "Add:z:0"}}); - - FunctionDef outer = FunctionDefHelper::Create( - "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"}, - {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); - - NodeDef* map_defun = - AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}}, - inner.signature().name(), &outer); - CHECK_NOTNULL(map_defun); - - FunctionDefLibrary lib; - *lib.add_function() = outer; - *lib.add_function() = inner; - FunctionDef* vectorized; - TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized)); - EXPECT_TRUE( - !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); -} - // TODO(rachelim): More test cases when we get around to implementing them: // [] A badly defined converter, e.g. doesn't produce nodes that have the // same number of outputs/inputs as the nodes to be converted diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py index d1d6cf28ab..803ff87924 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py @@ -80,7 +80,6 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): ("Basic", lambda x: (x, x + 1), None), ("Const", lambda x: 2, 12), ("Parallel", lambda x: (x, x + 1), 12), - ("Broadcast", lambda x: x + np.random.rand(5, 4, 3, 2), None), ("Gather", lambda x: array_ops.gather(x, 0), 12), ) def testOptimization(self, map_fn, num_parallel_calls): -- cgit v1.2.3