diff options
author | 2018-08-31 11:30:49 -0700 | |
---|---|---|
committer | 2018-08-31 11:35:21 -0700 | |
commit | e894ca7c736c58a8e4c71f0c3f1b1f0c327fa924 (patch) | |
tree | ed480e9041bebac1e5dd2583d56f498c8644ab68 /tensorflow/core/kernels/loss_test.cc | |
parent | 86ed8fada295758705a96a7390802eb4f6303641 (diff) |
Add the poisson log loss to the SDCA optimizer.
PiperOrigin-RevId: 211116606
Diffstat (limited to 'tensorflow/core/kernels/loss_test.cc')
-rw-r--r-- | tensorflow/core/kernels/loss_test.cc | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/loss_test.cc b/tensorflow/core/kernels/loss_test.cc index 6ab0ce5edb..9209ed2ab7 100644 --- a/tensorflow/core/kernels/loss_test.cc +++ b/tensorflow/core/kernels/loss_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/hinge-loss.h" #include "tensorflow/core/kernels/logistic-loss.h" +#include "tensorflow/core/kernels/poisson-loss.h" #include "tensorflow/core/kernels/smooth-hinge-loss.h" #include "tensorflow/core/kernels/squared-loss.h" #include "tensorflow/core/lib/core/errors.h" @@ -288,5 +289,68 @@ TEST(SmoothHingeLoss, ComputeUpdatedDual) { 0.8 /* wx */, 10.0 /* weighted_example_norm */); } +TEST(PoissonLoss, ComputePrimalLoss) { + PoissonLossUpdater loss_updater; + EXPECT_NEAR(1.0, + loss_updater.ComputePrimalLoss(0.0 /* wx */, 3.0 /* label */, + 1.0 /* example weight */), + 1e-3); + EXPECT_NEAR(21996.0, + loss_updater.ComputePrimalLoss(10.0 /* wx */, 3.0 /* label */, + 1.0 /* example weight */), + 1.0); + EXPECT_NEAR(0.606, + loss_updater.ComputePrimalLoss(-0.5 /* wx */, 0.0 /* label */, + 1.0 /* example weight */), + 1e-3); + EXPECT_NEAR(6.64, + loss_updater.ComputePrimalLoss(1.2 /* wx */, 0.0 /* label */, + 2.0 /* example weight */), + 1e-2); +} + +TEST(PoissonLoss, ComputeDualLoss) { + PoissonLossUpdater loss_updater; + // Dual is undefined. + EXPECT_NEAR( + std::numeric_limits<double>::max(), + loss_updater.ComputeDualLoss(1.0 /* current dual */, 0.0 /* label */, + 1.0 /* example weight */), + 1e-3); + EXPECT_NEAR( + 0.0, + loss_updater.ComputeDualLoss(0.0 /* current dual */, 0.0 /* label */, + 3.0 /* example weight */), + 1e-3); + EXPECT_NEAR( + -0.847, + loss_updater.ComputeDualLoss(1.5 /* current dual */, 2.0 /* label */, + 1.0 /* example weight */), + 1e-3); + EXPECT_NEAR( + -2.675, + loss_updater.ComputeDualLoss(0.5 /* current dual */, 2.0 /* label */, + 3.0 /* example weight */), + 1e-3); +} + +TEST(PoissonLoss, ConvertLabel) { + PoissonLossUpdater loss_updater; + float example_label = -1.0; + // Negative label should throw an error. + Status status = loss_updater.ConvertLabel(&example_label); + EXPECT_FALSE(status.ok()); +} + +TEST(PoissonLoss, ComputeUpdatedDual) { + PoissonLossUpdater loss_updater; + TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 2.0 /* label */, + 1.0 /* example weight */, 0.5 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */); + TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, 0.0 /* label */, + 1.0 /* example weight */, 0.0 /* current_dual */, + -0.8 /* wx */, 10.0 /* weighted_example_norm */); +} + } // namespace } // namespace tensorflow |