aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/array_ops.py23
-rw-r--r--tensorflow/python/ops/math_ops.py5
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py7
3 files changed, 32 insertions, 3 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 68d446602e..fa26e07c85 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1566,6 +1566,16 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
# pylint: enable=invalid-name
+def _constant_if_small(value, shape, dtype, name):
+ try:
+ if np.prod(shape) < 1000:
+ return constant(value, shape=shape, dtype=dtype, name=name)
+ except TypeError:
+ # Happens when shape is a Tensor, list with Tensor elements, etc.
+ pass
+ return None
+
+
@tf_export("zeros")
def zeros(shape, dtype=dtypes.float32, name=None):
"""Creates a tensor with all elements set to zero.
@@ -1596,8 +1606,15 @@ def zeros(shape, dtype=dtypes.float32, name=None):
zero = ""
else:
zero = 0
+
if not isinstance(shape, ops.Tensor):
try:
+ # Create a constant if it won't be very big. Otherwise create a fill op
+ # to prevent serialized GraphDefs from becoming too large.
+ output = _constant_if_small(zero, shape, dtype, name)
+ if output is not None:
+ return output
+
# Go through tensor shapes to get int64-if-needed semantics
shape = constant_op._tensor_shape_tensor_conversion_function(
tensor_shape.TensorShape(shape))
@@ -1729,6 +1746,12 @@ def ones(shape, dtype=dtypes.float32, name=None):
one = True if dtype == dtypes.bool else 1
if not isinstance(shape, ops.Tensor):
try:
+ # Create a constant if it won't be very big. Otherwise create a fill op
+ # to prevent serialized GraphDefs from becoming too large.
+ output = _constant_if_small(one, shape, dtype, name)
+ if output is not None:
+ return output
+
# Go through tensor shapes to get int64-if-needed semantics
shape = constant_op._tensor_shape_tensor_conversion_function(
tensor_shape.TensorShape(shape))
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index b460ce5b95..01d670ea2d 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1402,10 +1402,11 @@ def reduce_sum(input_tensor,
keep_dims: Deprecated alias for `keepdims`.
Returns:
- The reduced tensor.
+ The reduced tensor, of the same dtype as the input_tensor.
@compatibility(numpy)
- Equivalent to np.sum
+ Equivalent to np.sum apart from the fact that numpy upcasts uint8 and
+ int32 to int64 while tensorflow returns the same dtype as the input.
@end_compatibility
"""
keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 07e25e540c..508ba9bfee 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -72,7 +72,12 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
# know the shape and dtype of the variable pointed to by a handle. Since
# shape inference doesn't run in eager mode we copy this data here for when
# the handle is captured by an eager mode function.
- handle._handle_data = h._handle_data # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ if h._handle_data is None:
+ ops.set_shape_and_handle_data_for_outputs(h.op)
+ handle._handle_data = h._handle_data
+ # pylint: enable=protected-access
+
# Clean up our reference cycles to avoid making the garbage collector run.
# pylint: disable=protected-access
# OrderedDict, constructed on Graph creation, makes a simple reference loop