aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-28 10:47:43 -0700
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-06-28 21:37:43 -0700
commitd7ebc1f4ca2c677710c5257d30c757f0f8b604c6 (patch)
tree7992c0a89d42c3bb383d7ab2072fcbdd2e6e1664
parentbcd500259f891fd818c0d85e3bd00dface533b1b (diff)
Avoid overflow in flops calculations in nn_ops.py by forcing
np.prod() to use np.int64 in a few places. PiperOrigin-RevId: 202505308
-rw-r--r--tensorflow/python/ops/nn_ops.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index acaa9c0ce5..41d54a6c2f 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -2167,7 +2167,7 @@ def _calc_conv_flops(graph, node):
filter_height = int(filter_shape[0])
filter_width = int(filter_shape[1])
filter_in_depth = int(filter_shape[2])
- output_count = np.prod(output_shape.as_list())
+ output_count = np.prod(output_shape.as_list(), dtype=np.int64)
return ops.OpStats(
"flops",
(output_count * filter_in_depth * filter_height * filter_width * 2))
@@ -2185,7 +2185,7 @@ def _calc_depthwise_conv_flops(graph, node):
output_shape.assert_is_fully_defined()
filter_height = int(filter_shape[0])
filter_width = int(filter_shape[1])
- output_count = np.prod(output_shape.as_list())
+ output_count = np.prod(output_shape.as_list(), dtype=np.int64)
return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))
@@ -2595,7 +2595,7 @@ def _calc_dilation2d_flops(graph, node):
output_shape.assert_is_fully_defined()
filter_height = int(filter_shape[0])
filter_width = int(filter_shape[1])
- output_count = np.prod(output_shape.as_list())
+ output_count = np.prod(output_shape.as_list(), dtype=np.int64)
return ops.OpStats("flops", (output_count * filter_height * filter_width * 2))