diff options
author | 2018-06-04 15:19:39 -0700 | |
---|---|---|
committer | 2018-06-04 15:22:59 -0700 | |
commit | 18995ecf1a0c4a161b296fbafe63289e90437807 (patch) | |
tree | 436d0d26bdd6e4d466e96f6867026949cc72d653 /tensorflow/contrib/estimator | |
parent | d947e2c172b2eee4338e598a51d80d519907f991 (diff) |
Adds update_ops to train_op for all heads.
PiperOrigin-RevId: 199203634
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r-- | tensorflow/contrib/estimator/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/head.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/head_test.py | 29 |
3 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 47c7b7fc19..1937ffb583 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -312,6 +312,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", + "//tensorflow/python:variables", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 8b97f86db1..b798769d2c 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -845,6 +845,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access train_op = train_op_fn(regularized_training_loss) else: raise ValueError('train_op_fn and optimizer cannot both be None.') + train_op = head_lib._append_update_ops(train_op) # pylint:disable=protected-access # Only summarize mean_loss for SUM reduction to preserve backwards # compatibility. Otherwise skip it to avoid unnecessary computation. if self._loss_reduction == losses.Reduction.SUM: diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index d6c158608b..b2b57fa06b 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -989,6 +990,34 @@ class MultiLabelHead(test.TestCase): six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), train_result) + def test_train_with_update_ops(self): + head = head_lib.multi_label_head(n_classes=2) + + with ops.Graph().as_default(): + w = variables.Variable(1) + update_op = w.assign_add(1) + ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op) + + t = variables.Variable('') + expected_train_result = b'my_train_op' + def _train_op_fn(loss): + del loss + return t.assign(expected_train_result) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=np.array([[1, 0], [1, 1]], dtype=np.int64), + train_op_fn=_train_op_fn) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + sess.run(spec.train_op) + w_value, t_value = sess.run([w, t]) + self.assertEqual(2, w_value) + self.assertEqual(expected_train_result, t_value) + def test_train_with_regularization_losses(self): head = head_lib.multi_label_head( n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) |