about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-01 16:38:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-01 16:45:29 -0800
commit6e09c0e435d5c56763f4a84c7d5b8d54ac59ef15 (patch)
treeb6236446548ca119b801f9bb4b96068349c2c79d
parent7fedd59196530c7a50b0fd3c94e372e29ad9c850 (diff)
Revert centered_bias_weight/Adagrad var name to backward-compatible name.
Change: 140794371
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py17
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head_test.py8
2 files changed, 14 insertions, 11 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index b0783dacef..60515a3fc6 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -880,10 +880,10 @@ def _centered_bias_step(centered_bias, logits_dimension, labels, loss_fn):
with ops.name_scope(None, "centered_bias", (labels, logits)):
centered_bias_loss = math_ops.reduce_mean(
loss_fn(logits, labels), name="training_loss")
- # Learn central bias by an optimizer. 0.1 is a convervative lr for a
- # single variable.
- return training.AdagradOptimizer(0.1).minimize(
- centered_bias_loss, var_list=(centered_bias,), name=name)
+ # Learn central bias by an optimizer. 0.1 is a convervative lr for a
+ # single variable.
+ return training.AdagradOptimizer(0.1).minimize(
+ centered_bias_loss, var_list=(centered_bias,), name=name)
def _head_prefixed(head_name, val):
@@ -930,11 +930,14 @@ def _train_op(
loss, labels, train_op_fn, centered_bias=None, logits_dimension=None,
loss_fn=None):
"""Returns op for the training step."""
+ if centered_bias is not None:
+ centered_bias_step = _centered_bias_step(
+ centered_bias, logits_dimension, labels, loss_fn)
+ else:
+ centered_bias_step = None
with ops.name_scope(None, "train_op", (loss, labels)):
train_op = train_op_fn(loss)
- if centered_bias is not None:
- centered_bias_step = _centered_bias_step(
- centered_bias, logits_dimension, labels, loss_fn)
+ if centered_bias_step is not None:
train_op = control_flow_ops.group(train_op, centered_bias_step)
return train_op
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
index 5fa7e745bd..40eb7d17de 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py
@@ -113,7 +113,7 @@ class RegressionModelHeadTest(tf.test.TestCase):
self._assert_metrics(model_fn_ops)
_assert_variables(self, expected_global=(
"centered_bias_weight:0",
- "train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
+ "centered_bias_weight/Adagrad:0",
), expected_trainable=(
"centered_bias_weight:0",
))
@@ -202,7 +202,7 @@ class MultiLabelModelHeadTest(tf.test.TestCase):
self._assert_metrics(model_fn_ops)
_assert_variables(self, expected_global=(
"centered_bias_weight:0",
- "train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
+ "centered_bias_weight/Adagrad:0",
), expected_trainable=(
"centered_bias_weight:0",
))
@@ -307,7 +307,7 @@ class MultiClassModelHeadTest(tf.test.TestCase):
self._assert_binary_metrics(model_fn_ops)
_assert_variables(self, expected_global=(
"centered_bias_weight:0",
- "train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
+ "centered_bias_weight/Adagrad:0",
), expected_trainable=(
"centered_bias_weight:0",
))
@@ -444,7 +444,7 @@ class BinarySvmModelHeadTest(tf.test.TestCase):
self._assert_metrics(model_fn_ops)
_assert_variables(self, expected_global=(
"centered_bias_weight:0",
- "train_op/centered_bias_step/centered_bias_weight/Adagrad:0",
+ "centered_bias_weight/Adagrad:0",
), expected_trainable=(
"centered_bias_weight:0",
))