aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-27 16:43:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 16:48:25 -0700
commit71593602d95385fbd8c3dde361dab09d381b5ac6 (patch)
tree5a2e4c2e39d54b6ec8721ce4e531e3ac254fe9fd /tensorflow/contrib/kfac
parent496840acbdd8b8b7688c257793e09a02229d21f6 (diff)
Fixed a bug in ConvKFCBasicMultiIndepFB introduced in the last CL
PiperOrigin-RevId: 190695737
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index b04bf76a88..e0d9cb5ea9 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -861,12 +861,12 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
super(ConvKFCBasicFB, self).__init__(layer_collection)
def instantiate_factors(self, grads_list, damping):
+ inputs, grads_list = self._process_data(grads_list)
+
# Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(),
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
self._strides)
- inputs, grads_list = self._process_data(grads_list)
-
self._input_factor = self._layer_collection.make_or_get_factor(
fisher_factors.ConvInputKroneckerFactor,
(inputs, self._filter_shape, self._padding, self._strides,
@@ -1391,7 +1391,7 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
inputs, grads_list = self._process_data(grads_list)
# Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs.shape.as_list(),
+ self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
self._strides)
self._input_factor = self._layer_collection.make_or_get_factor(