diff options
author | 2018-03-27 16:43:48 -0700 | |
---|---|---|
committer | 2018-03-27 16:48:25 -0700 | |
commit | 71593602d95385fbd8c3dde361dab09d381b5ac6 (patch) | |
tree | 5a2e4c2e39d54b6ec8721ce4e531e3ac254fe9fd /tensorflow/contrib/kfac | |
parent | 496840acbdd8b8b7688c257793e09a02229d21f6 (diff) |
Fixed a bug in ConvKFCBasicMultiIndepFB introduced in the last CL
PiperOrigin-RevId: 190695737
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/fisher_blocks.py | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index b04bf76a88..e0d9cb5ea9 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -861,12 +861,12 @@ class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB): super(ConvKFCBasicFB, self).__init__(layer_collection) def instantiate_factors(self, grads_list, damping): + inputs, grads_list = self._process_data(grads_list) + # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(self._inputs[0].shape.as_list(), + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), self._strides) - inputs, grads_list = self._process_data(grads_list) - self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputKroneckerFactor, (inputs, self._filter_shape, self._padding, self._strides, @@ -1391,7 +1391,7 @@ class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse, inputs, grads_list = self._process_data(grads_list) # Infer number of locations upon which convolution is applied. - self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._num_locations = num_conv_locations(inputs[0].shape.as_list(), self._strides) self._input_factor = self._layer_collection.make_or_get_factor( |