author     Igor Saprykin <isaprykin@google.com>          2018-03-29 19:29:14 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-03-29 19:31:34 -0700
commit  8712a7fa24da7803a79c10501f8fd65d358a62b5 (patch)
tree    d9e5a0dfbd964589c9cefe15ad29f4ab37968826 /tensorflow/contrib/optimizer_v2
parent  28dec7f4669e8ed5af4a3abebf9888d3fdffe5fd (diff)
Internal change.
PiperOrigin-RevId: 191023160
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--  tensorflow/contrib/optimizer_v2/adagrad.py  24
1 file changed, 12 insertions, 12 deletions
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index e54f990cca..c333d1e089 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_array_ops
 from tensorflow.python.ops import init_ops
@@ -65,17 +64,18 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
   def _create_vars(self, var_list, state):
     for v in var_list:
-      with ops.colocate_with(v):
-        dtype = v.dtype.base_dtype
-        if v.get_shape().is_fully_defined():
-          init = init_ops.constant_initializer(self._initial_accumulator_value,
-                                               dtype=dtype)
-        else:
-          # Use a Tensor instead of initializer if variable does not have static
-          # shape.
-          init_constant = gen_array_ops.fill(
-              array_ops.shape(v), self._initial_accumulator_value)
-          init = math_ops.cast(init_constant, dtype)
+      # TODO(isaprykin): Delete colocate_with(v) from other optimizers and
+      # confirm that colocation will happen anyway.
+      dtype = v.dtype.base_dtype
+      if v.get_shape().is_fully_defined():
+        init = init_ops.constant_initializer(self._initial_accumulator_value,
+                                             dtype=dtype)
+      else:
+        # Use a Tensor instead of initializer if variable does not have static
+        # shape.
+        init_constant = gen_array_ops.fill(
+            array_ops.shape(v), self._initial_accumulator_value)
+        init = math_ops.cast(init_constant, dtype)
       state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
                                          "accumulator")