aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op_def_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/op_def_util.cc')
-rw-r--r--tensorflow/core/framework/op_def_util.cc344
1 files changed, 344 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
new file mode 100644
index 0000000000..e3aef011de
--- /dev/null
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -0,0 +1,344 @@
+#include "tensorflow/core/framework/op_def_util.h"
+
+#include <set>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+namespace { // ------ Helper functions ------
+
+bool HasAttrStyleType(const OpDef::ArgDef& arg) {
+ return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
+ !arg.type_list_attr().empty();
+}
+
+Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
+ const AttrValue& allowed_values(attr.allowed_values());
+ for (auto allowed : allowed_values.list().type()) {
+ if (dt == allowed) {
+ return Status::OK();
+ }
+ }
+ string allowed_str;
+ for (int i = 0; i < allowed_values.list().type_size(); ++i) {
+ if (!allowed_str.empty()) {
+ strings::StrAppend(&allowed_str, ", ");
+ }
+ strings::StrAppend(&allowed_str,
+ DataTypeString(allowed_values.list().type(i)));
+ }
+ return errors::InvalidArgument(
+ "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
+ " is not in the list of allowed values: ", allowed_str);
+}
+
+Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
+ const AttrValue& allowed_values(attr.allowed_values());
+ for (auto allowed : allowed_values.list().s()) {
+ if (str == allowed) {
+ return Status::OK();
+ }
+ }
+ string allowed_str;
+ for (const string& allowed : allowed_values.list().s()) {
+ if (!allowed_str.empty()) {
+ strings::StrAppend(&allowed_str, ", ");
+ }
+ strings::StrAppend(&allowed_str, "\"", allowed, "\"");
+ }
+ return errors::InvalidArgument(
+ "Value for attr '", attr.name(), "' of \"", str,
+ "\" is not in the list of allowed values: ", allowed_str);
+}
+
+} // namespace
+
+// Requires: attr has already been validated.
+Status ValidateAttrValue(const AttrValue& attr_value,
+ const OpDef::AttrDef& attr) {
+ // Is it a valid value?
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
+ " for attr '", attr.name(), "'");
+
+ // Does the value satisfy the minimum constraint in the AttrDef?
+ if (attr.has_minimum()) {
+ if (attr.type() == "int") {
+ if (attr_value.i() < attr.minimum()) {
+ return errors::InvalidArgument(
+ "Value for attr '", attr.name(), "' of ", attr_value.i(),
+ " must be at least minimum ", attr.minimum());
+ }
+ } else {
+ int length = -1;
+ if (attr.type() == "list(string)") {
+ length = attr_value.list().s_size();
+ } else if (attr.type() == "list(int)") {
+ length = attr_value.list().i_size();
+ } else if (attr.type() == "list(float)") {
+ length = attr_value.list().f_size();
+ } else if (attr.type() == "list(bool)") {
+ length = attr_value.list().b_size();
+ } else if (attr.type() == "list(type)") {
+ length = attr_value.list().type_size();
+ } else if (attr.type() == "list(shape)") {
+ length = attr_value.list().shape_size();
+ } else if (attr.type() == "list(tensor)") {
+ length = attr_value.list().tensor_size();
+ }
+ if (length < attr.minimum()) {
+ return errors::InvalidArgument(
+ "Length for attr '", attr.name(), "' of ", length,
+ " must be at least minimum ", attr.minimum());
+ }
+ }
+ }
+
+ // Does the value satisfy the allowed_value constraint in the AttrDef?
+ if (attr.has_allowed_values()) {
+ if (attr.type() == "type") {
+ TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
+ } else if (attr.type() == "list(type)") {
+ for (int dt : attr_value.list().type()) {
+ TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
+ }
+ } else if (attr.type() == "string") {
+ TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
+ } else if (attr.type() == "list(string)") {
+ for (const string& str : attr_value.list().s()) {
+ TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
+ }
+ } else {
+ return errors::Unimplemented(
+ "Support for allowed_values not implemented for type ", attr.type());
+ }
+ }
+ return Status::OK();
+}
+
+const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
+ for (int i = 0; i < op_def.attr_size(); ++i) {
+ if (op_def.attr(i).name() == name) {
+ return &op_def.attr(i);
+ }
+ }
+ return nullptr;
+}
+
+OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
+ for (int i = 0; i < op_def->attr_size(); ++i) {
+ if (op_def->attr(i).name() == name) {
+ return op_def->mutable_attr(i);
+ }
+ }
+ return nullptr;
+}
+
+#define VALIDATE(EXPR, ...) \
+ do { \
+ if (!(EXPR)) { \
+ return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ", \
+ op_def.ShortDebugString()); \
+ } \
+ } while (false)
+
+static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
+ bool output, std::set<string>* names) {
+ const string suffix = strings::StrCat(
+ output ? " for output '" : " for input '", arg.name(), "'");
+ VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), "Duplicate name: ",
+ arg.name());
+ VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
+
+ if (!arg.number_attr().empty()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
+ VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
+ suffix);
+ VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
+ suffix, " has type ", attr->type(), " != int");
+ VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
+ suffix, " must have minimum");
+ VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
+ suffix, " must have minimum >= 0");
+ VALIDATE(arg.type_list_attr().empty(),
+ "Can't have both number_attr and type_list_attr", suffix);
+ VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
+ (!arg.type_attr().empty() ? 1 : 0) ==
+ 1,
+ "Exactly one of type, type_attr must be set", suffix);
+ } else {
+ const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
+ (!arg.type_attr().empty() ? 1 : 0) +
+ (!arg.type_list_attr().empty() ? 1 : 0);
+ VALIDATE(num_type_fields == 1,
+ "Exactly one of type, type_attr, type_list_attr must be set",
+ suffix);
+ }
+
+ if (!arg.type_attr().empty()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
+ VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
+ suffix);
+ VALIDATE(attr->type() == "type", "Attr '", attr->name(),
+ "' used as type_attr", suffix, " has type ", attr->type(),
+ " != type");
+ } else if (!arg.type_list_attr().empty()) {
+ const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
+ VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
+ suffix);
+ VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
+ "' used as type_list_attr", suffix, " has type ", attr->type(),
+ " != list(type)");
+ } else {
+ // All argument types should be non-reference types at this point.
+ // ArgDef.is_ref is set to true for reference arguments.
+ VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
+ DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
+ }
+
+ return Status::OK();
+}
+
+Status ValidateOpDef(const OpDef& op_def) {
+ VALIDATE(RE2::FullMatch(op_def.name(), "(?:_.*|[A-Z][a-zA-Z0-9]*)"),
+ "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
+
+ std::set<string> names; // for detecting duplicate names
+ for (const auto& attr : op_def.attr()) {
+ // Validate name
+ VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), "Duplicate name: ",
+ attr.name());
+ DataType dt;
+ VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
+ attr.name(), " that matches a data type");
+
+ // Validate type
+ StringPiece type(attr.type());
+ bool is_list = type.Consume("list(");
+ bool found = false;
+ for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
+ "tensor", "func"}) {
+ if (type.Consume(valid)) {
+ found = true;
+ break;
+ }
+ }
+ VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
+ "'");
+ if (is_list) {
+ VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ",
+ attr.name(), "'s type ", attr.type());
+ }
+ VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
+ attr.name(), "'s type ", attr.type());
+
+ // Validate minimum
+ if (attr.has_minimum()) {
+ VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
+ "' has minimum for unsupported type ", attr.type());
+ if (is_list) {
+ VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
+ "' with list type must have a non-negative minimum, not ",
+ attr.minimum());
+ }
+ } else {
+ VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
+ "' with has_minimum = false but minimum ", attr.minimum(),
+ " not equal to default of 0");
+ }
+
+ // Validate allowed_values
+ if (attr.has_allowed_values()) {
+ const string list_type =
+ is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
+ attr.name(), "' in Op '", op_def.name(), "'");
+ }
+
+ // Validate default_value (after we have validated the rest of the attr,
+ // so we can use ValidateAttrValue()).
+ if (attr.has_default_value()) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ ValidateAttrValue(attr.default_value(), attr), " in Op '",
+ op_def.name(), "'");
+ }
+ }
+
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
+ }
+
+ for (const auto& arg : op_def.output_arg()) {
+ TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
+ }
+
+ return Status::OK();
+}
+
+#undef VALIDATE
+
+namespace {
+
+string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
+ string ret;
+ for (const OpDef::ArgDef& arg : args) {
+ if (!ret.empty()) strings::StrAppend(&ret, ", ");
+ strings::StrAppend(&ret, arg.name(), ":");
+ if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
+ if (!arg.number_attr().empty()) {
+ strings::StrAppend(&ret, arg.number_attr(), "*");
+ }
+ if (arg.type() != DT_INVALID) {
+ strings::StrAppend(&ret, DataTypeString(arg.type()));
+ } else {
+ strings::StrAppend(&ret, arg.type_attr());
+ }
+ if (arg.is_ref()) strings::StrAppend(&ret, ")");
+ }
+ return ret;
+}
+
+} // namespace
+
+string SummarizeOpDef(const OpDef& op_def) {
+ string ret = strings::StrCat("Op<name=", op_def.name());
+ strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
+ " -> ", SummarizeArgs(op_def.output_arg()));
+ for (int i = 0; i < op_def.attr_size(); ++i) {
+ strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
+ op_def.attr(i).type());
+ if (op_def.attr(i).has_default_value()) {
+ strings::StrAppend(&ret, ",default=",
+ SummarizeAttrValue(op_def.attr(i).default_value()));
+ }
+ if (op_def.attr(i).has_minimum()) {
+ strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
+ }
+ if (op_def.attr(i).has_allowed_values()) {
+ strings::StrAppend(&ret, ",allowed=",
+ SummarizeAttrValue(op_def.attr(i).allowed_values()));
+ }
+ }
+ if (op_def.is_commutative()) {
+ strings::StrAppend(&ret, "; is_commutative=true");
+ }
+ if (op_def.is_aggregate()) {
+ strings::StrAppend(&ret, "; is_aggregate=true");
+ }
+ if (op_def.is_stateful()) {
+ strings::StrAppend(&ret, "; is_stateful=true");
+ }
+ if (op_def.allows_uninitialized_input()) {
+ strings::StrAppend(&ret, "; allows_uninitialized_input=true");
+ }
+ strings::StrAppend(&ret, ">");
+ return ret;
+}
+
+} // namespace tensorflow