author A. Unique TensorFlower <nobody@tensorflow.org> 2016-04-04 15:39:16 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-04-04 16:41:40 -0700
commit 3ab39c186d2a7e260023b68809963d0813566473 (patch)
tree 1f29e3c52d9995ff183a266e06028c8693e54994
parent 3a6565bffa011fe388884f0e0b844e74fa9e065b (diff)
Updated documentation for the optimizer argument and added support for passing an instance of an Optimizer sub-class, so an optimizer with custom parameters can be provided.
Additionally fixed the gradient clipping usage. Change: 118995198
-rw-r--r-- tensorflow/contrib/layers/python/layers/optimizers.py 39
-rw-r--r-- tensorflow/contrib/layers/python/layers/optimizers_test.py 58
2 files changed, 72 insertions(+), 25 deletions(-)
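The updated optimizer argument accepts three forms. As a rough usage sketch (not part of the patch itself; the model setup mirrors _setup_model from the updated tests, and MomentumOptimizer is only an illustrative instance with custom parameters):

import tensorflow as tf

optimizers = [
    "SGD",                                    # name, resolved via OPTIMIZER_CLS_NAMES
    tf.train.GradientDescentOptimizer,        # class, instantiated internally with the learning rate
    tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9),  # instance, used as-is
]
for optimizer in optimizers:
  with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [])
    var = tf.get_variable("test", [], initializer=tf.constant_initializer(10))
    loss = tf.abs(var * x)
    global_step = tf.get_variable("global_step", [], trainable=False,
                                  initializer=tf.constant_initializer(0))
    train_op = tf.contrib.layers.optimize_loss(
        loss, global_step, learning_rate=0.1, optimizer=optimizer)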
diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py
index 7a1aa8ee1a..8c109c5ab9 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import six
+
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -26,6 +28,7 @@ from tensorflow.python.ops import init_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as vars_
+from tensorflow.python.training import optimizer as optimizer_
from tensorflow.python.training import training as train
OPTIMIZER_CLS_NAMES = {
@@ -52,7 +55,13 @@ def optimize_loss(loss,
loss: Tensor, 0 dimensional.
global_step: Tensor, step counter for each update.
learning_rate: float or Tensor, magnitude of update per each training step.
- optimizer: string or function, used as optimizer for training.
+ optimizer: string, class or optimizer instance, used as the trainer.
+ string should be the name of an optimizer, like 'SGD',
+ 'Adam' or 'Adagrad'; the full list is in the OPTIMIZER_CLS_NAMES constant.
+ class should be a sub-class of tf.Optimizer that implements
+ `compute_gradients` and `apply_gradients` functions.
+ optimizer instance should be an instantiation of a tf.Optimizer sub-class
+ and have `compute_gradients` and `apply_gradients` functions.
clip_gradients: float or None, clips gradients by this value.
moving_average_decay: float or None, takes into account previous loss
to make learning smoother due to outliers.
@@ -77,14 +86,6 @@ def optimize_loss(loss,
logging_ops.scalar_summary("loss/mean", loss_averages.average(loss))
loss = control_flow_ops.with_dependencies([loss_averages_op], loss)
- # Convert optimizer into the optimizer class.
- if isinstance(optimizer, str):
- opt_cls = OPTIMIZER_CLS_NAMES[optimizer]
- elif callable(optimizer):
- opt_cls = optimizer
- else:
- raise ValueError("Unrecognized optimizer: should be string or function.")
-
# Learning rate variable, with possible decay.
lr = vs.get_variable("learning_rate",
[],
@@ -93,8 +94,21 @@ def optimize_loss(loss,
if learning_rate_decay_fn is not None:
lr = learning_rate_decay_fn(lr, global_step)
- # Create optimizer.
- opt = opt_cls(learning_rate=lr)
+ # Create optimizer, given specified parameters.
+ if isinstance(optimizer, six.string_types):
+ if optimizer not in OPTIMIZER_CLS_NAMES:
+ raise ValueError("Optimizer name should be one of [%s], you provided %s."
+ % (", ".join(OPTIMIZER_CLS_NAMES), optimizer))
+ opt = OPTIMIZER_CLS_NAMES[optimizer](learning_rate=lr)
+ elif isinstance(optimizer, type) and issubclass(optimizer,
+ optimizer_.Optimizer):
+ opt = optimizer(learning_rate=lr)
+ elif isinstance(optimizer, optimizer_.Optimizer):
+ opt = optimizer
+ else:
+ raise ValueError("Unrecognized optimizer: should be string, "
+ "subclass of Optimizer or instance of "
+ "subclass of Optimizer. Got %s." % str(optimizer))
# All trainable variables, if specific variables are not specified.
if variables is None:
@@ -103,9 +117,10 @@ def optimize_loss(loss,
# Compute gradients and clip them if provided.
gradients = opt.compute_gradients(loss, variables)
if clip_gradients is not None:
+ gradients, variables = zip(*gradients)
clipped_gradients, _ = clip_ops.clip_by_global_norm(gradients,
clip_gradients)
- gradients = zip(clipped_gradients, variables)
+ gradients = list(zip(clipped_gradients, variables))
# Add scalar summary for loss.
logging_ops.scalar_summary("loss", loss)
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 57c83fd7ff..6c07bb838d 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -21,28 +21,60 @@ from __future__ import print_function
import tensorflow as tf
+def _setup_model():
+ x = tf.placeholder(tf.float32, [])
+ var = tf.get_variable("test", [], initializer=tf.constant_initializer(10))
+ loss = tf.abs(var * x)
+ global_step = tf.get_variable("global_step",
+ [],
+ trainable=False,
+ initializer=tf.constant_initializer(0))
+ return x, var, loss, global_step
+
+
class OptimizersTest(tf.test.TestCase):
def testSGDOptimizer(self):
+ optimizers = ["SGD", tf.train.GradientDescentOptimizer,
+ tf.train.GradientDescentOptimizer(learning_rate=0.1)]
+ for optimizer in optimizers:
+ with tf.Graph().as_default() as g:
+ with self.test_session(graph=g) as session:
+ x, var, loss, global_step = _setup_model()
+ train = tf.contrib.layers.optimize_loss(loss,
+ global_step,
+ learning_rate=0.1,
+ optimizer=optimizer)
+ tf.initialize_all_variables().run()
+ session.run(train, feed_dict={x: 5})
+ var_value, global_step_value = session.run([var, global_step])
+ self.assertEqual(var_value, 9.5)
+ self.assertEqual(global_step_value, 1)
+
+ def testWrongOptimizer(self):
+ optimizers = ["blah", tf.Variable, object()]
+ for optimizer in optimizers:
+ with tf.Graph().as_default() as g:
+ with self.test_session(graph=g):
+ _, _, loss, global_step = _setup_model()
+ with self.assertRaises(ValueError):
+ tf.contrib.layers.optimize_loss(loss,
+ global_step,
+ learning_rate=0.1,
+ optimizer=optimizer)
+
+ def testGradientClip(self):
with self.test_session() as session:
- x = tf.placeholder(tf.float32, [])
- var = tf.get_variable("test", [], initializer=tf.constant_initializer(10))
- loss = tf.abs(var * x)
- global_step = tf.get_variable("global_step",
- [],
- trainable=False,
- initializer=tf.constant_initializer(0))
- lr_decay = lambda lr, gs: tf.train.exponential_decay(lr, gs, 1, 0.5)
+ x, var, loss, global_step = _setup_model()
train = tf.contrib.layers.optimize_loss(loss,
global_step,
learning_rate=0.1,
- learning_rate_decay_fn=lr_decay,
- optimizer="SGD")
+ optimizer="SGD",
+ clip_gradients=0.1)
tf.initialize_all_variables().run()
session.run(train, feed_dict={x: 5})
- var_value, global_step_value = session.run([
- var, global_step])
- self.assertEqual(var_value, 9.5)
+ var_value, global_step_value = session.run([var, global_step])
+ self.assertAlmostEqual(var_value, 9.98999, 4)
self.assertEqual(global_step_value, 1)
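For reference, the expected value in testGradientClip can be checked by hand (an illustrative calculation, not part of the patch): at var=10 and x=5 the gradient of |var * x| with respect to var is 5, the global norm of [5.0] is 5.0, so clipping to 0.1 rescales the gradient to 0.1, and one SGD step with learning_rate=0.1 moves var by 0.01, from 10.0 to about 9.99 (which shows up as 9.98999... once stored as float32).

grad = 5.0
clipped = grad * min(1.0, 0.1 / grad)    # effect of clip_by_global_norm with clip_norm=0.1
new_var = 10.0 - 0.1 * clipped           # one SGD update with learning_rate=0.1
assert abs(new_var - 9.98999) < 1e-4     # consistent with assertAlmostEqual(var_value, 9.98999, 4)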