about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-02 05:02:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-02 05:06:21 -0800
commit7b700c515b132c0620e20f12eb032ea3dba397de (patch)
tree9ad94409ce845734cdd5b320a322da0aa705cb1c /tensorflow/contrib/kfac
parent12d82a1f53fd57b0ca0990a266121ec29d0d42b7 (diff)
K-FAC: Support onehot categorical in kfac.loss_functions.
PiperOrigin-RevId: 180536416
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py71
-rw-r--r--tensorflow/contrib/kfac/python/ops/BUILD1
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions.py14
3 files changed, 86 insertions, 0 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
index 39ce3e9337..63f45ea55b 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
@@ -114,5 +114,76 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
self.assertEqual(loss.num_registered_minibatches, num_towers)
+class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
+
+ def testSample(self):
+ """Ensure samples can be drawn."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits))
+ sample = loss.sample(42)
+ sample = sess.run(sample)
+ self.assertEqual(sample.shape, (2, 3))
+
+ def testEvaluateOnTargets(self):
+ """Ensure log probability can be evaluated correctly."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ targets = np.asarray([2, 1]).astype(np.int32)
+ loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
+ neg_log_prob = loss.evaluate()
+ neg_log_prob = sess.run(neg_log_prob)
+
+ # Calculate explicit log probability of targets.
+ probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
+ log_probs = np.log([
+ probs[0, targets[0]], #
+ probs[1, targets[1]]
+ ])
+ expected_log_prob = np.sum(log_probs)
+
+ self.assertAllClose(neg_log_prob, -expected_log_prob)
+
+ def testEvaluateOnSample(self):
+ """Ensure log probability of a sample can be evaluated."""
+ with ops.Graph().as_default(), self.test_session() as sess:
+ logits = np.asarray([
+ [0., 0., 0.], #
+ [1., -1., 0.]
+ ]).astype(np.float32)
+ loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
+ array_ops.constant(logits))
+ neg_log_prob = loss.evaluate_on_sample(42)
+
+ # Simply ensure this doesn't crash. As the output is random, it's
+ # difficult to say if the output is correct or not...
+ neg_log_prob = sess.run(neg_log_prob)
+
+ def testMultiMinibatchRegistration(self):
+ """Ensure this loss function supports registering multiple minibatches."""
+ with ops.Graph().as_default():
+ tower_logits = []
+ loss = None
+ num_towers = 5
+ for _ in range(num_towers):
+ logits = random_ops.random_uniform(shape=[2, 3])
+ tower_logits.append(logits)
+ if loss is None:
+ loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
+ logits)
+ else:
+ loss.register_additional_minibatch(logits)
+ self.assertListEqual(loss.input_minibatches, tower_logits)
+ self.assertEqual(loss.num_registered_minibatches, num_towers)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index 9be3d60dc0..cd9dca3f02 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -65,6 +65,7 @@ py_library(
srcs = ["loss_functions.py"],
srcs_version = "PY2AND3",
deps = [
+ "//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
index d449abcfa7..2daead2a71 100644
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py
@@ -22,6 +22,7 @@ import abc
import six
+from tensorflow.contrib.distributions.python.ops import onehot_categorical
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -785,3 +786,16 @@ def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
after[dim] = dim_size - position - 1
return array_ops.pad(slice_to_insert, list(zip(before, after)))
+
+
+class OnehotCategoricalLogitsNegativeLogProbLoss(
+ CategoricalLogitsNegativeLogProbLoss):
+ """Neg log prob loss for a categorical distribution with onehot targets.
+
+ Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying
+ distribution is OneHotCategorical as opposed to Categorical.
+ """
+
+ @property
+ def dist(self):
+ return onehot_categorical.OneHotCategorical(logits=self._logits)