diff options
author | 2018-03-27 19:58:53 -0700 | |
---|---|---|
committer | 2018-03-27 20:01:36 -0700 | |
commit | 078793e131ecec0eae3e7a0549eb13a321993dfb (patch) | |
tree | d5c1fd643e92922c739f940f1e64cd6e77737db8 | |
parent | 756dc39c83136cc3518e20993be4382fe77f9013 (diff) |
Fix _force_data_dependency for scalar inputs
PiperOrigin-RevId: 190715033
-rw-r--r-- | tensorflow/contrib/layers/python/layers/rev_block_lib.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 0b38c0c3fd..e49589ddf6 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -33,6 +33,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python import ops as contrib_framework_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops as framework_ops from tensorflow.python.layers import base @@ -660,7 +661,9 @@ def _force_data_dependency(first_compute, then_compute): if x.get_shape().ndims is None: raise ValueError("Rank of Tensor %s must be known" % x) ndims = x.get_shape().ndims - return array_ops.reshape(array_ops.slice(x, [0] * ndims, [1] * ndims), []) + begin = framework_ops.convert_to_tensor([0] * ndims, dtype=dtypes.int32) + size = framework_ops.convert_to_tensor([1] * ndims, dtype=dtypes.int32) + return array_ops.reshape(array_ops.slice(x, begin, size), []) first_compute_sum = math_ops.add_n( [_first_element(x) for x in first_compute if x is not None]) |