From 516cadc0ba71ca6170f63abd3887f097ac4cb39b Mon Sep 17 00:00:00 2001 From: Mustafa Ispir Date: Fri, 27 May 2016 10:37:09 -0800 Subject: Added bias centering. Change: 123438963 --- .../python/learn/estimators/dnn_linear_combined.py | 41 ++++++++++++++++++---- .../learn/estimators/dnn_linear_combined_test.py | 20 ++++++++++- tensorflow/contrib/learn/python/learn/monitors.py | 18 ++++++++++ 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 1c52aba925..d97a90f60c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -28,7 +28,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib import metrics as metrics_lib -from tensorflow.contrib.framework.python.ops import variables as variables +from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -39,6 +39,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import training # TODO(ispir): Increase test coverage @@ -107,6 +109,7 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator): self._dnn_activation_fn = nn.relu self._dnn_weight_collection = "DNNLinearCombined_dnn" self._linear_weight_collection = "DNNLinearCombined_linear" + self._centered_bias_weight_collection = "centered_bias" def predict(self, x=None, input_fn=None, batch_size=None): """Returns predictions for given features. 
@@ -141,10 +144,12 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator): def _get_train_ops(self, features, targets): """See base class.""" - global_step = variables.get_global_step() + global_step = contrib_variables.get_global_step() assert global_step - loss = self._loss( - self._logits(features), targets, self._get_weight_tensor(features)) + logits = self._logits(features) + with ops.control_dependencies([self._centered_bias_step( + targets, self._get_weight_tensor(features))]): + loss = self._loss(logits, targets, self._get_weight_tensor(features)) logging_ops.scalar_summary("loss", loss) linear_vars = self._get_linear_vars() @@ -274,6 +279,26 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator): return features return {"": features} + def _centered_bias(self): + centered_bias = variables.Variable( + array_ops.zeros([self._num_label_columns()]), + collections=[self._centered_bias_weight_collection, + ops.GraphKeys.VARIABLES], + name="centered_bias_weight") + # TODO(zakaria): Create summaries for centered_bias + return centered_bias + + def _centered_bias_step(self, targets, weight_tensor): + centered_bias = ops.get_collection(self._centered_bias_weight_collection) + batch_size = array_ops.size(targets) + logits = array_ops.reshape( + array_ops.tile(centered_bias[0], [batch_size]), + [-1, self._num_label_columns()]) + loss = self._loss(logits, targets, weight_tensor) + # Learn centered bias by an optimizer. 0.1 is a conservative lr for a single + # variable. 
+ return training.AdagradOptimizer(0.1).minimize(loss, var_list=centered_bias) + def _logits(self, features): if not (self._get_linear_feature_columns() or self._get_dnn_feature_columns()): @@ -282,11 +307,13 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator): features = self._get_feature_dict(features) if self._get_linear_feature_columns() and self._get_dnn_feature_columns(): - return self._linear_logits(features) + self._dnn_logits(features) + logits = self._linear_logits(features) + self._dnn_logits(features) elif self._get_dnn_feature_columns(): - return self._dnn_logits(features) + logits = self._dnn_logits(features) else: - return self._linear_logits(features) + logits = self._linear_logits(features) + + return nn.bias_add(logits, self._centered_bias()) def _get_weight_tensor(self, features): if not self._weight_column_name: diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index a9f7c6b99c..10a6c51704 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -235,7 +235,6 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase): def testPredict(self): """Tests weight column in evaluation.""" - def _input_fn_train(): # Create 4 rows, one of them (y = x), three of them (y=Not(x)) target = tf.constant([[1], [0], [0], [0]]) @@ -285,6 +284,25 @@ class DNNLinearCombinedClassifierTest(tf.test.TestCase): self.assertEqual(_sklearn.accuracy_score([1, 0, 0, 0], predictions), scores['my_accuracy']) + def testCenteredBias(self): + """Tests bias is centered or not.""" + def _input_fn_train(): + # Create 4 rows, three (y = x), one (y=Not(x)) + target = tf.constant([[1], [1], [1], [0]]) + features = {'x': tf.ones(shape=[4, 1], dtype=tf.float32),} + return features, target + + classifier = 
tf.contrib.learn.DNNLinearCombinedClassifier( + linear_feature_columns=[tf.contrib.layers.real_valued_column('x')], + dnn_feature_columns=[tf.contrib.layers.real_valued_column('x')], + dnn_hidden_units=[3, 3]) + + monitor = tf.contrib.learn.monitors.CaptureVariable( + var_name='centered_bias_weight:0', every_n=10) + classifier.train(input_fn=_input_fn_train, steps=500, monitors=[monitor]) + # logodds(0.75) = 1.09861228867 + self.assertAlmostEqual(1.0986, float(monitor.var_values[-1][0]), places=2) + class DNNLinearCombinedRegressorTest(tf.test.TestCase): diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 70561c53f2..1c9fca62c6 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -207,6 +207,24 @@ class ValidationMonitor(EveryN): return False +class CaptureVariable(EveryN): + """Capture a variable value into a `list`. + + It's useful for unit testing. + """ + + def __init__(self, var_name, every_n=100, first_n=1): + super(CaptureVariable, self).__init__(every_n, first_n) + self.var_name = var_name + self.var_values = [] + + def every_n_step_begin(self, unused_step, tensors): + return tensors + [self.var_name] + + def every_n_step_end(self, step, outputs): + self.var_values.append(outputs[self.var_name]) + + def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100, output_dir=None): monitors = [] -- cgit v1.2.3