aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-27 19:58:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 20:01:36 -0700
commit078793e131ecec0eae3e7a0549eb13a321993dfb (patch)
treed5c1fd643e92922c739f940f1e64cd6e77737db8
parent756dc39c83136cc3518e20993be4382fe77f9013 (diff)
Fix _force_data_dependency for scalar inputs
PiperOrigin-RevId: 190715033
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index 0b38c0c3fd..e49589ddf6 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -33,6 +33,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python import ops as contrib_framework_ops
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops as framework_ops
from tensorflow.python.layers import base
@@ -660,7 +661,9 @@ def _force_data_dependency(first_compute, then_compute):
if x.get_shape().ndims is None:
raise ValueError("Rank of Tensor %s must be known" % x)
ndims = x.get_shape().ndims
- return array_ops.reshape(array_ops.slice(x, [0] * ndims, [1] * ndims), [])
+ begin = framework_ops.convert_to_tensor([0] * ndims, dtype=dtypes.int32)
+ size = framework_ops.convert_to_tensor([1] * ndims, dtype=dtypes.int32)
+ return array_ops.reshape(array_ops.slice(x, begin, size), [])
first_compute_sum = math_ops.add_n(
[_first_element(x) for x in first_compute if x is not None])