aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/node_def_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/node_def_util.cc')
-rw-r--r--tensorflow/core/framework/node_def_util.cc414
1 files changed, 414 insertions, 0 deletions
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
new file mode 100644
index 0000000000..aefd416187
--- /dev/null
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -0,0 +1,414 @@
+#include "tensorflow/core/framework/node_def_util.h"
+
+#include <algorithm>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+string SummarizeNodeDef(const NodeDef& node_def) {
+ string ret = strings::StrCat(node_def.name(), " = ", node_def.op(), "[");
+
+ // We sort the attrs so the output is deterministic.
+ std::vector<string> attr_names;
+ attr_names.reserve(node_def.attr().size());
+ for (const auto& attr : node_def.attr()) {
+ attr_names.push_back(attr.first);
+ }
+ std::sort(attr_names.begin(), attr_names.end());
+ bool first = true;
+ for (const string& attr_name : attr_names) {
+ if (!first) strings::StrAppend(&ret, ", ");
+ first = false;
+ auto iter = node_def.attr().find(attr_name);
+ strings::StrAppend(&ret, attr_name, "=", SummarizeAttrValue(iter->second));
+ }
+
+ // Consider the device to be a final attr with name "_device".
+ if (!node_def.device().empty()) {
+ if (!first) strings::StrAppend(&ret, ", ");
+ first = false;
+ strings::StrAppend(&ret, "_device=\"", node_def.device(), "\"");
+ }
+ strings::StrAppend(&ret, "](");
+
+ // Output inputs, including control inputs, verbatim.
+ first = true;
+ for (const string& input : node_def.input()) {
+ if (!first) strings::StrAppend(&ret, ", ");
+ first = false;
+ strings::StrAppend(&ret, input);
+ }
+ strings::StrAppend(&ret, ")");
+ return ret;
+}
+
+const AttrValue* AttrSlice::Find(const string& attr_name) const {
+ auto iter = attrs_->find(attr_name);
+ if (iter == attrs_->end()) return nullptr;
+ return &iter->second;
+}
+
+Status AttrSlice::Find(const string& attr_name,
+ const AttrValue** attr_value) const {
+ *attr_value = Find(attr_name);
+ if (*attr_value != nullptr) {
+ return Status::OK();
+ }
+ Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
+ if (ndef_) {
+ s = AttachDef(s, *ndef_);
+ }
+ return s;
+}
+
+// The ... is to allow the caller to inject some value validation code. Use
+// just ; if no additional validation code is needed.
+#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
+ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \
+ TYPE* value) { \
+ const AttrValue* attr_value; \
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE)); \
+ const auto& v = attr_value->FIELD(); \
+ __VA_ARGS__; \
+ *value = CAST; \
+ return Status::OK(); \
+ } \
+ Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name, \
+ std::vector<TYPE>* value) { \
+ const AttrValue* attr_value; \
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
+ for (const auto& v : attr_value->list().FIELD()) { \
+ __VA_ARGS__; \
+ value->APPEND_OP(CAST); \
+ } \
+ return Status::OK(); \
+ }
+
+DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
+DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
+DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast<int32>(v),
+ if (static_cast<int64>(static_cast<int32>(v)) != v) {
+ return errors::InvalidArgument("Attr ", attr_name,
+ " has value ", v,
+ " out of range for an int32");
+ })
+DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
+// std::vector<bool> specialization does not have emplace_back until
+// c++14, so we have to use push_back (see
+// http://en.cppreference.com/w/cpp/container/vector/emplace_back)
+DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
+DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
+ ;)
+DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
+DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v), ;)
+DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
+ if (!t.FromProto(v)) {
+ return errors::InvalidArgument(
+ "Attr ", attr_name, " has value ", v.ShortDebugString(),
+ " that can't be converted to a Tensor");
+ })
+
+#undef DEFINE_GET_ATTR
+
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ DataTypeVector* value) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
+ for (const auto& v : attr_value->list().type()) {
+ value->push_back(static_cast<DataType>(v));
+ }
+ return Status::OK();
+}
+
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const TensorProto** value) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
+ *value = &attr_value->tensor();
+ return Status::OK();
+}
+
+Status GetNodeAttr(const AttrSlice& attrs, const string& attr_name,
+ const NameAttrList** value) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
+ TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
+ *value = &attr_value->func();
+ return Status::OK();
+}
+
+namespace { // Helper for InOutTypesForNode().
+
+Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
+ DataTypeVector* sig) {
+ const int original_size = sig->size();
+ if (!arg_def.number_attr().empty()) {
+ // Same type repeated "repeats" times.
+ int32 repeats = -1;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats));
+ if (repeats < 0) {
+ return errors::InvalidArgument("Value for number_attr() ", repeats,
+ " < 0");
+ }
+
+ if (!arg_def.type_attr().empty()) {
+ DataType dtype;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype));
+ for (int i = 0; i < repeats; ++i) {
+ sig->push_back(dtype);
+ }
+ } else if (arg_def.type() != DT_INVALID) {
+ for (int i = 0; i < repeats; ++i) {
+ sig->push_back(arg_def.type());
+ }
+ } else {
+ return errors::InvalidArgument("Missing type or type_attr field in ",
+ arg_def.ShortDebugString());
+ }
+ } else if (!arg_def.type_attr().empty()) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(
+ AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value));
+ sig->push_back(attr_value->type());
+ } else if (!arg_def.type_list_attr().empty()) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(
+ AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
+ for (int dtype : attr_value->list().type()) {
+ sig->push_back(static_cast<DataType>(dtype));
+ }
+ } else if (arg_def.type() != DT_INVALID) {
+ sig->push_back(arg_def.type());
+ } else {
+ return errors::InvalidArgument("No type fields in ",
+ arg_def.ShortDebugString());
+ }
+ if (arg_def.is_ref()) {
+ // For all types that were added by this function call, make them refs.
+ for (size_t i = original_size; i < sig->size(); ++i) {
+ (*sig)[i] = MakeRefType((*sig)[i]);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs, DataTypeVector* outputs) {
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
+ }
+ for (const auto& arg : op_def.output_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
+ }
+ return Status::OK();
+}
+
+Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
+ if (node_def.op() != op_def.name()) {
+ return errors::InvalidArgument("NodeDef op '", node_def.op(),
+ "' does not match ", SummarizeOpDef(op_def),
+ "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+
+ bool seen_control = false;
+ size_t num_inputs = 0;
+ // TODO(josh11b): Unify the input field validation.
+ for (const string& input : node_def.input()) {
+ if (StringPiece(input).starts_with("^")) {
+ seen_control = true;
+ if (input.find(':') != string::npos) {
+ return errors::InvalidArgument("Control input '", input,
+ "' must not have ':' in NodeDef: ",
+ SummarizeNodeDef(node_def));
+ }
+ } else if (seen_control) {
+ return errors::InvalidArgument("Non-control input '", input,
+ "' after control input in NodeDef: ",
+ SummarizeNodeDef(node_def));
+ } else {
+ ++num_inputs;
+ }
+ }
+
+ std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
+ for (const auto& attr : op_def.attr()) {
+ if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
+ return errors::InvalidArgument("OpDef has duplicate attr name '",
+ attr.name(), "': ",
+ SummarizeOpDef(op_def));
+ }
+ }
+ for (const auto& attr : node_def.attr()) {
+ // Allow internal optional attributes with names starting with "_".
+ if (StringPiece(attr.first).starts_with("_")) {
+ continue;
+ }
+ auto iter = op_attrs.find(attr.first);
+ if (iter == op_attrs.end()) {
+ return errors::InvalidArgument("NodeDef mentions attr '", attr.first,
+ "' not in ", SummarizeOpDef(op_def),
+ "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ ValidateAttrValue(attr.second, *iter->second), "; NodeDef: ",
+ SummarizeNodeDef(node_def), "; ", SummarizeOpDef(op_def));
+ // Keep track of which attr names have (not) been found in the NodeDef.
+ op_attrs.erase(iter);
+ }
+
+ // Were all attrs in the OpDef found in the NodeDef?
+ if (!op_attrs.empty()) {
+ string attrs;
+ for (const auto& attr_pair : op_attrs) {
+ if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
+ strings::StrAppend(&attrs, attr_pair.first);
+ }
+ return errors::InvalidArgument("NodeDef missing attr",
+ op_attrs.size() == 1 ? " '" : "s '", attrs,
+ "' from ", SummarizeOpDef(op_def),
+ "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+
+ // Validate the number of inputs.
+ DataTypeVector inputs, outputs;
+ TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
+
+ if (num_inputs != inputs.size()) {
+ return errors::InvalidArgument(
+ "NodeDef expected inputs '", DataTypeVectorString(inputs),
+ "' do not match ", num_inputs, " inputs specified; ",
+ SummarizeOpDef(op_def), "; NodeDef: ", SummarizeNodeDef(node_def));
+ }
+
+ return Status::OK();
+}
+
+namespace { // Helpers for NameRangesForNode()
+
+Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
+ const OpDef& op_def, int* num) {
+ if (!arg_def.number_attr().empty()) {
+ // Same type repeated "num" times.
+ return GetNodeAttr(node_def, arg_def.number_attr(), num);
+ } else if (!arg_def.type_list_attr().empty()) {
+ const AttrValue* attr_value;
+ TF_RETURN_IF_ERROR(
+ AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
+ *num = attr_value->list().type_size();
+ } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
+ *num = 1;
+ } else {
+ return errors::InvalidArgument("Argument '", arg_def.name(),
+ "' incorrectly specified in op definition: ",
+ SummarizeOpDef(op_def));
+ }
+ return Status::OK();
+}
+
+Status NameRangesHelper(const NodeDef& node_def,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+ const OpDef& op_def, NameRangeMap* result) {
+ int start = 0;
+ int num;
+ for (const auto& arg : args) {
+ TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num));
+ (*result)[arg.name()] = std::make_pair(start, start + num);
+ start += num;
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
+ NameRangeMap* inputs, NameRangeMap* outputs) {
+ TF_RETURN_IF_ERROR(
+ NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs));
+ return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs);
+}
+
+void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
+ for (const auto& attr_def : op_def.attr()) {
+ AttrSlice attrs(*node_def);
+ if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
+ AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
+ }
+ }
+}
+
+namespace {
+
+static RE2* valid_op_name_pattern = new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
+static RE2* valid_data_input_pattern =
+ new RE2("[A-Za-z0-9.][A-Za-z0-9_.\\-/]*(\\:(0|([1-9][0-9]*)))?");
+static RE2* valid_control_input_pattern =
+ new RE2("\\^[A-Za-z0-9.][A-Za-z0-9_.\\-/]*");
+
+} // namespace
+
+Status ValidateOpInput(const string& input_name, bool* is_control_input) {
+ *is_control_input = false;
+ if (RE2::FullMatch(input_name, *valid_data_input_pattern)) {
+ return Status::OK();
+ } else if (RE2::FullMatch(input_name, *valid_control_input_pattern)) {
+ *is_control_input = true;
+ return Status::OK();
+ } else {
+ return errors::InvalidArgument("Illegal op input name '", input_name, "'");
+ }
+}
+
+Status ValidateOpName(const string& op_name) {
+ if (RE2::FullMatch(op_name, *valid_op_name_pattern)) {
+ return Status::OK();
+ } else {
+ return errors::InvalidArgument("Illegal op name '", op_name, "'");
+ }
+}
+
+Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
+ Status s = ValidateOpName(node_def.name());
+ if (!s.ok()) {
+ return AttachDef(s, node_def);
+ }
+ bool in_control_inputs = false;
+ for (const string& input_name : node_def.input()) {
+ bool is_control_input;
+ s = ValidateOpInput(input_name, &is_control_input);
+ if (!s.ok()) {
+ return AttachDef(s, node_def);
+ }
+
+ if (in_control_inputs && !is_control_input) {
+ return AttachDef(errors::InvalidArgument(
+ "All control inputs must follow all data inputs"),
+ node_def);
+ }
+ in_control_inputs = is_control_input;
+ }
+ return Status::OK();
+}
+
+Status AttachDef(const Status& status, const NodeDef& node_def) {
+ Status ret = status;
+ errors::AppendToMessage(
+ &ret, strings::StrCat(" [[Node: ", SummarizeNodeDef(node_def), "]]"));
+ return ret;
+}
+
+} // namespace tensorflow