author    A. Unique TensorFlower <nobody@tensorflow.org>  2016-03-08 17:07:12 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-03-08 17:21:30 -0800
commit    e48c70ba22dce43fa474c0daeee15019a98a1e3c (patch)
tree      c233a099c355a27dc473e980ebc29d7cb33335da
parent    d5944fe3725a0e1194bf76c1a29422cbe6c66f29 (diff)
Fixes computation of SdcaModel.regularized_loss()

The L2 penalty gains its conventional 1/2 factor and the squared loss is
divided by 2 * sum(example weights) instead of sum(example weights); test
expectations are updated to match.
Change: 116714987
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py | 120
-rw-r--r--  tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py               |  10
2 files changed, 78 insertions(+), 52 deletions(-)
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 2bfbc6c329..13968457f7 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -153,8 +153,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
- loss_type='logistic_loss',
- prior=0.0)
+ loss_type='logistic_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
@@ -165,10 +164,15 @@ class SdcaOptimizerTest(TensorFlowTestCase):
self.assertAllClose(0.693147, loss.eval())
for _ in xrange(5):
lr.minimize().run()
- self.assertAllClose(0.395226, unregularized_loss.eval(),
- rtol=3e-2, atol=3e-2)
- self.assertAllClose(0.657446, loss.eval(),
- rtol=3e-2, atol=3e-2)
+ # The high tolerance in unregularized_loss comparisons is due to the
+ # fact that it's possible to trade off unregularized_loss vs.
+ # regularization and still have a sum that is quite close to the
+ # optimal regularized_loss value. SDCA's duality gap only ensures that
+ # the regularized_loss is within 0.01 of optimal.
+ # 0.525457 is the optimal regularized_loss.
+ # 0.411608 is the unregularized_loss at that optimum.
+ self.assertAllClose(0.411608, unregularized_loss.eval(), rtol=0.11)
+ self.assertAllClose(0.525457, loss.eval(), atol=0.01)
predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllEqual([0, 1], predicted_labels.eval())
self.assertAllClose(0.01,
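A quick check of how much slack these tolerances actually grant (plain Python; the constants are the ones asserted above):

```python
# rtol=0.11 on the unregularized loss allows about +/-0.045 of drift:
unreg_slack = 0.11 * 0.411608   # ~0.0453
# The duality gap only pins the *regularized* loss, to +/-0.01, so the
# unregularized part can legitimately wander several times further.
assert unreg_slack > 4 * 0.01
```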
@@ -213,10 +217,8 @@ class SdcaOptimizerTest(TensorFlowTestCase):
predictions = lr.predictions(examples)
for _ in xrange(5):
lr.minimize().run()
- self.assertAllClose(0.395226, unregularized_loss.eval(),
- rtol=3e-2, atol=3e-2)
- self.assertAllClose(0.657446, loss.eval(),
- rtol=3e-2, atol=3e-2)
+ self.assertAllClose(0.411608, unregularized_loss.eval(), rtol=0.12)
+ self.assertAllClose(0.525457, loss.eval(), atol=0.01)
predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllClose([0, 1, 1, 1], predicted_labels.eval())
self.assertAllClose(0.01,
@@ -324,8 +326,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(3, 1)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
- loss_type='logistic_loss',
- prior=-1.09861)
+ loss_type='logistic_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
@@ -334,9 +335,10 @@ class SdcaOptimizerTest(TensorFlowTestCase):
predictions = lr.predictions(examples)
for _ in xrange(5):
lr.minimize().run()
- self.assertAllClose(0.331710, unregularized_loss.eval(),
- rtol=3e-2, atol=3e-2)
- self.assertAllClose(0.591295, loss.eval(), rtol=3e-2, atol=3e-2)
+ self.assertAllClose(0.226487 + 0.102902,
+ unregularized_loss.eval(),
+ rtol=0.08)
+ self.assertAllClose(0.328394 + 0.131364, loss.eval(), atol=0.01)
predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllEqual([0, 0, 0, 1], predicted_labels.eval())
self.assertAllClose(0.01,
@@ -369,9 +371,8 @@ class SdcaOptimizerTest(TensorFlowTestCase):
predictions = lr.predictions(examples)
for _ in xrange(5):
lr.minimize().run()
- self.assertAllClose(0.266189, unregularized_loss.eval(),
- rtol=3e-2, atol=3e-2)
- self.assertAllClose(0.571912, loss.eval(), rtol=3e-2, atol=3e-2)
+ self.assertAllClose(0.284860, unregularized_loss.eval(), rtol=0.08)
+ self.assertAllClose(0.408044, loss.eval(), atol=0.012)
predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllEqual([0, 1], predicted_labels.eval())
self.assertAllClose(0.01,
@@ -393,7 +394,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)
variables = make_variable_dict(1, 1)
- options = dict(symmetric_l2_regularization=0.25,
+ options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
loss_type='logistic_loss')
@@ -404,11 +405,8 @@ class SdcaOptimizerTest(TensorFlowTestCase):
predictions = lr.predictions(examples)
for _ in xrange(5):
lr.minimize().run()
- self.assertAllClose(0.395226,
- unregularized_loss.eval(),
- rtol=3e-2,
- atol=3e-2)
- self.assertAllClose(0.460781, loss.eval(), rtol=3e-2, atol=3e-2)
+ self.assertAllClose(0.411608, unregularized_loss.eval(), rtol=0.12)
+ self.assertAllClose(0.525457, loss.eval(), atol=0.01)
predicted_labels = get_binary_predictions_for_logistic(predictions)
self.assertAllEqual([0, 0], predicted_labels.eval())
self.assertAllClose(0.01,
@@ -432,8 +430,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
- loss_type='squared_loss',
- prior=0.0)
+ loss_type='squared_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
@@ -452,7 +449,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
rtol=1e-2,
atol=1e-2)
- def testLinearRegularization(self):
+ def testLinearL2Regularization(self):
# Setup test data
example_protos = [
# 2 identical examples
@@ -476,8 +473,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=16,
symmetric_l1_regularization=0,
- loss_type='squared_loss',
- prior=0.0)
+ loss_type='squared_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
@@ -495,6 +491,39 @@ class SdcaOptimizerTest(TensorFlowTestCase):
predictions.eval(),
rtol=0.01)
+ def testLinearL1Regularization(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto(
+ {'age': [0],
+ 'gender': [0]}, -10.0),
+ make_example_proto(
+ {'age': [1],
+ 'gender': [1]}, 14.0),
+ ]
+ example_weights = [1.0, 1.0]
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ variables = make_variable_dict(1, 1)
+ options = dict(symmetric_l2_regularization=1.0,
+ symmetric_l1_regularization=4.0,
+ loss_type='squared_loss')
+ lr = SdcaModel(CONTAINER, examples, variables, options)
+ tf.initialize_all_variables().run()
+ prediction = lr.predictions(examples)
+ loss = lr.regularized_loss(examples)
+
+ for _ in xrange(5):
+ lr.minimize().run()
+
+ # Predictions should be -4.0, 20/3 due to minimizing, per example, the
+ # regularized loss (label - 2 * weight)^2 / 2 + L2 * weight^2 +
+ # L1 * 2 * |weight| (the two active features share one weight by symmetry).
+ self.assertAllClose([-4.0, 20.0 / 3.0], prediction.eval(), rtol=0.08)
+
+ # Loss should be the per-example regularized loss from above, averaged
+ # over the two examples after plugging in the optimal weights:
+ # (38 + 582/9) / 2 = 308/6.
+ self.assertAllClose(308.0 / 6.0, loss.eval(), atol=0.01)
+
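A hand check of the constants in this new test, under the per-example objective stated in the comment (a sketch: the closed-form minimizer comes from setting the derivative to zero, with one shared weight w per example since 'age' and 'gender' receive symmetric updates):

```python
# Per-example objective from the comment above:
#   f(w) = (y - 2w)^2 / 2 + l2 * w^2 + 2 * l1 * |w|
l1, l2 = 4.0, 1.0

def optimal_weight(y, sign):
  # Setting f'(w) = 0 gives w = (y - l1 * sign(w)) / (2 + l2).
  return (y - l1 * sign) / (2.0 + l2)

w_neg = optimal_weight(-10.0, -1.0)   # -2.0  -> prediction 2w = -4.0
w_pos = optimal_weight(14.0, 1.0)     # 10/3  -> prediction 2w = 20/3

def per_example_loss(y, w):
  return (y - 2 * w) ** 2 / 2.0 + l2 * w ** 2 + 2 * l1 * abs(w)

mean_loss = (per_example_loss(-10.0, w_neg) +
             per_example_loss(14.0, w_pos)) / 2.0   # (38 + 582/9) / 2
assert abs(mean_loss - 308.0 / 6.0) < 1e-6
```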
def testLinearFeatureValues(self):
# Setup test data
example_protos = [
@@ -512,8 +541,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
- loss_type='squared_loss',
- prior=0.0)
+ loss_type='squared_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
@@ -537,8 +565,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_dense_variable_dict(2, 2)
options = dict(symmetric_l2_regularization=1,
symmetric_l1_regularization=0,
- loss_type='squared_loss',
- prior=0.0)
+ loss_type='squared_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
predictions = lr.predictions(examples)
@@ -553,10 +580,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
rtol=0.01)
loss = lr.regularized_loss(examples)
- self.assertAllClose(
- (4.0 + 7.84 + 16.0 + 31.36) / 2,
- loss.eval(),
- rtol=0.01)
+ self.assertAllClose(148.0 / 10.0, loss.eval(), atol=0.01)
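Note that the new constant is exactly half the old one. Since this change puts the same 1/2 factor on both the squared loss and the L2 penalty, the minimizer is unchanged and the optimal regularized loss simply halves (quick check, values copied from the two versions of the assertion):

```python
old_expected = (4.0 + 7.84 + 16.0 + 31.36) / 2   # 29.6, pre-fix value
new_expected = 148.0 / 10.0                      # 14.8, post-fix value
assert abs(new_expected - old_expected / 2.0) < 1e-9
```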
def testSimpleHinge(self):
# Setup test data
@@ -574,8 +598,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
variables = make_variable_dict(1, 1)
options = dict(symmetric_l2_regularization=1.0,
symmetric_l1_regularization=0,
- loss_type='hinge_loss',
- prior=0.0)
+ loss_type='hinge_loss')
model = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
@@ -592,7 +615,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
# are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender (say w3
# and w4). Solving the system w1 + w3 = 1.0, w2 + w4 = -1.0 and minimizing
# wrt to \|\vec{w}\|_2, gives w1=w3=1/2 and w2=w4=-1/2. This gives 0.0
- # unregularized loss and 0.5 L2 loss.
+ # unregularized loss and 0.25 L2 loss.
for _ in xrange(5):
model.minimize().run()
@@ -600,7 +623,7 @@ class SdcaOptimizerTest(TensorFlowTestCase):
self.assertAllEqual([-1.0, 1.0], predictions.eval())
self.assertAllEqual([0.0, 1.0], binary_predictions.eval())
self.assertAllClose(0.0, unregularized_loss.eval())
- self.assertAllClose(0.5, regularized_loss.eval(), atol=0.05)
+ self.assertAllClose(0.25, regularized_loss.eval(), atol=0.05)
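Where the 0.25 comes from, as a hand check (the division by the total example weight, 2.0 for these two unit-weight examples, is inferred from the updated expectations in this file rather than shown in this hunk):

```python
weights = [0.5, -0.5, 0.5, -0.5]   # w1, w3, w2, w4 from the comment above
l2 = 1.0
l2_loss = l2 * sum(w * w for w in weights) / 2.0   # 0.5, with the new 1/2
assert abs(l2_loss / 2.0 - 0.25) < 1e-12           # per unit of example weight
```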
def testHingeDenseFeaturesPerfectlySeparable(self):
with self._single_threaded_test_session():
@@ -626,11 +649,11 @@ class SdcaOptimizerTest(TensorFlowTestCase):
# (1.0, 1.0) and (1.0, -1.0) are perfectly separable by x-axis (that is,
# the SVM's functional margin >=1), so the unregularized loss is ~0.0.
# There is only loss due to l2-regularization. For these datapoints, it
- # turns out that w_1~=0.0 and w_2~=1.0 which means that l2 loss is ~0.5.
+ # turns out that w_1~=0.0 and w_2~=1.0 which means that l2 loss is ~0.25.
unregularized_loss = model.unregularized_loss(examples)
regularized_loss = model.regularized_loss(examples)
self.assertAllClose(0.0, unregularized_loss.eval(), atol=0.02)
- self.assertAllClose(0.5, regularized_loss.eval(), atol=0.02)
+ self.assertAllClose(0.25, regularized_loss.eval(), atol=0.02)
def testHingeDenseFeaturesSeparableWithinMargins(self):
with self._single_threaded_test_session():
@@ -653,13 +676,13 @@ class SdcaOptimizerTest(TensorFlowTestCase):
# (1.0, 0.5) and (1.0, -0.5) are separable by x-axis but the datapoints
# are within the margins so there is unregularized loss (1/2 per example).
# For these datapoints, optimal weights are w_1~=0.0 and w_2~=1.0 which
- # gives an L2 loss of ~0.5.
+ # gives an L2 loss of ~0.25.
self.assertAllClose([0.5, -0.5], predictions.eval(), rtol=0.05)
self.assertAllClose([1.0, 0.0], binary_predictions.eval())
unregularized_loss = model.unregularized_loss(examples)
regularized_loss = model.regularized_loss(examples)
self.assertAllClose(0.5, unregularized_loss.eval(), atol=0.02)
- self.assertAllClose(1.0, regularized_loss.eval(), atol=0.02)
+ self.assertAllClose(0.75, regularized_loss.eval(), atol=0.02)
def testHingeDenseFeaturesWeightedExamples(self):
with self._single_threaded_test_session():
@@ -682,14 +705,15 @@ class SdcaOptimizerTest(TensorFlowTestCase):
# try to increase the margin from (1.0, 0.5). Due to regularization,
# (1.0, -0.5) will be within the margin. For these points and example
# weights, the optimal weights are w_1~=0.4 and w_2~=1.2 which give an L2
- # loss of 0.25 * 1.6 = 0.4. The binary predictions will be correct, but
- # the boundary will be much closer to the 2nd point than the first one.
+ # loss of 0.5 * 0.25 * 1.6 = 0.2. The binary predictions will be
+ # correct, but the boundary will be much closer to the 2nd point than the
+ # first one.
self.assertAllClose([1.0, -0.2], predictions.eval(), atol=0.05)
self.assertAllClose([1.0, 0.0], binary_predictions.eval(), atol=0.05)
unregularized_loss = model.unregularized_loss(examples)
regularized_loss = model.regularized_loss(examples)
self.assertAllClose(0.2, unregularized_loss.eval(), atol=0.02)
- self.assertAllClose(0.6, regularized_loss.eval(), atol=0.02)
+ self.assertAllClose(0.4, regularized_loss.eval(), atol=0.02)
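The same hand check for the weighted case (assuming a total example weight of 4.0, which is what the 0.25 factor in the comment implies):

```python
l2 = 1.0
l2_term = 0.5 * l2 * (0.4 ** 2 + 1.2 ** 2) / 4.0   # 0.5 * 0.25 * 1.6 = 0.2
assert abs(l2_term - 0.2) < 1e-12
# regularized loss = unregularized (0.2) + l2 term (0.2) = 0.4, as asserted.
```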
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index b83be34dca..957a734b07 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -182,7 +182,7 @@ class SdcaModel(object):
dense_weights = self._convert_n_to_tensor(self._variables[
'dense_features_weights'])
l1 = self._options['symmetric_l1_regularization']
- loss = 0
+ loss = 0.0
for w in sparse_weights:
loss += l1 * math_ops.reduce_sum(abs(w))
for w in dense_weights:
@@ -197,12 +197,13 @@ class SdcaModel(object):
dense_weights = self._convert_n_to_tensor(self._variables[
'dense_features_weights'])
l2 = self._options['symmetric_l2_regularization']
- loss = 0
+ loss = 0.0
for w in sparse_weights:
loss += l2 * math_ops.reduce_sum(math_ops.square(w))
for w in dense_weights:
loss += l2 * math_ops.reduce_sum(math_ops.square(w))
- return loss
+ # SDCA L2 regularization cost is 1/2 * l2 * sum(weights^2)
+ return loss / 2.0
def _convert_n_to_tensor(self, input_list, as_ref=False):
"""Converts input list to a set of tensors."""
@@ -361,8 +362,9 @@ class SdcaModel(object):
err = math_ops.sub(labels, predictions)
weighted_squared_err = math_ops.mul(math_ops.square(err), weights)
+ # SDCA squared loss function is sum(weight * err^2) / (2 * sum(weights))
return (math_ops.reduce_sum(weighted_squared_err) /
- math_ops.reduce_sum(weights))
+ (2.0 * math_ops.reduce_sum(weights)))
def regularized_loss(self, examples):
"""Add operations to compute the loss with regularization loss included.