diff options
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/array_ops.py | 23 | ||||
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 5 | ||||
-rw-r--r-- | tensorflow/python/ops/resource_variable_ops.py | 7 |
3 files changed, 32 insertions, 3 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 68d446602e..fa26e07c85 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1566,6 +1566,16 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False): # pylint: enable=invalid-name +def _constant_if_small(value, shape, dtype, name): + try: + if np.prod(shape) < 1000: + return constant(value, shape=shape, dtype=dtype, name=name) + except TypeError: + # Happens when shape is a Tensor, list with Tensor elements, etc. + pass + return None + + @tf_export("zeros") def zeros(shape, dtype=dtypes.float32, name=None): """Creates a tensor with all elements set to zero. @@ -1596,8 +1606,15 @@ def zeros(shape, dtype=dtypes.float32, name=None): zero = "" else: zero = 0 + if not isinstance(shape, ops.Tensor): try: + # Create a constant if it won't be very big. Otherwise create a fill op + # to prevent serialized GraphDefs from becoming too large. + output = _constant_if_small(zero, shape, dtype, name) + if output is not None: + return output + # Go through tensor shapes to get int64-if-needed semantics shape = constant_op._tensor_shape_tensor_conversion_function( tensor_shape.TensorShape(shape)) @@ -1729,6 +1746,12 @@ def ones(shape, dtype=dtypes.float32, name=None): one = True if dtype == dtypes.bool else 1 if not isinstance(shape, ops.Tensor): try: + # Create a constant if it won't be very big. Otherwise create a fill op + # to prevent serialized GraphDefs from becoming too large. + output = _constant_if_small(one, shape, dtype, name) + if output is not None: + return output + # Go through tensor shapes to get int64-if-needed semantics shape = constant_op._tensor_shape_tensor_conversion_function( tensor_shape.TensorShape(shape)) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index b460ce5b95..01d670ea2d 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1402,10 +1402,11 @@ def reduce_sum(input_tensor, keep_dims: Deprecated alias for `keepdims`. Returns: - The reduced tensor. + The reduced tensor, of the same dtype as the input_tensor. @compatibility(numpy) - Equivalent to np.sum + Equivalent to np.sum appart the fact that numpy upcast uint8 and int32 to + int64 while tensorflow returns the same dtype as the input. @end_compatibility """ keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims, diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 07e25e540c..508ba9bfee 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -72,7 +72,12 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): # know the shape and dtype of the variable pointed to by a handle. Since # shape inference doesn't run in eager mode we copy this data here for when # the handle is captured by an eager mode function. - handle._handle_data = h._handle_data # pylint: disable=protected-access + # pylint: disable=protected-access + if h._handle_data is None: + ops.set_shape_and_handle_data_for_outputs(h.op) + handle._handle_data = h._handle_data + # pylint: enable=protected-access + # Clean up our reference cycles to avoid making the garbage collector run. # pylint: disable=protected-access # OrderedDict, constructed on Graph creation, makes a simple reference loop |