author:    A. Unique TensorFlower <gardener@tensorflow.org>  2017-09-11 11:16:14 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>   2017-09-11 11:20:24 -0700
commit:    509372c2ee09d26f21e7040defa45bf7c63ce4c1 (patch)
tree:      0b785f14ef13a11d056eb29c553f9656f6073ce4 /tensorflow/python/profiler/internal
parent:    80ed8afc02b16adb209bffeb551e7f0c435985f6 (diff)
Add flops calculations for many operations
PiperOrigin-RevId: 168256746
Diffstat (limited to 'tensorflow/python/profiler/internal')
-rw-r--r--  tensorflow/python/profiler/internal/BUILD              |  10
-rw-r--r--  tensorflow/python/profiler/internal/flops_registry.py  | 446
2 files changed, 456 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD
index f9cc8af19d..dcac070a3f 100644
--- a/tensorflow/python/profiler/internal/BUILD
+++ b/tensorflow/python/profiler/internal/BUILD
@@ -8,6 +8,16 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
py_library(
+ name = "flops_registry",
+ srcs = ["flops_registry.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:graph_util",
+ ],
+)
+
+py_library(
name = "model_analyzer_testlib",
srcs = ["model_analyzer_testlib.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/profiler/internal/flops_registry.py b/tensorflow/python/profiler/internal/flops_registry.py
new file mode 100644
index 0000000000..e143501049
--- /dev/null
+++ b/tensorflow/python/profiler/internal/flops_registry.py
@@ -0,0 +1,446 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Register flops statistics for various TensorFlow operations.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import ops
+
+
+# List of all ops for which flops statistics are implemented.
+IMPLEMENTED_OPS = set([
+ # Unary ops
+ "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
+ "L2Loss", "Softmax",
+ # Binary ops
+ "Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
+ "GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
+ "SquaredDifference",
+ # Reduction ops
+ "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
+ # Convolution and pooling
+ "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
+ "Conv2DBackpropFilter",
+ # Other ops
+ "AddN",
+ # Ops implemented in core TensorFlow:
+ "MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
+])
+
+
+def _zero_flops(graph, node):
+ """Returns zero flops."""
+ del graph, node # graph and node are unused
+ return ops.OpStats("flops", 0)
+
+
+def _list_product(lst):
+ """Computes product of element of the list."""
+ result = 1
+ for item in lst:
+ result *= item
+ return result
+
+################################################################################
+# Unary operations
+################################################################################
+
+
+def _unary_op_flops(graph, node, ops_per_element=1):
+ """Common code which compute flops for unary operations."""
+ in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+ in_shape.assert_is_fully_defined()
+ return ops.OpStats("flops", in_shape.num_elements() * ops_per_element)
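+
+
+# Editorial worked example (an assumed shape, not part of the original
+# commit): for an input of shape [32, 128], num_elements() is 4096, so a
+# one-flop-per-element op such as Log counts as 4096 flops and Rsqrt
+# (2 flops per element) counts as 8192.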
+
+
+@ops.RegisterStatistics("Reciprocal", "flops")
+def _reciprocal_flops(graph, node):
+ """Compute flops for Reciprocal operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Square", "flops")
+def _square_flops(graph, node):
+ """Compute flops for Square operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Rsqrt", "flops")
+def _rsqrt_flops(graph, node):
+ """Compute flops for Rsqrt operation."""
+ # Rsqrt(x) = 1 / sqrt(x)
+ return _unary_op_flops(graph, node, ops_per_element=2)
+
+
+@ops.RegisterStatistics("Log", "flops")
+def _log_flops(graph, node):
+ """Compute flops for Log operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Neg", "flops")
+def _neg_flops(graph, node):
+ """Compute flops for Neg operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("AssignSub", "flops")
+def _assign_sub_flops(graph, node):
+ """Compute flops for AssignSub operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("AssignAdd", "flops")
+def _assign_add_flops(graph, node):
+ """Compute flops for AssignAdd operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("L2Loss", "flops")
+def _l2_loss_flops(graph, node):
+ """Compute flops for L2Loss operation."""
+ in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+ in_shape.assert_is_fully_defined()
+ # TensorFlow uses an inefficient implementation with (3*N - 1) flops;
+ # an optimal implementation would need only 2*N flops.
+ return ops.OpStats("flops", in_shape.num_elements() * 3 - 1)
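+
+
+# Editorial worked example (assumed N, not part of the original commit): for
+# L2Loss = sum(x**2) / 2 over N = 100 elements, the count above is
+# 3*100 - 1 = 299 flops; an optimal N multiplies + (N-1) adds + 1 divide
+# would be 200 flops.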
+
+
+@ops.RegisterStatistics("Softmax", "flops")
+def _softmax_flops(graph, node):
+ """Compute flops for Softmax operation."""
+ # Softmax implementation:
+ #
+ # Approximate flops breakdown:
+ # 2*n -- compute shifted logits
+ # n -- exp of shifted logits
+ # 2*n -- compute softmax from exp of shifted logits
+ return _unary_op_flops(graph, node, ops_per_element=5)
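+
+
+# Editorial worked example: for logits of shape [32, 10] (n = 320 elements),
+# the approximation above counts 5 * 320 = 1600 flops.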
+
+################################################################################
+# Binary operations
+################################################################################
+
+
+def _binary_per_element_op_flops(graph, node, ops_per_element=1):
+ """Common code which compute flops for binary operations."""
+ out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
+ out_shape.assert_is_fully_defined()
+ return ops.OpStats("flops", out_shape.num_elements() * ops_per_element)
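+
+
+# Editorial note: the output shape is used here rather than an input shape so
+# that the count stays correct under broadcasting; e.g. adding a [32, 128]
+# tensor and a [128] tensor still counts 32 * 128 = 4096 flops.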
+
+
+@ops.RegisterStatistics("Add", "flops")
+def _add_flops(graph, node):
+ """Compute flops for Add operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Sub", "flops")
+def _sub_flops(graph, node):
+ """Compute flops for Sub operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Mul", "flops")
+def _mul_flops(graph, node):
+ """Compute flops for Mul operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("RealDiv", "flops")
+def _real_div_flops(graph, node):
+ """Compute flops for RealDiv operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Maximum", "flops")
+def _maximum_flops(graph, node):
+ """Compute flops for Maximum operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Minimum", "flops")
+def _minimum_flops(graph, node):
+ """Compute flops for Minimum operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Pow", "flops")
+def _pow_flops(graph, node):
+ """Compute flops for Pow operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("RsqrtGrad", "flops")
+def _rsqrt_grad_flops(graph, node):
+ """Compute flops for RsqrtGrad operation."""
+ return _binary_per_element_op_flops(graph, node, ops_per_element=4)
+
+
+@ops.RegisterStatistics("GreaterEqual", "flops")
+def _greater_equal_flops(graph, node):
+ """Compute flops for GreaterEqual operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Greater", "flops")
+def _greater_flops(graph, node):
+ """Compute flops for Greater operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("LessEqual", "flops")
+def _less_equal_flops(graph, node):
+ """Compute flops for LessEqual operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Less", "flops")
+def _less_flops(graph, node):
+ """Compute flops for Less operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Equal", "flops")
+def _equal_flops(graph, node):
+ """Compute flops for Equal operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("NotEqual", "flops")
+def _not_equal_flops(graph, node):
+ """Compute flops for NotEqual operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("SquaredDifference", "flops")
+def _squared_difference_flops(graph, node):
+ """Compute flops for SquaredDifference operation."""
+ return _binary_per_element_op_flops(graph, node, ops_per_element=2)
+
+################################################################################
+# Reduction ops
+################################################################################
+
+
+def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0):
+ """Common code which compute flops for reduction operations."""
+ in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+ in_shape.assert_is_fully_defined()
+ out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
+ out_shape.assert_is_fully_defined()
+ num_flops = (in_shape.num_elements() * reduce_flops
+ + out_shape.num_elements() * (finalize_flops - reduce_flops))
+ return ops.OpStats("flops", num_flops)
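+
+
+# Editorial worked example: Mean over a [32, 10] input reduced to a [32]
+# output gives 320 * 1 + 32 * (1 - 1) = 320 flops, i.e. 32 * 9 = 288 adds
+# plus 32 divides.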
+
+
+@ops.RegisterStatistics("Mean", "flops")
+def _mean_flops(graph, node):
+ """Compute flops for Mean operation."""
+ # reduction - sum, finalization - divide
+ return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1)
+
+
+@ops.RegisterStatistics("Sum", "flops")
+def _sum_flops(graph, node):
+ """Compute flops for Sum operation."""
+ # reduction - sum, no finalization
+ return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
+
+
+@ops.RegisterStatistics("ArgMax", "flops")
+def _arg_max_flops(graph, node):
+ """Compute flops for ArgMax operation."""
+ # reduction - comparison, no finalization
+ return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
+
+
+@ops.RegisterStatistics("ArgMin", "flops")
+def _arg_min_flops(graph, node):
+ """Compute flops for ArgMin operation."""
+ # reduction - comparison, no finalization
+ return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
+
+
+@ops.RegisterStatistics("BiasAddGrad", "flops")
+def _bias_add_grad_flops(graph, node):
+ """Compute flops for BiasAddGrad operation."""
+ # BiasAddGrad is essentially a reduce sum plus a reshape,
+ # so its flops are computed the same way as for "Sum".
+ return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
+
+################################################################################
+# Convolution and pooling
+# Note: all flops statistics are implemented only for NHWC data format
+################################################################################
+
+
+def _verify_conv_data_format(node):
+ """Verifies data format for pooling and convolutional operations."""
+ # TODO(xpan): P1: Support NCHW
+ if node.attr["data_format"].s != b"NHWC":
+   raise ValueError("Only NHWC format is supported in flops computations")
+
+
+def _pool_flops(graph, node):
+ """Common code which compute flops for pooling operations."""
+ # compute flops for average and max pooling
+ _verify_conv_data_format(node)
+ #
+ # Pooling declaration:
+ # Inputs:
+ # - value
+ # Outputs:
+ # - output
+ # Attributes:
+ # - ksize
+ # - strides
+ # - padding
+ # - data_format
+ #
+ # Pooling implementation:
+ out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
+ out_shape.assert_is_fully_defined()
+ kernel_shape = list(node.attr["ksize"].list.i)
+ kernel_area = _list_product(kernel_shape)
+ return ops.OpStats("flops", kernel_area * out_shape.num_elements())
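+
+
+# Editorial worked example: a 3x3 max pool with ksize = [1, 3, 3, 1] (NHWC)
+# has kernel_area = 9; for an output of shape [1, 112, 112, 64]
+# (802816 elements) it counts 9 * 802816 = 7225344 flops.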
+
+
+@ops.RegisterStatistics("AvgPool", "flops")
+def _avg_pool_flops(graph, node):
+ """Compute flops for AvgPool operation."""
+ return _pool_flops(graph, node)
+
+
+@ops.RegisterStatistics("MaxPool", "flops")
+def _max_pool_flops(graph, node):
+ """Compute flops for MaxPool operation."""
+ return _pool_flops(graph, node)
+
+
+@ops.RegisterStatistics("AvgPoolGrad", "flops")
+def _avg_pool_grad_flops(graph, node):
+ """Compute flops for AvgPoolGrad operation."""
+ _verify_conv_data_format(node)
+ # Pooling gradient implementation:
+ out_backprop_shape = graph_util.tensor_shape_from_node_def_name(graph,
+ node.input[1])
+ out_backprop_shape.assert_is_fully_defined()
+ kernel_shape = list(node.attr["ksize"].list.i)
+ kernel_area = _list_product(kernel_shape)
+ # TensorFlow multiplies each element of the pooling window by a coefficient
+ # and then sums them up, so there are 2 flops per element.
+ # A more optimal implementation would do the division after the summation.
+ return ops.OpStats("flops",
+ kernel_area * out_backprop_shape.num_elements() * 2)
+
+
+@ops.RegisterStatistics("MaxPoolGrad", "flops")
+def _max_pool_grad_flops(graph, node):
+ """Compute flops for MaxPoolGrad operation."""
+ _verify_conv_data_format(node)
+ #
+ # MaxPoolGrad declaration:
+ # Inputs:
+ # - orig_input -- original input tensor (of max_pool)
+ # - orig_output -- original output tensor (of max_pool)
+ # - grad -- gradient with respect to output of max_pool
+ # Outputs:
+ # - output -- gradient with respect to input of max_pool
+ # Attributes:
+ # - ksize
+ # - strides
+ # - padding
+ # - data_format
+ # It computes MaxPool first, then adds one flop per element of the original output.
+ #
+ kernel_shape = list(node.attr["ksize"].list.i)
+ kernel_area = _list_product(kernel_shape)
+ orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph,
+                                                             node.input[1])
+ orig_out_shape.assert_is_fully_defined()
+ max_pool_ops = kernel_area * orig_out_shape.num_elements()
+ return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements())
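+
+
+# Editorial worked example: for a 2x2 MaxPoolGrad (kernel_area = 4) whose
+# original output is [1, 56, 56, 64] (200704 elements), the count is
+# 4 * 200704 (recomputing the max pool) + 200704 (one flop per output
+# element) = 1003520 flops.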
+
+
+@ops.RegisterStatistics("Conv2DBackpropInput", "flops")
+def _conv_2d_backprop_input_flops(graph, node):
+ """Compute flops for Conv2DBackpropInput operation."""
+ # Formula:
+ #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
+ #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
+ #
+ # Where:
+ #  image_x_dim, image_y_dim and input_depth --- sizes of the input to the
+ #    source (non-backprop) convolution; in other words, the sizes of the
+ #    backprop output.
+ #  output_depth --- number of filters in the original convolution, thus
+ #    the depth of the backprop input.
+ #  kernel_x_dim and kernel_y_dim --- sizes of the filter in the spatial
+ #    dimensions.
+ #  image_x_stride and image_y_stride --- strides of the convolution.
+ #
+ _verify_conv_data_format(node)
+ # out_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
+ out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
+ out_shape.assert_is_fully_defined()
+ # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
+ kernel_shape = graph_util.tensor_shape_from_node_def_name(graph,
+ node.input[1])
+ kernel_shape.assert_is_fully_defined()
+ # strides
+ strides_shape = list(node.attr["strides"].list.i)
+ strides_product = strides_shape[1] * strides_shape[2]
+ return ops.OpStats("flops",
+ (2 * out_shape.num_elements()
+ * kernel_shape.num_elements()
+ / (out_shape[-1].value * strides_product)))
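+
+
+# Editorial worked example (assumed shapes): with out_shape = [8, 28, 28, 64]
+# (401408 elements, input_depth 64), kernel_shape = [3, 3, 64, 128]
+# (73728 elements) and strides = [1, 1, 1, 1]:
+# 2 * 401408 * 73728 / (64 * 1) = 924844032 flops.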
+
+
+@ops.RegisterStatistics("Conv2DBackpropFilter", "flops")
+def _conv_2d_backprop_filter_flops(graph, node):
+ """Compute flops for Conv2DBackpropFilter operation."""
+ # Formula same as for Conv2DBackpropInput:
+ #  batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
+ #  * input_depth * output_depth * 2 / (image_x_stride * image_y_stride)
+ #
+ _verify_conv_data_format(node)
+ # image_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
+ image_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+ image_shape.assert_is_fully_defined()
+ # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
+ kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
+ kernel_shape.assert_is_fully_defined()
+ # strides
+ strides_shape = list(node.attr["strides"].list.i)
+ strides_product = strides_shape[1] * strides_shape[2]
+ return ops.OpStats("flops",
+ (2 * image_shape.num_elements()
+ * kernel_shape.num_elements()
+ / (image_shape[-1].value * strides_product)))
+
+################################################################################
+# Other ops
+################################################################################
+
+
+@ops.RegisterStatistics("AddN", "flops")
+def _add_n_flops(graph, node):
+ """Compute flops for AddN operation."""
+ if not node.input:
+   return _zero_flops(graph, node)
+ in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+ in_shape.assert_is_fully_defined()
+ return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1))
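+
+
+# Editorial usage sketch (not part of the original commit). A minimal way to
+# exercise these registrations, assuming the TF 1.x framework API of this
+# era, where ops.get_stats_for_node_def() looks up registered statistics:
+#
+#   import tensorflow as tf
+#   from tensorflow.python.framework import ops
+#   from tensorflow.python.profiler.internal import flops_registry  # registers "flops"
+#
+#   g = tf.Graph()
+#   with g.as_default():
+#     x = tf.constant(1.0, shape=[32, 128])
+#     y = tf.log(x)  # unary op: one flop per element
+#   log_node = g.as_graph_def().node[-1]
+#   stats = ops.get_stats_for_node_def(g, log_node, "flops")
+#   print(stats.value)  # expected: 32 * 128 = 4096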