diff options
Diffstat (limited to 'tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h')
-rw-r--r-- | tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h | 46 |
1 files changed, 32 insertions, 14 deletions
diff --git a/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h b/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h index 877fceeeb4..3655fa707e 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h +++ b/tensorflow/contrib/linear_optimizer/kernels/hinge-loss.h @@ -19,8 +19,13 @@ limitations under the License. #include <algorithm> #include <cmath> +#include "tensorflow/contrib/linear_optimizer/kernels/loss.h" +#include "tensorflow/core/lib/core/errors.h" + namespace tensorflow { -struct hinge_loss { + +class HingeLossUpdater : public DualLossUpdater { + public: // Computes the updated dual variable (corresponding) to a single example. The // updated dual value maximizes the objective function of the dual // optimization problem associated with hinge loss (conditioned on keeping the @@ -30,13 +35,11 @@ struct hinge_loss { // and the particular form of conjugate function for hinge loss. // TODO(pmol): Write up a doc with concrete derivation and point to it from // here. - inline static double ComputeUpdatedDual(const double label, - const double example_weight, - const double current_dual, - const double wx, - const double weighted_example_norm, - const double primal_loss, - const double dual_loss) { + double ComputeUpdatedDual(const double label, const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm, + const double primal_loss, + const double dual_loss) const final { // Intutitvely there are 3 cases: // a. new optimal value of the dual variable falls withing the admissible // range [0, 1]. In this case we set new dual to this value. @@ -65,9 +68,8 @@ struct hinge_loss { // on its label. In particular: // \phi_y*(z) = y*z if y*z \in [-w, 0] and +infinity everywhere else where // y \in {-1,1}. The following method implements \phi_y*(-\alpha/w). - inline static double ComputeDualLoss(const double current_dual, - const double example_label, - const double example_weight) { + double ComputeDualLoss(const double current_dual, const double example_label, + const double example_weight) const final { // For binary classification, there are 2 conjugate functions, one per // label value (-1 and 1). const double y_alpha = current_dual * example_label; // y \alpha @@ -80,13 +82,29 @@ struct hinge_loss { // Hinge loss for binary classification for a single example. Hinge loss // equals max(0, 1 - y * wx) (see https://en.wikipedia.org/wiki/Hinge_loss). // For weighted instances loss should be multiplied by the instance weight. - inline static double ComputePrimalLoss(const double wx, - const double example_label, - const double example_weight) { + double ComputePrimalLoss(const double wx, const double example_label, + const double example_weight) const final { const double y_wx = example_label * wx; return std::max(0.0, 1 - y_wx) * example_weight; } + + // Converts binary example labels from 0.0 or 1.0 to -1.0 or 1.0 respectively + // as expected by hinge loss. + Status ConvertLabel(float* const example_label) const final { + if (*example_label == 0.0) { + *example_label = -1; + return Status::OK(); + } + if (*example_label == 1.0) { + return Status::OK(); + } + return errors::InvalidArgument( + "Only labels of 0.0 or 1.0 are supported right now. " + "Found example with label: ", + *example_label); + } }; + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LINEAR_OPTIMIZER_KERNELS_HINGE_LOSS_H_ |