diff options
Diffstat (limited to 'tensorflow/core/framework/function.cc')
-rw-r--r-- | tensorflow/core/framework/function.cc | 878 |
1 files changed, 878 insertions, 0 deletions
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc new file mode 100644 index 0000000000..b73e1ab8a9 --- /dev/null +++ b/tensorflow/core/framework/function.cc @@ -0,0 +1,878 @@ +#include "tensorflow/core/framework/function.h" + +#include <unordered_set> + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +REGISTER_OP("_Arg") + .Output("output: T") + .Attr("T: type") + .Attr("index: int >= 0") + .Doc(R"doc( +A graph node which represents an argument to a function. + +output: The argument. +index: This argument is the index-th argument of the function. +)doc"); + +REGISTER_OP("_Retval") + .Input("input: T") + .Attr("T: type") + .Attr("index: int >= 0") + .Doc(R"doc( +A graph node which represents a return value of a function. + +input: The return value. +index: This return value is the index-th return value of the function. +)doc"); + +REGISTER_OP("_ListToArray") + .Input("input: Tin") + .Output("output: N * T") + .Attr("Tin: list(type)") + .Attr("T: type") + .Attr("N: int >= 1") + .Doc(R"doc( +Converts a list of tensors to an array of tensors. +)doc"); + +REGISTER_OP("_ArrayToList") + .Input("input: N * T") + .Output("output: out_types") + .Attr("T: type") + .Attr("N: int >= 1") + .Attr("out_types: list(type)") + .Doc(R"doc( +Converts an array of tensors to a list of tensors. +)doc"); + +namespace { + +// Extracts the actual type from "attr_values" based on its definition +// "arg_def". +Status ArgNumType(const InstantiateAttrValueMap& attrs, + const OpDef::ArgDef& arg_def, int* num, DataType* dtype) { + if (!arg_def.type_list_attr().empty()) { + return errors::Unimplemented("type_list is not supported."); + } + + if (arg_def.number_attr().empty()) { + *num = 1; + } else { + const AttrValue* v = gtl::FindOrNull(attrs, arg_def.number_attr()); + if (v == nullptr) { + return errors::NotFound("type attr not found: ", arg_def.type_attr()); + } + *num = v->i(); + } + + if (arg_def.type() != DT_INVALID) { + *dtype = arg_def.type(); + } else if (arg_def.type_attr().empty()) { + *dtype = DT_INVALID; + } else { + const AttrValue* v = gtl::FindOrNull(attrs, arg_def.type_attr()); + if (v == nullptr) { + return errors::NotFound("type attr not found: ", arg_def.type_attr()); + } + *dtype = v->type(); + } + return Status::OK(); +} + +string Name(int node_index) { return strings::StrCat("n", node_index); } + +string Name(int node_index, int output_index) { + if (output_index == 0) { + return Name(node_index); + } else { + return strings::StrCat("n", node_index, ":", output_index); + } +} + +string Dep(int node_index) { return strings::StrCat("^", Name(node_index)); } + +template <typename T> +void AddAttr(const string& name, const T& val, NodeDef* ndef) { + SetAttrValue(val, &((*ndef->mutable_attr())[name])); +} + +Status ValidateSignatureWithAttrs(const OpDef& sig, + const InstantiateAttrValueMap& attr_values) { + // attr_values should specify all attrs defined in fdef. + for (const auto& a : sig.attr()) { + if (attr_values.find(a.name()) == attr_values.end()) { + return errors::NotFound("Attr ", a.name(), " is not found."); + } + } + + for (const auto& p : attr_values) { + if (HasPlaceHolder(p.second)) { + return errors::InvalidArgument(p.first, + " in attr_values is still a placeholder."); + } + } + + return Status::OK(); +} + +// We build a small index for all names that can be used as a node's +// input arguments. +// +// If is_func_arg is true, the name is a function's argument. In +// this case, the produced graph def has gdef.node[nid ... nid + +// num). +// +// Otherwise, the name is a function body's node return value. In +// this case, the produced graph def has one node gdef.node[nid] and +// the node's output index [idx ... idx + num) corresponds to the +// named outputs. +// +// In all cases, "dtype" specifies the data type. +struct NameInfoItem { + bool is_func_arg; + int nid; + int idx; + int num; + DataType dtype; +}; +typedef std::unordered_map<string, NameInfoItem> NameInfoIndex; + +Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, + const InstantiateAttrValueMap& attr_values, + NameInfoIndex* name_info, + InstantiationResult* result) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR(ArgNumType(attr_values, arg_def, &num, &dtype)); + CHECK_GE(num, 1); + GraphDef* gdef = &result->gdef; + int arg_index = gdef->node_size(); + if (!name_info->insert({arg_def.name(), {true, arg_index, 0, num, dtype}}) + .second) { + return errors::InvalidArgument("Duplicated arg name."); + } + // Creates "num" nodes in the gdef. + for (int i = 0; i < num; ++i) { + DCHECK_EQ(arg_index, gdef->node_size()); + NodeDef* gnode = gdef->add_node(); + gnode->set_name(Name(arg_index)); + gnode->set_op("_Arg"); + AddAttr("T", dtype, gnode); + AddAttr("index", arg_index, gnode); + result->arg_types.push_back(dtype); + ++arg_index; + } + return Status::OK(); +} + +Status BuildNodeOutputIndex(const FunctionDef::Node& node, + const InstantiateAttrValueMap& attrs, + GetFunctionSignature get_function, + const int arg_index, NameInfoIndex* name_info) { + const OpDef* node_sig = nullptr; + TF_RETURN_IF_ERROR(get_function(node.op(), &node_sig)); + if (node_sig->output_arg_size() == 0) { + // This node produces no output. + if (node.ret_size() != 1) { + return errors::InvalidArgument("Expect one ret name."); + } + if (!name_info->insert({node.ret(0), {false, arg_index, 0, 0, DT_INVALID}}) + .second) { + return errors::InvalidArgument("Duplicated ret name."); + } + return Status::OK(); + } + + // When the signature says the last return value is of list(type), + // i.e., it's variadic, we need to consult + // attrs[last_retval.type_list_attr] to determine for the last arg + // * the actual number of outputs; + // * the actual data type of outputs. + const int num_retval = node_sig->output_arg_size(); + const OpDef::ArgDef& last_retval = node_sig->output_arg(num_retval - 1); + const bool last_retval_is_typelist = !last_retval.type_list_attr().empty(); + if (!last_retval_is_typelist && (node.ret_size() != num_retval)) { + return errors::InvalidArgument("Malformed function node (#ret)."); + } + int start = 0; + const int num_fixed_size_retval = + last_retval_is_typelist ? num_retval - 1 : num_retval; + for (int i = 0; i < num_fixed_size_retval; ++i) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR( + ArgNumType(attrs, node_sig->output_arg(i), &num, &dtype)); + if (!name_info->insert({node.ret(i), {false, arg_index, start, num, dtype}}) + .second) { + return errors::InvalidArgument("Duplicated ret name."); + } + start += num; + } + if (last_retval_is_typelist) { + const AttrValue* typelist = + gtl::FindOrNull(attrs, last_retval.type_list_attr()); + if (typelist == nullptr) { + return errors::InvalidArgument("Missing attr ", + last_retval.type_list_attr(), "."); + } + if (num_fixed_size_retval + typelist->list().type_size() != + node.ret_size()) { + return errors::InvalidArgument("Wrong #ret: ", num_fixed_size_retval, " ", + typelist->list().type_size(), " ", + node.ret_size(), "."); + } + for (int i = 0; i < typelist->list().type_size(); ++i) { + if (!name_info->insert({node.ret(i), + {false, arg_index, start, 1, + typelist->list().type(i)}}) + .second) { + return errors::InvalidArgument("Duplicated ret name."); + } + ++start; + } + } + return Status::OK(); +} + +Status InstantiateNode(const FunctionDef::Node& fnode, + const InstantiateAttrValueMap& attrs, + GetFunctionSignature get_function, + const NameInfoIndex& name_info, GraphDef* gdef) { + const OpDef* fnode_sig = nullptr; + TF_CHECK_OK(get_function(fnode.op(), &fnode_sig)); + NodeDef* gnode = gdef->add_node(); + gnode->set_name(Name(gdef->node_size() - 1)); + gnode->set_op(fnode.op()); + + // Input + // + // When the signature says the last argument is of list(type), + // i.e., it's variadic, we need to consult + // attrs[last_arg.type_list_attr] to determine for the last arg + // * the number of arguments; + // * the data types of arguments. + const int num_arg = fnode_sig->input_arg_size(); + bool last_arg_is_typelist = false; + if (num_arg > 0 && + !fnode_sig->input_arg(num_arg - 1).type_list_attr().empty()) { + last_arg_is_typelist = true; + } + if (!last_arg_is_typelist && (fnode.arg_size() != num_arg)) { + return errors::InvalidArgument("arg.size != sig.arg.size."); + } + const int num_fixed_size_args = last_arg_is_typelist ? num_arg - 1 : num_arg; + for (int i = 0; i < num_fixed_size_args; ++i) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR( + ArgNumType(attrs, fnode_sig->input_arg(i), &num, &dtype)); + const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.arg(i)); + if (item == nullptr) { + return errors::InvalidArgument("arg[", i, "] is not found: ", + fnode.ShortDebugString()); + } + if (num != item->num || dtype != item->dtype) { + return errors::InvalidArgument("Invalid arg(", i, ") for function arg: ", + " ", num, "/", dtype, " vs. ", item->num, + "/", item->dtype, "."); + } + for (int j = 0; j < num; ++j) { + if (item->is_func_arg) { + gnode->add_input(Name(item->nid + j)); + } else { + gnode->add_input(Name(item->nid, item->idx + j)); + } + } + } + if (last_arg_is_typelist) { + AttrValue typelist; + for (int i = num_fixed_size_args; i < fnode.arg_size(); ++i) { + const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.arg(i)); + if (item == nullptr) { + return errors::InvalidArgument("arg[", i, "] is not found."); + } + for (int j = 0; j < item->num; ++j) { + if (item->is_func_arg) { + gnode->add_input(Name(item->nid + j)); + } else { + gnode->add_input(Name(item->nid, item->idx + j)); + } + typelist.mutable_list()->add_type(item->dtype); + } + } + + // 'typelist' is inferred from the inputs' data types. + const auto& last_arg = fnode_sig->input_arg(num_arg - 1); + gnode->mutable_attr()->insert({last_arg.type_list_attr(), typelist}); + } + + // Control deps. + for (int i = 0; i < fnode.dep_size(); ++i) { + const NameInfoItem* item = gtl::FindOrNull(name_info, fnode.dep(i)); + if (item == nullptr) { + return errors::InvalidArgument("dep[", i, "] is not found."); + } + gnode->add_input(Dep(item->nid)); + } + + // Attrs. + for (const auto& p : attrs) { + (*gnode->mutable_attr())[p.first] = p.second; + } + + return Status::OK(); +} + +Status AddReturnNode(const OpDef::ArgDef& ret_def, + const InstantiateAttrValueMap& attrs, + const NameInfoIndex& name_info, int* ret_index, + InstantiationResult* result) { + int num; + DataType dtype; + TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &num, &dtype)); + CHECK_GE(num, 1); + const NameInfoItem* item = gtl::FindOrNull(name_info, ret_def.name()); + if (item == nullptr) { + return errors::InvalidArgument("ret is not found."); + } + if (num != item->num || dtype != item->dtype) { + return errors::InvalidArgument("Invalid ret name."); + } + GraphDef* gdef = &result->gdef; + for (int i = 0; i < num; ++i) { + NodeDef* gnode = gdef->add_node(); + gnode->set_name(Name(gdef->node_size() - 1)); + gnode->set_op("_Retval"); + gnode->add_input(Name(item->nid, item->idx + i)); + AddAttr("T", dtype, gnode); + AddAttr("index", (*ret_index)++, gnode); + result->ret_types.push_back(dtype); + } + return Status::OK(); +} + +// Various helpers Print(proto) to print relevant protos to ascii. +string Print(const OpDef::ArgDef& arg) { + string out; + strings::StrAppend(&out, arg.name(), ":"); + if (arg.is_ref()) strings::StrAppend(&out, "Ref("); + if (!arg.number_attr().empty()) { + strings::StrAppend(&out, arg.number_attr(), "*"); + } + if (arg.type() != DT_INVALID) { + strings::StrAppend(&out, DataTypeString(arg.type())); + } else { + strings::StrAppend(&out, arg.type_attr()); + } + if (arg.is_ref()) strings::StrAppend(&out, ")"); + return out; +} + +string Print(const AttrValue& attr_value) { + if (attr_value.value_case() == AttrValue::kType) { + return DataTypeString(attr_value.type()); + } else if ((attr_value.value_case() == AttrValue::kList) && + (attr_value.list().type_size() > 0)) { + string ret = "{"; + for (int i = 0; i < attr_value.list().type_size(); ++i) { + if (i > 0) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i))); + } + strings::StrAppend(&ret, "}"); + return ret; + } else if (attr_value.value_case() == AttrValue::kFunc) { + if (attr_value.func().attr_size() == 0) { + return attr_value.func().name(); + } + std::vector<string> entries; + for (auto p : attr_value.func().attr()) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + sort(entries.begin(), entries.end()); + return strings::StrCat(attr_value.func().name(), "[", + str_util::Join(entries, ", "), "]"); + } + return SummarizeAttrValue(attr_value); +} + +string Print(const FunctionDef::Node& node) { + string out; + for (int i = 0; i < node.ret_size(); ++i) { + const auto& name = node.ret(i); + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, name); + } + strings::StrAppend(&out, " = ", node.op()); + if (node.attr_size() > 0) { + std::vector<string> entries; + for (auto p : node.attr()) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + sort(entries.begin(), entries.end()); + strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]"); + } + strings::StrAppend(&out, "("); + for (int i = 0; i < node.arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, node.arg(i)); + } + strings::StrAppend(&out, ")"); + if (node.dep_size() > 0) { + strings::StrAppend(&out, " @ "); + for (int i = 0; i < node.dep_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, node.dep(i)); + } + } + return out; +} + +string Print(const FunctionDef& fdef) { + string out; + const OpDef& sig = fdef.signature(); + strings::StrAppend(&out, "\n", sig.name()); + if (sig.attr_size() > 0) { + strings::StrAppend(&out, "["); + for (int i = 0; i < sig.attr_size(); ++i) { + const auto& a = sig.attr(i); + if (i > 0) strings::StrAppend(&out, ", "); + if (a.type() == "type") { + strings::StrAppend(&out, a.name(), ":", Print(a.allowed_values())); + } else { + strings::StrAppend(&out, a.name(), ":", a.type()); + } + } + strings::StrAppend(&out, "]"); + } + strings::StrAppend(&out, "("); + for (int i = 0; i < sig.input_arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, Print(sig.input_arg(i))); + } + strings::StrAppend(&out, ") -> ("); + for (int i = 0; i < sig.output_arg_size(); ++i) { + if (i > 0) strings::StrAppend(&out, ", "); + strings::StrAppend(&out, Print(sig.output_arg(i))); + } + strings::StrAppend(&out, ") {\n"); + for (const auto& n : fdef.node()) { + strings::StrAppend(&out, " ", Print(n), "\n"); + } + strings::StrAppend(&out, "}\n"); + return out; +} + +string Print(const NodeDef& n) { + string out; + strings::StrAppend(&out, n.name(), " = ", n.op()); + if (n.attr_size() > 0) { + std::vector<string> entries; + for (auto& a : n.attr()) { + entries.push_back(strings::StrCat(a.first, "=", Print(a.second))); + } + sort(entries.begin(), entries.end()); + strings::StrAppend(&out, "[", str_util::Join(entries, ", "), "]"); + } + strings::StrAppend(&out, "("); + std::vector<StringPiece> dat; + std::vector<string> dep; + for (StringPiece s : n.input()) { + if (s.Consume("^")) { + dep.push_back(s.ToString()); + } else { + dat.push_back(s); + } + } + strings::StrAppend(&out, str_util::Join(dat, ", "), ")"); + if (!dep.empty()) { + strings::StrAppend(&out, " @ ", str_util::Join(dep, ", ")); + } + return out; +} + +string Print(const GraphDef& gdef) { + std::vector<const NodeDef*> arg; + std::vector<const NodeDef*> ret; + std::vector<const NodeDef*> body; + for (const NodeDef& n : gdef.node()) { + if (n.op() == "_Arg") { + arg.push_back(&n); + } else if (n.op() == "_Retval") { + ret.push_back(&n); + } else { + body.push_back(&n); + } + } + auto comp = [](const NodeDef* x, const NodeDef* y) { + int xi; + TF_CHECK_OK(GetNodeAttr(*x, "index", &xi)); + int yi; + TF_CHECK_OK(GetNodeAttr(*y, "index", &yi)); + return xi < yi; + }; + sort(arg.begin(), arg.end(), comp); + sort(ret.begin(), ret.end(), comp); + string out; + strings::StrAppend(&out, "\n("); + auto get_type = [](const NodeDef& n) { + for (auto a : n.attr()) { + if (a.first == "T") { + return DataTypeString(a.second.type()); + } + } + return DataTypeString(DT_INVALID); + }; + for (size_t i = 0; i < arg.size(); ++i) { + const NodeDef* n = arg[i]; + if (i > 0) strings::StrAppend(&out, ", "); + CHECK_EQ(2, n->attr_size()); + strings::StrAppend(&out, n->name(), ":", get_type(*n)); + } + strings::StrAppend(&out, ") -> ("); + for (size_t i = 0; i < ret.size(); ++i) { + const NodeDef* n = ret[i]; + if (i > 0) strings::StrAppend(&out, ", "); + CHECK_EQ(2, n->attr_size()); + CHECK_EQ(1, n->input_size()); + strings::StrAppend(&out, n->input(0), ":", get_type(*n)); + } + strings::StrAppend(&out, ") {\n"); + for (size_t i = 0; i < body.size(); ++i) { + strings::StrAppend(&out, " ", Print(*body[i]), "\n"); + } + strings::StrAppend(&out, "}\n"); + return out; +} + +} // end namespace + +Status InstantiateFunction(const FunctionDef& fdef, + const InstantiateAttrValueMap& attr_values, + GetFunctionSignature get_function, + InstantiationResult* result) { + const OpDef& sig = fdef.signature(); + GraphDef* gdef = &result->gdef; + gdef->Clear(); + + TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values)); + + auto substitute = [&attr_values](const string& name, AttrValue* val) { + auto iter = attr_values.find(name); + if (iter == attr_values.end()) { + return false; + } else { + *val = iter->second; + return true; + } + }; + + // Makes a copy of all attrs in fdef and substitutes placeholders. + // After this step, every attr is bound to a concrete value. + std::vector<InstantiateAttrValueMap> node_attrs; + node_attrs.resize(fdef.node_size()); + for (int i = 0; i < fdef.node_size(); ++i) { + for (auto attr : fdef.node(i).attr()) { + if (!SubstitutePlaceholders(substitute, &attr.second)) { + return errors::InvalidArgument("Failed to bind all placeholders in ", + SummarizeAttrValue(attr.second)); + } + CHECK(node_attrs[i].insert(attr).second); + } + } + + NameInfoIndex name_info; + Status s; + for (const OpDef::ArgDef& arg_def : sig.input_arg()) { + s = BuildInputArgIndex(arg_def, attr_values, &name_info, result); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(arg_def)); + return s; + } + } + for (int i = 0; i < fdef.node_size(); ++i) { + s = BuildNodeOutputIndex(fdef.node(i), node_attrs[i], get_function, + gdef->node_size() + i, &name_info); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(fdef.node(i))); + return s; + } + } + + // Emits one gdef.node for each fdef.node. + for (int i = 0; i < fdef.node_size(); ++i) { + s = InstantiateNode(fdef.node(i), node_attrs[i], get_function, name_info, + gdef); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(fdef.node(i))); + return s; + } + } + + // Emits nodes for the function's return values. + int ret_index = 0; + for (const OpDef::ArgDef& ret_def : sig.output_arg()) { + s = AddReturnNode(ret_def, attr_values, name_info, &ret_index, result); + if (!s.ok()) { + errors::AppendToMessage(&s, " In ", Print(ret_def)); + return s; + } + } + + return Status::OK(); +} + +string DebugString(const FunctionDef& func_def) { return Print(func_def); } + +string DebugString(const GraphDef& instantiated_func_def) { + return Print(instantiated_func_def); +} + +string DebugStringWhole(const GraphDef& gdef) { + string ret; + for (auto fdef : gdef.library().function()) { + strings::StrAppend(&ret, Print(fdef)); + } + strings::StrAppend(&ret, "\n"); + for (auto ndef : gdef.node()) { + strings::StrAppend(&ret, Print(ndef), "\n"); + } + return ret; +} + +string Canonicalize(const string& funcname, + const InstantiateAttrValueMap& attrs) { + std::vector<string> entries; + entries.reserve(attrs.size()); + for (auto p : attrs) { + entries.push_back(strings::StrCat(p.first, "=", Print(p.second))); + } + sort(entries.begin(), entries.end()); + return strings::StrCat(funcname, "[", str_util::Join(entries, ","), "]"); +} + +FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types, + DataTypeSlice ret_types) + : arg_types_(arg_types.begin(), arg_types.end()), + ret_types_(ret_types.begin(), ret_types.end()) { + args_.resize(arg_types_.size()); + rets_.resize(ret_types_.size()); +} + +FunctionCallFrame::~FunctionCallFrame() {} + +Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) { + // Input type checks. + if (args.size() != arg_types_.size()) { + return errors::InvalidArgument("Expects ", arg_types_.size(), + " arguments, but ", args.size(), + " is provided"); + } + for (size_t i = 0; i < args.size(); ++i) { + if (arg_types_[i] != args[i].dtype()) { + return errors::InvalidArgument( + "Expects arg[", i, "] to be ", DataTypeString(arg_types_[i]), " but ", + DataTypeString(args[i].dtype()), " is provided"); + } + args_[i] = args[i]; + } + return Status::OK(); +} + +Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const { + rets->clear(); + rets->reserve(rets_.size()); + for (size_t i = 0; i < rets_.size(); ++i) { + auto item = rets_[i]; + if (item.has_val) { + rets->push_back(item.val); + } else { + return errors::Internal("Retval[", i, "] does not have value"); + } + } + return Status::OK(); +} + +Status FunctionCallFrame::GetArg(int index, Tensor* val) const { + if (index < 0 || static_cast<size_t>(index) >= args_.size()) { + return errors::OutOfRange("GetArg ", index, " is not within [0, ", + args_.size(), ")"); + } + *val = args_[index]; + return Status::OK(); +} + +Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { + if (index < 0 || static_cast<size_t>(index) >= rets_.size()) { + return errors::OutOfRange("SetRetval ", index, " is not within [0, ", + rets_.size(), ")"); + } + if (val.dtype() != ret_types_[index]) { + return errors::InvalidArgument( + "Expects ret[", index, "] to be ", DataTypeString(ret_types_[index]), + ", but ", DataTypeString(val.dtype()), " is provided."); + } + Retval* item = &rets_[index]; + if (!item->has_val) { + item->has_val = true; + item->val = val; + } else { + return errors::Internal("Retval[", index, "] has already been set."); + } + return Status::OK(); +} + +FunctionLibraryDefinition::FunctionLibraryDefinition( + const FunctionDefLibrary& def_lib) + : function_defs_(def_lib.function_size()) { + for (auto fdef : def_lib.function()) { + // The latter function definition wins. + function_defs_[fdef.signature().name()] = fdef; + } +} + +FunctionLibraryDefinition::~FunctionLibraryDefinition() {} + +const FunctionDef* FunctionLibraryDefinition::Find(const string& name) const { + auto iter = function_defs_.find(name); + if (iter == function_defs_.end()) { + return nullptr; + } else { + return &iter->second; + } +} + +const OpDef* FunctionLibraryDefinition::LookUp(const string& op, + Status* status) const { + auto fdef = Find(op); + if (fdef != nullptr) { + return &(fdef->signature()); + } + return OpRegistry::Global()->LookUp(op, status); +} + +Status InstantiateFunction(const FunctionDef& fdef, + InstantiateAttrValueSlice attr_values, + GetFunctionSignature get_function, + InstantiationResult* result) { + InstantiateAttrValueMap m; + for (const auto& aval : attr_values) { + m.insert({aval.first, aval.second.proto}); + } + return InstantiateFunction(fdef, m, get_function, result); +} + +string Canonicalize(const string& funcname, InstantiateAttrValueSlice attrs) { + InstantiateAttrValueMap m; + for (const auto& aval : attrs) { + m.insert({aval.first, aval.second.proto}); + } + return Canonicalize(funcname, m); +} + +Status FunctionLibraryRuntime::Instantiate(const string& function_name, + InstantiateAttrValueSlice attrs, + Handle* handle) { + InstantiateAttrValueMap m; + for (const auto& aval : attrs) { + m.insert({aval.first, aval.second.proto}); + } + return Instantiate(function_name, m, handle); +} + +void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { + if (val.size() >= 2 && val[0] == '$') { + proto.set_placeholder(val.data() + 1, val.size() - 1); + } else { + SetAttrValue(val, &proto); + } +} + +FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef( + const string& name, + gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) { + AttrValueWrapper ret; + ret.proto.mutable_func()->set_name(name); + for (const auto& a : attrs) { + ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto}); + } + return ret; +} + +FunctionDef::Node FunctionDefHelper::Node::ToProto() const { + FunctionDef::Node n; + for (const string& r : this->ret) { + n.add_ret(r); + } + n.set_op(this->op); + for (const string& a : arg) { + n.add_arg(a); + } + for (const auto& a : this->attr) { + n.mutable_attr()->insert({a.first, a.second.proto}); + } + for (const string& d : dep) { + n.add_dep(d); + } + return n; +} + +/* static */ +FunctionDef FunctionDefHelper::Define(const string& name, + gtl::ArraySlice<string> arg_def, + gtl::ArraySlice<string> ret_def, + gtl::ArraySlice<string> attr_def, + gtl::ArraySlice<Node> node_def) { + FunctionDef fdef; + OpDefBuilder b(name); + for (const auto& a : arg_def) b.Input(a); + for (const auto& r : ret_def) b.Output(r); + for (const auto& a : attr_def) b.Attr(a); + TF_CHECK_OK(b.Finalize(fdef.mutable_signature())); + for (const auto& n : node_def) { + *(fdef.add_node()) = n.ToProto(); + } + return fdef; +} + +FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def, + gtl::ArraySlice<string> ret_def, + gtl::ArraySlice<string> attr_def, + gtl::ArraySlice<Node> node_def) { + return Define("_", arg_def, ret_def, attr_def, node_def); +} + +namespace gradient { + +typedef std::unordered_map<string, Creator> OpGradFactory; + +OpGradFactory* GetOpGradFactory() { + static OpGradFactory* factory = new OpGradFactory; + return factory; +} + +bool RegisterOp(const string& op, Creator func) { + CHECK(GetOpGradFactory()->insert({op, func}).second) + << "Duplicated gradient for " << op; + return true; +} + +Status GetOpGradientCreator(const string& op, Creator* creator) { + auto fac = GetOpGradFactory(); + auto iter = fac->find(op); + if (iter == fac->end()) { + return errors::NotFound("No gradient defined for op: ", op); + } + *creator = iter->second; + return Status::OK(); +} + +} // end namespace gradient + +} // end namespace tensorflow |