aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <mustafa.ispir@gmail.com>2016-05-27 10:37:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-27 11:51:21 -0700
commit516cadc0ba71ca6170f63abd3887f097ac4cb39b (patch)
treec0420955fa9c0bb6b6d0b08be658abf70a23472d
parente753383886b84900e4dcfb5850ff8e1c14af1372 (diff)
Added bias centering.
Change: 123438963
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py41
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py20
-rw-r--r--tensorflow/contrib/learn/python/learn/monitors.py18
3 files changed, 71 insertions, 8 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 1c52aba925..d97a90f60c 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -28,7 +28,7 @@ import six
from tensorflow.contrib import layers
from tensorflow.contrib import metrics as metrics_lib
-from tensorflow.contrib.framework.python.ops import variables as variables
+from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -39,6 +39,8 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import training
# TODO(ispir): Increase test coverage
@@ -107,6 +109,7 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
self._dnn_activation_fn = nn.relu
self._dnn_weight_collection = "DNNLinearCombined_dnn"
self._linear_weight_collection = "DNNLinearCombined_linear"
+ self._centered_bias_weight_collection = "centered_bias"
def predict(self, x=None, input_fn=None, batch_size=None):
"""Returns predictions for given features.
@@ -141,10 +144,12 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
def _get_train_ops(self, features, targets):
"""See base class."""
- global_step = variables.get_global_step()
+ global_step = contrib_variables.get_global_step()
assert global_step
- loss = self._loss(
- self._logits(features), targets, self._get_weight_tensor(features))
+ logits = self._logits(features)
+ with ops.control_dependencies([self._centered_bias_step(
+ targets, self._get_weight_tensor(features))]):
+ loss = self._loss(logits, targets, self._get_weight_tensor(features))
logging_ops.scalar_summary("loss", loss)
linear_vars = self._get_linear_vars()
@@ -274,6 +279,26 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
return features
return {"": features}
+ def _centered_bias(self):
+ centered_bias = variables.Variable(
+ array_ops.zeros([self._num_label_columns()]),
+ collections=[self._centered_bias_weight_collection,
+ ops.GraphKeys.VARIABLES],
+ name="centered_bias_weight")
+ # TODO(zakaria): Create summaries for centered_bias
+ return centered_bias
+
+ def _centered_bias_step(self, targets, weight_tensor):
+ centered_bias = ops.get_collection(self._centered_bias_weight_collection)
+ batch_size = array_ops.size(targets)
+ logits = array_ops.reshape(
+ array_ops.tile(centered_bias[0], [batch_size]),
+ [-1, self._num_label_columns()])
+ loss = self._loss(logits, targets, weight_tensor)
+ # Learn central bias by an optimizer. 0.1 is a conservative lr for a single
+ # variable.
+ return training.AdagradOptimizer(0.1).minimize(loss, var_list=centered_bias)
+
def _logits(self, features):
if not (self._get_linear_feature_columns() or
self._get_dnn_feature_columns()):
@@ -282,11 +307,13 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
features = self._get_feature_dict(features)
if self._get_linear_feature_columns() and self._get_dnn_feature_columns():
- return self._linear_logits(features) + self._dnn_logits(features)
+ logits = self._linear_logits(features) + self._dnn_logits(features)
elif self._get_dnn_feature_columns():
- return self._dnn_logits(features)
+ logits = self._dnn_logits(features)
else:
- return self._linear_logits(features)
+ logits = self._linear_logits(features)
+
+ return nn.bias_add(logits, self._centered_bias())
def _get_weight_tensor(self, features):
if not self._weight_column_name:
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
index a9f7c6b99c..10a6c51704 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -235,7 +235,6 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
def testPredict(self):
"""Tests weight column in evaluation."""
-
def _input_fn_train():
# Create 4 rows, one of them (y = x), three of them (y=Not(x))
target = tf.constant([[1], [0], [0], [0]])
@@ -285,6 +284,25 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase):
self.assertEqual(_sklearn.accuracy_score([1, 0, 0, 0], predictions),
scores['my_accuracy'])
+ def testCenteredBias(self):
+ """Tests that the learned bias is centered on the target distribution."""
+ def _input_fn_train():
+ # Create 4 rows, three (y = x), one (y=Not(x))
+ target = tf.constant([[1], [1], [1], [0]])
+ features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),}
+ return features, target
+
+ classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
+ linear_feature_columns=[tf.contrib.layers.real_valued_column('x')],
+ dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')],
+ dnn_hidden_units=[3, 3])
+
+ monitor = tf.contrib.learn.monitors.CaptureVariable(
+ var_name='centered_bias_weight:0', every_n=10)
+ classifier.train(input_fn=_input_fn_train, steps=500, monitors=[monitor])
+ # logodds(0.75) = 1.09861228867
+ self.assertAlmostEqual(1.0986, float(monitor.var_values[-1][0]), places=2)
+
class DNNLinearCombinedRegressorTest(tf.test.TestCase):
diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py
index 70561c53f2..1c9fca62c6 100644
--- a/tensorflow/contrib/learn/python/learn/monitors.py
+++ b/tensorflow/contrib/learn/python/learn/monitors.py
@@ -207,6 +207,24 @@ class ValidationMonitor(EveryN):
return False
+class CaptureVariable(EveryN):
+ """Capture a variable value into a `list`.
+
+ It's useful for unit testing.
+ """
+
+ def __init__(self, var_name, every_n=100, first_n=1):
+ super(CaptureVariable, self).__init__(every_n, first_n)
+ self.var_name = var_name
+ self.var_values = []
+
+ def every_n_step_begin(self, unused_step, tensors):
+ return tensors + [self.var_name]
+
+ def every_n_step_end(self, step, outputs):
+ self.var_values.append(outputs[self.var_name])
+
+
def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100,
output_dir=None):
monitors = []