diff options
author | Priya Gupta <priyag@google.com> | 2018-06-19 14:13:08 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-19 14:16:04 -0700
commit | 445f16740007f209f426149fcf9b3c6ef4344532 (patch)
tree | cf58436a70ec5271efead3fa62b066ef514fd6a1 /tensorflow/contrib/optimizer_v2
parent | 92a55c7abd5a99771315724f162fea711ee3d9fb (diff)
Create hyperparameter tensors in optimizer v2 outside of any control flow context.
Also, use lambdas to create the non-slot variables in Adam v2.
These changes are needed so that optimizer.minimize can run inside a while loop, which distribution strategies will start doing shortly.
PiperOrigin-RevId: 201238566
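
A minimal sketch (not part of this commit) of the problem the init_scope change addresses: if a hyperparameter tensor is created inside a control flow context such as a while-loop body, ops in other contexts cannot safely capture it; lifting the convert_to_tensor call into ops.init_scope creates the tensor in the outermost context, where any op can use it. The helper name make_hyper below is illustrative only, and TensorFlow 1.x-style graph APIs are assumed.

import tensorflow as tf
from tensorflow.python.framework import ops

def make_hyper(value, name):
  # Lift tensor creation out of whatever control flow context is currently
  # active, mirroring the `with ops.init_scope():` added in optimizer_v2.py.
  with ops.init_scope():
    return ops.convert_to_tensor(value, name=name)

lr = make_hyper(0.001, "learning_rate")

def body(i, x):
  # `lr` was created outside the while-loop context, so capturing it here is
  # safe even though this body runs inside tf.while_loop.
  return i + 1, x - lr * x

_, result = tf.while_loop(lambda i, x: i < 10, body, [0, tf.constant(1.0)])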
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/adam.py | 4
-rw-r--r-- | tensorflow/contrib/optimizer_v2/optimizer_v2.py | 5
2 files changed, 5 insertions, 4 deletions
diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py
index d538ad0fb0..631d4f44df 100644
--- a/tensorflow/contrib/optimizer_v2/adam.py
+++ b/tensorflow/contrib/optimizer_v2/adam.py
@@ -103,9 +103,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2):
 
   def _create_vars(self, var_list, state):
     # Non-slot variables end up on the same device(s).
-    state.create_non_slot(initial_value=state.get_hyper("beta1"),
+    state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"),
                           name="beta1_power")
-    state.create_non_slot(initial_value=state.get_hyper("beta2"),
+    state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"),
                           name="beta2_power")
 
     # Create slots for the first and second moments.
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index f537318b32..a44f29fa37 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -211,8 +211,9 @@ class _OptimizerV2State(object):
     # This dict starts with a single item with key "None" with the hyper
     # parameter value converted to a Tensor. Other items have dtype keys
     # with that Tensor cast to that dtype.
-    self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
-                   for name, (dynamic, value) in hyper.items() if not dynamic}
+    with ops.init_scope():
+      self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
+                     for name, (dynamic, value) in hyper.items() if not dynamic}
     self._slots = {}
     self._non_slot_dict = {}
     # Extra state to help Optimizers implement Checkpointable. Holds information
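
As a side note on the lambda change in adam.py: passing a callable instead of a concrete tensor defers tensor creation until the non-slot variable is actually built, so no tensor from the caller's control flow context gets baked into the initializer. The sketch below shows the same pattern with the public tf.Variable API, which accepts a callable initial value; it is an illustration of the idea, not this commit's create_non_slot code.

import tensorflow as tf

# Tensor initial value: the initializer tensor is created right here, in the
# caller's current context (problematic if that context is a while-loop body).
beta1 = tf.constant(0.9, name="beta1")
tensor_init = tf.Variable(beta1, name="beta1_power_tensor")

# Callable initial value: tf.Variable invokes the lambda itself when it builds
# the variable, so tensor creation is deferred to variable-creation time,
# which is what `initial_value=lambda: state.get_hyper("beta1")` achieves.
deferred_init = tf.Variable(lambda: tf.constant(0.9), name="beta1_power_lazy")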