about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core
diff options
context:
space:
mode:
author: Rachel Lim <rachelim@google.com>, 2018-10-05 17:01:01 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-10-05 17:08:15 -0700
commit: 15d399cd8590c18dc643d979883fe4201c8ea631 (patch)
tree: 1c913d2a5a0a45a89f6d3343f8c8cf297ba6f8e3 /tensorflow/core
parent: ab97f1323bd2a98d20ed82dc3ff8585481961f0d (diff)
[tf.data vectorization] Feed inputs to vectorizers with notion of stackedness
PiperOrigin-RevId: 215989259
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/BUILD | 10
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc | 16
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc | 16
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h | 19
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc | 2
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h | 15
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc | 11
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h | 44
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 116
9 files changed, 144 insertions, 105 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 37aa24b947..985d6c6c3a 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -13,9 +13,19 @@ VECTORIZER_DEPS = [
] + tf_protos_all()
cc_library(
+ name = "wrapped_tensor",
+ hdrs = ["wrapped_tensor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ ":wrapped_tensor",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index 3af6bab409..f445157531 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -19,13 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class CastVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
@@ -35,15 +35,17 @@ class CastVectorizer : public Vectorizer {
auto new_cast_node = outer_scope->AddNode(node.def(), &s);
TF_RETURN_IF_ERROR(s);
- // Add input and output mappings
- input_ports->push_back({new_cast_node, 0});
- output_ports->push_back({new_cast_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node,
+ 0);
+
+ // Add output mappings
+ outputs->push_back({new_cast_node, 0, true});
return Status::OK();
}
};
REGISTER_VECTORIZER("Cast", CastVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 74ce520ce1..f1ba741821 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -19,15 +19,15 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class UnpackVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
- if (node.num_inputs() != 1) {
+ if (node.num_inputs() != 1 || inputs.size() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
@@ -39,13 +39,13 @@ class UnpackVectorizer : public Vectorizer {
int new_axis = node.def().attr().at("axis").i() + 1;
new_unpack_node->AddAttr("axis", new_axis);
- // Add the input mappings
- input_ports->push_back({new_unpack_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index,
+ new_unpack_node, 0);
// Add the output mappings
int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- output_ports->push_back({new_unpack_node, i});
+ outputs->push_back({new_unpack_node, i, true});
}
return Status::OK();
@@ -54,6 +54,6 @@ class UnpackVectorizer : public Vectorizer {
REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index 56eb88c95e..8d4676aae0 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -18,15 +18,12 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
-
-// Describes a tensor with its operation Node and output position
-typedef std::pair<Node*, int> Port;
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
@@ -36,17 +33,17 @@ class Vectorizer {
// Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs. The new Node(s) collectively have the
+ // on elements of `inputs`. The new Node(s) collectively have the
// same number of input and output ports as the node being converted.
- // Adds mappings for the new nodes' input and output ports to `inputs` and
- // `outputs` respectively, where the i'th Port in inputs/outputs
- // corresponds to the i'th input/output port of the node to be converted.
+ // Adds edges between the newly created nodes and nodes in `inputs`, and adds
+ // mappings to the new nodes' output ports to `outputs`, where the i'th
+ // value in `outputs` corresponds to the i'th output port of the node
+ // to be converted.
virtual Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) = 0;
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) = 0;
};
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
index a6551e36ac..e1cf77a7d5 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -19,7 +19,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
VectorizerRegistry* VectorizerRegistry::Global() {
static VectorizerRegistry* registry = new VectorizerRegistry;
@@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type,
vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>(
op_type, std::move(vectorizer)));
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
index 16159d47ca..ad54c74933 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -23,7 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
// A global VectorizerRegistry is used to hold all the vectorizers.
class VectorizerRegistry {
@@ -59,16 +58,12 @@ class VectorizerRegistration {
#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
-#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
- static ::tensorflow::grappler::vectorization_utils:: \
- vectorizer_registration::VectorizerRegistration \
- vectorizer_registration_##ctr( \
- op_type, \
- ::std::unique_ptr< \
- ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
- new vectorizer()))
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorizer_registration:: \
+ VectorizerRegistration vectorizer_registration_##ctr( \
+ op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \
+ new vectorizer()))
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 663ceba027..054aeb9a8f 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -20,13 +20,12 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* inputs,
- std::vector<Port>* outputs) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
return Status::OK();
}
};
@@ -43,10 +42,10 @@ TEST(TestVectorizer, TestTestVectorizer) {
NodeDef node_def;
Status s;
Node* node = g.AddNode(node_def, &s);
- std::vector<Port> inputs, outputs;
- EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok());
+ std::vector<WrappedTensor> inputs, outputs;
+ EXPECT_TRUE(
+ vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok());
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
new file mode 100644
index 0000000000..4439b4ab4e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Represents a tensor that has been vectorized.
+struct WrappedTensor {
+ Node* const node;
+ const int output_index;
+
+ // Whether the tensor is stacked, i.e. represents the results of applying
+ // the operation on all slices of the input, where each row i of the
+ // tensor corresponds to the op's output on slice i of the input. False
+ // if the tensor is not stacked, i.e. represents the result of the op on
+ // a single slice of the input, where the result does not vary between
+ // slices.
+ bool stacked;
+
+ WrappedTensor(Node* node, int output_index, bool stacked)
+ : node(node), output_index(output_index), stacked(stacked) {}
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 344c420902..ba857ab5d9 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -45,22 +45,6 @@ namespace {
// Describes a tensor with its operation Node and output position
typedef std::pair<Node*, int> TensorDesc;
-// Equivalent to python Pfor's WrappedTensor struct
-struct WrappedTensor {
- TensorDesc tensor;
-
- // Whether the tensor is stacked, i.e. represents the results of applying
- // the operation on all slices of the input, where each row i of the
- // tensor corresponds to the op's output on slice i of the input. False
- // if the tensor is not stacked, i.e. represents the result of the op on
- // a single slice of the input, where the result does not vary between
- // slices.
- bool stacked;
-
- WrappedTensor(TensorDesc&& tensor, bool stacked)
- : tensor(std::move(tensor)), stacked(stacked) {}
-};
-
const char* const kRetValOp = "_Retval";
void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
@@ -239,34 +223,48 @@ Status Vectorization::AddConversionMapping(Node* op_node) {
return errors::Unimplemented("No vectorizer registered for op: ",
op_node->type_string());
}
- std::vector<Port> input_ports, output_ports;
- input_ports.reserve(op_node->num_inputs());
- output_ports.reserve(op_node->num_outputs());
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
- &input_ports, &output_ports));
+ std::vector<WrappedTensor> inputs, outputs;
+ inputs.reserve(op_node->num_inputs());
+ outputs.reserve(op_node->num_outputs());
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
- if (op_node->num_outputs() != output_ports.size() ||
- op_node->num_inputs() != input_ports.size() ||
- input_edges.size() != input_ports.size()) {
- return errors::Internal("Vectorizer inputs/outputs don't match.");
- }
-
- // Promote the inputs of the op to MapDefun outputs and connect the edges
- // accordingly.
+ // The inputs for the node to be converted may already have been converted
+ // themselves. For those that are not, we promote them to MapDefun outputs.
for (size_t i = 0; i < op_node->num_inputs(); ++i) {
auto edge = input_edges[i];
- TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
- {edge->src(), edge->src_output()}));
- outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
- input_ports[i].first, input_ports[i].second);
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ inputs.push_back(*found);
+ } else {
+ // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
+ // We assume that all unconverted inputs will be stacked, since we
+ // converted all unstacked nodes in `Initialize`. However, it's actually
+ // possible that yet-unconverted nodes may produce unstacked outputs after
+ // they are vectorized. (For example, see the "Shape" converter in
+ // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
+ // an unstacked input but receives a stacked one, vectorizer->Vectorize
+ // will return an error.
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ int output_index = map_defun_fn_->ret_nodes.size() - 1;
+ inputs.push_back({map_defun_node_, output_index, true});
+ }
+ }
+
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ std::move(inputs), &outputs));
+
+ if (op_node->num_outputs() != outputs.size()) {
+ return errors::Internal(
+ "Number of vectorizer outputs does not match. Expected: ",
+ op_node->num_outputs(), " Actual: ", outputs.size());
}
// Add output mappings.
for (size_t i = 0; i < op_node->num_outputs(); ++i) {
- conversion_map_.insert({{op_node, i}, {std::move(output_ports[i]), true}});
+ conversion_map_.insert({{op_node, i}, outputs[i]});
}
return Status::OK();
@@ -281,25 +279,22 @@ Status Vectorization::ConvertOutput(int output_position) {
TensorDesc output({ret_edge->src(), ret_edge->src_output()});
TensorDesc converted_output;
- if (auto found = gtl::FindOrNull(conversion_map_, output)) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (found->stacked) {
- converted_output = found->tensor;
- } else {
- // Some outputs may be unstacked if they don't derive from arg nodes
- // (for example, if a function returns a constant). For these, we
- // have to add extra nodes to tile it in the 0th dimension.
- TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
- }
- } else {
- // Note: All unstacked nodes are converted ahead of time in `Initialize`,
- // and here we assume that all op vectorizers create only stacked outputs.
- // This may not hold in the future, as more vectorizers are added that
- // may actually create unstacked outputs. For example, see the `Shape`
- // converter in third_party/tensorflow/python/ops/parallel_for/pfor.py
+
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ auto found = gtl::FindOrNull(conversion_map_, output);
+ if (!found) {
TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
- converted_output = conversion_map_.at(output).tensor;
+ found = &conversion_map_.at(output);
+ }
+
+ if (found->stacked) {
+ converted_output = {found->node, found->output_index};
+ } else {
+ // Some outputs may be unstacked if they don't derive from arg nodes
+ // (for example, if a function returns a constant). For these, we
+ // have to add extra nodes to tile it in the 0th dimension.
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
}
ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
@@ -455,7 +450,7 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked,
Node* ones_shape;
TF_RETURN_IF_ERROR(node_builder("Shape")
- .Input(unstacked->tensor.first) // input
+ .Input(unstacked->node) // input
.Finalize(g, &ones_shape));
Node* ones;
@@ -473,8 +468,8 @@ Status Vectorization::StackTensor(WrappedTensor* unstacked,
Node* expand_dims;
TF_RETURN_IF_ERROR(node_builder("ExpandDims")
- .Input(unstacked->tensor.first) // input
- .Input(const_0) // dim
+ .Input(unstacked->node) // input
+ .Input(const_0) // dim
.Finalize(g, &expand_dims));
TF_RETURN_IF_ERROR(node_builder("Tile")
@@ -491,11 +486,11 @@ Status Vectorization::AddArgNodeMappings() {
TF_RETURN_IF_ERROR(map_defun_node_->input_node(
arg_node->attrs().Find("index")->i(), &input_node));
- conversion_map_.insert({{arg_node, 0}, {{input_node, 0}, true}});
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}});
// Control inputs
conversion_map_.insert({{arg_node, Graph::kControlSlot},
- {{input_node, Graph::kControlSlot}, true}});
+ {input_node, Graph::kControlSlot, true}});
}
return Status::OK();
}
@@ -541,7 +536,7 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
if (auto found = gtl::FindOrNull(conversion_map_,
{edge->src(), edge->src_output()})) {
- outer_scope_->AddEdge(found->tensor.first, found->tensor.second, node,
+ outer_scope_->AddEdge(found->node, found->output_index, node,
edge->dst_input());
} else {
status->Update(errors::Internal(
@@ -552,11 +547,10 @@ bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
// Add output mappings
for (int i = 0; i < tensor.first->num_outputs(); ++i) {
- conversion_map_.insert(
- {{tensor.first, i}, WrappedTensor({node, i}, false)});
+ conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
}
conversion_map_.insert({{tensor.first, Graph::kControlSlot},
- WrappedTensor({node, Graph::kControlSlot}, false)});
+ WrappedTensor(node, Graph::kControlSlot, false)});
return true;
}