aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-12 18:29:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-12 18:31:28 -0700
commit93afca507ec09ff3b5cdf05cbd5eb265e83fc8cb (patch)
tree4e25cefafcc41c3e94bf7abed519c72a45cf0c18
parent7d89bfcd72bef4c5c9328a88ee520d81642b5284 (diff)
Convert GrapplerFunctionItem to (Specialized)FunctionDef.
PiperOrigin-RevId: 192704808
-rw-r--r--tensorflow/core/grappler/utils/BUILD3
-rw-r--r--tensorflow/core/grappler/utils/functions.cc328
-rw-r--r--tensorflow/core/grappler/utils/functions.h92
-rw-r--r--tensorflow/core/grappler/utils/functions_test.cc179
4 files changed, 504 insertions, 98 deletions
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index 05d9cbaa2b..b473f32c45 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -165,6 +165,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
],
)
@@ -177,6 +178,8 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/core:all_kernels",
"//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index dd0d918e72..e8d423a759 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -23,27 +23,82 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
+namespace {
+
+Status OutputNameRange(const FunctionLibraryDefinition& flib,
+ const NodeDef& node,
+ tensorflow::NameRangeMap* outputs_range_map) {
+ const OpRegistrationData* registration;
+ TF_RETURN_IF_ERROR(flib.LookUp(node.op(), &registration));
+ TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(node, registration->op_def,
+ nullptr, outputs_range_map));
+ return Status::OK();
+}
+
+Status RegisterFunctionBodyOutputs(const FunctionLibraryDefinition& flib,
+ const NodeDef& node,
+ GrapplerFunctionConnectivity* connectivity) {
+ tensorflow::NameRangeMap outputs_range_map;
+ TF_RETURN_IF_ERROR(OutputNameRange(flib, node, &outputs_range_map));
+ connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map);
+ return Status::OK();
+}
+
+// Replace the placeholder attribute values with the values specified in
+// instantiation attributes.
+Status ResolveFunctionBodyNodeAttrPlaceholders(
+ const AttrValueMap& func_instantiation_attr, NodeDef* node) {
+ for (auto& attr : *node->mutable_attr()) {
+ const string& placeholder = attr.second.placeholder();
+ if (placeholder.empty()) continue;
+
+ auto it = func_instantiation_attr.find(placeholder);
+ if (it != func_instantiation_attr.end()) {
+ attr.second = it->second;
+ } else {
+ return errors::InvalidArgument("Can't resolve placeholder: ",
+ placeholder);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
const InputArgExpansion& input_arg_expansion) {
- input_arg_expansions_.insert(
- {input_arg_expansion.input_name, input_arg_expansion});
+ const auto& input_name = input_arg_expansion.input_name;
+ const auto& placeholders = input_arg_expansion.placeholders;
+ input_arg_expansions_.emplace(input_name, input_arg_expansion);
+ for (int i = 0; i < placeholders.size(); ++i) {
+ const string& placeholder = input_arg_expansion.placeholders[i];
+ input_arg_placeholders_.emplace(
+ placeholder, InputArgPlaceholder{input_name, /*position=*/i});
+ }
}
void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
const string& node_name, const tensorflow::NameRangeMap& outputs) {
- function_body_outputs_.insert({node_name, outputs});
+ function_body_outputs_[node_name] = outputs;
}
Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
const string& func_def_input, std::vector<string>* graph_def_inputs) const {
using ::tensorflow::strings::Scanner;
+ if (IsControlInput(func_def_input)) {
+ graph_def_inputs->push_back(func_def_input);
+ return Status::OK();
+ }
+
// Parse input format: "node_name[:node_output][:position]"
string node_name;
string node_output;
@@ -150,11 +205,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
std::vector<string> expanded_inputs;
for (const string& function_def_input : function_body_node->input()) {
- if (!IsControlInput(function_def_input))
- TF_RETURN_IF_ERROR(
- ExpandFunctionDefInput(function_def_input, &expanded_inputs));
- else
- expanded_inputs.push_back(function_def_input);
+ TF_RETURN_IF_ERROR(
+ ExpandFunctionDefInput(function_def_input, &expanded_inputs));
}
function_body_node->clear_input();
@@ -163,10 +215,66 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
return Status::OK();
}
-Status GrapplerFunctionItemBuilder::GetTypeAttr(const string& type_attr_name,
- DataType* data_type) const {
- auto it = func_attr_->find(type_attr_name);
- if (it == func_attr_->end()) {
+Status GrapplerFunctionConnectivity::AsFunctionDefInput(
+ const string& graph_def_input, string* func_def_input) const {
+ using gtl::FindOrNull;
+
+ if (IsControlInput(graph_def_input)) {
+ *func_def_input = graph_def_input;
+ return Status::OK();
+ }
+
+ int position;
+ string node_name = ParseNodeName(graph_def_input, &position);
+ CHECK_GE(position, 0);
+
+ // Check if it's an input arg placeholder
+ if (position == 0) {
+ const InputArgPlaceholder* placeholder =
+ FindOrNull(input_arg_placeholders_, node_name);
+ if (placeholder != nullptr) {
+ *func_def_input =
+ strings::StrCat(placeholder->input_name, ":", placeholder->position);
+ return Status::OK();
+ }
+ }
+
+ // It must be output from one of the function body nodes
+ const tensorflow::NameRangeMap* outputs_range_map =
+ FindOrNull(function_body_outputs_, node_name);
+ if (outputs_range_map != nullptr) {
+ for (const auto& el : *outputs_range_map) {
+ const auto& output_name = el.first;
+ const auto& output_range = el.second;
+ if (position >= output_range.first && position < output_range.second) {
+ int pos = position - output_range.first;
+ *func_def_input =
+ strings::StrCat(node_name, ":", output_name, ":", pos);
+ return Status::OK();
+ }
+ }
+ }
+
+ return errors::InvalidArgument("Unknown graph def input: ", graph_def_input);
+}
+
+Status GrapplerFunctionConnectivity::AsFunctionDefNode(
+ NodeDef* function_body_node) const {
+ string func_def_input;
+
+ for (int i = 0; i < function_body_node->input_size(); ++i) {
+ TF_RETURN_IF_ERROR(
+ AsFunctionDefInput(function_body_node->input(i), &func_def_input));
+ function_body_node->set_input(i, func_def_input);
+ }
+
+ return Status::OK();
+}
+
+Status GrapplerFunctionItemInstantiation::GetTypeAttr(
+ const string& type_attr_name, DataType* data_type) const {
+ auto it = func_instantiation_attr_->find(type_attr_name);
+ if (it == func_instantiation_attr_->end()) {
return errors::InvalidArgument("Type attribute ", type_attr_name,
" is not defined");
} else if (it->second.type() == DT_INVALID) {
@@ -178,31 +286,48 @@ Status GrapplerFunctionItemBuilder::GetTypeAttr(const string& type_attr_name,
return Status::OK();
}
-Status GrapplerFunctionItemBuilder::GetArgType(const OpDef::ArgDef& arg,
- DataType* data_type) const {
+Status GrapplerFunctionItemInstantiation::GetArgType(
+ const OpDef::ArgDef& arg, DataType* data_type) const {
if (arg.type() != DT_INVALID) {
*data_type = arg.type();
} else {
+ if (!arg.type_list_attr().empty() || !arg.number_attr().empty()) {
+ return errors::InvalidArgument(
+ "Arguments with sequence of tensors are not supported. Unsupported "
+ "argument name: ",
+ arg.name());
+ }
TF_RETURN_IF_ERROR(GetTypeAttr(arg.type_attr(), data_type));
}
return Status::OK();
}
GrapplerFunctionItem::GrapplerFunctionItem(
- const string& function_name,
+ const string& func_name, const AttrValueMap& func_attr,
const std::vector<InputArgExpansion>& input_arg_expansions,
const std::vector<OutputArgExpansion>& output_arg_expansions,
GraphDef&& function_body)
- : function_name_(function_name),
+ : func_attr_(func_attr),
input_arg_expansions_(input_arg_expansions),
output_arg_expansions_(output_arg_expansions) {
+ id = func_name;
+ // Fill the feed nodes with input placeholders
+ for (const InputArgExpansion& input_arg : input_arg_expansions_) {
+ for (const string& placeholder : input_arg.placeholders) {
+ feed.emplace_back(placeholder, Tensor());
+ input_arg_placeholders_.insert(placeholder);
+ }
+ }
+ // Fill the fetch nodes with outputs
+ for (const OutputArgExpansion& output_arg : output_arg_expansions_) {
+ for (const string& output_tensor : output_arg.output_tensors) {
+ fetch.push_back(output_tensor);
+ }
+ }
+ // Swap the graph body
graph.Swap(&function_body);
}
-const string& GrapplerFunctionItem::function_name() const {
- return function_name_;
-}
-
const std::vector<InputArgExpansion>& GrapplerFunctionItem::inputs() const {
return input_arg_expansions_;
}
@@ -215,6 +340,11 @@ const std::size_t GrapplerFunctionItem::input_size() const {
return input_arg_expansions_.size();
}
+bool GrapplerFunctionItem::IsInputPlaceholder(const string& node_name) const {
+ return input_arg_placeholders_.find(node_name) !=
+ input_arg_placeholders_.end();
+}
+
const std::vector<OutputArgExpansion>& GrapplerFunctionItem::outputs() const {
return output_arg_expansions_;
}
@@ -227,10 +357,19 @@ const std::size_t GrapplerFunctionItem::output_size() const {
return output_arg_expansions_.size();
}
+const AttrValueMap& GrapplerFunctionItem::func_attr() const {
+ return func_attr_;
+}
+
const GraphDef& GrapplerFunctionItem::function_body() const { return graph; }
GraphDef& GrapplerFunctionItem::mutable_function_body() { return graph; }
+GrapplerFunctionItem& GrapplerFunctionItem::SwapFunctionBody(GraphDef&& other) {
+ graph.Swap(&other);
+ return *this;
+}
+
std::vector<string> OutputTensors(const GrapplerFunctionItem& item) {
std::vector<string> output_tensors;
for (const OutputArgExpansion& output : item.outputs()) {
@@ -241,18 +380,27 @@ std::vector<string> OutputTensors(const GrapplerFunctionItem& item) {
return output_tensors;
}
-Status MakeGrapplerFunctionItem(
- const FunctionDef& func,
- const std::unordered_map<string, AttrValue>& func_attr,
- const FunctionLibraryDefinition& func_library, GrapplerFunctionItem* item) {
+Status MakeGrapplerFunctionItem(const FunctionDef& func,
+ const AttrValueMap& func_instantiation_attr,
+ const FunctionLibraryDefinition& flib,
+ GrapplerFunctionItem* item) {
const OpDef& signature = func.signature();
if (signature.name().empty()) {
return errors::InvalidArgument("Function name must be specified");
}
- // Helper methods to lookup function attributes
- GrapplerFunctionItemBuilder builder(&func_attr);
+ // Function types will be resolved from function instantiation attributes. All
+ // other attributes will be lost during conversion to FunctionDef.
+ for (const OpDef::AttrDef& attr : signature.attr()) {
+ if (attr.type() != "type") {
+ return errors::InvalidArgument(
+ "Function signature must have only type attributes");
+ }
+ }
+
+ // Helper methods to lookup function instantiation attributes
+ GrapplerFunctionItemInstantiation instantiation(&func_instantiation_attr);
// Mapping from FunctionDef input format (name[:output][:position]) to
// GraphDef input format (name[:position])
@@ -260,7 +408,10 @@ Status MakeGrapplerFunctionItem(
std::vector<InputArgExpansion> inputs;
std::vector<OutputArgExpansion> outputs;
+
+ // Function body shares the library with the graph that instantiated it.
GraphDef function_body;
+ *function_body.mutable_library() = flib.ToProto();
// TODO(ezhulenev): support functions with tensor sequence inputs/outputs
@@ -284,7 +435,7 @@ Status MakeGrapplerFunctionItem(
}
DataType input_data_type;
- TF_RETURN_IF_ERROR(builder.GetArgType(input, &input_data_type));
+ TF_RETURN_IF_ERROR(instantiation.GetArgType(input, &input_data_type));
NodeDef* placeholder = function_body.add_node();
placeholder->set_name(input.name());
@@ -292,6 +443,7 @@ Status MakeGrapplerFunctionItem(
(*placeholder->mutable_attr())["T"].set_type(input_data_type);
InputArgExpansion input_expansion{/*input_name=*/input.name(),
+ /*data_type=*/input_data_type,
/*placeholders=*/{input.name()}};
connectivity.RegisterInputArgExpansion(input_expansion);
inputs.push_back(input_expansion);
@@ -302,24 +454,12 @@ Status MakeGrapplerFunctionItem(
NodeDef* new_node = function_body.add_node();
*new_node = func_def_node;
- // Replace the placeholder attribute values with the specified value
- for (auto& attr : *new_node->mutable_attr()) {
- const string& ph_name = attr.second.placeholder();
- auto it = func_attr.find(ph_name);
- if (it != func_attr.end()) {
- attr.second = it->second;
- }
- }
-
- // Functions use a custom format to encode connectivity. Map these custom
- // strings to regular ones.
- tensorflow::NameRangeMap outputs_range_map;
- const OpRegistrationData* registration;
- TF_RETURN_IF_ERROR(func_library.LookUp(func_def_node.op(), &registration));
- TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
- func_def_node, registration->op_def, nullptr, &outputs_range_map));
- connectivity.RegisterFunctionBodyOutputs(func_def_node.name(),
- outputs_range_map);
+ // Resolve all placeholder values using function instantiation attributes.
+ TF_RETURN_IF_ERROR(ResolveFunctionBodyNodeAttrPlaceholders(
+ func_instantiation_attr, new_node));
+ // Register node output range in a function connectivity.
+ TF_RETURN_IF_ERROR(
+ RegisterFunctionBodyOutputs(flib, func_def_node, &connectivity));
}
// Rewrite inputs to use GraphDef format
@@ -331,20 +471,96 @@ Status MakeGrapplerFunctionItem(
for (const OpDef::ArgDef& out : signature.output_arg()) {
std::vector<string> output_tensors;
auto ret = func.ret().find(out.name());
- if (ret != func.ret().end()) {
- // Expand outputs using provided output mapping
- TF_RETURN_IF_ERROR(
- connectivity.ExpandFunctionDefInput(ret->second, &output_tensors));
- } else {
- // Otherwise output must be one of the function inputs
- TF_RETURN_IF_ERROR(
- connectivity.ExpandFunctionDefInput(out.name(), &output_tensors));
+ TF_RETURN_IF_ERROR(
+ ret != func.ret().end()
+ // Expand outputs using provided output mapping
+ ? connectivity.ExpandFunctionDefInput(ret->second, &output_tensors)
+ // Otherwise output must be one of the function inputs
+ : connectivity.ExpandFunctionDefInput(out.name(), &output_tensors));
+
+ DataType output_data_type;
+ TF_RETURN_IF_ERROR(instantiation.GetArgType(out, &output_data_type));
+
+ OutputArgExpansion output{/*output_name=*/out.name(),
+ /*data_type=*/output_data_type,
+ /*output_tensors=*/output_tensors};
+ outputs.push_back(output);
+ }
+
+ *item = GrapplerFunctionItem(
+ /*func_name=*/signature.name(),
+ /*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()),
+ inputs, outputs, std::move(function_body));
+ return Status::OK();
+}
+
+// Register GrapplerFunctionItem input arg expansion and function body outputs
+// in the GrapplerFunctionConnectivity
+Status RegisterGrapplerFunctionConnectivity(
+ const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
+ GrapplerFunctionConnectivity* connectivity) {
+ for (const InputArgExpansion& input : item.inputs()) {
+ connectivity->RegisterInputArgExpansion(input);
+ }
+ for (const NodeDef& func_body_node : item.function_body().node()) {
+ TF_RETURN_IF_ERROR(
+ RegisterFunctionBodyOutputs(flib, func_body_node, connectivity));
+ }
+ return Status::OK();
+}
+
+Status MakeSpecializedFunctionDef(const GrapplerFunctionItem& item,
+ const FunctionLibraryDefinition& flib,
+ FunctionDef* func) {
+ func->mutable_signature()->set_name(item.id);
+
+ // Build a GrapplerFunctionConnectivity from inputs and new function body.
+ GrapplerFunctionConnectivity connectivity;
+ TF_RETURN_IF_ERROR(
+ RegisterGrapplerFunctionConnectivity(item, flib, &connectivity));
+
+ // Add function input arguments.
+ for (const InputArgExpansion& input_arg : item.inputs()) {
+ OpDef::ArgDef arg_def;
+ arg_def.set_name(input_arg.input_name);
+ arg_def.set_type(input_arg.data_type);
+ *func->mutable_signature()->add_input_arg() = arg_def;
+ }
+
+ // Add function output arguments.
+ for (const OutputArgExpansion& output_arg : item.outputs()) {
+ OpDef::ArgDef arg_def;
+ arg_def.set_name(output_arg.output_name);
+ arg_def.set_type(output_arg.data_type);
+ *func->mutable_signature()->add_output_arg() = arg_def;
+
+ CHECK(output_arg.output_tensors.size() == 1) // do some sanity checking
+ << "Outputs of tensor sequences are not supported";
+
+ string ret;
+ for (const string& output_tensor : output_arg.output_tensors) {
+ TF_RETURN_IF_ERROR(connectivity.AsFunctionDefInput(output_tensor, &ret));
+ (*func->mutable_ret())[output_arg.output_name] = ret;
}
- outputs.push_back({out.name(), output_tensors});
}
- *item = GrapplerFunctionItem(signature.name(), inputs, outputs,
- std::move(function_body));
+ // Copy function definition specific attributes.
+ for (const auto& attr : item.func_attr()) {
+ const auto& attr_name = attr.first;
+ const auto& attr_value = attr.second;
+ (*func->mutable_attr())[attr_name] = attr_value;
+ }
+
+ // Copy function body nodes to the FunctionDef and update input format
+ for (const NodeDef& func_body_node : item.function_body().node()) {
+ // Do not copy input placeholders
+ if (item.IsInputPlaceholder(func_body_node.name())) continue;
+
+ NodeDef* func_def_node = func->add_node_def();
+ *func_def_node = func_body_node;
+ TF_RETURN_IF_ERROR(connectivity.AsFunctionDefNode(func_def_node));
+ }
+
return Status::OK();
}
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 60ea8857c0..2ac3917a66 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -28,14 +28,19 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
+using AttrValueMap = std::unordered_map<string, AttrValue>;
+
// Depending on the function instantiation attributes, input argument to the
// function might be a single tensor, list of tensors of the same type, or a
// list of tensors of different types.
//
// InputArgExpansion keeps track of the placeholders that were added to the
-// function body in place of function inputs.
+// function body in place of function inputs and a resolved input data type.
struct InputArgExpansion {
+ // TODO(ezhulenev): Add support for functions with tensor sequence inputs of
+ // different data types
string input_name; // name of the function input argument
+ DataType data_type; // input data type
std::vector<string> placeholders; // names of placeholder nodes in the
// function body
};
@@ -44,11 +49,14 @@ struct InputArgExpansion {
// to one or more outputs of one of the function body nodes.
//
// OutputArgExpansion keeps mapping from a function output arg to the output
-// tensors of a function body nodes, that compute function outputs.
+// tensors of a function body nodes and a resolved output data type
struct OutputArgExpansion {
+ // TODO(ezhulenev): Add support for functions with tensor sequence outputs of
+ // different data types
string output_name; // name of the function output argument
- std::vector<string> output_tensors; // names of output tensors from the
- // function body graph nodes
+ DataType data_type; // output data type
+ std::vector<string> output_tensors; // names of output tensor from the
+ // function body nodes
};
// FunctionDef uses different connectivity encoding for the function body nodes,
@@ -67,26 +75,46 @@ class GrapplerFunctionConnectivity {
Status ExpandFunctionDefInput(const string& func_def_input,
std::vector<string>* graph_def_inputs) const;
- // Update Node inputs from FunctionDef to GraphDef format
+ // Update Node inputs from FunctionDef to GraphDef format.
Status ExpandNodeInputs(NodeDef* function_body_node) const;
- // TODO(ezhulenev): fold GraphDef inputs back to FunctionDef format
- // Status FoldGraphDefInputs(const std::vector<sting> graph_def_inputs,
- // std::vector<string>* function_def_inputs) const;
+ // When expanding inputs in function def format, single input might be
+ // expanded into multiple tensors. When converting back to the function def
+ // format from graph def format, it's always a 1-to-1 relationship.
+ // FunctionDef built from GrapplerFunctionItem is always specialized to it's
+ // instantiation attributes and length of input args (and node def outputs) is
+ // known.
+
+ // Map from GraphDef input format to FunctionDef input format using registered
+ // input arg expansion and function body outputs.
+ Status AsFunctionDefInput(const string& graph_def_input,
+ string* func_def_input) const;
+
+ // Update Node inputs from GraphDef to FunctionDef format.
+ Status AsFunctionDefNode(NodeDef* function_body_node) const;
private:
+ // Mapping from input name to input arg expansion.
std::unordered_map<string, InputArgExpansion> input_arg_expansions_;
+ // Mapping from function body node name to output names range map.
std::unordered_map<string, tensorflow::NameRangeMap> function_body_outputs_;
+
+ struct InputArgPlaceholder {
+ string input_name;
+ int position;
+ };
+
+ // Mapping from input arg placeholder to the function input tensor.
+ std::unordered_map<string, InputArgPlaceholder> input_arg_placeholders_;
};
-// Helper methods to build GrapplerFunctionItem from a function def and function
-// attributes.
-class GrapplerFunctionItemBuilder {
+// Get Function type attributes using attributes of a node that instantiated
+// a function.
+class GrapplerFunctionItemInstantiation {
public:
- using FunctionAttr = std::unordered_map<string, AttrValue>;
-
- explicit GrapplerFunctionItemBuilder(const FunctionAttr* func_attr)
- : func_attr_(func_attr) {}
+ explicit GrapplerFunctionItemInstantiation(
+ const AttrValueMap* func_instantiation_attr)
+ : func_instantiation_attr_(func_instantiation_attr) {}
// Get DataType from attributes by name. Return error if attribute is missing,
// or it doesn't define a valid data type.
@@ -97,20 +125,20 @@ class GrapplerFunctionItemBuilder {
Status GetArgType(const OpDef::ArgDef& arg, DataType* data_type) const;
private:
- const FunctionAttr* func_attr_; // do not own
+ const AttrValueMap* func_instantiation_attr_; // do not own
};
// A special case of GrapplerItem, constructed from a TensorFlow Function.
class GrapplerFunctionItem : public GrapplerItem {
public:
- GrapplerFunctionItem() {}
+ GrapplerFunctionItem() = default;
GrapplerFunctionItem(
- const string& function_name,
+ const string& func_name, const AttrValueMap& func_attr,
const std::vector<InputArgExpansion>& input_arg_expansions,
const std::vector<OutputArgExpansion>& output_arg_expansions,
GraphDef&& function_body);
- const string& function_name() const;
+ bool IsInputPlaceholder(const string& node_name) const;
const std::vector<InputArgExpansion>& inputs() const;
const InputArgExpansion& input(int i) const;
@@ -120,13 +148,20 @@ class GrapplerFunctionItem : public GrapplerItem {
const OutputArgExpansion& output(int i) const;
const std::size_t output_size() const;
+ const AttrValueMap& func_attr() const;
const GraphDef& function_body() const;
GraphDef& mutable_function_body();
+ GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other);
+
private:
- string function_name_;
+ AttrValueMap func_attr_; // Attributes specific to function definition that
+ // produced this item (FuncDef.attr field).
+
std::vector<InputArgExpansion> input_arg_expansions_;
std::vector<OutputArgExpansion> output_arg_expansions_;
+
+ std::set<string> input_arg_placeholders_;
};
// Return all output tensors referenced by item output args.
@@ -136,8 +171,21 @@ std::vector<string> OutputTensors(const GrapplerFunctionItem& item);
// Return error if the given function def cannot be converted.
Status MakeGrapplerFunctionItem(
const FunctionDef& func,
- const std::unordered_map<string, AttrValue>& func_attr,
- const FunctionLibraryDefinition& func_library, GrapplerFunctionItem* item);
+ const std::unordered_map<string, AttrValue>& func_instantiation_attr,
+ const FunctionLibraryDefinition& flib, GrapplerFunctionItem* item);
+
+// Register GrapplerFunctionItem input arg expansion and function body outputs
+// in the GrapplerFunctionConnectivity. Use function library definition to
+// lookup function body nodes output names and ranges.
+Status RegisterGrapplerFunctionConnectivity(
+ const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
+ GrapplerFunctionConnectivity* connectivity);
+
+// Make a specialized FunctionDef from the GrapplerFunctionItem. Use function
+// library definition to lookup function body nodes output names and ranges.
+Status MakeSpecializedFunctionDef(const GrapplerFunctionItem& item,
+ const FunctionLibraryDefinition& flib,
+ FunctionDef* func);
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc
index 1eb3298e89..a9a708bf67 100644
--- a/tensorflow/core/grappler/utils/functions_test.cc
+++ b/tensorflow/core/grappler/utils/functions_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
@@ -32,8 +33,9 @@ class FunctionsTest : public ::testing::Test {};
TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) {
GrapplerFunctionConnectivity connectivity;
- connectivity.RegisterInputArgExpansion({"inputA", {"inputA"}});
- connectivity.RegisterInputArgExpansion({"inputB", {"inputB_0", "inputB_1"}});
+ connectivity.RegisterInputArgExpansion({"inputA", DT_FLOAT, {"inputA"}});
+ connectivity.RegisterInputArgExpansion(
+ {"inputB", DT_FLOAT, {"inputB_0", "inputB_1"}});
connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}});
connectivity.RegisterFunctionBodyOutputs("Func",
@@ -93,11 +95,50 @@ TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandFunctionDefInput) {
EXPECT_EQ("Func:3", inputs[0]);
}
+TEST_F(FunctionsTest, GrapplerFunctionConnectivity_AsFunctionDefInput) {
+ GrapplerFunctionConnectivity connectivity;
+
+ connectivity.RegisterInputArgExpansion({"inputA", DT_FLOAT, {"inputA"}});
+ connectivity.RegisterInputArgExpansion(
+ {"inputB", DT_FLOAT, {"inputB_0", "inputB_1"}});
+
+ connectivity.RegisterFunctionBodyOutputs("Add", {{"z", {0, 1}}});
+ connectivity.RegisterFunctionBodyOutputs("Func",
+ {{"o1", {0, 2}}, {"o2", {2, 4}}});
+
+ string input;
+
+ TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputA", &input));
+ EXPECT_EQ("inputA:0", input);
+
+ TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_0", &input));
+ EXPECT_EQ("inputB:0", input);
+
+ TF_EXPECT_OK(connectivity.AsFunctionDefInput("inputB_1", &input));
+ EXPECT_EQ("inputB:1", input);
+
+ TF_EXPECT_OK(connectivity.AsFunctionDefInput("Add", &input));
+ EXPECT_EQ("Add:z:0", input);
+
+ TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func", &input));
+ EXPECT_EQ("Func:o1:0", input);
+
+ TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:1", &input));
+ EXPECT_EQ("Func:o1:1", input);
+
+ TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:2", &input));
+ EXPECT_EQ("Func:o2:0", input);
+
+ TF_EXPECT_OK(connectivity.AsFunctionDefInput("Func:3", &input));
+ EXPECT_EQ("Func:o2:1", input);
+}
+
TEST_F(FunctionsTest, GrapplerFunctionConnectivity_ExpandNodeInputs) {
GrapplerFunctionConnectivity connectivity;
- connectivity.RegisterInputArgExpansion({"inputA", {"inputA"}});
- connectivity.RegisterInputArgExpansion({"inputB", {"inputB_0", "inputB_1"}});
+ connectivity.RegisterInputArgExpansion({"inputA", DT_FLOAT, {"inputA"}});
+ connectivity.RegisterInputArgExpansion(
+ {"inputB", DT_FLOAT, {"inputB_0", "inputB_1"}});
NodeDef node;
node.add_input("inputA:0");
@@ -131,12 +172,12 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) {
std::unordered_map<string, AttrValue> func_attr;
func_attr["T"].set_type(DT_FLOAT);
- FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
+ FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
- EXPECT_EQ("XTimesTwo", item.function_name());
+ EXPECT_EQ("XTimesTwo", item.id);
EXPECT_EQ(4, item.function_body().node_size());
EXPECT_EQ(1, item.input_size());
@@ -206,12 +247,12 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
std::unordered_map<string, AttrValue> func_attr;
func_attr["T"].set_type(DT_FLOAT);
- FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
+ FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
- EXPECT_EQ("SubGrad", item.function_name());
+ EXPECT_EQ("SubGrad", item.id);
EXPECT_EQ(12, item.function_body().node_size());
ASSERT_EQ(3, item.input_size());
@@ -251,8 +292,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
}
TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) {
- FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
- TF_ASSERT_OK(library.AddFunctionDef(FunctionDefHelper::Define(
+ FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
+ TF_ASSERT_OK(flib.AddFunctionDef(FunctionDefHelper::Define(
// Name
"Swap",
// Args
@@ -290,7 +331,7 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) {
func_attr["T"].set_type(DT_FLOAT);
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
int count = 0;
for (const NodeDef &node : item.function_body().node()) {
@@ -348,10 +389,10 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
{{"out", "Exp:y:0"}});
std::unordered_map<string, AttrValue> func_attr;
- FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
+ FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
EXPECT_EQ(1, item.output_size());
EXPECT_EQ("Exp", item.output(0).output_tensors[0]);
@@ -391,12 +432,12 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) {
{{"out0", "in0"}});
std::unordered_map<string, AttrValue> func_attr;
- FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
+ FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
- EXPECT_EQ("ForwardInputs", item.function_name());
+ EXPECT_EQ("ForwardInputs", item.id);
EXPECT_EQ(5, item.function_body().node_size());
EXPECT_EQ(3, item.output_size());
@@ -437,10 +478,10 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
std::unordered_map<string, AttrValue> func_attr;
func_attr["T"].set_type(DT_FLOAT);
- FunctionLibraryDefinition library(OpRegistry::Global(), FunctionDefLibrary());
+ FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, library, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
EXPECT_EQ(0, item.input_size());
EXPECT_EQ(1, item.output_size());
@@ -456,6 +497,104 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
EXPECT_EQ("two", cast.input(0));
}
+TEST_F(FunctionsTest, MakeSpecializedFunctionDef) {
+ const Tensor kTwo = test::AsScalar<int64>(2);
+ FunctionDef func = FunctionDefHelper::Define(
+ // Name
+ "XTimesTwo",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"two"}, "Const", {}, {{"value", kTwo}, {"dtype", DT_INT64}}},
+ {{"scale"}, "Cast", {"two"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"y"}, "Mul", {"x", "scale"}, {{"T", "$T"}}},
+ });
+
+ std::unordered_map<string, AttrValue> func_attr;
+ func_attr["T"].set_type(DT_FLOAT);
+ FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
+
+ GrapplerFunctionItem item;
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+
+ FunctionDef specialized;
+ TF_EXPECT_OK(MakeSpecializedFunctionDef(item, flib, &specialized));
+
+ // Input and output types are resolved based on instantiation attributes.
+ EXPECT_EQ("x", specialized.signature().input_arg(0).name());
+ EXPECT_EQ(DT_FLOAT, specialized.signature().input_arg(0).type());
+ EXPECT_EQ("y", specialized.signature().output_arg(0).name());
+ EXPECT_EQ(DT_FLOAT, specialized.signature().output_arg(0).type());
+
+ // Function body specialized for instantiation types
+ int count = 0;
+ for (const NodeDef &node : specialized.node_def()) {
+ if (node.name() == "scale" && count++) {
+ EXPECT_EQ(DT_FLOAT, node.attr().at("DstT").type());
+ } else if (node.name() == "y" && count++) {
+ EXPECT_EQ("Mul", node.op());
+ EXPECT_EQ("x:0", node.input(0));
+ EXPECT_EQ("scale:y:0", node.input(1));
+ EXPECT_EQ(DT_FLOAT, node.attr().at("T").type());
+ }
+ }
+ EXPECT_EQ(2, count);
+}
+
+TEST_F(FunctionsTest, SwapFunctionBodyAndMakeSpecializedFunctionDef) {
+ using test::function::NDef;
+
+ FunctionDef mul_func = FunctionDefHelper::Create(
+ "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
+ {{{"output"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "output:z:0"}});
+
+ FunctionDef func = FunctionDefHelper::Create(
+ "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
+ {{{"output"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "output:z:0"}});
+
+ GraphDef id_func_body = test::function::GDef(
+ {/* pass input to output through identity */
+ NDef("output", "Identity", {"x"}, {{"T", "float"}})});
+
+ std::unordered_map<string, AttrValue> func_attr;
+ func_attr["T"].set_type(DT_FLOAT);
+
+ FunctionDefLibrary lib_def;
+ *lib_def.add_function() = func;
+ *lib_def.add_function() = mul_func;
+ FunctionLibraryDefinition flib(OpRegistry::Global(), lib_def);
+
+ GrapplerFunctionItem item;
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+
+ // Replace function body with identity function
+ item.SwapFunctionBody(std::move(id_func_body));
+ FunctionDef specialized;
+ TF_EXPECT_OK(MakeSpecializedFunctionDef(item, flib, &specialized));
+
+ // Check that graph body was updated.
+ int count = 0;
+ for (const NodeDef &node : specialized.node_def()) {
+ if (node.name() == "output" && count++) {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("x:0", node.input(0));
+ }
+ }
+ EXPECT_EQ(1, count);
+
+ // And return tensor mapping was updated with a new output name (z->output).
+ EXPECT_EQ("output:output:0", (*specialized.mutable_ret())["z"]);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow