diff options
author | Igor Saprykin <isaprykin@google.com> | 2018-03-29 19:29:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-29 19:31:34 -0700 |
commit | 8712a7fa24da7803a79c10501f8fd65d358a62b5 (patch) | |
tree | d9e5a0dfbd964589c9cefe15ad29f4ab37968826 /tensorflow/contrib/optimizer_v2 | |
parent | 28dec7f4669e8ed5af4a3abebf9888d3fdffe5fd (diff) |
Internal change.
PiperOrigin-RevId: 191023160
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/adagrad.py | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py index e54f990cca..c333d1e089 100644 --- a/tensorflow/contrib/optimizer_v2/adagrad.py +++ b/tensorflow/contrib/optimizer_v2/adagrad.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.optimizer_v2 import optimizer_v2 -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import init_ops @@ -65,17 +64,18 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2): def _create_vars(self, var_list, state): for v in var_list: - with ops.colocate_with(v): - dtype = v.dtype.base_dtype - if v.get_shape().is_fully_defined(): - init = init_ops.constant_initializer(self._initial_accumulator_value, - dtype=dtype) - else: - # Use a Tensor instead of initializer if variable does not have static - # shape. - init_constant = gen_array_ops.fill( - array_ops.shape(v), self._initial_accumulator_value) - init = math_ops.cast(init_constant, dtype) + # TODO(isaprykin): Delete colocate_with(v) from other optimizers and + # confirm that colocation will happen anyway. + dtype = v.dtype.base_dtype + if v.get_shape().is_fully_defined(): + init = init_ops.constant_initializer(self._initial_accumulator_value, + dtype=dtype) + else: + # Use a Tensor instead of initializer if variable does not have static + # shape. + init_constant = gen_array_ops.fill( + array_ops.shape(v), self._initial_accumulator_value) + init = math_ops.cast(init_constant, dtype) state.create_slot_with_initializer(v, init, v.get_shape(), dtype, "accumulator") |