diff options
author | Rachel Lim <rachelim@google.com> | 2018-09-28 16:10:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 16:17:55 -0700 |
commit | 478d370eb116ad2294134d75a886637a7d6da225 (patch) | |
tree | 279ef8e8a2c9abeeda583393a986f055b9be314c /tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | |
parent | a98bac521406bedef3ff2b9af9564b21ddda4d82 (diff) |
[tf.data] Use Graph instead of GraphDef/FunctionDef for vectorization transforms
PiperOrigin-RevId: 215011835
Diffstat (limited to 'tensorflow/core/grappler/optimizers/data/vectorization_utils.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 451 |
1 files changed, 273 insertions, 178 deletions
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index cb56b65985..cea667f668 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -14,13 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" +#include <memory> #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" #include "absl/strings/str_join.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" @@ -36,255 +40,346 @@ namespace tensorflow { namespace grappler { namespace vectorization_utils { -using function_utils::FunctionDefTensorDesc; - namespace { -void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node, - const string& output_retval, const DataType t) { - // Set to unknown shape - TensorShapeProto tensor_shape_proto; - PartialTensorShape().AsProto(&tensor_shape_proto); +// Describes a tensor with its operation Node and output position +typedef std::pair<Node*, int> TensorDesc; - function_utils::AddFunctionOutputWithUniqueName( - "vectorized_out", output_retval, map_defun_fn, t); +const char* const kRetValOp = "_Retval"; - *(*map_defun_node->mutable_attr())["output_shapes"] - .mutable_list() - ->add_shape() = tensor_shape_proto; - (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t); +void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src, + Graph* graph) { + // NOTE: We need two for loops here because we can't mutate the set of output + // edges as we iterate over them. + std::vector<const Edge*> edges_to_replace; + for (auto edge : old_src.first->out_edges()) { + if (edge->src_output() == old_src.second) { + edges_to_replace.push_back(edge); + } + } + for (auto edge : edges_to_replace) { + graph->AddEdge(new_src.first, new_src.second, edge->dst(), + edge->dst_input()); + graph->RemoveEdge(edge); + } } -void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node, int output_position) { - DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size()) - << "Trying to remove output that doesn't exist. Output number: " - << output_position; +Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node, + const TensorDesc& output) { + // Note that we don't update MapDefun attrs as we go, only when we are done + DataType type = output.first->output_type(output.second); + int index = map_defun_fn->ret_nodes.size(); - int num_later_outputs = - map_defun_fn->signature().output_arg_size() - output_position - 1; + NodeDef ret_node_def; + ret_node_def.set_name("map_out"); + ret_node_def.set_op(kRetValOp); + AddNodeAttr("T", type, &ret_node_def); + AddNodeAttr("index", index, &ret_node_def); - // Remove from map_defun_fn's ret dict and output args - map_defun_fn->mutable_ret()->erase( - map_defun_fn->signature().output_arg(output_position).name()); - map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange( - output_position, 1); + Status s; + Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s); + TF_RETURN_IF_ERROR(s); - // Renumber outputs that come after - for (int i = 0; i < num_later_outputs; ++i) { - function_utils::ReplaceReferences( - strings::StrCat(map_defun_node->name(), - ":output:", output_position + i + 1), - strings::StrCat(map_defun_node->name(), - ":output:", output_position + i), - outer_scope); - } - map_defun_node->mutable_attr() - ->at("output_shapes") - .mutable_list() - ->mutable_shape() - ->DeleteSubrange(output_position, 1); - map_defun_node->mutable_attr() - ->at("output_types") - .mutable_list() - ->mutable_type() - ->ExtractSubrange(output_position, 1, nullptr); + map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0); + map_defun_fn->ret_nodes.push_back(ret_node); + map_defun_fn->ret_types.push_back(type); + + return s; } -int FindOutputToConvert(const FunctionDef& function, - const std::set<string>& unconvertible, - FunctionDefTensorDesc* f) { - for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) { - const string& ret_key = function.signature().output_arg(i).name(); - *f = FunctionDefTensorDesc(function.ret().at(ret_key)); +void RemoveMapDefunOutput(int output_position, Graph* outer_scope, + FunctionBody* map_defun_fn, Node* map_defun_node) { + // Note that we don't update MapDefun attrs as we go, only when we are done + DCHECK_LT(output_position, map_defun_fn->ret_nodes.size()) + << "Trying to remove output that doesn't exist. Output number: " + << output_position; + + int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1; - if (unconvertible.find(f->node_name) == unconvertible.end()) { - return i; - } + // Modify map_defun_fn's signature and remove the output node from its graph + map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]); + map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() + + output_position); + map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() + + output_position); + + // Renumber the nodes and edges that come after + for (int i = 0; i < num_later_outputs; ++i) { + ReplaceEdgeSources({map_defun_node, output_position + i + 1}, + {map_defun_node, output_position + i}, outer_scope); + // Each ret node has an "index" attr that has to be updated + map_defun_fn->ret_nodes[output_position + i]->AddAttr("index", + output_position + i); } - return -1; } // Helper class that vectorizes the body of a MapDefun node, adding new // operations to the graph that collectively compute the same value as what // running the MapDefun function on slices of the input would produce. -// Each instance of the class encapsulates all the data necessary to vectorize a -// MapDefun op in place. +// This class transforms the input FunctionDefs into their corresponding +// Graph objects and works on the graphs directly, then converts them back +// to FunctionDefs when GetResult is called. class Vectorization { public: - Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node) - : outer_scope_(outer_scope), - map_defun_fn_(map_defun_fn), - map_defun_node_(map_defun_node) {} + explicit Vectorization(FunctionDefLibrary* lib) + : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {} - // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in - // the outer_scope_, until there are no convertible outputs remaining. - // This method is idempotent. - void Vectorize(); + // Adds the vectorized function and new map_defun_fn to lib, and points + // vectorized_function to the former. Returns an error status if + // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere + // along the way. + Status Vectorize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDef** result); private: - // Vectorizes the map defun function's output at output_position - Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc); - // Given a descriptor of the original output tensor, gets a string - // corresponding to the converted output tensor. - Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc, - string* converted); - Status AddConversionMappingFromInput( - const FunctionDefTensorDesc& output_desc); + // Converts FunctionDefs to Graphs. + Status Initialize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node); + + // Converts Graphs back to FunctionDefs and adds them to `lib_`. + Status GetResult(FunctionDef** vectorized_function); + + // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in + // `outer_scope_`, until there are no convertible outputs remaining. + void VectorizeHelper(); + + // Vectorizes map_defun_fn's output at output_position. + Status ConvertOutput(int output_position); // Adds mappings from node's outputs tensors to converted output tensors, // creating the necessary new node(s). Generally, the steps to convert an op // are: - // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_, - // and modify map_defun_node_ attrs accordingly - // 2) Create new node(s) in outer_scope_ that act on batched input tensors. + // 1) Create new node(s) in `outer_scope_` that act on batched input tensors. // These operations collectively compute the same value as what running // the original operation on slices of the input tensors would produce. // For example, a Cast op in MapDefun translates to a Cast op in - // outer_scope_, since the vectorized version of Cast is itself. - // 3) Set inputs of new node(s) to the corresponding converted inputs (that - // are now outputs of map_defun_node_) - // 4) For each output of the old node, add the mapping of output strings to - // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0") - Status AddConversionMappingFromOp(const NodeDef& node, - const FunctionDefTensorDesc& output_desc); - - // Maps a tensor name to the name of the corresponding vectorized tensor. For - // example, "Cast:y:0" -> "Vectorize/Cast:y:0" - std::map<string, string> conversion_map_; - // Unconvertible node names - std::set<string> unconvertible_; - - FunctionDef* outer_scope_; - FunctionDef* map_defun_fn_; - NodeDef* map_defun_node_; + // `outer_scope_`, since the vectorized version of Cast is itself. + // 2) Promote the inputs of the op inputs to outputs of the + // `map_defun_node_` and `map_defun_fn_`. + // 3) Add edges between the promoted inputs (that are now outputs of + // `map_defun_node`) and the inputs ports of the new node(s). + // 4) For each output of the old node, add the mapping of output tensors to + // the conversion map. + Status AddConversionMapping(Node* op_node); + + // Maps a tensor to the corresponding vectorized tensor. For example, + // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0} + std::map<TensorDesc, TensorDesc> conversion_map_; + + // Unconvertible ret nodes + std::set<Node*> unconvertible_; + + FunctionDefLibrary* lib_; // Not owned + FunctionLibraryDefinition lib_def_; + // Note that FunctionBody has a pointer to a Graph object that corresponds + // to the function's subgraph, with additional kArgOp and kRetValOp nodes + // that denote that function arguments and return values. These nodes have the + // attrs "T" for the type, and "index" for the argument / retval index + // respectively. FunctionBody also keeps track of arg/ret_nodes and + // arg/ret_types, that should be ordered according to argument/output indices. + std::unique_ptr<Graph> outer_scope_; + std::unique_ptr<FunctionBody> map_defun_fn_; + Node* map_defun_node_ = nullptr; // Owned by `outer_scope` + Status status_; }; -Status Vectorization::AddConversionMappingFromOp( - const NodeDef& node, const FunctionDefTensorDesc& output_desc) { - for (const string& input_name : node.input()) { - if (IsControlInput(input_name)) { +Status Vectorization::AddConversionMapping(Node* op_node) { + for (auto edge : op_node->in_edges()) { + if (edge->IsControlEdge()) { return errors::InvalidArgument( "Vectorizing outputs with control inputs is currently not " "supported."); } } - // TODO(rachelim): Have some mechanism for registering converters and some - // uniform, simpler way to represent them. - - DataTypeVector types; - const OpDef* op_def = nullptr; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def)); - TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types)); - - std::vector<string> promoted_inputs; - promoted_inputs.reserve(node.input_size()); - for (int i = 0; i < node.input_size(); ++i) { - promoted_inputs.push_back(strings::StrCat( - map_defun_node_->name(), - ":output:", map_defun_fn_->signature().output_arg_size() + i)); - } - - auto vectorizer = VectorizerRegistry::Global()->Get(node.op()); + auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string()); if (vectorizer == nullptr) { return errors::Unimplemented("No vectorizer registered for op: ", - node.op()); + op_node->type_string()); + } + std::vector<Port> input_ports, output_ports; + input_ports.reserve(op_node->num_inputs()); + output_ports.reserve(op_node->num_outputs()); + TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), + &input_ports, &output_ports)); + + std::vector<const Edge*> input_edges; + TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges)); + + if (op_node->num_outputs() != output_ports.size() || + op_node->num_inputs() != input_ports.size() || + input_edges.size() != input_ports.size()) { + return errors::Internal("Vectorizer inputs/outputs don't match."); } - TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_, - &conversion_map_)); + // Promote the inputs of the op to MapDefun outputs and connect the edges + // accordingly. + for (size_t i = 0; i < op_node->num_inputs(); ++i) { + auto edge = input_edges[i]; + TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_, + {edge->src(), edge->src_output()})); + outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1, + input_ports[i].first, input_ports[i].second); + } - // If we get here, the conversion was successful, so we promote the inputs - // of the ops to MapDefun outputs. - for (int i = 0; i < types.size(); ++i) { - AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]); + // Add output mappings. + for (size_t i = 0; i < op_node->num_outputs(); ++i) { + conversion_map_.insert({{op_node, i}, std::move(output_ports[i])}); } return Status::OK(); } -Status Vectorization::AddConversionMappingFromInput( - const FunctionDefTensorDesc& output_desc) { - int input_index = function_utils::FindFunctionInputWithName( - output_desc.node_name, *map_defun_fn_); - if (input_index == -1) { - return errors::Internal("Cannot convert non-existent input."); +Status Vectorization::ConvertOutput(int output_position) { + // ret_edge->src() is the actual op that generated the retval, and + // ret_edge->dst() is the retval node whose op is "_Retval" + const Edge* ret_edge; + TF_RETURN_IF_ERROR( + map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge)); + + TensorDesc output({ret_edge->src(), ret_edge->src_output()}); + TensorDesc converted_output; + if (auto found = gtl::FindOrNull(conversion_map_, output)) { + // It's possible the output already has a mapping, if it comes from a node + // that has already been converted. + converted_output = *found; + } else { + TF_RETURN_IF_ERROR(AddConversionMapping(output.first)); + converted_output = conversion_map_.at(output); } - conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index); + ReplaceEdgeSources({map_defun_node_, output_position}, converted_output, + outer_scope_.get()); + RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(), + map_defun_node_); + return Status::OK(); } -Status Vectorization::ConvertOutputHelper( - const FunctionDefTensorDesc& output_desc, string* converted) { - // It's possible the output already has a mapping, if it comes from a node - // that has already been converted. - if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) { - *converted = *found; - return Status::OK(); +Status Vectorization::Vectorize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, + FunctionDef** result) { + TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node)); + VectorizeHelper(); + return GetResult(result); +} + +void Vectorization::VectorizeHelper() { + while (true) { + int output_position = graph_utils::GetFirstElementIndexWithPredicate( + [this](Node* n) { + return this->unconvertible_.find(n) == this->unconvertible_.end(); + }, + map_defun_fn_->ret_nodes); + + // No outputs left to convert + if (output_position == -1) break; + + Status s = ConvertOutput(output_position); + if (!s.ok()) { + Node* output_node = map_defun_fn_->ret_nodes.at(output_position); + VLOG(2) << "Could not convert the output at node: " + << output_node->DebugString() << "\nError: " << s; + unconvertible_.insert(output_node); + } } - int index = function_utils::FindFunctionNodeWithName(output_desc.node_name, - *map_defun_fn_); - if (index == -1) { // The output comes from an input - TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc)); + // If we've converted all the outputs of the MapDefun function, we no longer + // need the MapDefun node and can delete it. + if (map_defun_fn_->ret_nodes.empty()) { + outer_scope_->RemoveNode(map_defun_node_); } else { - TF_RETURN_IF_ERROR(AddConversionMappingFromOp( - map_defun_fn_->node_def(index), output_desc)); + // Update MapDefun node attrs accordingly + DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size()); + map_defun_node_->AddAttr( + "output_shapes", + std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size())); + map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types); } - *converted = conversion_map_.at(output_desc.full_str); - return Status::OK(); } +Status Vectorization::Initialize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node) { + // Convert outer_scope and map_defun_fn to FunctionBodys so we can + // work on Graphs directly. + const FunctionDef* map_defun_fn = + lib_def_.Find(map_defun_node.attr().at("f").func().name()); + + if (map_defun_fn == nullptr) { + return errors::NotFound("Could not find function with name ", + map_defun_node.attr().at("f").func().name(), + " in function library."); + } -Status Vectorization::ConvertOutput(int output_position, - const FunctionDefTensorDesc& output_desc) { - string converted_output_name; - TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name)); + auto get_func_sig = [this](const string& op, const OpDef** sig) { + return this->lib_def_.LookUpOpDef(op, sig); + }; + + FunctionBody* outer_fn; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_, + get_func_sig, &outer_fn)); + // We don't need outer_fn, just the graph + outer_scope_.reset(outer_fn->graph); + outer_fn->graph = nullptr; + delete outer_fn; + + FunctionBody* tmp; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_, + get_func_sig, &tmp)); + map_defun_fn_.reset(tmp); + + // Find the MapDefun node in outer_scope_ + int node_id = graph_utils::GetFirstElementIndexWithPredicate( + [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); }, + outer_scope_->nodes()); + if (node_id == -1) { + return errors::NotFound("Could not find node with name ", + map_defun_node.name(), " in outer_scope."); + } + map_defun_node_ = outer_scope_->FindNodeId(node_id); + + // Add mappings from map_defun_fn_ arg nodes to map_defun_node_ input nodes to + // the conversion map + for (auto arg_node : map_defun_fn_->arg_nodes) { + Node* input_node; + TF_RETURN_IF_ERROR(map_defun_node_->input_node( + arg_node->attrs().Find("index")->i(), &input_node)); - // Remove the old output and make everything that referenced it point - // to the new string - function_utils::ReplaceReferences( - strings::StrCat(map_defun_node_->name(), ":output:", output_position), - converted_output_name, outer_scope_); - RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_, - output_position); + conversion_map_.insert({{arg_node, 0}, {input_node, 0}}); + } return Status::OK(); } -void Vectorization::Vectorize() { - while (true) { - FunctionDefTensorDesc desc; - int output_position = - FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc); - if (output_position == -1) break; +Status Vectorization::GetResult(FunctionDef** vectorized_function) { + TF_RETURN_IF_ERROR(status_); - if (!ConvertOutput(output_position, desc).ok()) { - unconvertible_.insert(desc.node_name); - } - } + if (!map_defun_fn_->ret_nodes.empty()) { + FunctionDef* map_defun_fn = lib_->add_function(); + graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn)); - // If we've converted all the outputs of the MapDefun function, we no longer - // need the MapDefun node and can delete it. - if (map_defun_fn_->signature().output_arg_size() == 0) { - outer_scope_->mutable_node_def()->DeleteSubrange( - function_utils::FindFunctionNodeWithName(map_defun_node_->name(), - *outer_scope_), - 1); + AttrValue func_attr; + func_attr.mutable_func()->set_name(map_defun_fn->signature().name()); + map_defun_node_->AddAttr("f", func_attr); } - if (!unconvertible_.empty()) { - VLOG(2) << "The following nodes could not be converted: [" - << absl::StrJoin(unconvertible_, ", ") << "]."; - } + *vectorized_function = lib_->add_function(); + graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_, + *vectorized_function); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *outer_scope_, (*vectorized_function)->signature().name(), + *vectorized_function)); + return Status::OK(); } + } // namespace -void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node) { - Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize(); +Status VectorizeMapDefun(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDefLibrary* lib, + FunctionDef** result) { + *result = nullptr; + return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result); } } // end namespace vectorization_utils |