about · summary · refs · log · tree · commit · diff · homepage
path: root/tensorflow/contrib/optimizer_v2
diff options
context:
space:
mode:
author: Priya Gupta <priyag@google.com> 2018-06-19 14:13:08 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-19 14:16:04 -0700
commit: 445f16740007f209f426149fcf9b3c6ef4344532 (patch)
tree: cf58436a70ec5271efead3fa62b066ef514fd6a1 /tensorflow/contrib/optimizer_v2
parent: 92a55c7abd5a99771315724f162fea711ee3d9fb (diff)
Create hyper parameter tensors in optimizer v2 outside any control flow contexts.
Also, use lambdas for creating the non-slot variables in adam v2. These changes are needed to allow optimizer.minimize to run inside a while loop, which will be done in distribution strategies shortly.

PiperOrigin-RevId: 201238566
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--  tensorflow/contrib/optimizer_v2/adam.py          | 4
-rw-r--r--  tensorflow/contrib/optimizer_v2/optimizer_v2.py  | 5
2 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py
index d538ad0fb0..631d4f44df 100644
--- a/tensorflow/contrib/optimizer_v2/adam.py
+++ b/tensorflow/contrib/optimizer_v2/adam.py
@@ -103,9 +103,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2):
def _create_vars(self, var_list, state):
# Non-slot variables end up on the same device(s).
- state.create_non_slot(initial_value=state.get_hyper("beta1"),
+ state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"),
name="beta1_power")
- state.create_non_slot(initial_value=state.get_hyper("beta2"),
+ state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"),
name="beta2_power")
# Create slots for the first and second moments.
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index f537318b32..a44f29fa37 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -211,8 +211,9 @@ class _OptimizerV2State(object):
# This dict starts with a single item with key "None" with the hyper
# parameter value converted to a Tensor. Other items have dtype keys
# with that Tensor cast to that dtype.
- self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
- for name, (dynamic, value) in hyper.items() if not dynamic}
+ with ops.init_scope():
+ self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
+ for name, (dynamic, value) in hyper.items() if not dynamic}
self._slots = {}
self._non_slot_dict = {}
# Extra state to help Optimizers implement Checkpointable. Holds information