aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/op_def_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-11 16:13:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-11 17:16:04 -0700
commit67f123ae3d37426f592dee67fb34d1da7865baf0 (patch)
tree8332aca2995d6fdccb618b51000c763011df3fde /tensorflow/core/framework/op_def_util.cc
parent225822ae5ad6db5cda260658bd2c9cc5bc6725d1 (diff)
Allow "ref" to be removed from inputs and added to outputs without
breaking GraphDef compatibility. Change: 130051016
Diffstat (limited to 'tensorflow/core/framework/op_def_util.cc')
-rw-r--r--tensorflow/core/framework/op_def_util.cc42
1 files changed, 29 insertions, 13 deletions
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index 96f2a77f78..d92a1d47ac 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -484,7 +484,8 @@ void AddComma(string* s, bool* add_comma) {
// old_attrs, or substituting the default value from new_attrs.
string ComputeArgSignature(
const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
- const AttrMap& old_attrs, const AttrMap& new_attrs) {
+ const AttrMap& old_attrs, const AttrMap& new_attrs,
+ std::vector<bool>* ref) {
string s;
bool add_comma = false;
for (const OpDef::ArgDef& arg : args) {
@@ -495,7 +496,7 @@ string ComputeArgSignature(
// Both old and new have the list(type) attr, so can use it directly.
AddComma(&s, &add_comma);
strings::StrAppend(&s, arg.type_list_attr());
- if (arg.is_ref()) strings::StrAppend(&s, " ref");
+ ref->push_back(arg.is_ref());
} else {
// Missing the list(type) attr in the old, so use the default
// value for the attr from new instead.
@@ -507,7 +508,7 @@ string ComputeArgSignature(
AddComma(&s, &add_comma);
strings::StrAppend(
&s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
- if (arg.is_ref()) strings::StrAppend(&s, " ref");
+ ref->push_back(arg.is_ref());
}
}
} else {
@@ -548,12 +549,12 @@ string ComputeArgSignature(
type = DataTypeString(new_attr->default_value().type());
}
}
- if (arg.is_ref()) strings::StrAppend(&type, " ref");
// Record `num` * `type` in the signature.
for (int i = 0; i < num; ++i) {
AddComma(&s, &add_comma);
strings::StrAppend(&s, type);
+ ref->push_back(arg.is_ref());
}
}
}
@@ -598,19 +599,34 @@ Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
new_attr.name(), "' added without default");
}
- const string old_in_sig =
- ComputeArgSignature(old_op.input_arg(), old_attrs, new_attrs);
- const string new_in_sig =
- ComputeArgSignature(new_op.input_arg(), old_attrs, new_attrs);
+ std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref;
+ const string old_in_sig = ComputeArgSignature(old_op.input_arg(), old_attrs,
+ new_attrs, &old_in_ref);
+ const string new_in_sig = ComputeArgSignature(new_op.input_arg(), old_attrs,
+ new_attrs, &new_in_ref);
VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
"' vs. '", new_in_sig, "'");
-
- const string old_out_sig =
- ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs);
- const string new_out_sig =
- ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs);
+ VALIDATE(old_in_ref.size() == new_in_ref.size(), // Should not happen
+ "Unexpected change in input ref lists.");
+ for (int i = 0; i < old_in_ref.size(); ++i) {
+ // Allowed to remove "ref" from an input (or leave it unchanged).
+ VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i,
+ " changed from non-ref to ref");
+ }
+
+ const string old_out_sig = ComputeArgSignature(old_op.output_arg(), old_attrs,
+ new_attrs, &old_out_ref);
+ const string new_out_sig = ComputeArgSignature(new_op.output_arg(), old_attrs,
+ new_attrs, &new_out_ref);
VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
old_out_sig, "' vs. '", new_out_sig, "'");
+ VALIDATE(old_out_ref.size() == new_out_ref.size(), // Should not happen
+ "Unexpected change in output ref lists");
+ for (int i = 0; i < old_out_ref.size(); ++i) {
+ // Allowed to add "ref" to an output (or leave it unchanged).
+ VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i,
+ " changed from ref to non-ref");
+ }
return Status::OK();
}