diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-02 05:02:34 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-02 05:06:21 -0800 |
commit | 7b700c515b132c0620e20f12eb032ea3dba397de (patch) | |
tree | 9ad94409ce845734cdd5b320a322da0aa705cb1c /tensorflow/contrib/kfac | |
parent | 12d82a1f53fd57b0ca0990a266121ec29d0d42b7 (diff) |
K-FAC: Support onehot categorical in kfac.loss_functions.
PiperOrigin-RevId: 180536416
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py | 71 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/loss_functions.py | 14 |
3 files changed, 86 insertions, 0 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py index 39ce3e9337..63f45ea55b 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -114,5 +114,76 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): self.assertEqual(loss.num_registered_minibatches, num_towers) +class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): + + def testSample(self): + """Ensure samples can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + sample = loss.sample(42) + sample = sess.run(sample) + self.assertEqual(sample.shape, (2, 3)) + + def testEvaluateOnTargets(self): + """Ensure log probability can be evaluated correctly.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + targets = np.asarray([2, 1]).astype(np.int32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits), targets=array_ops.one_hot(targets, 3)) + neg_log_prob = loss.evaluate() + neg_log_prob = sess.run(neg_log_prob) + + # Calculate explicit log probability of targets. + probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) + log_probs = np.log([ + probs[0, targets[0]], # + probs[1, targets[1]] + ]) + expected_log_prob = np.sum(log_probs) + + self.assertAllClose(neg_log_prob, -expected_log_prob) + + def testEvaluateOnSample(self): + """Ensure log probability of a sample can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] 
+ ]).astype(np.float32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + neg_log_prob = loss.evaluate_on_sample(42) + + # Simply ensure this doesn't crash. As the output is random, it's + # difficult to say if the output is correct or not... + neg_log_prob = sess.run(neg_log_prob) + + def testMultiMinibatchRegistration(self): + """Ensure this loss function supports registering multiple minibatches.""" + with ops.Graph().as_default(): + tower_logits = [] + loss = None + num_towers = 5 + for _ in range(num_towers): + logits = random_ops.random_uniform(shape=[2, 3]) + tower_logits.append(logits) + if loss is None: + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + logits) + else: + loss.register_additional_minibatch(logits) + self.assertListEqual(loss.input_minibatches, tower_logits) + self.assertEqual(loss.num_registered_minibatches, num_towers) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index 9be3d60dc0..cd9dca3f02 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -65,6 +65,7 @@ py_library( srcs = ["loss_functions.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/python:array_ops", "//tensorflow/python:math_ops", "//tensorflow/python:tensor_shape", diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index d449abcfa7..2daead2a71 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -22,6 +22,7 @@ import abc import six +from tensorflow.contrib.distributions.python.ops import onehot_categorical from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -785,3 +786,16 
@@ def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): after[dim] = dim_size - position - 1 return array_ops.pad(slice_to_insert, list(zip(before, after))) + + +class OnehotCategoricalLogitsNegativeLogProbLoss( + CategoricalLogitsNegativeLogProbLoss): + """Neg log prob loss for a categorical distribution with onehot targets. + + Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying + distribution is OneHotCategorical as opposed to Categorical. + """ + + @property + def dist(self): + return onehot_categorical.OneHotCategorical(logits=self._logits)