aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/linear_optimizer/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-03-05 20:20:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-05 21:01:24 -0800
commit5fff4f796149d0557c2ee69a333f4aaecdfbb05e (patch)
tree7559aa3abc539723413cf9633e38b715866b0969 /tensorflow/contrib/linear_optimizer/kernels
parent917a662ae15b4f3dbd5eeb1f88b70caced47561f (diff)
Refactor logistic loss into component parts.
Change: 116471151
Diffstat (limited to 'tensorflow/contrib/linear_optimizer/kernels')
-rw-r--r--tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h44
1 files changed, 29 insertions, 15 deletions
diff --git a/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h b/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h
index 729f55fd55..d75a707820 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h
+++ b/tensorflow/contrib/linear_optimizer/kernels/logistic-loss.h
@@ -24,6 +24,29 @@ limitations under the License.
namespace tensorflow {
struct logistic_loss {
+ // Partial derivative of the logistic loss w.r.t (1 + exp(-ywx)).
+ inline static double PartialDerivativeLogisticLoss(const double wx,
+ const double label) {
+ // To avoid overflow, we compute partial derivative of logistic loss as
+ // follows.
+ const double ywx = label * wx;
+ if (ywx > 0) {
+ const double exp_minus_ywx = exp(-ywx);
+ return exp_minus_ywx / (1 + exp_minus_ywx);
+ }
+ return 1 / (1 + exp(ywx));
+ }
+
+ // Smoothness constant for the logistic loss.
+ inline static double SmoothnessConstantLogisticLoss(
+ const double partial_derivative_loss, const double wx,
+ const double label) {
+ // Upper bound on the smoothness constant of log loss. This is 0.25 i.e.
+ // when log-odds is zero.
+ return (wx == 0) ? 0.25
+ : (1 - 2 * partial_derivative_loss) / (2 * label * wx);
+ }
+
// Use an approximate step that is guaranteed to decrease the dual loss.
// Derivation of this is available in Page 14 Eq 16 of
// http://arxiv.org/pdf/1211.2717v1.pdf
@@ -34,23 +57,14 @@ struct logistic_loss {
const double weighted_example_norm,
const double primal_loss,
const double dual_loss) {
- const double ywx = label * wx;
- // To avoid overflow, we compute derivative of logistic loss with respect to
- // log-odds as follows.
- double inverse_exp_term = 0;
- if (ywx > 0) {
- const double exp_minus_ywx = exp(-ywx);
- inverse_exp_term = exp_minus_ywx / (1 + exp_minus_ywx);
- } else {
- inverse_exp_term = 1 / (1 + exp(ywx));
- }
+ const double partial_derivative_loss =
+ PartialDerivativeLogisticLoss(label, wx);
// f(a) = sup (a*x - f(x)) then a = f'(x), where a is the aproximate dual.
- const double approximate_dual = inverse_exp_term * label;
- const double delta_dual = approximate_dual - current_dual;
- // Upper bound on the smoothness constant of log loss. This is 0.25 i.e.
- // when log-odds is zero.
+ const double approximate_dual = partial_derivative_loss * label;
+ // Dual loss is gamma-strongly convex.
const double gamma =
- (wx == 0) ? 0.25 : (1 - 2 * inverse_exp_term) / (2 * ywx);
+ 1 / SmoothnessConstantLogisticLoss(partial_derivative_loss, label, wx);
+ const double delta_dual = approximate_dual - current_dual;
const double wx_dual = wx * current_dual * example_weight;
const double delta_dual_squared = delta_dual * delta_dual;
const double smooth_delta_dual_squared = delta_dual_squared * gamma * 0.5;