aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-02 01:36:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 01:39:15 -0700
commit5e1448f691afe6e9ba57bb67497311c45b855b82 (patch)
tree494e00ce35ce5d8623606951ec95ab33234cb585 /tensorflow/contrib/kfac
parent7715b7b0650c2f20b47189a060580a45e510acd8 (diff)
BUGFIX: Convert inputs and list of gradients into tuple if they are not instance of tuple. Otherwise this causes "unhashable keys" error when we try to hash.
Also fixed lint error. PiperOrigin-RevId: 195061425
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
index 32c776cb38..3a5c8eb5f9 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
@@ -673,9 +673,6 @@ class KroneckerProductFB(FisherBlock):
output factors.
"""
- def __init__(self, layer_collection):
- super(KroneckerProductFB, self).__init__(layer_collection)
-
def _setup_damping(self, damping, normalization=None):
"""Makes functions that compute the damping values for both factors."""
def compute_damping():
@@ -1309,6 +1306,8 @@ class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
else:
raise ValueError("Global config variable TOWER_STRATEGY must be one of "
"'concat' or 'separate'.")
+ else:
+ inputs = tuple(inputs)
# Now we perform the analogous processing for grads_list
if isinstance(grads_list[0][0], (list, tuple)):
@@ -1351,6 +1350,8 @@ class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
else:
raise ValueError("Global config variable TOWER_STRATEGY must be one of "
"'concat' or 'separate'.")
+ else:
+ grads_list = tuple(tuple(grads) for grads in grads_list)
if self._num_uses is None:
raise ValueError("You must supply a value for the num_uses argument if "