aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-04 15:19:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-04 15:22:59 -0700
commit18995ecf1a0c4a161b296fbafe63289e90437807 (patch)
tree436d0d26bdd6e4d466e96f6867026949cc72d653 /tensorflow/contrib/estimator
parentd947e2c172b2eee4338e598a51d80d519907f991 (diff)
Adds update_ops to train_op for all heads.
PiperOrigin-RevId: 199203634
Diffstat (limited to 'tensorflow/contrib/estimator')
-rw-r--r--tensorflow/contrib/estimator/BUILD1
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py1
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py29
3 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 47c7b7fc19..1937ffb583 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -312,6 +312,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:variables",
"//tensorflow/python/estimator:metric_keys",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:prediction_keys",
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 8b97f86db1..b798769d2c 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -845,6 +845,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
train_op = train_op_fn(regularized_training_loss)
else:
raise ValueError('train_op_fn and optimizer cannot both be None.')
+ train_op = head_lib._append_update_ops(train_op) # pylint:disable=protected-access
# Only summarize mean_loss for SUM reduction to preserve backwards
# compatibility. Otherwise skip it to avoid unnecessary computation.
if self._loss_reduction == losses.Reduction.SUM:
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index d6c158608b..b2b57fa06b 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
@@ -989,6 +990,34 @@ class MultiLabelHead(test.TestCase):
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
+ def test_train_with_update_ops(self):
+ head = head_lib.multi_label_head(n_classes=2)
+
+ with ops.Graph().as_default():
+ w = variables.Variable(1)
+ update_op = w.assign_add(1)
+ ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, update_op)
+
+ t = variables.Variable('')
+ expected_train_result = b'my_train_op'
+ def _train_op_fn(loss):
+ del loss
+ return t.assign(expected_train_result)
+
+ spec = head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
+ labels=np.array([[1, 0], [1, 1]], dtype=np.int64),
+ train_op_fn=_train_op_fn)
+
+ with self.test_session() as sess:
+ _initialize_variables(self, spec.scaffold)
+ sess.run(spec.train_op)
+ w_value, t_value = sess.run([w, t])
+ self.assertEqual(2, w_value)
+ self.assertEqual(expected_train_result, t_value)
+
def test_train_with_regularization_losses(self):
head = head_lib.multi_label_head(
n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)