author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-10 02:42:39 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-10-10 02:47:15 -0700
commit     d6a3d6a8295359364c86aecc479e6392bcde0ce4 (patch)
tree       98658454a85871179cf61e734d2edeb4abab024a
parent     dd7d31fa7bfa357e58987c2f3881d99c8050b6de (diff)
Automated rollback of commit 950cf87104bfee28e2165fe368f66337b8a1336d
PiperOrigin-RevId: 216500702
-rw-r--r--  tensorflow/core/graph/graph.cc                                                              2
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/BUILD                               34
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization/add_vectorizer.cc                  150
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils.cc                            21
-rw-r--r--  tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc                      103
-rw-r--r--  tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py     1
6 files changed, 31 insertions(+), 280 deletions(-)
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index a17491d4f7..6f068546d2 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -34,7 +34,7 @@ namespace tensorflow {
const int Graph::kControlSlot = -1;
-struct NodeProperties {
+class NodeProperties {
public:
NodeProperties(const OpDef* op_def, const NodeDef& node_def,
const DataTypeSlice inputs, const DataTypeSlice outputs)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 09018d0124..985d6c6c3a 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -9,11 +9,7 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
VECTORIZER_DEPS = [
":vectorizer_registry",
- "//tensorflow/cc:ops",
"//tensorflow/core/grappler/optimizers/data:graph_utils",
- "//tensorflow/core:core_cpu",
- "//tensorflow/cc:scope_internal",
- "//tensorflow/cc:cc_ops",
] + tf_protos_all()
cc_library(
@@ -46,24 +42,6 @@ cc_library(
],
)
-tf_cc_test(
- name = "vectorizer_registry_test",
- srcs = ["vectorizer_registry_test.cc"],
- deps = [
- ":vectorizer_registry",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
- ] + tf_protos_all(),
-)
-
-cc_library(
- name = "add_vectorizer",
- srcs = ["add_vectorizer.cc"],
- deps = VECTORIZER_DEPS,
- alwayslink = 1,
-)
-
cc_library(
name = "cast_vectorizer",
srcs = ["cast_vectorizer.cc"],
@@ -83,10 +61,20 @@ cc_library(
hdrs = ["vectorizer_registry.h"],
visibility = ["//visibility:public"],
deps = [
- ":add_vectorizer",
":cast_vectorizer",
":unpack_vectorizer",
":vectorizer",
":vectorizer_registry",
],
)
+
+tf_cc_test(
+ name = "vectorizer_registry_test",
+ srcs = ["vectorizer_registry_test.cc"],
+ deps = [
+ ":vectorizer_registry",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/add_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/add_vectorizer.cc
deleted file mode 100644
index d90a51b01a..0000000000
--- a/tensorflow/core/grappler/optimizers/data/vectorization/add_vectorizer.cc
+++ /dev/null
@@ -1,150 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/cc/framework/ops.h"
-#include "tensorflow/cc/framework/scope_internal.h"
-#include "tensorflow/cc/ops/array_ops.h"
-#include "tensorflow/cc/ops/math_ops.h"
-#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
-
-namespace tensorflow {
-namespace grappler {
-
-namespace {
-
-const char* const kExpandDimsPrefix = "vectorized/expanddims/";
-
-// Reshapes stacked inputs for broadcast. Stacked inputs have an extra leading
-// dimension, which may cause automatic broadcasting rules to expand the
-// input dimensions wrongly when the unstacked shapes have different ranks.
-// To avoid that, we reshape stacked inputs to the maximum rank they need
-// to be broadcasted to.
-//
-// For example, suppose we have inputs A and B, where A is a stacked tensor with
-// shape [n, 5] (where n is the stack size) and B is an unstacked tensor with
-// shape [12, 7, 5]. If we added them directly, tensorflow broadcasting rules
-// would expand the dimensions of A to [1, n, 5], then (incorrectly) check that
-// the dimensions n and 7 are compatible, and if so, create an output of shape
-// [12, 7, 5]. However, correct addition of these inputs would create an output
-// with shape [n, 12, 7, 5]: we need to manually expand the dimensions of A
-// *after* the leading dimension, i.e. expand A to the shape [n, 1, 1, 5] before
-// broadcasting.
-Status ExpandDimsForBroadcast(std::vector<WrappedTensor>* inputs, Graph* g) {
- Status status;
- Scope parent = NewInternalScope(g, &status, nullptr);
- Scope s = parent.NewSubScope(kExpandDimsPrefix);
-
- // TODO(rachelim): We can potentially get rid of all these ops if shapes are
- // known statically
-
- Output const_0 = ops::Const(s, 0);
- Output const_1 = ops::Const(s, 1);
-
- std::vector<Output> ranks;
- ranks.reserve(inputs->size());
-
- // Get the stacked rank of each input
- for (const auto& input : *inputs) {
- Output rank = ops::Rank(s, Output(input.node, input.output_index));
-
- if (!input.stacked) {
- // If the input is unstacked, add 1
- rank = ops::Add(s, rank, const_1);
- }
-
- ranks.push_back(rank);
- }
-
- // Pack the ranks into one tensor to get the max
- Output packed_ranks = ops::Stack(s, ranks);
-
- Output max_rank =
- ops::Max(s, packed_ranks, const_0, ops::Max::Attrs().KeepDims(true));
-
- std::vector<WrappedTensor> expanded_inputs;
- expanded_inputs.reserve(inputs->size());
-
- // For all inputs that are stacked, expand dimensions after dim 0.
- for (size_t i = 0; i < inputs->size(); ++i) {
- if (!inputs->at(i).stacked) {
- expanded_inputs.push_back(inputs->at(i));
- continue;
- }
-
- Output input(inputs->at(i).node, inputs->at(i).output_index);
-
- // Number of dimensions to expand
- Output rank_diff = ops::Sub(s, max_rank, ranks[i]);
-
- // [1] * rank_diff
- Output ones = ops::Tile(s, ops::Const(s, {1}), rank_diff);
-
- Output const_vec_1 = ops::Const(s, {1});
-
- Output shape = ops::Shape(s, input);
-
- // shape[:1]
- Output concat_pre =
- ops::StridedSlice(s, shape, const_vec_1, const_vec_1, const_vec_1,
- ops::StridedSlice::Attrs().BeginMask(1));
-
- // shape[1:]
- Output concat_post =
- ops::StridedSlice(s, shape, const_vec_1, const_vec_1, const_vec_1,
- ops::StridedSlice::Attrs().EndMask(1));
-
- // tf.concat([shape[:1], ones, shape[1:]], 0)
- Output new_shape = ops::Concat(s, {concat_pre, ones, concat_post}, const_0);
-
- Output result = ops::Reshape(s, input, new_shape);
-
- expanded_inputs.push_back({result.node(), 0, true});
- }
-
- inputs->swap(expanded_inputs);
- return status;
-}
-
-class AddVectorizer : public Vectorizer {
- public:
- Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<WrappedTensor>&& inputs,
- std::vector<WrappedTensor>* outputs) override {
- if (node.num_inputs() != 2) {
- return errors::Internal("Add op should only have two inputs.");
- }
-
- TF_RETURN_IF_ERROR(ExpandDimsForBroadcast(&inputs, outer_scope));
-
- // Add new Add node with the same op and attrs as the original node
- Node* new_add_node;
- TF_RETURN_IF_ERROR(NodeBuilder("Add", "Add")
- .Input(inputs[0].node, inputs[0].output_index)
- .Input(inputs[1].node, inputs[1].output_index)
- .Finalize(outer_scope, &new_add_node));
-
- // Add output mappings
- outputs->push_back({new_add_node, 0, true});
- return Status::OK();
- }
-};
-
-REGISTER_VECTORIZER("Add", AddVectorizer);
-
-} // namespace
-} // namespace grappler
-} // namespace tensorflow
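The header comment of the deleted ExpandDimsForBroadcast above walks through why a stacked input of shape [n, 5] cannot be added directly to an unstacked [12, 7, 5] tensor. As a minimal illustration of that shape arithmetic (NumPy only, not part of this commit), the same example can be reproduced directly:

# Illustration only (not part of the patch): the shape arithmetic from the
# ExpandDimsForBroadcast comment above, reproduced with NumPy broadcasting.
import numpy as np

n = 3
a_stacked = np.ones((n, 5))        # stacked input: leading dim is the stack size n
b_unstacked = np.ones((12, 7, 5))  # unstacked input

# Naive addition: broadcasting aligns shapes from the right, padding A to
# (1, n, 5) and then requiring n to match 7 -- a shape error (or, if n == 7,
# a silently wrong result of shape (12, 7, 5) instead of (n, 12, 7, 5)).
try:
    _ = a_stacked + b_unstacked
except ValueError as e:
    print("naive add fails:", e)

# Correct vectorized addition: keep the leading stack dim and insert singleton
# dims after it, so A becomes (n, 1, 1, 5) before broadcasting.
a_expanded = a_stacked.reshape(n, 1, 1, 5)
out = a_expanded + b_unstacked
print(out.shape)  # (3, 12, 7, 5)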
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 8b93b1f2b8..d977ff3198 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -64,18 +64,9 @@ void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
}
}
-// Update node attrs to keep its properties consistent with the function
-void UpdateMapDefunAttrs(FunctionBody* map_defun_fn, Node* map_defun_node) {
- map_defun_node->AddAttr("output_types", map_defun_fn->ret_types);
-
- // TODO(rachelim): Propagate precise shapes if they're known, which may enable
- // subsequent optimizations.
- map_defun_node->AddAttr("output_shapes", std::vector<PartialTensorShape>(
- map_defun_fn->ret_types.size()));
-}
-
Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
const TensorDesc& output) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
DataType type = output.first->output_type(output.second);
int index = map_defun_fn->ret_nodes.size();
@@ -92,13 +83,13 @@ Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
map_defun_fn->ret_nodes.push_back(ret_node);
map_defun_fn->ret_types.push_back(type);
- UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
return s;
}
void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
FunctionBody* map_defun_fn, Node* map_defun_node) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
<< "Trying to remove output that doesn't exist. Output number: "
<< output_position;
@@ -111,7 +102,6 @@ void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
output_position);
map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
output_position);
- UpdateMapDefunAttrs(map_defun_fn, map_defun_node);
// Renumber the nodes and edges that come after
for (int i = 0; i < num_later_outputs; ++i) {
@@ -352,6 +342,13 @@ void Vectorization::VectorizeHelper() {
// need the MapDefun node and can delete it.
if (map_defun_fn_->ret_nodes.empty()) {
outer_scope_->RemoveNode(map_defun_node_);
+ } else {
+ // Update MapDefun node attrs accordingly
+ DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
+ map_defun_node_->AddAttr(
+ "output_shapes",
+ std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
+ map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
}
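The hunk above moves the output_types/output_shapes attr updates out of AddMapDefunOutput and RemoveMapDefunOutput and into VectorizeHelper, so the MapDefun node attrs are written once after all outputs have been converted rather than after every change. A minimal sketch of that pattern, with hypothetical names and plain Python containers instead of TensorFlow types:

# Hypothetical sketch of the refactor above: mutate the output list freely,
# then derive the node attrs exactly once at the end.
class FakeMapDefun:
    def __init__(self):
        self.ret_types = []
        self.attrs = {}

def add_output(fn, dtype):
    fn.ret_types.append(dtype)      # no attr update here

def remove_output(fn, pos):
    del fn.ret_types[pos]           # no attr update here

def finalize(fn):
    # Single point where the attrs are made consistent with the function.
    fn.attrs["output_types"] = list(fn.ret_types)
    fn.attrs["output_shapes"] = [None] * len(fn.ret_types)  # shapes unknown

fn = FakeMapDefun()
add_output(fn, "int32")
add_output(fn, "float32")
remove_output(fn, 0)
finalize(fn)
print(fn.attrs)  # {'output_types': ['float32'], 'output_shapes': [None]}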
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index be498d150b..a6020e36bb 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -145,7 +145,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
FunctionDef* vectorized;
Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
LOG(ERROR) << s;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
@@ -237,7 +237,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
auto map_defun_node = vectorized->node_def(
function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
@@ -311,7 +311,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& cast_node = vectorized->node_def(
@@ -389,7 +389,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& cast_node = vectorized->node_def(
@@ -475,7 +475,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& unpack_node = vectorized->node_def(
@@ -574,7 +574,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
const NodeDef& cast_node = vectorized->node_def(
@@ -654,7 +654,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
// They should be unchanged
// We check this somewhat manually as the names of nodes may have changed
EXPECT_EQ(vectorized->node_def_size(), 1);
@@ -738,7 +738,7 @@ TEST(VectorizeMapDefunTest, VectorizeConst) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized));
@@ -817,7 +817,7 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
auto const_node = vectorized->node_def(
@@ -902,7 +902,7 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
*lib.add_function() = inner;
FunctionDef* vectorized;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
auto find_const = [vectorized](int val) -> const NodeDef* {
for (const auto& n : vectorized->node_def()) {
@@ -924,89 +924,6 @@ TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name()));
}
-// Before:
-//
-// +------+
-// +-----------------+ Arg0 +----------------------+
-// | +---+--+ |
-// | | |
-// | +---v--+ |
-// | +-------------+ Arg0 +------------------+ |
-// | | +---+--+ | |
-// | | | | |
-// | | | +-----+ | |
-// | | | |Const| | |
-// | | | +-+---+ | |
-// | | | | | |
-// | | | +--------+ | |
-// | | | | | |
-// | | +-v---v-+ | |
-// | | | Add | | |
-// | | +-+-----+ | |
-// | | | | |
-// | | | | |
-// | | MapDefun +-v----+ | |
-// | +---------------| Ret |----------------+ |
-// | +--v---+ |
-// | | |
-// | | |
-// | +--v---- |
-// +-------------------| Ret |--------------------+
-// +------+
-//
-//
-// After:
-//
-// +------+
-// +------------+ Arg0 +----------------------+
-// | +---+--+ |
-// | | |
-// | | +-----+ |
-// | | |Const| |
-// | +-v---------+ +--+--+ |
-// | |ExpandDims*| | |
-// | +-----+-----+ | |
-// | | | |
-// | +-----+ +-----+ |
-// | | | |
-// | +-v-v-+ |
-// | | Add | |
-// | +--+--+ |
-// | | |
-// | +---v--+ |
-// +-----------------------+ Ret +-----------+
-// +------+
-//
-TEST(VectorizeMapDefunTest, VectorizeDefunAdd) {
- // Note that this checks that the "Add" vectorizer is successful, but does not
- // check that the transformed function is correct (i.e. produces the same
- // output as the unvectorized map defun). For the latter, the tests are in
- // tensorflow/python/data/experimental/kernel_tests/optimization/
- // map_vectorization_test.py
- FunctionDef inner = FunctionDefHelper::Create(
- "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */},
- {/* nodes */ FunctionDefHelper::Const("Const", 2),
- {{"Add"}, "Add", {"arg0", "Const:output:0"}, {{"T", DT_INT32}}}},
- {{"ret0", "Add:z:0"}});
-
- FunctionDef outer = FunctionDefHelper::Create(
- "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"},
- {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
-
- NodeDef* map_defun =
- AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}},
- inner.signature().name(), &outer);
- CHECK_NOTNULL(map_defun);
-
- FunctionDefLibrary lib;
- *lib.add_function() = outer;
- *lib.add_function() = inner;
- FunctionDef* vectorized;
- TF_EXPECT_OK(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized));
- EXPECT_TRUE(
- !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
-}
-
// TODO(rachelim): More test cases when we get around to implementing them:
// [] A badly defined converter, e.g. doesn't produce nodes that have the
// same number of outputs/inputs as the nodes to be converted
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index d1d6cf28ab..803ff87924 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -80,7 +80,6 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
("Basic", lambda x: (x, x + 1), None),
("Const", lambda x: 2, 12),
("Parallel", lambda x: (x, x + 1), 12),
- ("Broadcast", lambda x: x + np.random.rand(5, 4, 3, 2), None),
("Gather", lambda x: array_ops.gather(x, 0), 12),
)
def testOptimization(self, map_fn, num_parallel_calls):
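The removed "Broadcast" case exercised the Add vectorizer end to end by mapping x -> x + np.random.rand(5, 4, 3, 2) over a dataset and comparing the optimized pipeline against the unoptimized one. A standalone sketch of the property it checked, using the public tf.data API in eager mode with hypothetical shapes rather than the test harness in this file:

# Standalone sketch (hypothetical shapes, not the parameterized harness above):
# mapping x -> x + <rank-4 constant> element by element must equal the
# vectorized computation over a whole batch.
import numpy as np
import tensorflow as tf

const = np.random.rand(5, 4, 3, 2).astype(np.float32)
elements = np.random.rand(6, 3, 2).astype(np.float32)  # 6 elements of shape (3, 2)

# map-then-batch: the function runs once per element (the unvectorized plan).
ds = tf.data.Dataset.from_tensor_slices(elements)
map_then_batch = next(iter(ds.map(lambda x: x + const).batch(6)))

# batch-then-map with a manually vectorized function: the stacked input of
# shape (6, 3, 2) gets singleton dims inserted after the batch dim (the same
# trick as the deleted ExpandDimsForBroadcast), so one broadcast add covers
# the whole batch.
batch = next(iter(ds.batch(6)))
vectorized = tf.reshape(batch, [6, 1, 1, 3, 2]) + const

np.testing.assert_allclose(map_then_batch.numpy(), vectorized.numpy())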