path: root/tensorflow/contrib/linear_optimizer
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-31 11:30:49 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-31 11:35:21 -0700
commit    e894ca7c736c58a8e4c71f0c3f1b1f0c327fa924 (patch)
tree      ed480e9041bebac1e5dd2583d56f498c8644ab68 /tensorflow/contrib/linear_optimizer
parent    86ed8fada295758705a96a7390802eb4f6303641 (diff)
Add the Poisson log loss to the SDCA optimizer.
PiperOrigin-RevId: 211116606
Diffstat (limited to 'tensorflow/contrib/linear_optimizer')
-rw-r--r-- | tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md              | 40
-rw-r--r-- | tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py | 51
-rw-r--r-- | tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py               | 14
3 files changed, 104 insertions(+), 1 deletion(-)
diff --git a/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md b/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
index a4f5086dde..5fe883d647 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
+++ b/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
@@ -199,6 +199,46 @@ does.
However, in practice, convergence with $$x_0 = 0$$ always happens (tested for a
sample of generic values for the parameters).
+### Poisson log loss
+
+Poisson log loss is defined as $$ \l(u) = e^u - uy $$ for label $$y \geq 0.$$
+Its dual is
+
+$$ \l^\star(v) = (y+v) (\log(y+v) - 1) $$
+
+and is only defined for $$ y+v > 0 $$. Since the dual variable enters the
+dual objective as $$ v = -(\a+\d) $$, this gives the constraint
+
+$$ y > \a+\d. $$
+
+The dual objective is
+
+$$ D(\d) = -(y-\a-\d) (\log(y-\a-\d) - 1) - \bar{y} \d - \frac{A}{2} \d^2 $$
+
+and its derivative is
+
+$$ D'(\d) = \log(y-\a-\d) - \bar{y} - A\d $$
+
+As with the logistic loss, we perform a change of variable to handle the
+constraint on $$ \d $$:
+
+$$ y - (\a+\d) = e^x $$
+
+After this change of variable, the goal is to find the zero of the function
+
+$$ H(x) = x - \bar{y} -A(y-\a-e^x) $$
+
+whose first derivative is
+
+$$ H'(x) = 1+Ae^x $$
+
+Since $$H'$$ is always positive, $$H$$ is strictly increasing and has a unique
+zero.
+
+We can start the Newton algorithm at $$\d=0$$, which corresponds to $$ x =
+\log(y-\a) $$. As before, the Newton step is given by
+
+$$x_{k+1} = x_k - \frac{H(x_k)}{H'(x_k)}. $$
+
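+For concreteness, here is a minimal NumPy sketch of this iteration (an
+illustration of the update above, not the actual SDCA kernel; the scalars
+`y`, `alpha`, `y_bar` and `A` stand for $$y$$, $$\a$$, $$\bar{y}$$ and $$A$$,
+and the fixed step count is arbitrary):
+
+```python
+import numpy as np
+
+def poisson_dual_delta(y, alpha, y_bar, A, num_steps=10):
+  """Solves H(x) = x - y_bar - A * (y - alpha - exp(x)) = 0 by Newton."""
+  x = np.log(y - alpha)  # Start at delta = 0; requires y > alpha.
+  for _ in range(num_steps):
+    h = x - y_bar - A * (y - alpha - np.exp(x))
+    h_prime = 1.0 + A * np.exp(x)
+    x -= h / h_prime  # Newton step: x <- x - H(x) / H'(x).
+  return y - alpha - np.exp(x)  # Recover delta from x.
+```
+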
### References
[1] C. Ma et al., Adding vs. Averaging in Distributed Primal-Dual Optimization,
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index ef0e08a777..1d2db1cec8 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -1192,6 +1192,57 @@ class SdcaWithSmoothHingeLossTest(SdcaModelTest):
self.assertAllClose(0.33, unregularized_loss.eval(), atol=0.02)
self.assertAllClose(0.44, regularized_loss.eval(), atol=0.02)
+class SdcaWithPoissonLossTest(SdcaModelTest):
+ """SDCA optimizer test class for poisson loss."""
+
+ def testSimple(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0],
+ 'gender': [0]
+ }, 0),
+ make_example_proto({
+ 'age': [1],
+ 'gender': [1]
+ }, 2),
+ ]
+ example_weights = [100.0, 100.0]
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ variables = make_variable_dict(1, 1)
+ options = dict(
+ symmetric_l2_regularization=1.0,
+ symmetric_l1_regularization=0,
+ loss_type='poisson_loss')
+ model = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+
+ # Before minimization, the weights default to zero. There is no loss due
+ # to regularization, only unregularized loss which is 1 for each example.
+ predictions = model.predictions(examples)
+ self.assertAllClose([1.0, 1.0], predictions.eval())
+ unregularized_loss = model.unregularized_loss(examples)
+ regularized_loss = model.regularized_loss(examples)
+ approximate_duality_gap = model.approximate_duality_gap()
+ self.assertAllClose(1.0, unregularized_loss.eval())
+ self.assertAllClose(1.0, regularized_loss.eval())
+
+ # There are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender
+ # (say w3 and w4). The minimization leads to:
+ # w1=w3=-1.96487, argmin of 100*(exp(2*w)-2*w*0)+w**2.
+ # w2=w4=0.345708, argmin of 100*(exp(2*w)-2*w*2)+w**2.
+ # This gives an unregularized loss of .3167 and .3366 with regularization.
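+ # (A hypothetical external check: these minimizers solve the stationarity
+ # condition 200 * exp(2 * w) - 200 * y + 2 * w = 0, e.g.
+ # scipy.optimize.brentq(lambda w: 200 * math.exp(2 * w) - 200 * y + 2 * w,
+ # -10, 10) reproduces -1.96487 for y=0 and 0.345708 for y=2.)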
+ train_op = model.minimize()
+ for _ in range(_MAX_ITERATIONS):
+ train_op.run()
+ model.update_weights(train_op).run()
+
+ self.assertAllClose([0.0196, 1.9965], predictions.eval(), atol=1e-4)
+ self.assertAllClose(0.3167, unregularized_loss.eval(), atol=1e-4)
+ self.assertAllClose(0.3366, regularized_loss.eval(), atol=1e-4)
+ self.assertAllClose(0., approximate_duality_gap.eval(), atol=1e-6)
+
class SdcaFprintTest(SdcaModelTest):
"""Tests for the SdcaFprint op.
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 0047d5753a..14f59a3f64 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as var_ops
+from tensorflow.python.ops.nn import log_poisson_loss
from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
from tensorflow.python.summary import summary
@@ -51,6 +52,7 @@ class SdcaModel(object):
* Squared loss
* Hinge loss
* Smooth hinge loss
+ * Poisson log loss
This class defines an optimizer API to train a linear model.
@@ -112,7 +114,7 @@ class SdcaModel(object):
raise ValueError('examples, variables and options must all be specified.')
supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss',
- 'smooth_hinge_loss')
+ 'smooth_hinge_loss', 'poisson_loss')
if options['loss_type'] not in supported_losses:
raise ValueError('Unsupported loss_type: ', options['loss_type'])
@@ -315,6 +317,7 @@ class SdcaModel(object):
"""Add operations to compute predictions by the model.
If logistic_loss is being used, predicted probabilities are returned.
+ If poisson_loss is being used, predictions are exponentiated.
Otherwise, (raw) linear predictions (w*x) are returned.
Args:
@@ -335,6 +338,10 @@ class SdcaModel(object):
# Convert logits to probability for logistic loss predictions.
with name_scope('sdca/logistic_prediction'):
result = math_ops.sigmoid(result)
+ elif self._options['loss_type'] == 'poisson_loss':
+ # Exponentiate the prediction for poisson loss.
+ with name_scope('sdca/poisson_prediction'):
+ result = math_ops.exp(result)
return result
def _get_partitioned_update_ops(self,
@@ -624,6 +631,11 @@ class SdcaModel(object):
logits=predictions),
weights)) / math_ops.reduce_sum(weights)
+ if self._options['loss_type'] == 'poisson_loss':
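+ # With the default compute_full_loss=False, log_poisson_loss(targets=y,
+ # log_input=u) computes exp(u) - u * y elementwise, i.e. the Poisson log
+ # loss l(u) = e^u - uy from the g3doc readme.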
+ return math_ops.reduce_sum(math_ops.multiply(
+ log_poisson_loss(targets=labels, log_input=predictions),
+ weights)) / math_ops.reduce_sum(weights)
+
if self._options['loss_type'] in ['hinge_loss', 'smooth_hinge_loss']:
# hinge_loss = max{0, 1 - y_i w*x} where y_i \in {-1, 1}. So, we need to
# first convert 0/1 labels into -1/1 labels.