Diffstat (limited to 'tensorflow')
3 files changed, 64 insertions, 16 deletions
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
index 5d86373a23..5b7747b0a1 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
@@ -139,6 +139,7 @@ py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:random_ops",
         "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
index 87339cb059..39ce3e9337 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
@@ -24,6 +24,7 @@ from tensorflow.contrib.kfac.python.ops import loss_functions
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
 
 
@@ -96,6 +97,22 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
       # difficult to say if the output is correct or not...
       neg_log_prob = sess.run(neg_log_prob)
 
+  def testMultiMinibatchRegistration(self):
+    """Ensure this loss function supports registering multiple minibatches."""
+    with ops.Graph().as_default():
+      tower_logits = []
+      loss = None
+      num_towers = 5
+      for _ in range(num_towers):
+        logits = random_ops.random_uniform(shape=[2, 3])
+        tower_logits.append(logits)
+        if loss is None:
+          loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
+        else:
+          loss.register_additional_minibatch(logits)
+      self.assertListEqual(loss.input_minibatches, tower_logits)
+      self.assertEqual(loss.num_registered_minibatches, num_towers)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
index 3cfde7f9ab..e2e5bc3ffe 100644
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py
@@ -56,6 +56,30 @@ class LossFunction(object):
     """The inputs to the loss function (excluding the targets)."""
     pass
 
+  @property
+  def input_minibatches(self):
+    """A `list` of inputs to the loss function, separated by minibatch.
+
+    Typically there will be one minibatch per tower in a multi-tower setup.
+    Returns a list consisting of `self.inputs` by default; `LossFunction`s
+    supporting registering multiple minibatches should override this property.
+
+    Returns:
+      A `list` of `Tensor`s representing the registered minibatches of inputs.
+    """
+    return [self.inputs]
+
+  @property
+  def num_registered_minibatches(self):
+    """Number of minibatches registered for this LossFunction.
+
+    Typically equal to the number of towers in a multi-tower setup.
+
+    Returns:
+      An `int` representing the number of registered minibatches.
+    """
+    return len(self.input_minibatches)
+
   def evaluate(self):
     """Evaluate the loss function on the targets."""
     if self.targets is not None:
@@ -75,7 +99,6 @@ class LossFunction(object):
     Returns:
       log probability of each target, summed across all targets.
     """
-    pass
 
   @abc.abstractmethod
@@ -415,8 +438,8 @@ class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
         array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
         axis=-1)
     output_slice = self._var**-0.5 * ones_slice
-    return insert_slice_in_zeros(output_slice, 1,
-                                 int(self._mean.shape[1]), index[0])
+    return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
+                                 index[0])
 
   @property
   def fisher_factor_inner_shape(self):
@@ -474,24 +497,23 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
 
   @property
   def _fisher_mean(self):
-    return 1./self._variance
+    return 1. / self._variance
 
   @property
   def _fisher_mean_factor(self):
-    return 1./self._scale
+    return 1. / self._scale
 
   @property
   def _fisher_var(self):
-    return 1./(2*math_ops.square(self._variance))
+    return 1. / (2 * math_ops.square(self._variance))
 
   @property
   def _fisher_var_factor(self):
-    return 1./(math_ops.sqrt(2.)*self._variance)
+    return 1. / (math_ops.sqrt(2.) * self._variance)
 
   def multiply_fisher(self, vecs):
     mean_vec, var_vec = vecs
-    return (self._fisher_mean * mean_vec,
-            self._fisher_var * var_vec)
+    return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)
 
   def multiply_fisher_factor(self, vecs):
     mean_vec, var_vec = self._split(vecs)
@@ -511,8 +533,8 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
       # Index corresponds to mean parameter.
       mean_slice = self._fisher_mean_factor[:, index]
       mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
-      mean_output = insert_slice_in_zeros(mean_slice, 1,
-                                          int(self._mean.shape[1]), index)
+      mean_output = insert_slice_in_zeros(mean_slice, 1, int(
+          self._mean.shape[1]), index)
       var_output = array_ops.zeros_like(mean_output)
     else:
       index -= int(self._mean.shape[-1])
@@ -527,13 +549,17 @@ class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
 
   @property
   def fisher_factor_inner_shape(self):
-    return array_ops.concat([array_ops.shape(self._mean)[:-1],
-                             2*array_ops.shape(self._mean)[-1:]], axis=0)
+    return array_ops.concat(
+        [
+            array_ops.shape(self._mean)[:-1],
+            2 * array_ops.shape(self._mean)[-1:]
+        ],
+        axis=0)
 
   @property
   def fisher_factor_inner_static_shape(self):
     shape = self._mean.shape.as_list()
-    return tensor_shape.TensorShape(shape[-1:] + [2*shape[-1]])
+    return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]])
 
   def multiply_hessian(self, vector):
     raise NotImplementedError()
@@ -606,6 +632,10 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
     return array_ops.concat(self._logits_components, axis=0)
 
   @property
+  def input_minibatches(self):
+    return self._logits_components
+
+  @property
   def targets(self):
     if all(target is None for target in self._targets_components):
       return None
@@ -710,8 +740,8 @@ class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
     assert len(index) == 1, "Length of index was {}".format(len(index))
     probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
     output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
-    return insert_slice_in_zeros(output_slice, 1,
-                                 int(self._logits.shape[1]), index[0])
+    return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
+                                 index[0])
 
   @property
   def fisher_factor_inner_shape(self):