aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/kfac
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-03 10:00:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-03 10:03:06 -0700
commitcf8c504688c5f5813c8772eb107ed3d4a1385888 (patch)
treedf8684393eb737fc974a2e803bec99ed33030242 /tensorflow/contrib/kfac
parent27c762c336bb11c8f74694e3d3ea5c8c47a28003 (diff)
Bug Fix: If num_uses > 0 the the inputs tensor need not be a list but can be reshaped to
[batch_size*num_uses, input_size]. `num_uses` should be incremented by one in this case.' PiperOrigin-RevId: 191456184
Diffstat (limited to 'tensorflow/contrib/kfac')
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection.py23
1 files changed, 15 insertions, 8 deletions
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
index 586a004f88..19608aca47 100644
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py
@@ -990,9 +990,11 @@ class LayerCollection(object):
num_uses=num_uses),
reuse=reuse)
block.register_additional_tower(inputs, outputs)
-
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
+ if isinstance(inputs, (tuple, list)):
+ assert len(inputs) == len(outputs)
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
def register_conv2d_multi(self,
params,
@@ -1066,9 +1068,11 @@ class LayerCollection(object):
reuse=reuse)
block.register_additional_tower(inputs, outputs)
-
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
+ if isinstance(inputs, (tuple, list)):
+ assert len(inputs) == len(outputs)
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
# TODO(b/74108452): change the loss registration functions names to refer
# to "loss functions" instead of distributions. Following naming convention
@@ -1088,7 +1092,7 @@ class LayerCollection(object):
inputs: A list of Tensors, each of shape [batch_size, input_size] and
dtype int32. Indices into embedding matrix. The list indexes each use
in the graph (which might correspond to a "time-step" in an RNN).
- OR, can be single Tensor, of shape [num_uses, batch_size, input_size],
+ OR, can be single Tensor, of shape [num_uses*batch_size, input_size],
which is a reshaped version of a Tensor of shape [num_uses, batch_size,
input_size].
outputs: A list of Tensors, each of shape [batch_size, embedding_size].
@@ -1129,7 +1133,10 @@ class LayerCollection(object):
params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
block.register_additional_tower(inputs, outputs)
- self._add_uses(params, len(inputs))
+ if isinstance(inputs, (tuple, list)):
+ self._add_uses(params, len(inputs))
+ else:
+ self._add_uses(params, 1)
def register_categorical_predictive_distribution(self,
logits,