aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op_def_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/op_def_util.cc')
-rw-r--r--tensorflow/core/framework/op_def_util.cc162
1 files changed, 160 insertions, 2 deletions
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index e3aef011de..7b0fa668bf 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -1,13 +1,14 @@
#include "tensorflow/core/framework/op_def_util.h"
#include <set>
+#include <unordered_map>
#include "tensorflow/core/framework/attr_value_util.h"
-#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
namespace tensorflow {
@@ -341,4 +342,161 @@ string SummarizeOpDef(const OpDef& op_def) {
return ret;
}
+namespace {
+
+typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
+void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
+ for (const auto& attr : op_def.attr()) {
+ (*attr_map)[attr.name()] = &attr;
+ }
+}
+
+// Add a comma to *s every call but the first (*add_comma should be
+// initialized to false).
+void AddComma(string* s, bool* add_comma) {
+ if (*add_comma) {
+ strings::StrAppend(s, ", ");
+ } else {
+ *add_comma = true;
+ }
+}
+
+// Compute a signature for either inputs or outputs that will be the
+// same for both the old and new OpDef if they are compatible. We
+// assume that new_attrs is a superset of old_attrs, and that any attr
+// in the difference has a default. Our strategy is to make a list of
+// types, where the types are things like:
+// * "int32", "float", etc.,
+// * "T" for some attr "T" in old_attrs, or
+// * "N * type" for "N" either some attr in old_attrs.
+//
+// We get the types by either using the attrs in args if they are in
+// old_attrs, or substituting the default value from new_attrs.
+string ComputeArgSignature(
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+ const AttrMap& old_attrs, const AttrMap& new_attrs) {
+ string s;
+ bool add_comma = false;
+ for (const OpDef::ArgDef& arg : args) {
+ if (!arg.type_list_attr().empty()) {
+ const OpDef::AttrDef* old_attr =
+ gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
+ if (old_attr) {
+ // Both old and new have the list(type) attr, so can use it directly.
+ AddComma(&s, &add_comma);
+ strings::StrAppend(&s, arg.type_list_attr());
+ if (arg.is_ref()) strings::StrAppend(&s, " ref");
+ } else {
+ // Missing the list(type) attr in the old, so use the default
+ // value for the attr from new instead.
+ const OpDef::AttrDef* new_attr =
+ gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
+ const auto& type_list = new_attr->default_value().list().type();
+ if (type_list.empty()) continue;
+ for (int i = 0; i < type_list.size(); ++i) {
+ AddComma(&s, &add_comma);
+ strings::StrAppend(
+ &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
+ if (arg.is_ref()) strings::StrAppend(&s, " ref");
+ }
+ }
+ } else {
+ int num = 1; // How many input/outputs does this represent?
+ if (!arg.number_attr().empty()) {
+ // N * type case.
+ const OpDef::AttrDef* old_attr =
+ gtl::FindPtrOrNull(old_attrs, arg.number_attr());
+ if (old_attr) {
+ // Both old and new have the number attr, so can use it directly.
+ AddComma(&s, &add_comma);
+ strings::StrAppend(&s, arg.number_attr(), " * ");
+ add_comma = false; // Don't add another comma before the type.
+ } else {
+ // Missing the number attr in the old, so use the default
+ // value for the attr from new instead.
+ const OpDef::AttrDef* new_attr =
+ gtl::FindPtrOrNull(new_attrs, arg.number_attr());
+ num = new_attr->default_value().i();
+ }
+ }
+
+ string type; // What is the type of this arg?
+ if (arg.type() != DT_INVALID) {
+ // int32, float, etc. case
+ type = DataTypeString(arg.type());
+ } else {
+ const OpDef::AttrDef* old_attr =
+ gtl::FindPtrOrNull(old_attrs, arg.type_attr());
+ if (old_attr) {
+ // Both old and new have the type attr, so can use it directly.
+ type = arg.type_attr();
+ } else {
+ // Missing the type attr in the old, so use the default
+ // value for the attr from new instead.
+ const OpDef::AttrDef* new_attr =
+ gtl::FindPtrOrNull(new_attrs, arg.type_attr());
+ type = DataTypeString(new_attr->default_value().type());
+ }
+ }
+ if (arg.is_ref()) strings::StrAppend(&type, " ref");
+
+ // Record `num` * `type` in the signature.
+ for (int i = 0; i < num; ++i) {
+ AddComma(&s, &add_comma);
+ strings::StrAppend(&s, type);
+ }
+ }
+ }
+
+ return s;
+}
+
+} // namespace
+
+Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
+#define VALIDATE(CONDITION, ...) \
+ if (!(CONDITION)) { \
+ return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \
+ "; old: ", SummarizeOpDef(old_op), \
+ "; new: ", SummarizeOpDef(new_op)); \
+ }
+
+ VALIDATE(old_op.name() == new_op.name(), "Name mismatch");
+
+ AttrMap new_attrs, old_attrs;
+ FillAttrMap(old_op, &old_attrs);
+ FillAttrMap(new_op, &new_attrs);
+ for (const auto& old_attr : old_op.attr()) {
+ const OpDef::AttrDef* new_attr =
+ gtl::FindPtrOrNull(new_attrs, old_attr.name());
+ VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed");
+ VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(),
+ "' changed type '", old_attr.type(), "' -> '", new_attr->type(),
+ "'");
+ }
+
+ for (const auto& new_attr : new_op.attr()) {
+ const OpDef::AttrDef* old_attr =
+ gtl::FindPtrOrNull(old_attrs, new_attr.name());
+ VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '",
+ new_attr.name(), "' added without default");
+ }
+
+ const string old_in_sig =
+ ComputeArgSignature(old_op.input_arg(), old_attrs, new_attrs);
+ const string new_in_sig =
+ ComputeArgSignature(new_op.input_arg(), old_attrs, new_attrs);
+ VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
+ "' vs. '", new_in_sig, "'");
+
+ const string old_out_sig =
+ ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs);
+ const string new_out_sig =
+ ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs);
+ VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
+ old_out_sig, "' vs. '", new_out_sig, "'");
+
+ return Status::OK();
+}
+
} // namespace tensorflow