diff options
author | 2018-01-02 16:52:50 -0800 | |
---|---|---|
committer | 2018-01-02 16:56:30 -0800 | |
commit | 6a20edf95fcaf45c46385eaf649e814a571737ed (patch) | |
tree | 63f1a2679245c620fdfb721bf69c5b5d049c7350 | |
parent | 89cd0cd81ae829610fcbf4437597451ae5a59fe6 (diff) |
backward compatibility: Disallow changes to an OpDef attribute's default value.
PiperOrigin-RevId: 180611380
-rw-r--r-- | tensorflow/core/framework/op_compatibility_test.cc | 99 | ||||
-rw-r--r-- | tensorflow/core/framework/op_def_util.cc | 15 |
2 files changed, 77 insertions, 37 deletions
diff --git a/tensorflow/core/framework/op_compatibility_test.cc b/tensorflow/core/framework/op_compatibility_test.cc index ae2fdae379..4f4813d9fa 100644 --- a/tensorflow/core/framework/op_compatibility_test.cc +++ b/tensorflow/core/framework/op_compatibility_test.cc @@ -163,6 +163,18 @@ class OpCompatibilityTest : public OpsTestBase { ExpectIncompatible(old_op_def, *new_op_def, compatibility_error); } + + void ExpectDefaultChangeFailure(const OpDef& old_op_def, + const string& compatibility_error) { + // This should be all that is needed to get compatibility. + const OpDef* new_op_def = RegisteredOpDef(); + AddDefaultsToNodeDef(*new_op_def, node_def()); + + // Validate that the NodeDef is valid. + TF_ASSERT_OK(ValidateNodeDef(*node_def(), *new_op_def)); + + ExpectIncompatible(old_op_def, *new_op_def, compatibility_error); + } }; // Should be compatible if the Op hasn't changed (sanity check). @@ -260,40 +272,6 @@ TEST_F(OpCompatibilityTest, AttrOrder) { EXPECT_EQ("attr_order = AttrOrder[a=7, b=true]()", Result()); } -// Should be able to add a default to an attr. -REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234"); -REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel); - -TEST_F(OpCompatibilityTest, AddDefault) { - OpRegistrationData old_op; - TF_ASSERT_OK(OpDefBuilder("AddDefault") - .Output("ndef: string") - .Attr("a: int") - .Finalize(&old_op)); - TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def) - .Attr("a", 765) - .Finalize(node_def())); - ExpectSuccess(old_op.op_def); - EXPECT_EQ("add_default = AddDefault[a=765]()", Result()); -} - -// Should be able to remove a default from an attr, *as long as that -// attr has always existed*. -REGISTER_OP("RemoveDefault").Output("ndef: string").Attr("a: int"); -REGISTER_KERNEL_BUILDER(Name("RemoveDefault").Device(DEVICE_CPU), TestKernel); - -TEST_F(OpCompatibilityTest, RemoveDefault) { - OpRegistrationData old_op; - TF_ASSERT_OK(OpDefBuilder("RemoveDefault") - .Output("ndef: string") - .Attr("a: int = 91") - .Finalize(&old_op)); - TF_ASSERT_OK( - NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def())); - ExpectSuccess(old_op.op_def); - EXPECT_EQ("remove_default = RemoveDefault[a=91]()", Result()); -} - // Should be able to make an input/output polymorphic. // Changing from int32 -> T (where T: type = DT_INT32 by default). REGISTER_OP("TypePolymorphic") @@ -1054,9 +1032,56 @@ TEST_F(OpCompatibilityTest, RenameOutputListFails) { "Output signature mismatch 'old:T' vs. 'new:T'"); } -// Changing an attr's default is not technically illegal, but should -// be forbidden if it the attr ever didn't exist since it likely -// affects semantics. +// Should not be able to add a default to an attr. +REGISTER_OP("AddDefault").Output("ndef: string").Attr("a: int = 1234"); +REGISTER_KERNEL_BUILDER(Name("AddDefault").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, AddDefault) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("AddDefault") + .Output("ndef: string") + .Attr("a: int") + .Finalize(&old_op)); + TF_ASSERT_OK(NodeDefBuilder("add_default", &old_op.op_def) + .Attr("a", 765) + .Finalize(node_def())); + ExpectDefaultChangeFailure( + old_op.op_def, + "Attr 'a' has added/removed it's default; from no default to 1234"); +} + +// Should not be able to remove a default from an attr. +REGISTER_OP("RemoveDefault").Output("ndef: string").Attr("a: int"); +REGISTER_KERNEL_BUILDER(Name("RemoveDefault").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, RemoveDefault) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("RemoveDefault") + .Output("ndef: string") + .Attr("a: int = 91") + .Finalize(&old_op)); + TF_ASSERT_OK( + NodeDefBuilder("remove_default", &old_op.op_def).Finalize(node_def())); + ExpectDefaultChangeFailure( + old_op.op_def, + "Attr 'a' has added/removed it's default; from 91 to no default"); +} + +// Should not be able to change a default for an attr. +REGISTER_OP("ChangeDefault").Output("ndef: string").Attr("a: int = 1"); +REGISTER_KERNEL_BUILDER(Name("ChangeDefault").Device(DEVICE_CPU), TestKernel); + +TEST_F(OpCompatibilityTest, ChangeDefault) { + OpRegistrationData old_op; + TF_ASSERT_OK(OpDefBuilder("ChangeDefault") + .Output("ndef: string") + .Attr("a: int = 2") + .Finalize(&old_op)); + TF_ASSERT_OK( + NodeDefBuilder("change_default", &old_op.op_def).Finalize(node_def())); + ExpectDefaultChangeFailure( + old_op.op_def, "Attr 'a' has changed it's default value; from 2 to 1"); +} } // namespace } // namespace tensorflow diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc index 29feda499f..f9030e93ab 100644 --- a/tensorflow/core/framework/op_def_util.cc +++ b/tensorflow/core/framework/op_def_util.cc @@ -449,6 +449,11 @@ string AllowedStr(const OpDef::AttrDef& attr) { return SummarizeAttrValue(attr.allowed_values()); } +string DefaultAttrStr(const OpDef::AttrDef& attr) { + if (!attr.has_default_value()) return "no default"; + return SummarizeAttrValue(attr.default_value()); +} + bool HigherMinimum(const OpDef::AttrDef& old_attr, const OpDef::AttrDef& new_attr) { // Anything -> no restriction : not more restrictive. @@ -610,6 +615,16 @@ Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) { VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(), "' has a higher minimum; from ", MinStr(old_attr), " to ", MinStr(*new_attr)); + VALIDATE(old_attr.has_default_value() == new_attr->has_default_value(), + "Attr '", old_attr.name(), "' has added/removed it's default; ", + "from ", DefaultAttrStr(old_attr), " to ", + DefaultAttrStr(*new_attr)); + VALIDATE(!old_attr.has_default_value() || + AreAttrValuesEqual(old_attr.default_value(), + new_attr->default_value()), + "Attr '", old_attr.name(), "' has changed it's default value; ", + "from ", DefaultAttrStr(old_attr), " to ", + DefaultAttrStr(*new_attr)); } for (const auto& new_attr : new_op.attr()) { |