diff options
author | 2018-04-03 10:00:00 -0700 | |
---|---|---|
committer | 2018-04-03 10:03:06 -0700 | |
commit | cf8c504688c5f5813c8772eb107ed3d4a1385888 (patch) | |
tree | df8684393eb737fc974a2e803bec99ed33030242 /tensorflow/contrib/kfac | |
parent | 27c762c336bb11c8f74694e3d3ea5c8c47a28003 (diff) |
Bug Fix: If num_uses > 0, then the inputs tensor need not be a list but can be a single Tensor reshaped to
[batch_size*num_uses, input_size]. `num_uses` should be incremented by one in this case.
PiperOrigin-RevId: 191456184
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r-- | tensorflow/contrib/kfac/python/ops/layer_collection.py | 23 |
1 file changed, 15 insertions, 8 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 586a004f88..19608aca47 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -990,9 +990,11 @@ class LayerCollection(object): num_uses=num_uses), reuse=reuse) block.register_additional_tower(inputs, outputs) - - assert len(inputs) == len(outputs) - self._add_uses(params, len(inputs)) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) def register_conv2d_multi(self, params, @@ -1066,9 +1068,11 @@ class LayerCollection(object): reuse=reuse) block.register_additional_tower(inputs, outputs) - - assert len(inputs) == len(outputs) - self._add_uses(params, len(inputs)) + if isinstance(inputs, (tuple, list)): + assert len(inputs) == len(outputs) + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) # TODO(b/74108452): change the loss registration functions names to refer # to "loss functions" instead of distributions. Following naming convention @@ -1088,7 +1092,7 @@ class LayerCollection(object): inputs: A list of Tensors, each of shape [batch_size, input_size] and dtype int32. Indices into embedding matrix. The list indexes each use in the graph (which might correspond to a "time-step" in an RNN). - OR, can be single Tensor, of shape [num_uses, batch_size, input_size], + OR, can be single Tensor, of shape [num_uses*batch_size, input_size], which is a reshaped version of a Tensor of shape [num_uses, batch_size, input_size]. outputs: A list of Tensors, each of shape [batch_size, embedding_size]. 
@@ -1129,7 +1133,10 @@ class LayerCollection(object): params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse) block.register_additional_tower(inputs, outputs) - self._add_uses(params, len(inputs)) + if isinstance(inputs, (tuple, list)): + self._add_uses(params, len(inputs)) + else: + self._add_uses(params, 1) def register_categorical_predictive_distribution(self, logits, |