aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2016-07-01 08:09:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-01 09:18:57 -0700
commit242fe922e43440bab0098a208e3ca73e8e042014 (patch)
tree5bb2b4ae3ee5bb868a065617eaa89ca44222a6ac
parenta00fa7b7013b86afafb773b954e5b81ea1c8c85a (diff)
Removed caching of optimizer. Since optimizer may depend on a graph element (global_step).
Change: 126415442
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index a5b1dc722f..5ab5eec26a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -110,8 +110,7 @@ class _ComposableModel(object):
grads = gradients.gradients(loss, my_vars)
if self._gradient_clip_norm:
grads, _ = clip_ops.clip_by_global_norm(grads, self._gradient_clip_norm)
- self._optimizer = self._get_optimizer()
- return [self._optimizer.apply_gradients(zip(grads, my_vars))]
+ return [self._get_optimizer().apply_gradients(zip(grads, my_vars))]
def _get_feature_columns(self):
if not self._feature_columns:
@@ -132,10 +131,12 @@ class _ComposableModel(object):
def _get_optimizer(self):
if (self._optimizer is None or isinstance(self._optimizer,
six.string_types)):
- self._optimizer = self._get_default_optimizer(self._optimizer)
+ optimizer = self._get_default_optimizer(self._optimizer)
elif callable(self._optimizer):
- self._optimizer = self._optimizer()
- return self._optimizer
+ optimizer = self._optimizer()
+ else:
+ optimizer = self._optimizer
+ return optimizer
def _get_default_optimizer(self, optimizer_name=None):
raise NotImplementedError