/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/function.h" #include #include #include #include #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/function.pb_text.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { // Extracts the actual type from "attr_values" based on its definition // "arg_def". // // If "arg_def" is a N*T type, *is_type_list is set to false, and // *dtypes is set to be a vector of size N and each element is T. // // If "arg_def" is a list(type), *is_type_list is set to true, and // *dtypes is set to be a vector of types specified in attrs for // arg_def. // // Otherwise (arg_def is a simple type T), *is_type_list is set to // false, and *dtypes is set to a single element vector, whose only // element is T. Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, bool* is_type_list, DataTypeVector* dtypes) { dtypes->clear(); if (!arg_def.type_list_attr().empty()) { const AttrValue* v = attrs.Find(arg_def.type_list_attr()); if (v == nullptr) { return errors::NotFound("type attr not found: ", arg_def.type_list_attr()); } *is_type_list = true; for (int i = 0; i < v->list().type_size(); ++i) { dtypes->push_back(v->list().type(i)); } return Status::OK(); } *is_type_list = false; int num = 1; if (!arg_def.number_attr().empty()) { const AttrValue* v = attrs.Find(arg_def.number_attr()); if (v == nullptr) { return errors::NotFound("type attr not found: ", arg_def.type_attr()); } num = v->i(); } DataType dtype; if (arg_def.type() != DT_INVALID) { dtype = arg_def.type(); } else if (arg_def.type_attr().empty()) { dtype = DT_INVALID; } else { const AttrValue* v = attrs.Find(arg_def.type_attr()); if (v == nullptr) { return errors::NotFound("type attr not found: ", arg_def.type_attr()); } dtype = v->type(); } dtypes->resize(num, dtype); return Status::OK(); } namespace { template void AddAttr(const string& name, const T& val, NodeDef* ndef) { SetAttrValue(val, &((*ndef->mutable_attr())[name])); } Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) { // attr_values should specify all attrs defined in fdef. for (const auto& a : sig.attr()) { const AttrValue* v = attr_values.Find(a.name()); if (!v) { return errors::NotFound("Attr ", a.name(), " is not found from ", SummarizeOpDef(sig)); } Status status = AttrValueHasType(*v, a.type()); if (!status.ok()) { errors::AppendToMessage(&status, "for attr '", a.name(), "'"); return status; } } // TODO(josh11b): Enable this code once it works with function gradients. // Right now the C++ function gradient code assumes it can pass // all the attrs of the function to the gradient, and any attrs that // the gradient doesn't care about will be ignored. #if 0 if (attr_values.size() != sig.attr_size()) { for (const auto& a : attr_values) { // TODO(josh11b): Possibly should ignore attrs that start with "_" here? bool found = false; for (const auto& s : sig.attr()) { if (a.first == s.name()) { found = true; break; } } if (!found) { return errors::NotFound("Attr ", a.first, " is not found in ", SummarizeOpDef(sig)); } } } #endif return Status::OK(); } // A helper class for instantiating functions. This contains shared information // like the resulting graph and node name index. class FunctionInstantiationHelper { public: FunctionInstantiationHelper(GetFunctionSignature get_function, InstantiationResult* result) : get_function_(std ::move(get_function)), result_(*result) { result_.nodes.clear(); } // Builds index for nodes that can be used as node's input arguments. Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values) { bool is_type_list; DataTypeVector dtypes; TF_RETURN_IF_ERROR( ArgNumType(attr_values, arg_def, &is_type_list, &dtypes)); CHECK_GE(dtypes.size(), size_t{1}); int arg_index = result_.nodes.size(); TF_RETURN_IF_ERROR( AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes})); // Creates dtypes.size() nodes in the graph. for (size_t i = 0; i < dtypes.size(); ++i) { TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":", i), {true, arg_index, 0, false, {dtypes[i]}})); DCHECK_EQ(arg_index, result_.nodes.size()); string name = arg_def.name(); if (dtypes.size() > 1) { strings::StrAppend(&name, "_", i); } NodeDef* gnode = AddNode(name); gnode->set_op(FunctionLibraryDefinition::kArgOp); AddAttr("T", dtypes[i], gnode); AddAttr("index", arg_index, gnode); result_.arg_types.push_back(dtypes[i]); ++arg_index; } return Status::OK(); } Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs, const int arg_index) { const OpDef* node_sig = nullptr; TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig)); if (node_sig->output_arg_size() == 0) { return AddItem(node.name(), {false, arg_index, 0, false, {}}); } const int num_retval = node_sig->output_arg_size(); int start = 0; bool is_type_list; DataTypeVector dtypes; for (int i = 0; i < num_retval; ++i) { TF_RETURN_IF_ERROR( ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes)); // Note that we rely on the backwards-compatibility test enforcing // that output_arg(*).name() doesn't change here. const string base_name = strings::StrCat(node.name(), ":", node_sig->output_arg(i).name()); TF_RETURN_IF_ERROR( AddItem(base_name, {false, arg_index, start, is_type_list, dtypes})); for (int j = 0; j < static_cast(dtypes.size()); ++j) { TF_RETURN_IF_ERROR( AddItem(strings::StrCat(base_name, ":", j), {false, arg_index, start + j, false, {dtypes[j]}})); } start += dtypes.size(); } return Status::OK(); } Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) { const OpDef* fnode_sig = nullptr; TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig)); NodeDef* gnode = AddNode(fnode.name()); gnode->set_op(fnode.op()); gnode->set_device(fnode.device()); int gnode_idx = nodes_.size() - 1; // Input const int num_args = fnode_sig->input_arg_size(); bool is_type_list; // ignored DataTypeVector dtypes; int fnode_arg_index = 0; for (int i = 0; i < num_args; ++i) { TF_RETURN_IF_ERROR( ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes)); // Consume inputs (indexed by fnode_arg_index) until we have // matched each element of dtypes (indexed by j). for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) { if (fnode_arg_index >= fnode.input_size()) { // Should never happen if we computed dtypes correctly. return errors::InvalidArgument( "Attempt to access beyond input size: ", fnode_arg_index, " >= ", fnode.input_size()); } // Look up the next input. const string& input_name = fnode.input(fnode_arg_index); const auto* item = GetItemOrNull(input_name); if (item == nullptr) { return errors::InvalidArgument( "input ", input_name, " is not found: ", SummarizeNodeDef(fnode)); } if (item->dtypes.size() > dtypes.size() - j) { return errors::InvalidArgument("Input ", input_name, " too long for ", fnode_sig->input_arg(i).name()); } // Match up all the elements of this input (indexed by k) with // elements of dtypes (advancing j). for (int k = 0; k < item->dtypes.size(); ++k, ++j) { if (item->dtypes[k] != dtypes[j]) { return errors::InvalidArgument( "input ", fnode_sig->input_arg(i).name(), "[", j, "] expected type ", DataTypeString(dtypes[j]), " != ", DataTypeString(item->dtypes[k]), ", the type of ", input_name, "[", k, "]"); } if (item->is_func_arg) { AddInput(gnode_idx, item->nid + k, 0); } else { AddInput(gnode_idx, item->nid, item->idx + k); } } } } // Control deps. for (int i = fnode_arg_index; i < fnode.input_size(); ++i) { const string& input = fnode.input(i); if (input.empty() || input[0] != '^') { return errors::InvalidArgument("Expected input[", i, "] == '", input, "' to be a control input."); } int nid = -1; const string node_name = input.substr(1); const string node_colon = node_name + ":"; const string node_colon_bound = node_name + ";"; // index_ is a map sorted lexicographically, so the key we are looking for // must lie in the range [node_name, node_colon_bound). auto it = index_.lower_bound(node_name); while (it != index_.end() && it->first <= node_colon_bound) { if (it->first == node_name || tensorflow::str_util::StartsWith(it->first, node_colon)) { nid = it->second.nid; break; } ++it; } if (nid == -1) { return errors::InvalidArgument("input[", i, "] == '", input, "', is not found."); } AddDep(gnode_idx, nid); } // Attrs. for (const auto& p : attrs) { (*gnode->mutable_attr())[p.first] = p.second; } return Status::OK(); } Status AddReturnNode( const OpDef::ArgDef& ret_def, AttrSlice attrs, const ::tensorflow::protobuf::Map& ret_map, int* ret_index) { auto ret_iter = ret_map.find(ret_def.name()); if (ret_iter == ret_map.end()) { return errors::InvalidArgument("Return ", ret_def.name(), " missing."); } bool is_type_list; DataTypeVector dtypes; TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes)); CHECK_GE(dtypes.size(), size_t{1}); const auto* item = GetItemOrNull(ret_iter->second); if (item == nullptr) { return errors::InvalidArgument("Return ", ret_def.name(), " -> ", ret_iter->second, " is not found."); } if (dtypes != item->dtypes) { return errors::InvalidArgument("Invalid ret types ", ret_def.name(), " : ", DataTypeVectorString(dtypes), " vs. ", DataTypeVectorString(item->dtypes)); } for (size_t i = 0; i < dtypes.size(); ++i) { string name = strings::StrCat(ret_def.name(), "_RetVal"); if (dtypes.size() > 1) { strings::StrAppend(&name, "_", i); } NodeDef* gnode = AddNode(name); gnode->set_op(FunctionLibraryDefinition::kRetOp); AddInput(nodes_.size() - 1, item->nid, item->idx + i); AddAttr("T", dtypes[i], gnode); AddAttr("index", (*ret_index)++, gnode); result_.ret_types.push_back(dtypes[i]); } return Status::OK(); } // Adds the actual node inputs to the result graph by converting indexes to // the node names. void AddNodeInputs() { for (int i = 0; i < result_.nodes.size(); i++) { NodeInfo& node_info = nodes_[i]; for (const auto& p : node_info.data_inputs) { result_.nodes[i].add_input(Name(p.first, p.second)); } for (int index : node_info.control_inputs) { result_.nodes[i].add_input(Dep(index)); } } } private: // This is used to build a small index for all names that can be used as a // node's input arguments. // // If is_func_arg is true, the name is a function's argument. In // this case, the produced graph def has node[nid:nid + dtype.size()]. // // Otherwise, the name is a function body's node return value. In // this case, the produced graph def has one node node[nid] and // the node's output index [idx ... idx + num) corresponds to the // named outputs. // // In all cases, "dtype" specifies the data type. struct NameInfoItem { bool is_func_arg; int nid; int idx; bool is_type_list; DataTypeVector dtypes; }; // Adds an item into the input name index. Status AddItem(const string& name, const NameInfoItem& item) { if (!index_.insert({name, item}).second) { return errors::InvalidArgument( strings::StrCat("Duplicated ", item.is_func_arg ? "arg" : "ret", " name: "), name); } return Status::OK(); } const NameInfoItem* GetItemOrNull(const string& name) const { return gtl::FindOrNull(index_, name); } string Dep(int node_index) const { return strings::StrCat("^", Name(node_index)); } string Name(int node_index) const { CHECK_LT(node_index, nodes_.size()); return nodes_[node_index].name; } string Name(int node_index, int output_index) const { if (output_index == 0) { return Name(node_index); } else { return strings::StrCat(Name(node_index), ":", output_index); } } NodeDef* AddNode(const string& name) { result_.nodes.emplace_back(); NodeDef* gnode = &result_.nodes.back(); gnode->set_name(name); nodes_.push_back({name, {}, {}}); CHECK_EQ(result_.nodes.size(), nodes_.size()); return gnode; } void AddInput(int node_index, int output_node, int output_index) { CHECK_LT(node_index, nodes_.size()); nodes_[node_index].data_inputs.push_back( std::make_pair(output_node, output_index)); } void AddDep(int node_index, int dep_index) { CHECK_LT(node_index, nodes_.size()); nodes_[node_index].control_inputs.push_back(dep_index); } GetFunctionSignature get_function_; InstantiationResult& result_; // A small index for all names that can be used as a node's input arguments. std::map index_; // This contains information about a node in the new graph including the node // names and input nodes' indexes. struct NodeInfo { string name; // Data inputs where means arg k of node n. std::vector> data_inputs; // Control inputs (dependencies). std::vector control_inputs; }; // nodes_[i] is the information about result_.nodes[i]. std::vector nodes_; }; // Various helpers Print(proto) to print relevant protos to ascii. string Print(const OpDef::ArgDef& arg) { string out; strings::StrAppend(&out, arg.name(), ":"); if (arg.is_ref()) strings::StrAppend(&out, "Ref("); if (!arg.number_attr().empty()) { strings::StrAppend(&out, arg.number_attr(), "*"); } if (arg.type() != DT_INVALID) { strings::StrAppend(&out, DataTypeString(arg.type())); } else { strings::StrAppend(&out, arg.type_attr()); } if (arg.is_ref()) strings::StrAppend(&out, ")"); return out; } // TODO(josh11b): Merge this with SummarizeAttrValue(). string Print(const AttrValue& attr_value) { if (attr_value.value_case() == AttrValue::kType) { return DataTypeString(attr_value.type()); } else if ((attr_value.value_case() == AttrValue::kList) && (attr_value.list().type_size() > 0)) { string ret = "{"; for (int i = 0; i < attr_value.list().type_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i))); } strings::StrAppend(&ret, "}"); return ret; } else if (attr_value.value_case() == AttrValue::kFunc) { if (attr_value.func().attr_size() == 0) { return attr_value.func().name(); } std::vector entries; for (auto p : attr_value.func().attr()) { entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); } std::sort(entries.begin(), entries.end()); return strings::StrCat(attr_value.func().name(), "[", str_util::Join(entries, ", "), "]"); } return SummarizeAttrValue(attr_value); } // TODO(josh11b): Merge this with SummarizeNodeDef(). string Print(const NodeDef& n) { string out; strings::StrAppend(&out, n.name(), " = ", n.op()); if (n.attr_size() > 0) { std::vector entries; for (auto& a : n.attr()) { entries.push_back(strings::StrCat(a.first, "=", Print(a.second))); } std::sort(entries.begin(), entries.end()); strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]"); } strings::StrAppend(&out, "("); std::vector dat; std::vector dep; for (StringPiece s : n.input()) { if (str_util::ConsumePrefix(&s, "^")) { dep.emplace_back(s); } else { dat.push_back(s); } } strings::StrAppend(&out, str_util::Join(dat, ", "), ")"); if (!dep.empty()) { strings::StrAppend(&out, " @ ", str_util::Join(dep, ", ")); } return out; } string Print(const FunctionDef& fdef) { string out; const OpDef& sig = fdef.signature(); strings::StrAppend(&out, "\n", sig.name()); if (sig.attr_size() > 0) { strings::StrAppend(&out, "["); for (int i = 0; i < sig.attr_size(); ++i) { const auto& a = sig.attr(i); if (i > 0) strings::StrAppend(&out, ", "); if (a.type() == "type") { strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values())); } else { strings::StrAppend(&out, a.name(), ":", a.type()); } } strings::StrAppend(&out, "]"); } strings::StrAppend(&out, "("); for (int i = 0; i < sig.input_arg_size(); ++i) { if (i > 0) strings::StrAppend(&out, ", "); strings::StrAppend(&out, Print(sig.input_arg(i))); } strings::StrAppend(&out, ") -> ("); for (int i = 0; i < sig.output_arg_size(); ++i) { if (i > 0) strings::StrAppend(&out, ", "); strings::StrAppend(&out, Print(sig.output_arg(i))); } strings::StrAppend(&out, ") {\n"); for (const auto& n : fdef.node_def()) { strings::StrAppend(&out, " ", Print(n), "\n"); } for (const auto& r : fdef.ret()) { strings::StrAppend(&out, " return ", r.first, " = ", r.second, "\n"); } strings::StrAppend(&out, "}\n"); return out; } string Print(gtl::ArraySlice nodes) { std::vector arg; std::vector ret; std::vector body; for (const NodeDef* n : nodes) { if (n->op() == FunctionLibraryDefinition::kArgOp) { arg.push_back(n); } else if (n->op() == FunctionLibraryDefinition::kRetOp) { ret.push_back(n); } else { body.push_back(n); } } auto comp = [](const NodeDef* x, const NodeDef* y) { int xi; TF_CHECK_OK(GetNodeAttr(*x, "index", &xi)); int yi; TF_CHECK_OK(GetNodeAttr(*y, "index", &yi)); return xi < yi; }; std::sort(arg.begin(), arg.end(), comp); std::sort(ret.begin(), ret.end(), comp); string out; strings::StrAppend(&out, "\n("); auto get_type = [](const NodeDef& n) { DataType dt; if (!GetNodeAttr(n, "T", &dt).ok()) { dt = DT_INVALID; } return DataTypeString(dt); }; for (size_t i = 0; i < arg.size(); ++i) { const NodeDef* n = arg[i]; if (i > 0) strings::StrAppend(&out, ", "); CHECK_GE(n->attr_size(), 2); strings::StrAppend(&out, n->name(), ":", get_type(*n)); } strings::StrAppend(&out, ") -> ("); for (size_t i = 0; i < ret.size(); ++i) { const NodeDef* n = ret[i]; if (i > 0) strings::StrAppend(&out, ", "); CHECK_LE(2, n->attr_size()); CHECK_EQ(1, n->input_size()); strings::StrAppend(&out, n->input(0), ":", get_type(*n)); } strings::StrAppend(&out, ") {\n"); for (size_t i = 0; i < body.size(); ++i) { strings::StrAppend(&out, " ", Print(*body[i]), "\n"); } strings::StrAppend(&out, "}\n"); return out; } Status AddDefaultAttrs(const string& op, const GetFunctionSignature& get_function, AttrValueMap* attrs) { const OpDef* op_def = nullptr; TF_RETURN_IF_ERROR(get_function(op, &op_def)); AttrSlice attr_slice(attrs); for (const auto& attr_def : op_def->attr()) { if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) { if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) { return errors::Internal("Somehow duplicated: ", attr_def.name()); } } } return Status::OK(); } } // end namespace Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, GetFunctionSignature get_function, InstantiationResult* result) { VLOG(3) << "Instantiation Function: " << Print(fdef); const OpDef& sig = fdef.signature(); TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values)); FunctionInstantiationHelper helper(get_function, result); Status s; for (const OpDef::ArgDef& arg_def : sig.input_arg()) { s = helper.BuildInputArgIndex(arg_def, attr_values); if (!s.ok()) { errors::AppendToMessage(&s, "In ", Print(arg_def)); return s; } } auto substitute = [attr_values](StringPiece name, AttrValue* val) { if (const AttrValue* v = attr_values.Find(name)) { *val = *v; return true; } return false; }; // Makes a copy of all attrs in fdef and substitutes placeholders. // After this step, every attr is bound to a concrete value. std::vector node_attrs; node_attrs.resize(fdef.node_def_size()); for (int i = 0; i < fdef.node_def_size(); ++i) { for (auto attr : fdef.node_def(i).attr()) { if (!SubstitutePlaceholders(substitute, &attr.second)) { return errors::InvalidArgument("Failed to bind all placeholders in ", SummarizeAttrValue(attr.second)); } if (!node_attrs[i].insert(attr).second) { return errors::Internal("Somehow duplicated: ", attr.first); } } TF_RETURN_IF_ERROR( AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i])); } for (int i = 0; i < fdef.node_def_size(); ++i) { s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]), result->nodes.size() + i); if (!s.ok()) { errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); return s; } } // Emits one node for each fdef.node_def. for (int i = 0; i < fdef.node_def_size(); ++i) { s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i])); if (!s.ok()) { errors::AppendToMessage(&s, "In ", SummarizeNodeDef(fdef.node_def(i))); return s; } } // Emits nodes for the function's return values. int ret_index = 0; for (const OpDef::ArgDef& ret_def : sig.output_arg()) { s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), &ret_index); if (!s.ok()) { errors::AppendToMessage(&s, "In function output ", Print(ret_def)); return s; } } // Adds the actual node inputs using the input indexes. helper.AddNodeInputs(); return Status::OK(); } string DebugString(const FunctionDef& func_def) { return Print(func_def); } string DebugString(const GraphDef& instantiated_func_def) { std::vector ptrs; for (const NodeDef& n : instantiated_func_def.node()) { ptrs.push_back(&n); } return Print(ptrs); } string DebugString(gtl::ArraySlice instantiated_func_nodes) { std::vector ptrs; for (const NodeDef& n : instantiated_func_nodes) { ptrs.push_back(&n); } return Print(ptrs); } string DebugStringWhole(const GraphDef& gdef) { string ret; for (const auto& fdef : gdef.library().function()) { strings::StrAppend(&ret, Print(fdef)); } strings::StrAppend(&ret, "\n"); for (const auto& ndef : gdef.node()) { strings::StrAppend(&ret, Print(ndef), "\n"); } return ret; } namespace { // Returns the name -> attr mapping of fdef's attrs that have a value set. In // Python, it's possible to access unset attrs, which returns a default value // and adds an unset attr to the map. std::map GetSetAttrs(const FunctionDef& fdef) { std::map set_attrs; for (auto pair : fdef.attr()) { if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) { set_attrs[pair.first] = pair.second; } } return set_attrs; } } // end namespace bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { if (!OpDefEqual(f1.signature(), f2.signature())) return false; std::map f1_attrs = GetSetAttrs(f1); std::map f2_attrs = GetSetAttrs(f2); if (f1_attrs.size() != f2_attrs.size()) return false; for (auto iter1 : f1_attrs) { auto iter2 = f2_attrs.find(iter1.first); if (iter2 == f2_attrs.end()) return false; if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false; } if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) { return false; } std::map ret1(f1.ret().begin(), f1.ret().end()); std::map ret2(f2.ret().begin(), f2.ret().end()); if (ret1 != ret2) return false; return true; } uint64 FunctionDefHash(const FunctionDef& fdef) { // signature uint64 h = OpDefHash(fdef.signature()); // attrs std::map attrs = GetSetAttrs(fdef); for (const auto& p : attrs) { h = Hash64(p.first.data(), p.first.size(), h); h = Hash64Combine(AttrValueHash(p.second), h); } // node defs h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h); // output names std::map ret(fdef.ret().begin(), fdef.ret().end()); for (const auto& p : ret) { h = Hash64(p.first.data(), p.first.size(), h); h = Hash64(p.second.data(), p.second.size(), h); } return h; } string Canonicalize(const string& funcname, AttrSlice attrs, const FunctionLibraryRuntime::InstantiateOptions& options) { std::vector entries; entries.reserve(options.target.empty() ? attrs.size() : (attrs.size() + 1)); for (auto p : attrs) { entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); } if (!options.target.empty()) { entries.push_back( strings::StrCat("_target", "=", str_util::CEscape(options.target))); } if (options.overlay_lib) { entries.push_back(strings::StrCat( "_overlay_lib", "=", reinterpret_cast(options.overlay_lib))); } if (!options.state_handle.empty()) { entries.push_back( strings::StrCat("_state_handle", "=", options.state_handle)); } if (!options.executor_type.empty()) { entries.push_back( strings::StrCat("_executor_type", "=", options.executor_type)); } std::sort(entries.begin(), entries.end()); return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]"); } FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types, DataTypeSlice ret_types) : arg_types_(arg_types.begin(), arg_types.end()), ret_types_(ret_types.begin(), ret_types.end()) { args_.resize(arg_types_.size()); rets_.resize(ret_types_.size()); } FunctionCallFrame::~FunctionCallFrame() {} Status FunctionCallFrame::SetArgs(gtl::ArraySlice args) { // Input type checks. if (args.size() != arg_types_.size()) { return errors::InvalidArgument("Expects ", arg_types_.size(), " arguments, but ", args.size(), " is provided"); } for (size_t i = 0; i < args.size(); ++i) { if (arg_types_[i] != args[i].dtype()) { return errors::InvalidArgument( "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ", DataTypeString(args[i].dtype()), " is provided"); } args_[i] = args[i]; } return Status::OK(); } Status FunctionCallFrame::GetRetvals(std::vector* rets) const { rets->clear(); rets->reserve(rets_.size()); for (size_t i = 0; i < rets_.size(); ++i) { const auto& item = rets_[i]; if (item.has_val) { rets->push_back(item.val); } else { return errors::Internal("Retval[", i, "] does not have value"); } } return Status::OK(); } Status FunctionCallFrame::ConsumeRetvals(std::vector* rets, bool allow_dead_tensors) { rets->clear(); rets->reserve(rets_.size()); for (size_t i = 0; i < rets_.size(); ++i) { if (rets_[i].has_val) { rets->emplace_back(std::move(rets_[i].val)); } else if (allow_dead_tensors) { rets->emplace_back(); } else { return errors::Internal("Retval[", i, "] does not have value"); } } return Status::OK(); } Status FunctionCallFrame::GetArg(int index, Tensor* val) const { if (index < 0 || static_cast(index) >= args_.size()) { return errors::InvalidArgument("GetArg ", index, " is not within [0, ", args_.size(), ")"); } *val = args_[index]; return Status::OK(); } Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { if (index < 0 || static_cast(index) >= rets_.size()) { return errors::InvalidArgument("SetRetval ", index, " is not within [0, ", rets_.size(), ")"); } if (val.dtype() != ret_types_[index]) { return errors::InvalidArgument( "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]), ", but ", DataTypeString(val.dtype()), " is provided."); } Retval* item = &rets_[index]; if (!item->has_val) { item->has_val = true; item->val = val; } else { return errors::Internal("Retval[", index, "] has already been set."); } return Status::OK(); } FunctionLibraryDefinition::FunctionDefAndOpRegistration:: FunctionDefAndOpRegistration(const FunctionDef& fdef_in) : fdef(fdef_in), // Exact shape inference for functions is handled by ShapeRefiner. // Here we pass a dummy shape inference function for legacy code paths. op_registration_data(fdef.signature(), shape_inference::UnknownShape, true /* is_function */) {} FunctionLibraryDefinition::FunctionLibraryDefinition( const FunctionLibraryDefinition& other) : default_registry_(other.default_registry_) { tf_shared_lock l(other.mu_); for (const auto& it : other.function_defs_) { TF_CHECK_OK(AddFunctionDef(it.second->fdef)); } func_grad_ = other.func_grad_; } FunctionLibraryDefinition::FunctionLibraryDefinition( const OpRegistryInterface* default_registry, const FunctionDefLibrary& def_lib) : default_registry_(default_registry), function_defs_(def_lib.function_size()) { for (const auto& fdef : def_lib.function()) { // The latter function definition wins. auto& ptr = function_defs_[fdef.signature().name()]; ptr.reset(new FunctionDefAndOpRegistration(fdef)); } for (const auto& grad : def_lib.gradient()) { func_grad_[grad.function_name()] = grad.gradient_func(); } } FunctionLibraryDefinition::~FunctionLibraryDefinition() {} bool FunctionLibraryDefinition::Contains(const string& func) const { tf_shared_lock l(mu_); return function_defs_.find(func) != function_defs_.end(); } const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const { tf_shared_lock l(mu_); return FindHelper(func); } const FunctionDef* FunctionLibraryDefinition::FindHelper( const string& func) const { auto iter = function_defs_.find(func); if (iter == function_defs_.end()) { return nullptr; } else { return &iter->second->fdef; } } Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) { mutex_lock l(mu_); bool added; return AddFunctionDefHelper(fdef, &added); } Status FunctionLibraryDefinition::AddFunctionDefHelper(const FunctionDef& fdef, bool* added) { *added = false; std::unique_ptr* entry = &function_defs_[fdef.signature().name()]; if (*entry != nullptr) { if (!FunctionDefsEqual((*entry)->fdef, fdef)) { return errors::InvalidArgument( "Cannot add function '", fdef.signature().name(), "' because a different function with the same name already " "exists."); } // Ignore duplicate FunctionDefs return Status::OK(); } const OpDef* op_def; if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) { return errors::InvalidArgument( "Cannot add function '", fdef.signature().name(), "' because an op with the same name already exists."); } entry->reset(new FunctionDefAndOpRegistration(fdef)); *added = true; return Status::OK(); } Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) { mutex_lock l(mu_); bool added; return AddGradientDefHelper(grad, &added); } Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad, bool* added) { *added = false; string* entry = &func_grad_[grad.function_name()]; if (!entry->empty()) { if (*entry != grad.gradient_func()) { return errors::InvalidArgument( "Cannot assign gradient function '", grad.gradient_func(), "' to '", grad.function_name(), "' because it already has gradient function ", "'", *entry, "'"); } // Ignore duplicate GradientDefs return Status::OK(); } *entry = grad.gradient_func(); *added = true; return Status::OK(); } Status FunctionLibraryDefinition::AddLibrary( const FunctionLibraryDefinition& other) { // Clone `other` to ensure thread-safety (grabbing `other`'s lock for // the duration of the function could lead to deadlock). FunctionLibraryDefinition clone(other); mutex_lock l(mu_); // Remember the funcs and grads that we added successfully so that // we can roll them back on error. std::vector funcs; std::vector funcs_with_grads; Status s; bool added; for (auto iter : clone.function_defs_) { s = AddFunctionDefHelper(iter.second->fdef, &added); if (!s.ok()) { Remove(funcs, funcs_with_grads); return s; } if (added) { funcs.push_back(iter.second->fdef.signature().name()); } } for (auto iter : clone.func_grad_) { GradientDef grad; grad.set_function_name(iter.first); grad.set_gradient_func(iter.second); s = AddGradientDefHelper(grad, &added); if (!s.ok()) { Remove(funcs, funcs_with_grads); return s; } if (added) { funcs_with_grads.push_back(grad.function_name()); } } return Status::OK(); } Status FunctionLibraryDefinition::AddLibrary( const FunctionDefLibrary& lib_def) { // Remember the funcs and grads that we added successfully so that // we can roll them back on error. mutex_lock l(mu_); std::vector funcs; std::vector funcs_with_grads; Status s; bool added; for (const FunctionDef& fdef : lib_def.function()) { s = AddFunctionDefHelper(fdef, &added); if (!s.ok()) { Remove(funcs, funcs_with_grads); return s; } if (added) { funcs.push_back(fdef.signature().name()); } } for (const GradientDef& grad : lib_def.gradient()) { s = AddGradientDefHelper(grad, &added); if (!s.ok()) { Remove(funcs, funcs_with_grads); return s; } if (added) { funcs_with_grads.push_back(grad.function_name()); } } return Status::OK(); } Status FunctionLibraryDefinition::ReplaceFunction(const string& func, const FunctionDef& fdef) { mutex_lock l(mu_); bool added; TF_RETURN_IF_ERROR(RemoveFunction(func)); TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, &added)); return Status::OK(); } Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) { mutex_lock l(mu_); bool added; TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name())); TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added)); return Status::OK(); } Status FunctionLibraryDefinition::RemoveFunction(const string& func) { const auto& i = function_defs_.find(func); if (i == function_defs_.end()) { return errors::InvalidArgument("Tried to remove non-existent function ", func); } function_defs_.erase(i); return Status::OK(); } Status FunctionLibraryDefinition::RemoveGradient(const string& func) { const auto& i = func_grad_.find(func); if (i == func_grad_.end()) { return errors::InvalidArgument("Tried to remove non-existent gradient ", func); } func_grad_.erase(i); return Status::OK(); } void FunctionLibraryDefinition::Remove( const std::vector& funcs, const std::vector& funcs_with_grads) { for (const string& f : funcs) { Status s = RemoveFunction(f); DCHECK(s.ok()); } for (const string& f : funcs_with_grads) { Status s = RemoveGradient(f); DCHECK(s.ok()); } } string FunctionLibraryDefinition::FindGradient(const string& func) const { tf_shared_lock l(mu_); return gtl::FindWithDefault(func_grad_, func, ""); } string FunctionLibraryDefinition::FindGradientHelper(const string& func) const { return gtl::FindWithDefault(func_grad_, func, ""); } Status FunctionLibraryDefinition::LookUp( const string& op, const OpRegistrationData** op_reg_data) const { tf_shared_lock l(mu_); auto iter = function_defs_.find(op); if (iter != function_defs_.end()) { *op_reg_data = &iter->second->op_registration_data; return Status::OK(); } return default_registry_->LookUp(op, op_reg_data); } string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const { tf_shared_lock l(mu_); int index = 0; string name = strings::StrCat(prefix, index); while (function_defs_.find(name) != function_defs_.end()) { ++index; name = strings::StrCat(prefix, index); } return name; } const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( const NodeDef& ndef) const { if (ndef.op() != kGradientOp) { // If 'ndef' calls a function and the function's def has the attr, // returns it. return Find(ndef.op()); } // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or // Foo's attributes. const NameAttrList* forward_func_attrs; if (!GetNodeAttr(ndef, kFuncAttr, &forward_func_attrs).ok()) { return nullptr; } const string& func_name = forward_func_attrs->name(); { tf_shared_lock l(mu_); const string& grad_name = FindGradientHelper(func_name); // If 'func' has a user-defined gradient function, uses the grad // function's attrs to see if noinline is specified. Otherwise, // uses func's attrs. if (!grad_name.empty()) { return FindHelper(grad_name); } return FindHelper(func_name); } } FunctionDefLibrary FunctionLibraryDefinition::ToProto() const { FunctionDefLibrary lib; tf_shared_lock l(mu_); for (const auto& f : function_defs_) { *lib.add_function() = f.second->fdef; } for (const auto& g : func_grad_) { GradientDef* gd = lib.add_gradient(); gd->set_function_name(g.first); gd->set_gradient_func(g.second); } return lib; } template Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, const string& attr, T* value) const { const FunctionDef* fdef = GetAttrImpl(ndef); if (fdef && GetNodeAttr(AttrSlice(&fdef->attr()), attr, value).ok()) { return Status::OK(); } return errors::InvalidArgument("Attr ", attr, " is not defined."); } template Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr, T* value) const { return GetAttr(node.def(), attr, value); } #define GET_ATTR(T) \ template Status FunctionLibraryDefinition::GetAttr(const Node&, \ const string&, T*) const; \ template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \ const string&, T*) const; GET_ATTR(string) GET_ATTR(bool) #undef GET_ATTR void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { if (val.size() >= 2 && val[0] == '$') { proto.set_placeholder(val.data() + 1, val.size() - 1); } else { SetAttrValue(val, &proto); } } FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef( const string& name, gtl::ArraySlice> attrs) { AttrValueWrapper ret; ret.proto.mutable_func()->set_name(name); for (const auto& a : attrs) { ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto}); } return ret; } NodeDef FunctionDefHelper::Node::ToNodeDef() const { NodeDef n; n.set_op(this->op); n.set_name(this->ret[0]); for (const auto& a : this->attr) { n.mutable_attr()->insert({a.first, a.second.proto}); } for (const string& a : this->arg) { n.add_input(a); } for (const string& d : this->dep) { n.add_input(strings::StrCat("^", d)); } return n; } /* static */ FunctionDef FunctionDefHelper::Create( const string& function_name, gtl::ArraySlice in_def, gtl::ArraySlice out_def, gtl::ArraySlice attr_def, gtl::ArraySlice node_def, gtl::ArraySlice> ret_def) { FunctionDef fdef; // Signature OpDefBuilder b(function_name); for (const auto& i : in_def) b.Input(i); for (const auto& o : out_def) b.Output(o); for (const auto& a : attr_def) b.Attr(a); OpRegistrationData op_reg_data; TF_CHECK_OK(b.Finalize(&op_reg_data)); fdef.mutable_signature()->Swap(&op_reg_data.op_def); // Function body for (const auto& n : node_def) { *(fdef.add_node_def()) = n.ToNodeDef(); } // Returns for (const auto& r : ret_def) { fdef.mutable_ret()->insert({r.first, r.second}); } auto* op_def_registry = OpRegistry::Global(); // Check if any op is stateful. for (const auto& n : node_def) { const OpDef* op_def = nullptr; auto status = op_def_registry->LookUpOpDef(n.op, &op_def); // Lookup can fail if e.g. we are calling a function that was not yet // defined. If it happens, conservatively assume the op is stateful. if (!status.ok() || op_def->is_stateful()) { fdef.mutable_signature()->set_is_stateful(true); } } return fdef; } /* static */ FunctionDef FunctionDefHelper::Define(const string& name, gtl::ArraySlice arg_def, gtl::ArraySlice ret_def, gtl::ArraySlice attr_def, gtl::ArraySlice node_def) { FunctionDef fdef; OpDefBuilder b(name); for (const auto& a : arg_def) b.Input(a); for (const auto& r : ret_def) b.Output(r); for (const auto& a : attr_def) b.Attr(a); OpRegistrationData op_reg_data; TF_CHECK_OK(b.Finalize(&op_reg_data)); fdef.mutable_signature()->Swap(&op_reg_data.op_def); // Mapping from legacy output names to NodeDef outputs. std::unordered_map ret_index; for (const auto& a : fdef.signature().input_arg()) { ret_index[a.name()] = a.name(); } // For looking up OpDefs auto* op_def_registry = OpRegistry::Global(); // Function body for (const auto& src : node_def) { NodeDef* n = fdef.add_node_def(); n->set_op(src.op); n->set_name(src.ret[0]); for (const auto& a : src.attr) { n->mutable_attr()->insert({a.first, a.second.proto}); } for (const string& a : src.arg) { const auto iter = ret_index.find(a); CHECK(iter != ret_index.end()) << "Node input '" << a << "' in '" << src.ret[0] << "' of " << name; n->add_input(iter->second); } for (const string& d : src.dep) { n->add_input(strings::StrCat("^", d)); } // Add the outputs of this node to ret_index. const OpDef* op_def = nullptr; TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op(); CHECK(op_def != nullptr) << n->op(); NameRangeMap output_names; TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names)); for (const auto& o : output_names) { CHECK_LE(o.second.second, src.ret.size()) << "Missing ret for output '" << o.first << "' in '" << src.ret[0] << "' of " << name; for (int i = o.second.first; i < o.second.second; ++i) { ret_index[src.ret[i]] = strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first); } } if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true); } // Returns for (const auto& r : fdef.signature().output_arg()) { const auto iter = ret_index.find(r.name()); CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name; fdef.mutable_ret()->insert({r.name(), iter->second}); } return fdef; } FunctionDef FunctionDefHelper::Define(gtl::ArraySlice arg_def, gtl::ArraySlice ret_def, gtl::ArraySlice attr_def, gtl::ArraySlice node_def) { return Define("_", arg_def, ret_def, attr_def, node_def); } namespace gradient { typedef std::unordered_map OpGradFactory; OpGradFactory* GetOpGradFactory() { static OpGradFactory* factory = new OpGradFactory; return factory; } bool RegisterOp(const string& op, Creator func) { CHECK(GetOpGradFactory()->insert({op, func}).second) << "Duplicated gradient for " << op; return true; } Status GetOpGradientCreator(const string& op, Creator* creator) { auto fac = GetOpGradFactory(); auto iter = fac->find(op); if (iter == fac->end()) { return errors::NotFound("No gradient defined for op: ", op); } *creator = iter->second; return Status::OK(); } } // end namespace gradient } // end namespace tensorflow