aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/compatibility
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-05 17:13:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-05 17:17:21 -0700
commit017599d0a1fa7a7227a43649db67e96311033a4e (patch)
treedc4b00270acc53b2e098022fb03e4b54cb40ab9e /tensorflow/tools/compatibility
parente7b37766f53d5d9d976f2ba3046d3df3333c8ebb (diff)
This CL changes the graph-mode API of the learning_rate_decay functions in TF 2.0 to return a no-arg callable to output a learning rate, instead of directly outputting a learning rate tensor.
This brings the graph mode API in line with the eager execution API, where this change was made to allow changing the learning rate value across different invocations of optimizer functions. PiperOrigin-RevId: 211726295
Diffstat (limited to 'tensorflow/tools/compatibility')
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_v2.py24
-rw-r--r--tensorflow/tools/compatibility/tf_upgrade_v2_test.py13
2 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py
index 9702430a12..38216ce9b1 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import argparse
+import functools
from tensorflow.tools.compatibility import ast_edits
from tensorflow.tools.compatibility import renames_v2
@@ -45,6 +46,29 @@ class TFAPIChangeSpec(ast_edits.APIChangeSpec):
# Specially handled functions.
self.function_handle = {}
+ for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant",
+ "tf.train.polynomial_decay", "tf.train.natural_exp_decay",
+ "tf.train.inverse_time_decay", "tf.train.cosine_decay",
+ "tf.train.cosine_decay_restarts",
+ "tf.train.linear_cosine_decay",
+ "tf.train.noisy_linear_cosine_decay"]:
+ self.function_handle[decay] = functools.partial(
+ self._learning_rate_decay_handler, decay_name=decay)
+
+ @staticmethod
+ def _learning_rate_decay_handler(file_edit_recorder, node, decay_name):
+ comment = ("ERROR: %s has been changed to return a callable instead of a "
+ "tensor when graph building, but its functionality remains "
+ "unchanged during eager execution (returns a callable like "
+ "before). The converter cannot detect and fix this reliably, so "
+ "you need to inspect this usage manually.\n") % decay_name
+ file_edit_recorder.add(
+ comment,
+ node.lineno,
+ node.col_offset,
+ decay_name,
+ decay_name,
+ error="%s requires manual check." % decay_name)
if __name__ == "__main__":
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
index 57ac04de06..3886c1e8b9 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
@@ -63,6 +63,19 @@ class TestUpgrade(test_util.TensorFlowTestCase):
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, "tf.math.rsqrt(tf.math.log(3.8))\n")
+ def testLearningRateDecay(self):
+ for decay in ["tf.train.exponential_decay", "tf.train.piecewise_constant",
+ "tf.train.polynomial_decay", "tf.train.natural_exp_decay",
+ "tf.train.inverse_time_decay", "tf.train.cosine_decay",
+ "tf.train.cosine_decay_restarts",
+ "tf.train.linear_cosine_decay",
+ "tf.train.noisy_linear_cosine_decay"]:
+
+ text = "%s(a, b)\n" % decay
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(text, new_text)
+ self.assertEqual(errors, ["test.py:1: %s requires manual check." % decay])
+
class TestUpgradeFiles(test_util.TensorFlowTestCase):