diff options
author | 2018-06-28 10:47:43 -0700 | |
---|---|---|
committer | 2018-06-28 21:37:43 -0700 | |
commit | d7ebc1f4ca2c677710c5257d30c757f0f8b604c6 (patch) | |
tree | 7992c0a89d42c3bb383d7ab2072fcbdd2e6e1664 | |
parent | bcd500259f891fd818c0d85e3bd00dface533b1b (diff) |
Avoid overflow in flops calculations in nn_ops.py by forcing
np.prod() to use np.int64 in a few places.
PiperOrigin-RevId: 202505308
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index acaa9c0ce5..41d54a6c2f 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -2167,7 +2167,7 @@ def _calc_conv_flops(graph, node): filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) filter_in_depth = int(filter_shape[2]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats( "flops", (output_count * filter_in_depth * filter_height * filter_width * 2)) @@ -2185,7 +2185,7 @@ def _calc_depthwise_conv_flops(graph, node): output_shape.assert_is_fully_defined() filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats("flops", (output_count * filter_height * filter_width * 2)) @@ -2595,7 +2595,7 @@ def _calc_dilation2d_flops(graph, node): output_shape.assert_is_fully_defined() filter_height = int(filter_shape[0]) filter_width = int(filter_shape[1]) - output_count = np.prod(output_shape.as_list()) + output_count = np.prod(output_shape.as_list(), dtype=np.int64) return ops.OpStats("flops", (output_count * filter_height * filter_width * 2)) |