about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-03-27 17:13:22 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-03-27 17:15:48 -0700
commit 3dc861e754ffb86286038ef9c78327f59384eaad (patch)
tree 827abe93f85172b4df1446a4d51812ef43e7850f /tensorflow/contrib/kfac
parent 6f8b85d301140ce42c0aa4871750ee0aec758105 (diff)
K-FAC: Bugfixes for TPU compatibility with covariance update ops.
PiperOrigin-RevId: 190699635
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- tensorflow/contrib/kfac/python/ops/fisher_factors.py 16
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 353e1c6abb..0d40d265a1 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -336,12 +336,16 @@ class FisherFactor(object):
new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers)
- # I have no idea if the TPU code below is still correct since I don't know
- # what it actually does. Also, this code is not present in some of the
- # other versions of make_covariance_update_op. Does it matter?
- # Synchronize value across all TPU cores.
+ # Compute average of 'new_cov' across all TPU cores. On a TPU, each
+ # instance of 'new_cov' will be based on a different minibatch. This ensures
+ # that by the end of assign_moving_average(), all TPU cores see the same
+ # value for self._cov.
+ #
+ # Other implementations of make_covariance_update_op() that accumulate
+ # statistics in other variables should mimic this behavior.
if utils.on_tpu():
new_cov = utils.cross_replica_mean(new_cov)
+
return moving_averages.assign_moving_average(
self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
@@ -1398,6 +1402,10 @@ class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs)
/ float(self._num_towers))
+ # See comments in FisherFactor.make_covariance_update_op() for details.
+ if utils.on_tpu():
+ new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1)
+
op2 = moving_averages.assign_moving_average(
self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)