diff options
author | 2016-12-01 16:38:25 -0800 | |
---|---|---|
committer | 2016-12-01 16:45:29 -0800 | |
commit | 6e09c0e435d5c56763f4a84c7d5b8d54ac59ef15 (patch) | |
tree | b6236446548ca119b801f9bb4b96068349c2c79d | |
parent | 7fedd59196530c7a50b0fd3c94e372e29ad9c850 (diff) |
Revert centered_bias_weight/Adagrad var name to backward-compatible name.
Change: 140794371
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head.py | 17 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/head_test.py | 8 |
2 files changed, 14 insertions, 11 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index b0783dacef..60515a3fc6 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -880,10 +880,10 @@ def _centered_bias_step(centered_bias, logits_dimension, labels, loss_fn): with ops.name_scope(None, "centered_bias", (labels, logits)): centered_bias_loss = math_ops.reduce_mean( loss_fn(logits, labels), name="training_loss") - # Learn central bias by an optimizer. 0.1 is a convervative lr for a - # single variable. - return training.AdagradOptimizer(0.1).minimize( - centered_bias_loss, var_list=(centered_bias,), name=name) + # Learn central bias by an optimizer. 0.1 is a convervative lr for a + # single variable. + return training.AdagradOptimizer(0.1).minimize( + centered_bias_loss, var_list=(centered_bias,), name=name) def _head_prefixed(head_name, val): @@ -930,11 +930,14 @@ def _train_op( loss, labels, train_op_fn, centered_bias=None, logits_dimension=None, loss_fn=None): """Returns op for the training step.""" + if centered_bias is not None: + centered_bias_step = _centered_bias_step( + centered_bias, logits_dimension, labels, loss_fn) + else: + centered_bias_step = None with ops.name_scope(None, "train_op", (loss, labels)): train_op = train_op_fn(loss) - if centered_bias is not None: - centered_bias_step = _centered_bias_step( - centered_bias, logits_dimension, labels, loss_fn) + if centered_bias_step is not None: train_op = control_flow_ops.group(train_op, centered_bias_step) return train_op diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 5fa7e745bd..40eb7d17de 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -113,7 +113,7 @@ class RegressionModelHeadTest(tf.test.TestCase): self._assert_metrics(model_fn_ops) _assert_variables(self, expected_global=( "centered_bias_weight:0", - "train_op/centered_bias_step/centered_bias_weight/Adagrad:0", + "centered_bias_weight/Adagrad:0", ), expected_trainable=( "centered_bias_weight:0", )) @@ -202,7 +202,7 @@ class MultiLabelModelHeadTest(tf.test.TestCase): self._assert_metrics(model_fn_ops) _assert_variables(self, expected_global=( "centered_bias_weight:0", - "train_op/centered_bias_step/centered_bias_weight/Adagrad:0", + "centered_bias_weight/Adagrad:0", ), expected_trainable=( "centered_bias_weight:0", )) @@ -307,7 +307,7 @@ class MultiClassModelHeadTest(tf.test.TestCase): self._assert_binary_metrics(model_fn_ops) _assert_variables(self, expected_global=( "centered_bias_weight:0", - "train_op/centered_bias_step/centered_bias_weight/Adagrad:0", + "centered_bias_weight/Adagrad:0", ), expected_trainable=( "centered_bias_weight:0", )) @@ -444,7 +444,7 @@ class BinarySvmModelHeadTest(tf.test.TestCase): self._assert_metrics(model_fn_ops) _assert_variables(self, expected_global=( "centered_bias_weight:0", - "train_op/centered_bias_step/centered_bias_weight/Adagrad:0", + "centered_bias_weight/Adagrad:0", ), expected_trainable=( "centered_bias_weight:0", )) |