author Rachel Lim <rachelim@google.com> 2018-10-04 09:26:19 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-04 09:35:51 -0700
commit 1fb84c2e41c454939a02a69093cb214673eab343 (patch)
tree e3ee1c19e3a73e1d1cddbc76d5573b7800b1048b /tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
parent ac22e1583aed390d78d2e87a4bf8a6ec39400ec4 (diff)
Add ability to vectorize nodes that do not derive from function arguments. (This indirectly handles "Const" outputs automagically, since they are always unstacked.)
PiperOrigin-RevId: 215749824
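As an illustrative sketch of the idea (our example, not taken from the commit): if the mapped function returns tf.cast(arg, tf.float32) together with tf.constant([1, 2]), the constant does not derive from the function argument, so it is unstacked; it can be lifted out of the MapDefun, computed once, and tiled along dimension 0 only where a stacked output is required.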
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 247
1 file changed, 235 insertions(+), 12 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 2d6cf562b1..344c420902 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -14,10 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
-#include <memory>
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
@@ -28,13 +28,13 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
@@ -45,6 +45,22 @@ namespace {
// Describes a tensor with its operation Node and output position
typedef std::pair<Node*, int> TensorDesc;
+// Equivalent to Python pfor's WrappedTensor struct.
+struct WrappedTensor {
+ TensorDesc tensor;
+
+ // Whether the tensor is stacked, i.e. it represents the results of applying
+ // the operation to all slices of the input, where each row i of the
+ // tensor corresponds to the op's output on slice i of the input. False
+ // if the tensor is unstacked, i.e. it represents the result of the op on
+ // a single slice of the input, where the result does not vary between
+ // slices.
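+ // For example, a Const node inside the map function is unstacked, since
+ // its value does not depend on which input slice is being processed.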
+ bool stacked;
+
+ WrappedTensor(TensorDesc&& tensor, bool stacked)
+ : tensor(std::move(tensor)), stacked(stacked) {}
+};
+
const char* const kRetValOp = "_Retval";
void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
@@ -132,7 +148,8 @@ class Vectorization {
const NodeDef& map_defun_node, FunctionDef** result);
private:
- // Converts FunctionDefs to Graphs.
+ // Converts FunctionDefs to Graphs and adds mappings from
+ // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
Status Initialize(const FunctionDef& outer_scope,
const NodeDef& map_defun_node);
@@ -162,9 +179,30 @@ class Vectorization {
// the conversion map.
Status AddConversionMapping(Node* op_node);
- // Maps a tensor to the corresponding vectorized tensor. For example,
- // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0}
- std::map<TensorDesc, TensorDesc> conversion_map_;
+ // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
+ // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]), where n is the size of
+ // dimension 0 of the inputs to `map_defun_node_`. The stacked tensor is
+ // compatible with the expected output shape of `map_defun_node_`.
+ // This is equivalent to the _stack function in Python pfor.
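+ // For example, stacking an unstacked t = [1, 2] with loop length n = 3
+ // yields [[1, 2], [1, 2], [1, 2]], which has shape [3, 2].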
+ Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);
+
+ // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
+ // doing a depth-first search from the ret nodes. Lifts nodes that are
+ // unstacked (i.e. don't derive from arg nodes) directly into `outer_scope_`
+ // and adds mappings to `conversion_map_`.
+ Status AddUnstackedNodeMappings();
+
+ // Recursive helper for `AddUnstackedNodeMappings`; returns true if the
+ // tensor is unstacked.
+ bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status);
+
+ // Adds mappings from `map_defun_fn_` arg nodes to the corresponding
+ // `map_defun_node_` input nodes in `conversion_map_`.
+ Status AddArgNodeMappings();
+
+ // Maps a tensor to the corresponding WrappedTensor. For example,
+ // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true)
+ std::map<TensorDesc, WrappedTensor> conversion_map_;
// Unconvertible ret nodes
std::set<Node*> unconvertible_;
@@ -180,6 +218,10 @@ class Vectorization {
std::unique_ptr<Graph> outer_scope_;
std::unique_ptr<FunctionBody> map_defun_fn_;
Node* map_defun_node_ = nullptr; // Owned by `outer_scope`
+
+ // Caches the loop_len_node_ needed for tiling unstacked outputs. This
+ // corresponds to a vector with a single element: the loop length n.
+ Node* loop_len_node_ = nullptr; // Owned by `outer_scope`
Status status_;
};
@@ -224,7 +266,7 @@ Status Vectorization::AddConversionMapping(Node* op_node) {
// Add output mappings.
for (size_t i = 0; i < op_node->num_outputs(); ++i) {
- conversion_map_.insert({{op_node, i}, std::move(output_ports[i])});
+ conversion_map_.insert({{op_node, i}, {std::move(output_ports[i]), true}});
}
return Status::OK();
@@ -242,10 +284,22 @@ Status Vectorization::ConvertOutput(int output_position) {
if (auto found = gtl::FindOrNull(conversion_map_, output)) {
// It's possible the output already has a mapping, if it comes from a node
// that has already been converted.
- converted_output = *found;
+ if (found->stacked) {
+ converted_output = found->tensor;
+ } else {
+ // Some outputs may be unstacked if they don't derive from arg nodes
+ // (for example, if a function returns a constant). For these, we
+ // have to add extra nodes to tile them in the 0th dimension.
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
+ }
} else {
+ // Note: All unstacked nodes are converted ahead of time in `Initialize`,
+ // and here we assume that all op vectorizers create only stacked outputs.
+ // This may not hold in the future, as more vectorizers are added that
+ // may actually create unstacked outputs. For example, see the `Shape`
+ // converter in third_party/tensorflow/python/ops/parallel_for/pfor.py
TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
- converted_output = conversion_map_.at(output);
+ converted_output = conversion_map_.at(output).tensor;
}
ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
@@ -297,6 +351,7 @@ void Vectorization::VectorizeHelper() {
map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
}
+
Status Vectorization::Initialize(const FunctionDef& outer_scope,
const NodeDef& map_defun_node) {
// Convert outer_scope and map_defun_fn to FunctionBodys so we can
@@ -337,16 +392,184 @@ Status Vectorization::Initialize(const FunctionDef& outer_scope,
}
map_defun_node_ = outer_scope_->FindNodeId(node_id);
- // Add mappings from map_defun_fn_ arg nodes to map_defun_node_ input nodes to
- // the conversion map
+ TF_RETURN_IF_ERROR(AddArgNodeMappings());
+
+ TF_RETURN_IF_ERROR(AddUnstackedNodeMappings());
+ loop_len_node_ = nullptr;
+
+ return Status::OK();
+}
+
+// TODO(rachelim): It might be profitable to use the C++ API for this instead of
+// NodeBuilder
+Status Vectorization::StackTensor(WrappedTensor* unstacked,
+ TensorDesc* result) {
+ // Note that all these nodes are necessary as the size of the batch may not be
+ // constant.
+ if (unstacked->stacked) {
+ return errors::Internal("Can only stack an unstacked tensor.");
+ }
+
+ Graph* g = outer_scope_.get();
+ auto node_builder = [](StringPiece op) {
+ return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
+ };
+
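+ // Helper that builds a Const node from an Input::Initializer value.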
+ auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
+ Node** result) {
+ TF_RETURN_IF_ERROR(val.status);
+ return node_builder("Const")
+ .Attr("value", val.tensor)
+ .Attr("dtype", val.tensor.dtype())
+ .Finalize(graph, result);
+ };
+
+ // If loop_len_node_ hasn't been created yet, add the node and cache it.
+ if (loop_len_node_ == nullptr) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));
+
+ Node* shape_node;
+ TF_RETURN_IF_ERROR(
+ node_builder("Shape").Input(input_node).Finalize(g, &shape_node));
+
+ Node* const_vec_0;
+ TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
+ Node* const_vec_1;
+ TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));
+
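+ // Take shape[0:1], i.e. the loop length n.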
+ Node* strided_slice_node;
+ TF_RETURN_IF_ERROR(node_builder("StridedSlice")
+ .Input(shape_node) // input
+ .Input(const_vec_0) // begin
+ .Input(const_vec_1) // end
+ .Input(const_vec_1) // strides
+ .Finalize(g, &strided_slice_node));
+
+ // Produces a vector of length 1
+ TF_RETURN_IF_ERROR(node_builder("Reshape")
+ .Input(strided_slice_node) // tensor
+ .Input(const_vec_1) // shape
+ .Finalize(g, &loop_len_node_));
+ }
+
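+ // Build the Tile multiples [n, 1, 1, ...]: one 1 per dimension of the
+ // unstacked tensor, prefixed by the loop length n.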
+ Node* ones_shape;
+ TF_RETURN_IF_ERROR(node_builder("Shape")
+ .Input(unstacked->tensor.first) // input
+ .Finalize(g, &ones_shape));
+
+ Node* ones;
+ TF_RETURN_IF_ERROR(
+ node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));
+
+ Node* const_0;
+ TF_RETURN_IF_ERROR(make_const(0, g, &const_0));
+
+ Node* multiples;
+ TF_RETURN_IF_ERROR(node_builder("Concat")
+ .Input(const_0) // concat_dim
+ .Input({{loop_len_node_, 0}, {ones, 0}}) // values
+ .Finalize(g, &multiples));
+
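+ // Add a leading dimension of size 1, then tile it n times to stack.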
+ Node* expand_dims;
+ TF_RETURN_IF_ERROR(node_builder("ExpandDims")
+ .Input(unstacked->tensor.first) // input
+ .Input(const_0) // dim
+ .Finalize(g, &expand_dims));
+
+ TF_RETURN_IF_ERROR(node_builder("Tile")
+ .Input(expand_dims) // input
+ .Input(multiples) // multiples
+ .Finalize(g, &result->first));
+ result->second = 0;
+ return Status::OK();
+}
+
+Status Vectorization::AddArgNodeMappings() {
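+ // Each arg node's "index" attr gives the position of the corresponding
+ // input to map_defun_node_.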
for (auto arg_node : map_defun_fn_->arg_nodes) {
Node* input_node;
TF_RETURN_IF_ERROR(map_defun_node_->input_node(
arg_node->attrs().Find("index")->i(), &input_node));
- conversion_map_.insert({{arg_node, 0}, {input_node, 0}});
+ conversion_map_.insert({{arg_node, 0}, {{input_node, 0}, true}});
+
+ // Control inputs
+ conversion_map_.insert({{arg_node, Graph::kControlSlot},
+ {{input_node, Graph::kControlSlot}, true}});
}
+ return Status::OK();
+}
+
+bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
+ Status* status) {
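+ // If the tensor already has a mapping, it is unstacked iff that mapping
+ // is marked unstacked.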
+ if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
+ return !found->stacked;
+ }
+
+ if (tensor.first->op_def().is_stateful()) {
+ // We don't lift stateful nodes directly out of the MapDefun, since they may
+ // have to be executed N times.
+ return false;
+ }
+
+ bool is_unstacked = true;
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ // A node is unstacked if all of its inputs are unstacked
+ is_unstacked &= AddUnstackedNodeMappingsHelper(
+ {edge->src(), edge->src_output()}, status);
+ }
+
+ if (!is_unstacked) {
+ return false;
+ }
+
+ // If the node is unstacked, we copy it into outer_scope_ and
+ // add it to the map. Note that we don't remove the original nodes from
+ // map_defun_fn_, and rely on them being pruned out later.
+ Node* node = outer_scope_->AddNode(tensor.first->def(), status);
+ if (!status->ok()) return true;
+
+ // Add input edges to nodes that should already have been lifted.
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ outer_scope_->AddEdge(found->tensor.first, found->tensor.second, node,
+ edge->dst_input());
+ } else {
+ status->Update(errors::Internal(
+ "Could not find input conversion even though we did depth first "
+ "conversion."));
+ }
+ }
+
+ // Add output mappings
+ for (int i = 0; i < tensor.first->num_outputs(); ++i) {
+ conversion_map_.insert(
+ {{tensor.first, i}, WrappedTensor({node, i}, false)});
+ }
+ conversion_map_.insert({{tensor.first, Graph::kControlSlot},
+ WrappedTensor({node, Graph::kControlSlot}, false)});
+
+ return true;
+}
+
+Status Vectorization::AddUnstackedNodeMappings() {
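+ // Walk backwards from each ret node; unstacked subgraphs reachable from
+ // the ret nodes are lifted into outer_scope_ by the helper.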
+ Status s;
+ for (const auto& ret_node : map_defun_fn_->ret_nodes) {
+ const Edge* in_edge = nullptr;
+ TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
+ AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s);
+ TF_RETURN_IF_ERROR(s);
+ }
return Status::OK();
}