| field | value |
|---|---|
| author | 2016-07-01 08:09:22 -0800 |
| committer | 2016-07-01 09:18:57 -0700 |
| commit | 242fe922e43440bab0098a208e3ca73e8e042014 (patch) |
| tree | 5bb2b4ae3ee5bb868a065617eaa89ca44222a6ac |
| parent | a00fa7b7013b86afafb773b954e5b81ea1c8c85a (diff) |
Removed caching of the optimizer, since the optimizer may depend on a graph element (global_step).
Change: 126415442
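The motivation is easiest to see with a concrete case: contrib.learn estimators rebuild the TensorFlow graph on each `fit()`/`evaluate()` call, so an optimizer whose construction touches a graph element (for example a `global_step`-driven learning-rate schedule) must be re-created inside each new graph; a cached instance would hold tensors from a previous, now-dead graph. Below is a minimal sketch of such a graph-dependent optimizer factory, using period-appropriate APIs; the factory name and the decay hyperparameters are illustrative, not taken from the commit:

```python
import tensorflow as tf

def make_optimizer():
  # Hypothetical factory, not from the commit. The decayed learning
  # rate references global_step, a graph element: if the optimizer
  # built here were cached across the graph rebuilds estimators
  # perform, it would point into a stale graph. Building it lazily,
  # once per graph, avoids that.
  global_step = tf.contrib.framework.get_global_step()
  learning_rate = tf.train.exponential_decay(
      learning_rate=0.1,    # assumed hyperparameters, illustration only
      global_step=global_step,
      decay_steps=1000,
      decay_rate=0.96)
  return tf.train.GradientDescentOptimizer(learning_rate)
```

Passing the factory itself (rather than its result) as the optimizer is exactly the `callable(self._optimizer)` case the patched `_get_optimizer()` handles below.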
tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
```diff
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index a5b1dc722f..5ab5eec26a 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -110,8 +110,7 @@ class _ComposableModel(object):
     grads = gradients.gradients(loss, my_vars)
     if self._gradient_clip_norm:
       grads, _ = clip_ops.clip_by_global_norm(grads, self._gradient_clip_norm)
-    self._optimizer = self._get_optimizer()
-    return [self._optimizer.apply_gradients(zip(grads, my_vars))]
+    return [self._get_optimizer().apply_gradients(zip(grads, my_vars))]
 
   def _get_feature_columns(self):
     if not self._feature_columns:
@@ -132,10 +131,12 @@ class _ComposableModel(object):
   def _get_optimizer(self):
     if (self._optimizer is None or
         isinstance(self._optimizer, six.string_types)):
-      self._optimizer = self._get_default_optimizer(self._optimizer)
+      optimizer = self._get_default_optimizer(self._optimizer)
     elif callable(self._optimizer):
-      self._optimizer = self._optimizer()
-    return self._optimizer
+      optimizer = self._optimizer()
+    else:
+      optimizer = self._optimizer
+    return optimizer
 
   def _get_default_optimizer(self, optimizer_name=None):
     raise NotImplementedError
```
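After the patch, `self._optimizer` permanently stores whatever spec the caller supplied (None, a string naming a default, a callable factory, or an optimizer instance), and `_get_optimizer()` materializes a fresh instance on every call instead of overwriting the stored spec. A stand-alone sketch of the resulting resolution logic; the function and parameter names here are hypothetical, mirroring the patched branches:

```python
import six
import tensorflow as tf

def resolve_optimizer(spec, default_factory):
  # Stand-alone mirror of the patched _get_optimizer() branches;
  # `resolve_optimizer` and `default_factory` are hypothetical names.
  if spec is None or isinstance(spec, six.string_types):
    optimizer = default_factory(spec)   # None or a name: build a default
  elif callable(spec):
    optimizer = spec()                  # factory: fresh instance per call
  else:
    optimizer = spec                    # already an optimizer instance
  return optimizer

# A callable spec is re-invoked on every resolution, so any graph
# elements it creates belong to the graph that is current at the time.
sgd = resolve_optimizer(
    lambda: tf.train.GradientDescentOptimizer(0.1),
    default_factory=lambda name: tf.train.AdagradOptimizer(0.05))
```

The key behavioral change is the `callable` branch: because the spec is never replaced by its result, the factory runs again for every new graph, which is what makes graph-dependent optimizers safe to use with these estimators.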