diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-12 18:29:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-12 18:31:28 -0700 |
commit | 93afca507ec09ff3b5cdf05cbd5eb265e83fc8cb (patch) | |
tree | 4e25cefafcc41c3e94bf7abed519c72a45cf0c18 | |
parent | 7d89bfcd72bef4c5c9328a88ee520d81642b5284 (diff) |
Convert GrapplerFunctionItem to (Specialized)FunctionDef.
PiperOrigin-RevId: 192704808
-rw-r--r-- | tensorflow/core/grappler/utils/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/functions.cc | 328 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/functions.h | 92 | ||||
-rw-r--r-- | tensorflow/core/grappler/utils/functions_test.cc | 179 |
4 files changed, 504 insertions, 98 deletions
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 05d9cbaa2b..b473f32c45 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -165,6 +165,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", ], ) @@ -177,6 +178,8 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/core:all_kernels", "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index dd0d918e72..e8d423a759 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -23,27 +23,82 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { namespace grappler { +namespace { + +Status OutputNameRange(const FunctionLibraryDefinition& flib, + const NodeDef& node, + tensorflow::NameRangeMap* outputs_range_map) { + const OpRegistrationData* registration; + TF_RETURN_IF_ERROR(flib.LookUp(node.op(), ®istration)); + TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(node, registration->op_def, + nullptr, outputs_range_map)); + return Status::OK(); +} + +Status RegisterFunctionBodyOutputs(const FunctionLibraryDefinition& flib, + const NodeDef& node, + GrapplerFunctionConnectivity* connectivity) { + tensorflow::NameRangeMap outputs_range_map; + TF_RETURN_IF_ERROR(OutputNameRange(flib, node, &outputs_range_map)); + connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map); + return Status::OK(); +} + +// Replace the placeholder attribute values with the values specified in +// instantiation attributes. +Status ResolveFunctionBodyNodeAttrPlaceholders( + const AttrValueMap& func_instantiation_attr, NodeDef* node) { + for (auto& attr : *node->mutable_attr()) { + const string& placeholder = attr.second.placeholder(); + if (placeholder.empty()) continue; + + auto it = func_instantiation_attr.find(placeholder); + if (it != func_instantiation_attr.end()) { + attr.second = it->second; + } else { + return errors::InvalidArgument("Can't resolve placeholder: ", + placeholder); + } + } + return Status::OK(); +} + +} // namespace + void GrapplerFunctionConnectivity::RegisterInputArgExpansion( const InputArgExpansion& input_arg_expansion) { - input_arg_expansions_.insert( - {input_arg_expansion.input_name, input_arg_expansion}); + const auto& input_name = input_arg_expansion.input_name; + const auto& placeholders = input_arg_expansion.placeholders; + input_arg_expansions_.emplace(input_name, input_arg_expansion); + for (int i = 0; i < placeholders.size(); ++i) { + const string& placeholder = input_arg_expansion.placeholders[i]; + input_arg_placeholders_.emplace( + placeholder, InputArgPlaceholder{input_name, /*position=*/i}); + } } void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs( const string& node_name, const tensorflow::NameRangeMap& outputs) { - function_body_outputs_.insert({node_name, outputs}); + function_body_outputs_[node_name] = outputs; } Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( const string& func_def_input, std::vector<string>* graph_def_inputs) const { using ::tensorflow::strings::Scanner; + if (IsControlInput(func_def_input)) { + graph_def_inputs->push_back(func_def_input); + return Status::OK(); + } + // Parse input format: "node_name[:node_output][:position]" string node_name; string node_output; @@ -150,11 +205,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs( std::vector<string> expanded_inputs; for (const string& function_def_input : function_body_node->input()) { - if (!IsControlInput(function_def_input)) - TF_RETURN_IF_ERROR( - ExpandFunctionDefInput(function_def_input, &expanded_inputs)); - else - expanded_inputs.push_back(function_def_input); + TF_RETURN_IF_ERROR( + ExpandFunctionDefInput(function_def_input, &expanded_inputs)); } function_body_node->clear_input(); @@ -163,10 +215,66 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs( return Status::OK(); } -Status GrapplerFunctionItemBuilder::GetTypeAttr(const string& type_attr_name, - DataType* data_type) const { - auto it = func_attr_->find(type_attr_name); - if (it == func_attr_->end()) { +Status GrapplerFunctionConnectivity::AsFunctionDefInput( + const string& graph_def_input, string* func_def_input) const { + using gtl::FindOrNull; + + if (IsControlInput(graph_def_input)) { + *func_def_input = graph_def_input; + return Status::OK(); + } + + int position; + string node_name = ParseNodeName(graph_def_input, &position); + CHECK_GE(position, 0); + + // Check if it's an input arg placeholder + if (position == 0) { + const InputArgPlaceholder* placeholder = + FindOrNull(input_arg_placeholders_, node_name); + if (placeholder != nullptr) { + *func_def_input = + strings::StrCat(placeholder->input_name, ":", placeholder->position); + return Status::OK(); + } + } + + // It must be output from one of the function body nodes + const tensorflow::NameRangeMap* outputs_range_map = + FindOrNull(function_body_outputs_, node_name); + if (outputs_range_map != nullptr) { + for (const auto& el : *outputs_range_map) { + const auto& output_name = el.first; + const auto& output_range = el.second; + if (position >= output_range.first && position < output_range.second) { + int pos = position - output_range.first; + *func_def_input = + strings::StrCat(node_name, ":", output_name, ":", pos); + return Status::OK(); + } + } + } + + return errors::InvalidArgument("Unknown graph def input: ", graph_def_input); +} + +Status GrapplerFunctionConnectivity::AsFunctionDefNode( + NodeDef* function_body_node) const { + string func_def_input; + + for (int i = 0; i < function_body_node->input_size(); ++i) { + TF_RETURN_IF_ERROR( + AsFunctionDefInput(function_body_node->input(i), &func_def_input)); + function_body_node->set_input(i, func_def_input); + } + + return Status::OK(); +} + +Status GrapplerFunctionItemInstantiation::GetTypeAttr( + const string& type_attr_name, DataType* data_type) const { + auto it = func_instantiation_attr_->find(type_attr_name); + if (it == func_instantiation_attr_->end()) { return errors::InvalidArgument("Type attribute ", type_attr_name, " is not defined"); } else if (it->second.type() == DT_INVALID) { @@ -178,31 +286,48 @@ Status GrapplerFunctionItemBuilder::GetTypeAttr(const string& type_attr_name, return Status::OK(); } -Status GrapplerFunctionItemBuilder::GetArgType(const OpDef::ArgDef& arg, - DataType* data_type) const { +Status GrapplerFunctionItemInstantiation::GetArgType( + const OpDef::ArgDef& arg, DataType* data_type) const { if (arg.type() != DT_INVALID) { *data_type = arg.type(); } else { + if (!arg.type_list_attr().empty() || !arg.number_attr().empty()) { + return errors::InvalidArgument( + "Arguments with sequence of tensors are not supported. Unsupported " + "argument name: ", + arg.name()); + } TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type)); } return Status::OK(); } GrapplerFunctionItem::GrapplerFunctionItem( - const string& function_name, + const string& func_name, const AttrValueMap& func_attr, const std::vector<InputArgExpansion>& input_arg_expansions, const std::vector<OutputArgExpansion>& output_arg_expansions, GraphDef&& function_body) - : function_name_(function_name), + : func_attr_(func_attr), input_arg_expansions_(input_arg_expansions), output_arg_expansions_(output_arg_expansions) { + id = func_name; + // Fill the feed nodes with input placeholders + for (const InputArgExpansion& input_arg : input_arg_expansions_) { + for (const string& placeholder : input_arg.placeholders) { + feed.emplace_back(placeholder, Tensor()); + input_arg_placeholders_.insert(placeholder); + } + } + // Fill the fetch nodes with outputs + for (const OutputArgExpansion& output_arg : output_arg_expansions_) { + for (const string& output_tensor : output_arg.output_tensors) { + fetch.push_back(output_tensor); + } + } + // Swap the graph body graph.Swap(&function_body); } -const string& GrapplerFunctionItem::function_name() const { - return function_name_; -} - const std::vector<InputArgExpansion>& GrapplerFunctionItem::inputs() const { return input_arg_expansions_; } @@ -215,6 +340,11 @@ const std::size_t GrapplerFunctionItem::input_size() const { return input_arg_expansions_.size(); } +bool GrapplerFunctionItem::IsInputPlaceholder(const string& node_name) const { + return input_arg_placeholders_.find(node_name) != + input_arg_placeholders_.end(); +} + const std::vector<OutputArgExpansion>& GrapplerFunctionItem::outputs() const { return output_arg_expansions_; } @@ -227,10 +357,19 @@ const std::size_t GrapplerFunctionItem::output_size() const { return output_arg_expansions_.size(); } +const AttrValueMap& GrapplerFunctionItem::func_attr() const { + return func_attr_; +} + const GraphDef& GrapplerFunctionItem::function_body() const { return graph; } GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; } +GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) { + graph.Swap(&other); + return *this; +} + std::vector<string> OutputTensors(const GrapplerFunctionItem& item) { std::vector<string> output_tensors; for (const OutputArgExpansion& output : item.outputs()) { @@ -241,18 +380,27 @@ std::vector<string> OutputTensors(const GrapplerFunctionItem& item) { return output_tensors; } -Status MakeGrapplerFunctionItem( - const FunctionDef& func, - const std::unordered_map<string, AttrValue>& func_attr, - const FunctionLibraryDefinition& func_library, GrapplerFunctionItem* item) { +Status MakeGrapplerFunctionItem(const FunctionDef& func, + const AttrValueMap& func_instantiation_attr, + const FunctionLibraryDefinition& flib, + GrapplerFunctionItem* item) { const OpDef& signature = func.signature(); if (signature.name().empty()) { return errors::InvalidArgument("Function name must be specified"); } - // Helper methods to lookup function attributes - GrapplerFunctionItemBuilder builder(&func_attr); + // Function types will be resolved from function instantiation attributes. All + // other attributes will be lost during conversion to FunctionDef. + for (const OpDef::AttrDef& attr : signature.attr()) { + if (attr.type() != "type") { + return errors::InvalidArgument( + "Function signature must have only type attributes"); + } + } + + // Helper methods to lookup function instantiation attributes + GrapplerFunctionItemInstantiation instantiation(&func_instantiation_attr); // Mapping from FunctionDef input format (name[:output][:position]) to // GraphDef input format (name[:position]) @@ -260,7 +408,10 @@ Status MakeGrapplerFunctionItem( std::vector<InputArgExpansion> inputs; std::vector<OutputArgExpansion> outputs; + + // Function body shares the library with the graph that instantiated it. GraphDef function_body; + *function_body.mutable_library() = flib.ToProto(); // TODO(ezhulenev): support functions with tensor sequence inputs/outputs @@ -284,7 +435,7 @@ Status MakeGrapplerFunctionItem( } DataType input_data_type; - TF_RETURN_IF_ERROR(builder.GetArgType(input, &input_data_type)); + TF_RETURN_IF_ERROR(instantiation.GetArgType(input, &input_data_type)); NodeDef* placeholder = function_body.add_node(); placeholder->set_name(input.name()); @@ -292,6 +443,7 @@ Status MakeGrapplerFunctionItem( (*placeholder->mutable_attr())["T"].set_type(input_data_type); InputArgExpansion input_expansion{/*input_name=*/input.name(), + /*data_type=*/input_data_type, /*placeholders=*/{input.name()}}; connectivity.RegisterInputArgExpansion(input_expansion); inputs.push_back(input_expansion); @@ -302,24 +454,12 @@ Status MakeGrapplerFunctionItem( NodeDef* new_node = function_body.add_node(); *new_node = func_def_node; - // Replace the placeholder attribute values with the specified value - for (auto& attr : *new_node->mutable_attr()) { - const string& ph_name = attr.second.placeholder(); - auto it = func_attr.find(ph_name); - if (it != func_attr.end()) { - attr.second = it->second; - } - } - - // Functions use a custom format to encode connectivity. Map these custom - // strings to regular ones. - tensorflow::NameRangeMap outputs_range_map; - const OpRegistrationData* registration; - TF_RETURN_IF_ERROR(func_library.LookUp(func_def_node.op(), ®istration)); - TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode( - func_def_node, registration->op_def, nullptr, &outputs_range_map)); - connectivity.RegisterFunctionBodyOutputs(func_def_node.name(), - outputs_range_map); + // Resolve all placeholder values using function instantiation attributes. + TF_RETURN_IF_ERROR(ResolveFunctionBodyNodeAttrPlaceholders( + func_instantiation_attr, new_node)); + // Register node output range in a function connectivity. + TF_RETURN_IF_ERROR( + RegisterFunctionBodyOutputs(flib, func_def_node, &connectivity)); } // Rewrite inputs to use GraphDef format @@ -331,20 +471,96 @@ Status MakeGrapplerFunctionItem( for (const OpDef::ArgDef& out : signature.output_arg()) { std::vector<string> output_tensors; auto ret = func.ret().find(out.name()); - if (ret != func.ret().end()) { - // Expand outputs using provided output mapping - TF_RETURN_IF_ERROR( - connectivity.ExpandFunctionDefInput(ret->second, &output_tensors)); - } else { - // Otherwise output must be one of the function inputs - TF_RETURN_IF_ERROR( - connectivity.ExpandFunctionDefInput(out.name(), &output_tensors)); + TF_RETURN_IF_ERROR( + ret != func.ret().end() + // Expand outputs using provided output mapping + ? connectivity.ExpandFunctionDefInput(ret->second, &output_tensors) + // Otherwise output must be one of the function inputs + : connectivity.ExpandFunctionDefInput(out.name(), &output_tensors)); + + DataType output_data_type; + TF_RETURN_IF_ERROR(instantiation.GetArgType(out, &output_data_type)); + + OutputArgExpansion output{/*output_name=*/out.name(), + /*data_type=*/output_data_type, + /*output_tensors=*/output_tensors}; + outputs.push_back(output); + } + + *item = GrapplerFunctionItem( + /*func_name=*/signature.name(), + /*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()), + inputs, outputs, std::move(function_body)); + return Status::OK(); +} + +// Register GrapplerFunctionItem input arg expansion and function body outputs +// in the GrapplerFunctionConnectivity +Status RegisterGrapplerFunctionConnectivity( + const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib, + GrapplerFunctionConnectivity* connectivity) { + for (const InputArgExpansion& input : item.inputs()) { + connectivity->RegisterInputArgExpansion(input); + } + for (const NodeDef& func_body_node : item.function_body().node()) { + TF_RETURN_IF_ERROR( + RegisterFunctionBodyOutputs(flib, func_body_node, connectivity)); + } + return Status::OK(); +} + +Status MakeSpecializedFunctionDef(const GrapplerFunctionItem& item, + const FunctionLibraryDefinition& flib, + FunctionDef* func) { + func->mutable_signature()->set_name(item.id); + + // Build a GrapplerFunctionConnectivity from inputs and new function body. + GrapplerFunctionConnectivity connectivity; + TF_RETURN_IF_ERROR( + RegisterGrapplerFunctionConnectivity(item, flib, &connectivity)); + + // Add function input arguments. + for (const InputArgExpansion& input_arg : item.inputs()) { + OpDef::ArgDef arg_def; + arg_def.set_name(input_arg.input_name); + arg_def.set_type(input_arg.data_type); + *func->mutable_signature()->add_input_arg() = arg_def; + } + + // Add function output arguments. + for (const OutputArgExpansion& output_arg : item.outputs()) { + OpDef::ArgDef arg_def; + arg_def.set_name(output_arg.output_name); + arg_def.set_type(output_arg.data_type); + *func->mutable_signature()->add_output_arg() = arg_def; + + CHECK(output_arg.output_tensors.size() == 1) // do some sanity checking + << "Outputs of tensor sequences are not supported"; + + string ret; + for (const string& output_tensor : output_arg.output_tensors) { + TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(output_tensor, &ret)); + (*func->mutable_ret())[output_arg.output_name] = ret; } - outputs.push_back({out.name(), output_tensors}); } - *item = GrapplerFunctionItem(signature.name(), inputs, outputs, - std::move(function_body)); + // Copy function definition specific attributes. + for (const auto& attr : item.func_attr()) { + const auto& attr_name = attr.first; + const auto& attr_value = attr.second; + (*func->mutable_attr())[attr_name] = attr_value; + } + + // Copy function body nodes to the FunctionDef and update input format + for (const NodeDef& func_body_node : item.function_body().node()) { + // Do not copy input placeholders + if (item.IsInputPlaceholder(func_body_node.name())) continue; + + NodeDef* func_def_node = func->add_node_def(); + *func_def_node = func_body_node; + TF_RETURN_IF_ERROR(connectivity.AsFunctionDefNode(func_def_node)); + } + return Status::OK(); } diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 60ea8857c0..2ac3917a66 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -28,14 +28,19 @@ limitations under the License. namespace tensorflow { namespace grappler { +using AttrValueMap = std::unordered_map<string, AttrValue>; + // Depending on the function instantiation attributes, input argument to the // function might be a single tensor, list of tensors of the same type, or a // list of tensors of different types. // // InputArgExpansion keeps track of the placeholders that were added to the -// function body in place of function inputs. +// function body in place of function inputs and a resolved input data type. struct InputArgExpansion { + // TODO(ezhulenev): Add support for functions with tensor sequence inputs of + // different data types string input_name; // name of the function input argument + DataType data_type; // input data type std::vector<string> placeholders; // names of placeholder nodes in the // function body }; @@ -44,11 +49,14 @@ struct InputArgExpansion { // to one or more outputs of one of the function body nodes. // // OutputArgExpansion keeps mapping from a function output arg to the output -// tensors of a function body nodes, that compute function outputs. +// tensors of a function body nodes and a resolved output data type struct OutputArgExpansion { + // TODO(ezhulenev): Add support for functions with tensor sequence outputs of + // different data types string output_name; // name of the function output argument - std::vector<string> output_tensors; // names of output tensors from the - // function body graph nodes + DataType data_type; // output data type + std::vector<string> output_tensors; // names of output tensor from the + // function body nodes }; // FunctionDef uses different connectivity encoding for the function body nodes, @@ -67,26 +75,46 @@ class GrapplerFunctionConnectivity { Status ExpandFunctionDefInput(const string& func_def_input, std::vector<string>* graph_def_inputs) const; - // Update Node inputs from FunctionDef to GraphDef format + // Update Node inputs from FunctionDef to GraphDef format. Status ExpandNodeInputs(NodeDef* function_body_node) const; - // TODO(ezhulenev): fold GraphDef inputs back to FunctionDef format - // Status FoldGraphDefInputs(const std::vector<sting> graph_def_inputs, - // std::vector<string>* function_def_inputs) const; + // When expanding inputs in function def format, single input might be + // expanded into multiple tensors. When converting back to the function def + // format from graph def format, it's always a 1-to-1 relationship. + // FunctionDef built from GrapplerFunctionItem is always specialized to it's + // instantiation attributes and length of input args (and node def outputs) is + // known. + + // Map from GraphDef input format to FunctionDef input format using registered + // input arg expansion and function body outputs. + Status AsFunctionDefInput(const string& graph_def_input, + string* func_def_input) const; + + // Update Node inputs from GraphDef to FunctionDef format. + Status AsFunctionDefNode(NodeDef* function_body_node) const; private: + // Mapping from input name to input arg expansion. std::unordered_map<string, InputArgExpansion> input_arg_expansions_; + // Mapping from function body node name to output names range map. std::unordered_map<string, tensorflow::NameRangeMap> function_body_outputs_; + + struct InputArgPlaceholder { + string input_name; + int position; + }; + + // Mapping from input arg placeholder to the function input tensor. + std::unordered_map<string, InputArgPlaceholder> input_arg_placeholders_; }; -// Helper methods to build GrapplerFunctionItem from a function def and function -// attributes. -class GrapplerFunctionItemBuilder { +// Get Function type attributes using attributes of a node that instantiated +// a function. +class GrapplerFunctionItemInstantiation { public: - using FunctionAttr = std::unordered_map<string, AttrValue>; - - explicit GrapplerFunctionItemBuilder(const FunctionAttr* func_attr) - : func_attr_(func_attr) {} + explicit GrapplerFunctionItemInstantiation( + const AttrValueMap* func_instantiation_attr) + : func_instantiation_attr_(func_instantiation_attr) {} // Get DataType from attributes by name. Return error if attribute is missing, // or it doesn't define a valid data type. @@ -97,20 +125,20 @@ class GrapplerFunctionItemBuilder { Status GetArgType(const OpDef::ArgDef& arg, DataType* data_type) const; private: - const FunctionAttr* func_attr_; // do not own + const AttrValueMap* func_instantiation_attr_; // do not own }; // A special case of GrapplerItem, constructed from a TensorFlow Function. class GrapplerFunctionItem : public GrapplerItem { public: - GrapplerFunctionItem() {} + GrapplerFunctionItem() = default; GrapplerFunctionItem( - const string& function_name, + const string& func_name, const AttrValueMap& func_attr, const std::vector<InputArgExpansion>& input_arg_expansions, const std::vector<OutputArgExpansion>& output_arg_expansions, GraphDef&& function_body); - const string& function_name() const; + bool IsInputPlaceholder(const string& node_name) const; const std::vector<InputArgExpansion>& inputs() const; const InputArgExpansion& input(int i) const; @@ -120,13 +148,20 @@ class GrapplerFunctionItem : public GrapplerItem { const OutputArgExpansion& output(int i) const; const std::size_t output_size() const; + const AttrValueMap& func_attr() const; const GraphDef& function_body() const; GraphDef& mutable_function_body(); + GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other); + private: - string function_name_; + AttrValueMap func_attr_; // Attributes specific to function definition that + // produced this item (FuncDef.attr field). + std::vector<InputArgExpansion> input_arg_expansions_; std::vector<OutputArgExpansion> output_arg_expansions_; + + std::set<string> input_arg_placeholders_; }; // Return all output tensors referenced by item output args. @@ -136,8 +171,21 @@ std::vector<string> OutputTensors(const GrapplerFunctionItem& item); // Return error if the given function def cannot be converted. Status MakeGrapplerFunctionItem( const FunctionDef& func, - const std::unordered_map<string, AttrValue>& func_attr, - const FunctionLibraryDefinition& func_library, GrapplerFunctionItem* item); + const std::unordered_map<string, AttrValue>& func_instantiation_attr, + const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item); + +// Register GrapplerFunctionItem input arg expansion and function body outputs +// in the GrapplerFunctionConnectivity. Use function library definition to +// lookup function body nodes output names and ranges. +Status RegisterGrapplerFunctionConnectivity( + const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib, + GrapplerFunctionConnectivity* connectivity); + +// Make a specialized FunctionDef from the GrapplerFunctionItem. Use function +// library definition to lookup function body nodes output names and ranges. +Status MakeSpecializedFunctionDef(const GrapplerFunctionItem& item, + const FunctionLibraryDefinition& flib, + FunctionDef* func); } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc index 1eb3298e89..a9a708bf67 100644 --- a/tensorflow/core/grappler/utils/functions_test.cc +++ b/tensorflow/core/grappler/utils/functions_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" @@ -32,8 +33,9 @@ class FunctionsTest : public ::testing::Test {}; TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) { GrapplerFunctionConnectivity connectivity; - connectivity.RegisterInputArgExpansion({"inputA", {"inputA"}}); - connectivity.RegisterInputArgExpansion({"inputB", {"inputB_0", "inputB_1"}}); + connectivity.RegisterInputArgExpansion({"inputA", DT_FLOAT, {"inputA"}}); + connectivity.RegisterInputArgExpansion( + {"inputB", DT_FLOAT, {"inputB_0", "inputB_1"}}); connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}}); connectivity.RegisterFunctionBodyOutputs("Func", @@ -93,11 +95,50 @@ TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) { EXPECT_EQ("Func:3", inputs[0]); } +TEST_F(FunctionsTest, GrapplerFunctionConnectivity_AsFunctionDefInput) { + GrapplerFunctionConnectivity connectivity; + + connectivity.RegisterInputArgExpansion({"inputA", DT_FLOAT, {"inputA"}}); + connectivity.RegisterInputArgExpansion( + {"inputB", DT_FLOAT, {"inputB_0", "inputB_1"}}); + + connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}}); + connectivity.RegisterFunctionBodyOutputs("Func", + {{"o1", {0, 2}}, {"o2", {2, 4}}}); + + string input; + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputA", &input)); + EXPECT_EQ("inputA:0", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_0", &input)); + EXPECT_EQ("inputB:0", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_1", &input)); + EXPECT_EQ("inputB:1", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("Add", &input)); + EXPECT_EQ("Add:z:0", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func", &input)); + EXPECT_EQ("Func:o1:0", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:1", &input)); + EXPECT_EQ("Func:o1:1", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:2", &input)); + EXPECT_EQ("Func:o2:0", input); + + TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:3", &input)); + EXPECT_EQ("Func:o2:1", input); +} + TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandNodeInputs) { GrapplerFunctionConnectivity connectivity; - connectivity.RegisterInputArgExpansion({"inputA", {"inputA"}}); - connectivity.RegisterInputArgExpansion({"inputB", {"inputB_0", "inputB_1"}}); + connectivity.RegisterInputArgExpansion({"inputA", DT_FLOAT, {"inputA"}}); + connectivity.RegisterInputArgExpansion( + {"inputB", DT_FLOAT, {"inputB_0", "inputB_1"}}); NodeDef node; node.add_input("inputA:0"); @@ -131,12 +172,12 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) { std::unordered_map<string, AttrValue> func_attr; func_attr["T"].set_type(DT_FLOAT); - FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); + FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); - EXPECT_EQ("XTimesTwo", item.function_name()); + EXPECT_EQ("XTimesTwo", item.id); EXPECT_EQ(4, item.function_body().node_size()); EXPECT_EQ(1, item.input_size()); @@ -206,12 +247,12 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) { std::unordered_map<string, AttrValue> func_attr; func_attr["T"].set_type(DT_FLOAT); - FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); + FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); - EXPECT_EQ("SubGrad", item.function_name()); + EXPECT_EQ("SubGrad", item.id); EXPECT_EQ(12, item.function_body().node_size()); ASSERT_EQ(3, item.input_size()); @@ -251,8 +292,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) { } TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) { - FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); - TF_ASSERT_OK(library.AddFunctionDef(FunctionDefHelper::Define( + FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); + TF_ASSERT_OK(flib.AddFunctionDef(FunctionDefHelper::Define( // Name "Swap", // Args @@ -290,7 +331,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) { func_attr["T"].set_type(DT_FLOAT); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); int count = 0; for (const NodeDef &node : item.function_body().node()) { @@ -348,10 +389,10 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) { {{"out", "Exp:y:0"}}); std::unordered_map<string, AttrValue> func_attr; - FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); + FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); EXPECT_EQ(1, item.output_size()); EXPECT_EQ("Exp", item.output(0).output_tensors[0]); @@ -391,12 +432,12 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) { {{"out0", "in0"}}); std::unordered_map<string, AttrValue> func_attr; - FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); + FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); - EXPECT_EQ("ForwardInputs", item.function_name()); + EXPECT_EQ("ForwardInputs", item.id); EXPECT_EQ(5, item.function_body().node_size()); EXPECT_EQ(3, item.output_size()); @@ -437,10 +478,10 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) { std::unordered_map<string, AttrValue> func_attr; func_attr["T"].set_type(DT_FLOAT); - FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary()); + FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); GrapplerFunctionItem item; - TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item)); + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); EXPECT_EQ(0, item.input_size()); EXPECT_EQ(1, item.output_size()); @@ -456,6 +497,104 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) { EXPECT_EQ("two", cast.input(0)); } +TEST_F(FunctionsTest, MakeSpecializedFunctionDef) { + const Tensor kTwo = test::AsScalar<int64>(2); + FunctionDef func = FunctionDefHelper::Define( + // Name + "XTimesTwo", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}}, + {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}}, + {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}}, + }); + + std::unordered_map<string, AttrValue> func_attr; + func_attr["T"].set_type(DT_FLOAT); + FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary()); + + GrapplerFunctionItem item; + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + + FunctionDef specialized; + TF_EXPECT_OK(MakeSpecializedFunctionDef(item, flib, &specialized)); + + // Input and output types are resolved based on instantiation attributes. + EXPECT_EQ("x", specialized.signature().input_arg(0).name()); + EXPECT_EQ(DT_FLOAT, specialized.signature().input_arg(0).type()); + EXPECT_EQ("y", specialized.signature().output_arg(0).name()); + EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type()); + + // Function body specialized for instantiation types + int count = 0; + for (const NodeDef &node : specialized.node_def()) { + if (node.name() == "scale" && count++) { + EXPECT_EQ(DT_FLOAT, node.attr().at("DstT").type()); + } else if (node.name() == "y" && count++) { + EXPECT_EQ("Mul", node.op()); + EXPECT_EQ("x:0", node.input(0)); + EXPECT_EQ("scale:y:0", node.input(1)); + EXPECT_EQ(DT_FLOAT, node.attr().at("T").type()); + } + } + EXPECT_EQ(2, count); +} + +TEST_F(FunctionsTest, SwapFunctionBodyAndMakeSpecializedFunctionDef) { + using test::function::NDef; + + FunctionDef mul_func = FunctionDefHelper::Create( + "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"}, + {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "output:z:0"}}); + + FunctionDef func = FunctionDefHelper::Create( + "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"}, + {{{"output"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "output:z:0"}}); + + GraphDef id_func_body = test::function::GDef( + {/* pass input to output through identity */ + NDef("output", "Identity", {"x"}, {{"T", "float"}})}); + + std::unordered_map<string, AttrValue> func_attr; + func_attr["T"].set_type(DT_FLOAT); + + FunctionDefLibrary lib_def; + *lib_def.add_function() = func; + *lib_def.add_function() = mul_func; + FunctionLibraryDefinition flib(OpRegistry::Global(), lib_def); + + GrapplerFunctionItem item; + TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item)); + + // Replace function body with identity function + item.SwapFunctionBody(std::move(id_func_body)); + FunctionDef specialized; + TF_EXPECT_OK(MakeSpecializedFunctionDef(item, flib, &specialized)); + + // Check that graph body was updated. + int count = 0; + for (const NodeDef &node : specialized.node_def()) { + if (node.name() == "output" && count++) { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("x:0", node.input(0)); + } + } + EXPECT_EQ(1, count); + + // And return tensor mapping was updated with a new output name (z->output). + EXPECT_EQ("output:output:0", (*specialized.mutable_ret())["z"]); +} + } // namespace } // namespace grappler } // namespace tensorflow |