author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-11-10 10:27:25 -0800
committer Andrew Selle <aselle@andyselle.com>  2017-11-10 16:14:42 -0800
commit    7ac140a5845553275427162aabd9d54987144b4a (patch)
tree      3929de71e5cafe72b96a5aab8fa49b4cc246294b /tensorflow
parent    51889acee1a266478b578afad3fbe7b3a90fc17a (diff)
Adds properties to LossFunction to access inputs separated by minibatch.
Currently, information about separate minibatches registered by `LossFunction`s is private, and only the concatenation of all minibatch inputs is exposed through the `inputs` property. This change adds `input_minibatches` and `num_registered_minibatches` to `LossFunction` to expose this information.

PiperOrigin-RevId: 175306297
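For context, here is a minimal usage sketch of the new properties. This snippet is not part of the commit; it assumes TF 1.x graph mode and the contrib KFAC API shown in the diff below.

    # Hypothetical sketch of the new API, assuming the contrib KFAC
    # loss_functions module as modified by this change (TF 1.x, graph mode).
    import tensorflow as tf
    from tensorflow.contrib.kfac.python.ops import loss_functions

    # One logits tensor per tower; each is a separate minibatch.
    logits_a = tf.random_uniform(shape=[2, 3])
    logits_b = tf.random_uniform(shape=[2, 3])

    # The first minibatch is registered by the constructor; additional
    # towers are registered via register_additional_minibatch().
    loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits_a)
    loss.register_additional_minibatch(logits_b)

    # New in this change: per-minibatch access alongside the existing
    # concatenated `inputs` property.
    assert loss.num_registered_minibatches == 2
    assert loss.input_minibatches == [logits_a, logits_b]
    concatenated = loss.inputs  # concatenation of both minibatches, shape [4, 3]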
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/BUILD                    1
-rw-r--r--  tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py  17
-rw-r--r--  tensorflow/contrib/kfac/python/ops/loss_functions.py                62
3 files changed, 64 insertions(+), 16 deletions(-)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index 5d86373a23..5b7747b0a1 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -139,6 +139,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_ops",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
index 87339cb059..39ce3e9337 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.kfac.python.ops import loss_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
@@ -96,6 +97,22 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
# difficult to say if the output is correct or not...
neg_log_prob = sess.run(neg_log_prob)
+ def testMultiMinibatchRegistration(self):
+ """Ensure this loss function supports registering multiple minibatches."""
+ with ops.Graph().as_default():
+ tower_logits = []
+ loss = None
+ num_towers = 5
+ for _ in range(num_towers):
+ logits = random_ops.random_uniform(shape=[2, 3])
+ tower_logits.append(logits)
+ if loss is None:
+ loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
+ else:
+ loss.register_additional_minibatch(logits)
+ self.assertListEqual(loss.input_minibatches, tower_logits)
+ self.assertEqual(loss.num_registered_minibatches, num_towers)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
index 3cfde7f9ab..e2e5bc3ffe 100644
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py
@@ -56,6 +56,30 @@ class LossFunction(object):
"""The inputs to the loss function (excluding the targets)."""
pass
+ @property
+ def input_minibatches(self):
+ """A `list` of inputs to the loss function, separated by minibatch.
+
+ Typically there will be one minibatch per tower in a multi-tower setup.
+ Returns a list consisting of `self.inputs` by default; `LossFunction`s
+ that support registering multiple minibatches should override this
+ property.
+
+ Returns:
+ A `list` of `Tensor`s representing the minibatch(es) registered with
+ this loss function.
+ """
+ return [self.inputs]
+
+ @property
+ def num_registered_minibatches(self):
+ """Number of minibatches registered for this LossFunction.
+
+ Typically equal to the number of towers in a multi-tower setup.
+
+ Returns:
+ An `int` representing the number of registered minibatches.
+ """
+ return len(self.input_minibatches)
+
def evaluate(self):
"""Evaluate the loss function on the targets."""
if self.targets is not None:
@@ -75,7 +99,6 @@ class LossFunction(object):
Returns:
log probability of each target, summed across all targets.
"""
-
pass
@abc.abstractmethod
@@ -415,8 +438,8 @@ class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
axis=-1)
output_slice = self._var**-0.5 * ones_slice
- return insert_slice_in_zeros(output_slice, 1,
- int(self._mean.shape[1]), index[0])
+ return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
+ index[0])
@property
def fisher_factor_inner_shape(self):
@@ -474,24 +497,23 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
@property
def _fisher_mean(self):
- return 1./self._variance
+ return 1. / self._variance
@property
def _fisher_mean_factor(self):
- return 1./self._scale
+ return 1. / self._scale
@property
def _fisher_var(self):
- return 1./(2*math_ops.square(self._variance))
+ return 1. / (2 * math_ops.square(self._variance))
@property
def _fisher_var_factor(self):
- return 1./(math_ops.sqrt(2.)*self._variance)
+ return 1. / (math_ops.sqrt(2.) * self._variance)
def multiply_fisher(self, vecs):
mean_vec, var_vec = vecs
- return (self._fisher_mean * mean_vec,
- self._fisher_var * var_vec)
+ return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)
def multiply_fisher_factor(self, vecs):
mean_vec, var_vec = self._split(vecs)
@@ -511,8 +533,8 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
# Index corresponds to mean parameter.
mean_slice = self._fisher_mean_factor[:, index]
mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
- mean_output = insert_slice_in_zeros(mean_slice, 1,
- int(self._mean.shape[1]), index)
+ mean_output = insert_slice_in_zeros(mean_slice, 1, int(
+ self._mean.shape[1]), index)
var_output = array_ops.zeros_like(mean_output)
else:
index -= int(self._mean.shape[-1])
@@ -527,13 +549,17 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
@property
def fisher_factor_inner_shape(self):
- return array_ops.concat([array_ops.shape(self._mean)[:-1],
- 2*array_ops.shape(self._mean)[-1:]], axis=0)
+ return array_ops.concat(
+ [
+ array_ops.shape(self._mean)[:-1],
+ 2 * array_ops.shape(self._mean)[-1:]
+ ],
+ axis=0)
@property
def fisher_factor_inner_static_shape(self):
shape = self._mean.shape.as_list()
- return tensor_shape.TensorShape(shape[-1:] + [2*shape[-1]])
+ return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]])
def multiply_hessian(self, vector):
raise NotImplementedError()
@@ -606,6 +632,10 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
return array_ops.concat(self._logits_components, axis=0)
@property
+ def input_minibatches(self):
+ return self._logits_components
+
+ @property
def targets(self):
if all(target is None for target in self._targets_components):
return None
@@ -710,8 +740,8 @@ class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
assert len(index) == 1, "Length of index was {}".format(len(index))
probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
- return insert_slice_in_zeros(output_slice, 1,
- int(self._logits.shape[1]), index[0])
+ return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
+ index[0])
@property
def fisher_factor_inner_shape(self):