diff options
Diffstat (limited to 'tensorflow/core/framework/op_def_util.cc')
-rw-r--r-- | tensorflow/core/framework/op_def_util.cc | 162 |
1 files changed, 160 insertions, 2 deletions
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index e3aef011de..7b0fa668bf 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -1,13 +1,14 @@ #include "tensorflow/core/framework/op_def_util.h" #include <set> +#include <unordered_map> #include "tensorflow/core/framework/attr_value_util.h" -#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/port.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/regexp.h" namespace tensorflow { @@ -341,4 +342,161 @@ string SummarizeOpDef(const OpDef& op_def) { return ret; } +namespace { + +typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap; +void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) { + for (const auto& attr : op_def.attr()) { + (*attr_map)[attr.name()] = &attr; + } +} + +// Add a comma to *s every call but the first (*add_comma should be +// initialized to false). +void AddComma(string* s, bool* add_comma) { + if (*add_comma) { + strings::StrAppend(s, ", "); + } else { + *add_comma = true; + } +} + +// Compute a signature for either inputs or outputs that will be the +// same for both the old and new OpDef if they are compatible. We +// assume that new_attrs is a superset of old_attrs, and that any attr +// in the difference has a default. Our strategy is to make a list of +// types, where the types are things like: +// * "int32", "float", etc., +// * "T" for some attr "T" in old_attrs, or +// * "N * type" for "N" either some attr in old_attrs. +// +// We get the types by either using the attrs in args if they are in +// old_attrs, or substituting the default value from new_attrs. +string ComputeArgSignature( + const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, + const AttrMap& old_attrs, const AttrMap& new_attrs) { + string s; + bool add_comma = false; + for (const OpDef::ArgDef& arg : args) { + if (!arg.type_list_attr().empty()) { + const OpDef::AttrDef* old_attr = + gtl::FindPtrOrNull(old_attrs, arg.type_list_attr()); + if (old_attr) { + // Both old and new have the list(type) attr, so can use it directly. + AddComma(&s, &add_comma); + strings::StrAppend(&s, arg.type_list_attr()); + if (arg.is_ref()) strings::StrAppend(&s, " ref"); + } else { + // Missing the list(type) attr in the old, so use the default + // value for the attr from new instead. + const OpDef::AttrDef* new_attr = + gtl::FindPtrOrNull(new_attrs, arg.type_list_attr()); + const auto& type_list = new_attr->default_value().list().type(); + if (type_list.empty()) continue; + for (int i = 0; i < type_list.size(); ++i) { + AddComma(&s, &add_comma); + strings::StrAppend( + &s, DataTypeString(static_cast<DataType>(type_list.Get(i)))); + if (arg.is_ref()) strings::StrAppend(&s, " ref"); + } + } + } else { + int num = 1; // How many input/outputs does this represent? + if (!arg.number_attr().empty()) { + // N * type case. + const OpDef::AttrDef* old_attr = + gtl::FindPtrOrNull(old_attrs, arg.number_attr()); + if (old_attr) { + // Both old and new have the number attr, so can use it directly. + AddComma(&s, &add_comma); + strings::StrAppend(&s, arg.number_attr(), " * "); + add_comma = false; // Don't add another comma before the type. + } else { + // Missing the number attr in the old, so use the default + // value for the attr from new instead. + const OpDef::AttrDef* new_attr = + gtl::FindPtrOrNull(new_attrs, arg.number_attr()); + num = new_attr->default_value().i(); + } + } + + string type; // What is the type of this arg? + if (arg.type() != DT_INVALID) { + // int32, float, etc. case + type = DataTypeString(arg.type()); + } else { + const OpDef::AttrDef* old_attr = + gtl::FindPtrOrNull(old_attrs, arg.type_attr()); + if (old_attr) { + // Both old and new have the type attr, so can use it directly. + type = arg.type_attr(); + } else { + // Missing the type attr in the old, so use the default + // value for the attr from new instead. + const OpDef::AttrDef* new_attr = + gtl::FindPtrOrNull(new_attrs, arg.type_attr()); + type = DataTypeString(new_attr->default_value().type()); + } + } + if (arg.is_ref()) strings::StrAppend(&type, " ref"); + + // Record `num` * `type` in the signature. + for (int i = 0; i < num; ++i) { + AddComma(&s, &add_comma); + strings::StrAppend(&s, type); + } + } + } + + return s; +} + +} // namespace + +Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { +#define VALIDATE(CONDITION, ...) \ + if (!(CONDITION)) { \ + return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \ + "; old: ", SummarizeOpDef(old_op), \ + "; new: ", SummarizeOpDef(new_op)); \ + } + + VALIDATE(old_op.name() == new_op.name(), "Name mismatch"); + + AttrMap new_attrs, old_attrs; + FillAttrMap(old_op, &old_attrs); + FillAttrMap(new_op, &new_attrs); + for (const auto& old_attr : old_op.attr()) { + const OpDef::AttrDef* new_attr = + gtl::FindPtrOrNull(new_attrs, old_attr.name()); + VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed"); + VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(), + "' changed type '", old_attr.type(), "' -> '", new_attr->type(), + "'"); + } + + for (const auto& new_attr : new_op.attr()) { + const OpDef::AttrDef* old_attr = + gtl::FindPtrOrNull(old_attrs, new_attr.name()); + VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '", + new_attr.name(), "' added without default"); + } + + const string old_in_sig = + ComputeArgSignature(old_op.input_arg(), old_attrs, new_attrs); + const string new_in_sig = + ComputeArgSignature(new_op.input_arg(), old_attrs, new_attrs); + VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig, + "' vs. '", new_in_sig, "'"); + + const string old_out_sig = + ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs); + const string new_out_sig = + ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs); + VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '", + old_out_sig, "' vs. '", new_out_sig, "'"); + + return Status::OK(); +} + } // namespace tensorflow |