diff options
Diffstat (limited to 'tensorflow/core/framework/op_def_util.cc')
-rw-r--r-- | tensorflow/core/framework/op_def_util.cc | 344 |
1 files changed, 344 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc new file mode 100644 index 0000000000..e3aef011de --- /dev/null +++ b/tensorflow/core/framework/op_def_util.cc @@ -0,0 +1,344 @@ +#include "tensorflow/core/framework/op_def_util.h" + +#include <set> +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/regexp.h" + +namespace tensorflow { +namespace { // ------ Helper functions ------ + +bool HasAttrStyleType(const OpDef::ArgDef& arg) { + return arg.type() != DT_INVALID || !arg.type_attr().empty() || + !arg.type_list_attr().empty(); +} + +Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) { + const AttrValue& allowed_values(attr.allowed_values()); + for (auto allowed : allowed_values.list().type()) { + if (dt == allowed) { + return Status::OK(); + } + } + string allowed_str; + for (int i = 0; i < allowed_values.list().type_size(); ++i) { + if (!allowed_str.empty()) { + strings::StrAppend(&allowed_str, ", "); + } + strings::StrAppend(&allowed_str, + DataTypeString(allowed_values.list().type(i))); + } + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of ", DataTypeString(dt), + " is not in the list of allowed values: ", allowed_str); +} + +Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) { + const AttrValue& allowed_values(attr.allowed_values()); + for (auto allowed : allowed_values.list().s()) { + if (str == allowed) { + return Status::OK(); + } + } + string allowed_str; + for (const string& allowed : allowed_values.list().s()) { + if (!allowed_str.empty()) { + strings::StrAppend(&allowed_str, ", "); + } + strings::StrAppend(&allowed_str, "\"", allowed, "\""); + } + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of \"", str, + "\" is not in the list of allowed values: ", allowed_str); +} + +} // namespace + +// Requires: attr has already been validated. +Status ValidateAttrValue(const AttrValue& attr_value, + const OpDef::AttrDef& attr) { + // Is it a valid value? + TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()), + " for attr '", attr.name(), "'"); + + // Does the value satisfy the minimum constraint in the AttrDef? + if (attr.has_minimum()) { + if (attr.type() == "int") { + if (attr_value.i() < attr.minimum()) { + return errors::InvalidArgument( + "Value for attr '", attr.name(), "' of ", attr_value.i(), + " must be at least minimum ", attr.minimum()); + } + } else { + int length = -1; + if (attr.type() == "list(string)") { + length = attr_value.list().s_size(); + } else if (attr.type() == "list(int)") { + length = attr_value.list().i_size(); + } else if (attr.type() == "list(float)") { + length = attr_value.list().f_size(); + } else if (attr.type() == "list(bool)") { + length = attr_value.list().b_size(); + } else if (attr.type() == "list(type)") { + length = attr_value.list().type_size(); + } else if (attr.type() == "list(shape)") { + length = attr_value.list().shape_size(); + } else if (attr.type() == "list(tensor)") { + length = attr_value.list().tensor_size(); + } + if (length < attr.minimum()) { + return errors::InvalidArgument( + "Length for attr '", attr.name(), "' of ", length, + " must be at least minimum ", attr.minimum()); + } + } + } + + // Does the value satisfy the allowed_value constraint in the AttrDef? + if (attr.has_allowed_values()) { + if (attr.type() == "type") { + TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr)); + } else if (attr.type() == "list(type)") { + for (int dt : attr_value.list().type()) { + TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr)); + } + } else if (attr.type() == "string") { + TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr)); + } else if (attr.type() == "list(string)") { + for (const string& str : attr_value.list().s()) { + TF_RETURN_IF_ERROR(AllowedStringValue(str, attr)); + } + } else { + return errors::Unimplemented( + "Support for allowed_values not implemented for type ", attr.type()); + } + } + return Status::OK(); +} + +const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) { + for (int i = 0; i < op_def.attr_size(); ++i) { + if (op_def.attr(i).name() == name) { + return &op_def.attr(i); + } + } + return nullptr; +} + +OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) { + for (int i = 0; i < op_def->attr_size(); ++i) { + if (op_def->attr(i).name() == name) { + return op_def->mutable_attr(i); + } + } + return nullptr; +} + +#define VALIDATE(EXPR, ...) \ + do { \ + if (!(EXPR)) { \ + return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ", \ + op_def.ShortDebugString()); \ + } \ + } while (false) + +static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def, + bool output, std::set<string>* names) { + const string suffix = strings::StrCat( + output ? " for output '" : " for input '", arg.name(), "'"); + VALIDATE(gtl::InsertIfNotPresent(names, arg.name()), "Duplicate name: ", + arg.name()); + VALIDATE(HasAttrStyleType(arg), "Missing type", suffix); + + if (!arg.number_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'", + suffix); + VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length", + suffix, " has type ", attr->type(), " != int"); + VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length", + suffix, " must have minimum"); + VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length", + suffix, " must have minimum >= 0"); + VALIDATE(arg.type_list_attr().empty(), + "Can't have both number_attr and type_list_attr", suffix); + VALIDATE((arg.type() != DT_INVALID ? 1 : 0) + + (!arg.type_attr().empty() ? 1 : 0) == + 1, + "Exactly one of type, type_attr must be set", suffix); + } else { + const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) + + (!arg.type_attr().empty() ? 1 : 0) + + (!arg.type_list_attr().empty() ? 1 : 0); + VALIDATE(num_type_fields == 1, + "Exactly one of type, type_attr, type_list_attr must be set", + suffix); + } + + if (!arg.type_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'", + suffix); + VALIDATE(attr->type() == "type", "Attr '", attr->name(), + "' used as type_attr", suffix, " has type ", attr->type(), + " != type"); + } else if (!arg.type_list_attr().empty()) { + const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def); + VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'", + suffix); + VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(), + "' used as type_list_attr", suffix, " has type ", attr->type(), + " != list(type)"); + } else { + // All argument types should be non-reference types at this point. + // ArgDef.is_ref is set to true for reference arguments. + VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '", + DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix); + } + + return Status::OK(); +} + +Status ValidateOpDef(const OpDef& op_def) { + VALIDATE(RE2::FullMatch(op_def.name(), "(?:_.*|[A-Z][a-zA-Z0-9]*)"), + "Invalid name: ", op_def.name(), " (Did you use CamelCase?)"); + + std::set<string> names; // for detecting duplicate names + for (const auto& attr : op_def.attr()) { + // Validate name + VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()), "Duplicate name: ", + attr.name()); + DataType dt; + VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ", + attr.name(), " that matches a data type"); + + // Validate type + StringPiece type(attr.type()); + bool is_list = type.Consume("list("); + bool found = false; + for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape", + "tensor", "func"}) { + if (type.Consume(valid)) { + found = true; + break; + } + } + VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(), + "'"); + if (is_list) { + VALIDATE(type.Consume(")"), "'list(' is missing ')' in attr ", + attr.name(), "'s type ", attr.type()); + } + VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ", + attr.name(), "'s type ", attr.type()); + + // Validate minimum + if (attr.has_minimum()) { + VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(), + "' has minimum for unsupported type ", attr.type()); + if (is_list) { + VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(), + "' with list type must have a non-negative minimum, not ", + attr.minimum()); + } + } else { + VALIDATE(attr.minimum() == 0, "Attr '", attr.name(), + "' with has_minimum = false but minimum ", attr.minimum(), + " not equal to default of 0"); + } + + // Validate allowed_values + if (attr.has_allowed_values()) { + const string list_type = + is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")"); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + AttrValueHasType(attr.allowed_values(), list_type), " for attr '", + attr.name(), "' in Op '", op_def.name(), "'"); + } + + // Validate default_value (after we have validated the rest of the attr, + // so we can use ValidateAttrValue()). + if (attr.has_default_value()) { + TF_RETURN_WITH_CONTEXT_IF_ERROR( + ValidateAttrValue(attr.default_value(), attr), " in Op '", + op_def.name(), "'"); + } + } + + for (const auto& arg : op_def.input_arg()) { + TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names)); + } + + for (const auto& arg : op_def.output_arg()) { + TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names)); + } + + return Status::OK(); +} + +#undef VALIDATE + +namespace { + +string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { + string ret; + for (const OpDef::ArgDef& arg : args) { + if (!ret.empty()) strings::StrAppend(&ret, ", "); + strings::StrAppend(&ret, arg.name(), ":"); + if (arg.is_ref()) strings::StrAppend(&ret, "Ref("); + if (!arg.number_attr().empty()) { + strings::StrAppend(&ret, arg.number_attr(), "*"); + } + if (arg.type() != DT_INVALID) { + strings::StrAppend(&ret, DataTypeString(arg.type())); + } else { + strings::StrAppend(&ret, arg.type_attr()); + } + if (arg.is_ref()) strings::StrAppend(&ret, ")"); + } + return ret; +} + +} // namespace + +string SummarizeOpDef(const OpDef& op_def) { + string ret = strings::StrCat("Op<name=", op_def.name()); + strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()), + " -> ", SummarizeArgs(op_def.output_arg())); + for (int i = 0; i < op_def.attr_size(); ++i) { + strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":", + op_def.attr(i).type()); + if (op_def.attr(i).has_default_value()) { + strings::StrAppend(&ret, ",default=", + SummarizeAttrValue(op_def.attr(i).default_value())); + } + if (op_def.attr(i).has_minimum()) { + strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum()); + } + if (op_def.attr(i).has_allowed_values()) { + strings::StrAppend(&ret, ",allowed=", + SummarizeAttrValue(op_def.attr(i).allowed_values())); + } + } + if (op_def.is_commutative()) { + strings::StrAppend(&ret, "; is_commutative=true"); + } + if (op_def.is_aggregate()) { + strings::StrAppend(&ret, "; is_aggregate=true"); + } + if (op_def.is_stateful()) { + strings::StrAppend(&ret, "; is_stateful=true"); + } + if (op_def.allows_uninitialized_input()) { + strings::StrAppend(&ret, "; allows_uninitialized_input=true"); + } + strings::StrAppend(&ret, ">"); + return ret; +} + +} // namespace tensorflow |