aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Geoffrey Irving <geoffreyi@google.com>2016-04-20 14:58:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-04-20 16:01:43 -0700
commit30334d28c0ac0c2d0369c4f7f31f5b7ffff06ddd (patch)
treec65bac9a10386302ae769895e5652e0c3a206560
parentc3af083a6411d7facfded6f6087f6524696efc0d (diff)
Add a .Deprecated method to REGISTER_OP
This replaces the OP_DEPRECATED macro with something declarative, which in particular lets us throw exceptions at graph construction time based on deprecation. I've left the OP_DEPRECATED macro around in case uses elsewhere can't be expressed in a purely declarative manner. Change: 120386133
-rw-r--r--tensorflow/core/framework/op.h5
-rw-r--r--tensorflow/core/framework/op_def.proto12
-rw-r--r--tensorflow/core/framework/op_def_builder.cc12
-rw-r--r--tensorflow/core/framework/op_def_builder.h3
-rw-r--r--tensorflow/core/framework/op_def_util.cc9
-rw-r--r--tensorflow/core/framework/op_def_util.h3
-rw-r--r--tensorflow/core/framework/op_kernel.cc4
-rw-r--r--tensorflow/core/framework/op_kernel.h5
-rw-r--r--tensorflow/core/kernels/adjust_contrast_op.cc1
-rw-r--r--tensorflow/core/kernels/batch_norm_op.cc2
-rw-r--r--tensorflow/core/kernels/random_crop_op.cc1
-rw-r--r--tensorflow/core/kernels/tile_ops.cc4
-rw-r--r--tensorflow/core/kernels/topk_op.cc1
-rw-r--r--tensorflow/core/ops/array_ops.cc1
-rw-r--r--tensorflow/core/ops/image_ops.cc2
-rw-r--r--tensorflow/core/ops/nn_ops.cc3
-rw-r--r--tensorflow/python/framework/ops_test.py29
-rw-r--r--tensorflow/python/framework/python_op_gen.cc2
-rw-r--r--tensorflow/python/framework/test_ops.cc11
-rw-r--r--tensorflow/python/ops/op_def_library.py11
20 files changed, 111 insertions, 10 deletions
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index bf887982fd..8f26025d42 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -209,6 +209,10 @@ class OpDefBuilderWrapper<true> {
builder_.SetAllowsUninitializedInput();
return *this;
}
+ OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
+ builder_.Deprecated(version, explanation);
+ return *this;
+ }
OpDefBuilderWrapper<true>& Doc(StringPiece text) {
builder_.Doc(text);
return *this;
@@ -231,6 +235,7 @@ class OpDefBuilderWrapper<false> {
OpDefBuilderWrapper<false>& SetIsAggregate() { return *this; }
OpDefBuilderWrapper<false>& SetIsStateful() { return *this; }
OpDefBuilderWrapper<false>& SetAllowsUninitializedInput() { return *this; }
+ OpDefBuilderWrapper<false>& Deprecated(int, StringPiece) { return *this; }
OpDefBuilderWrapper<false>& Doc(StringPiece text) { return *this; }
};
diff --git a/tensorflow/core/framework/op_def.proto b/tensorflow/core/framework/op_def.proto
index da25994727..0ba3b15b1e 100644
--- a/tensorflow/core/framework/op_def.proto
+++ b/tensorflow/core/framework/op_def.proto
@@ -94,6 +94,9 @@ message OpDef {
}
repeated AttrDef attr = 4;
+ // Optional deprecation based on GraphDef versions.
+ OpDeprecation deprecation = 8;
+
// One-line human-readable description of what the Op does.
string summary = 5;
@@ -139,6 +142,15 @@ message OpDef {
bool allows_uninitialized_input = 19; // for Assign, etc.
};
+// Information about version-dependent deprecation of an op
+message OpDeprecation {
+ // First GraphDef version at which the op is disallowed.
+ int32 version = 1;
+
+ // Explanation of why it was deprecated and what to use instead.
+ string explanation = 2;
+};
+
// A collection of OpDefs
message OpList {
repeated OpDef op = 1;
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 5931b1ace1..2d02ed3a42 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -541,6 +541,18 @@ OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
return *this;
}
+OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
+ if (op_def_.has_deprecation()) {
+ errors_.push_back(
+ strings::StrCat("Deprecated called twice for Op ", op_def_.name()));
+ } else {
+ OpDeprecation* deprecation = op_def_.mutable_deprecation();
+ deprecation->set_version(version);
+ deprecation->set_explanation(explanation.ToString());
+ }
+ return *this;
+}
+
Status OpDefBuilder::Finalize(OpDef* op_def) const {
std::vector<string> errors = errors_;
*op_def = op_def_;
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
index 7acf062163..70d45d3333 100644
--- a/tensorflow/core/framework/op_def_builder.h
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -89,6 +89,9 @@ class OpDefBuilder {
OpDefBuilder& SetIsStateful();
OpDefBuilder& SetAllowsUninitializedInput();
+ // Deprecate the op at a certain GraphDef version.
+ OpDefBuilder& Deprecated(int version, StringPiece explanation);
+
// Adds docs to this OpDefBuilder (and returns *this).
// Docs have the format:
// <1-line summary>
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index f7e4f1f05a..5d0a633ae7 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -561,7 +561,7 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
return Status::OK();
}
-void RemoveDescriptionsFromOpDef(OpDef* op_def) {
+void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
for (int i = 0; i < op_def->input_arg_size(); ++i) {
op_def->mutable_input_arg(i)->clear_description();
}
@@ -575,6 +575,13 @@ void RemoveDescriptionsFromOpDef(OpDef* op_def) {
op_def->clear_description();
}
+void RemoveDescriptionsFromOpDef(OpDef* op_def) {
+ RemoveNonDeprecationDescriptionsFromOpDef(op_def);
+ if (op_def->has_deprecation()) {
+ op_def->mutable_deprecation()->clear_explanation();
+ }
+}
+
void RemoveDescriptionsFromOpList(OpList* op_list) {
for (int i = 0; i < op_list->op_size(); ++i) {
OpDef* op_def = op_list->mutable_op(i);
diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h
index de350d9aaa..23da543e00 100644
--- a/tensorflow/core/framework/op_def_util.h
+++ b/tensorflow/core/framework/op_def_util.h
@@ -58,6 +58,9 @@ Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
void RemoveDescriptionsFromOpDef(OpDef* op_def);
void RemoveDescriptionsFromOpList(OpList* op_list);
+// Remove docs from *op_def but leave explanations of deprecations.
+void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def);
+
} // namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index 1ed7267664..8be297b72a 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -89,6 +89,10 @@ OpKernel::OpKernel(OpKernelConstruction* context)
OP_REQUIRES_OK(context,
NameRangesForNode(def_, context->op_def(), &input_name_map_,
&output_name_map_));
+ if (context->op_def().has_deprecation()) {
+ const OpDeprecation& deprecation = context->op_def().deprecation();
+ OP_DEPRECATED(context, deprecation.version(), deprecation.explanation());
+ }
}
OpKernel::~OpKernel() {}
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index ca017aa189..85b9b063b6 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -1252,6 +1252,11 @@ inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
// }
// Declares an op deprecated, and illegal starting at GraphDef version VERSION
+// Cleverly, OP_DEPRECATED is itself deprecated for most users; instead, use
+// REGISTER_OP(...)
+// ...
+// .Deprecated(version, note)
+// ...
#define OP_DEPRECATED(CTX, VERSION, NOTE) \
if ((CTX)->graph_def_version() >= (VERSION)) { \
::tensorflow::Status _s(::tensorflow::errors::Unimplemented( \
diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc
index 878eceae2c..a85bcc99f1 100644
--- a/tensorflow/core/kernels/adjust_contrast_op.cc
+++ b/tensorflow/core/kernels/adjust_contrast_op.cc
@@ -38,7 +38,6 @@ template <typename Device, typename T>
class AdjustContrastOp : public OpKernel {
public:
explicit AdjustContrastOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_DEPRECATED(context, 2, "Use AdjustContrastv2 instead");
}
void Compute(OpKernelContext* context) override {
diff --git a/tensorflow/core/kernels/batch_norm_op.cc b/tensorflow/core/kernels/batch_norm_op.cc
index fb663c9e83..2389af050a 100644
--- a/tensorflow/core/kernels/batch_norm_op.cc
+++ b/tensorflow/core/kernels/batch_norm_op.cc
@@ -33,7 +33,6 @@ template <typename Device, typename T>
class BatchNormOp : public OpKernel {
public:
explicit BatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_DEPRECATED(context, 9, "Use tf.nn.batch_normalization()");
OP_REQUIRES_OK(context,
context->GetAttr("variance_epsilon", &variance_epsilon_));
OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
@@ -82,7 +81,6 @@ template <typename Device, typename T>
class BatchNormGradOp : public OpKernel {
public:
explicit BatchNormGradOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_DEPRECATED(context, 9, "Use tf.nn.batch_normalization()");
OP_REQUIRES_OK(context,
context->GetAttr("variance_epsilon", &variance_epsilon_));
OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
diff --git a/tensorflow/core/kernels/random_crop_op.cc b/tensorflow/core/kernels/random_crop_op.cc
index 80b4041e05..e27de3b4e0 100644
--- a/tensorflow/core/kernels/random_crop_op.cc
+++ b/tensorflow/core/kernels/random_crop_op.cc
@@ -28,7 +28,6 @@ template <typename T>
class RandomCropOp : public OpKernel {
public:
explicit RandomCropOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_DEPRECATED(context, 8, "Random crop is now pure Python");
OP_REQUIRES_OK(context, generator_.Init(context));
}
diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc
index 5fc791506a..1d0d341c8f 100644
--- a/tensorflow/core/kernels/tile_ops.cc
+++ b/tensorflow/core/kernels/tile_ops.cc
@@ -186,9 +186,7 @@ HANDLE_CASE_DIM(GPUDevice, DT_INT64);
template <typename Device>
class TileGradientOp : public OpKernel {
public:
- explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {
- OP_DEPRECATED(context, 3, "TileGrad has been replaced with reduce_sum");
- }
+ explicit TileGradientOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
diff --git a/tensorflow/core/kernels/topk_op.cc b/tensorflow/core/kernels/topk_op.cc
index d5ea5e15a8..27c00e912c 100644
--- a/tensorflow/core/kernels/topk_op.cc
+++ b/tensorflow/core/kernels/topk_op.cc
@@ -33,7 +33,6 @@ class TopK : public OpKernel {
explicit TopK(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("sorted", &sorted_));
if (num_inputs() < 2) { // k is an attr (TopK).
- OP_DEPRECATED(context, 7, "Use TopKV2 instead");
OP_REQUIRES_OK(context, context->GetAttr("k", &k_));
} else { // k is an input (TopKV2), so we won't know it until Compute.
k_ = -1;
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 899e05c35e..970759db7a 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -993,6 +993,7 @@ REGISTER_OP("TileGrad")
.Input("multiples: int32")
.Output("output: T")
.Attr("T: type")
+ .Deprecated(3, "TileGrad has been replaced with reduce_sum")
.Doc(R"doc(
Returns the gradient of `Tile`.
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 64553c6bb5..ed0a5faa51 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -153,6 +153,7 @@ REGISTER_OP("RandomCrop")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.SetIsStateful()
+ .Deprecated(8, "Random crop is now pure Python")
.Doc(R"doc(
Randomly crop `image`.
@@ -267,6 +268,7 @@ REGISTER_OP("AdjustContrast")
.Input("max_value: float")
.Output("output: float")
.Attr("T: {uint8, int8, int16, int32, int64, float, double}")
+ .Deprecated(2, "Use AdjustContrastv2 instead")
.Doc(R"Doc(
Deprecated. Disallowed in GraphDef version >= 2.
)Doc");
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 6db0289f1f..0c31bb5be7 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -85,6 +85,7 @@ REGISTER_OP("BatchNormWithGlobalNormalization")
.Attr("T: numbertype")
.Attr("variance_epsilon: float")
.Attr("scale_after_normalization: bool")
+ .Deprecated(9, "Use tf.nn.batch_normalization()")
.Doc(R"doc(
Batch normalization.
@@ -121,6 +122,7 @@ REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
.Attr("T: numbertype")
.Attr("variance_epsilon: float")
.Attr("scale_after_normalization: bool")
+ .Deprecated(9, "Use tf.nn.batch_normalization()")
.Doc(R"doc(
Gradients for batch normalization.
@@ -815,6 +817,7 @@ REGISTER_OP("TopK")
.Attr("k: int >= 0")
.Attr("sorted: bool = true")
.Attr("T: realnumbertype")
+ .Deprecated(7, "Use TopKV2 instead")
.Doc(R"doc(
Finds values and indices of the `k` largest elements for the last dimension.
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 1d3cb4e54e..6dafe4c1a1 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -1353,5 +1353,34 @@ class ColocationGroupTest(test_util.TensorFlowTestCase):
self.assertEqual("/device:CPU:0", b.device)
+class DeprecatedTest(test_util.TensorFlowTestCase):
+
+ def testSuccess(self):
+ with ops.Graph().as_default() as g:
+ g.graph_def_versions.producer = 7
+ old = test_ops.old()
+ with self.test_session(graph=g):
+ old.run()
+
+ def _error(self):
+ return ((r"Op Old is not available in GraphDef version %d\. "
+ r"It has been removed in version 8\. For reasons\.") %
+ versions.GRAPH_DEF_VERSION)
+
+ def testGraphConstructionFail(self):
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(NotImplementedError, self._error()):
+ test_ops.old()
+
+ def testGraphExecutionFail(self):
+ with ops.Graph().as_default() as g:
+ g.graph_def_versions.producer = 7
+ old = test_ops.old()
+ g.graph_def_versions.producer = versions.GRAPH_DEF_VERSION
+ with self.test_session(graph=g):
+ with self.assertRaisesOpError(self._error()):
+ old.run()
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index 28c0da1db0..c06c40d670 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -661,7 +661,7 @@ from tensorflow.python.ops import op_def_library
auto added = out->Add();
*added = op_def;
- RemoveDescriptionsFromOpDef(added);
+ RemoveNonDeprecationDescriptionsFromOpDef(added);
}
strings::Appendf(&result, R"(def _InitOpDefLibrary():
diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc
index 925e938c6d..1c86826159 100644
--- a/tensorflow/python/framework/test_ops.cc
+++ b/tensorflow/python/framework/test_ops.cc
@@ -23,6 +23,8 @@ REGISTER_OP("KernelLabel").Output("result: string");
REGISTER_OP("GraphDefVersion").Output("version: int32").SetIsStateful();
+REGISTER_OP("Old").Deprecated(8, "For reasons");
+
namespace {
enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
} // namespace
@@ -79,4 +81,13 @@ class GraphDefVersionOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("GraphDefVersion").Device(DEVICE_CPU),
GraphDefVersionOp);
+class OldOp : public OpKernel {
+ public:
+ OldOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {}
+};
+
+REGISTER_KERNEL_BUILDER(Name("Old").Device(DEVICE_CPU), OldOp);
+
} // end namespace tensorflow
diff --git a/tensorflow/python/ops/op_def_library.py b/tensorflow/python/ops/op_def_library.py
index d657cc1333..f7150d4f39 100644
--- a/tensorflow/python/ops/op_def_library.py
+++ b/tensorflow/python/ops/op_def_library.py
@@ -340,6 +340,17 @@ class OpDefLibrary(object):
if name is None:
name = op_type_name
+ # Check for deprecation
+ deprecation_version = op_def.deprecation.version
+ if deprecation_version:
+ producer = g.graph_def_versions.producer
+ if producer >= deprecation_version:
+ raise NotImplementedError(
+ ("Op %s is not available in GraphDef version %d. "
+ "It has been removed in version %d. %s.") %
+ (op_type_name, producer, deprecation_version,
+ op_def.deprecation.explanation))
+
# Requires that op_def has passed validation (using the C++
# ValidateOpDef() from ../framework/op_def_util.h).
attrs = {}