diff options
author | 2017-09-11 11:16:14 -0700 | |
---|---|---|
committer | 2017-09-11 11:20:24 -0700 | |
commit | 509372c2ee09d26f21e7040defa45bf7c63ce4c1 (patch) | |
tree | 0b785f14ef13a11d056eb29c553f9656f6073ce4 /tensorflow/python/profiler/internal | |
parent | 80ed8afc02b16adb209bffeb551e7f0c435985f6 (diff) |
Add a lot of operations' flops calculations
PiperOrigin-RevId: 168256746
Diffstat (limited to 'tensorflow/python/profiler/internal')
-rw-r--r-- | tensorflow/python/profiler/internal/BUILD | 10 | ||||
-rw-r--r-- | tensorflow/python/profiler/internal/flops_registry.py | 446 |
2 files changed, 456 insertions, 0 deletions
diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD index f9cc8af19d..dcac070a3f 100644 --- a/tensorflow/python/profiler/internal/BUILD +++ b/tensorflow/python/profiler/internal/BUILD @@ -8,6 +8,16 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") py_library( + name = "flops_registry", + srcs = ["flops_registry.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:graph_util", + ], +) + +py_library( name = "model_analyzer_testlib", srcs = ["model_analyzer_testlib.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/python/profiler/internal/flops_registry.py b/tensorflow/python/profiler/internal/flops_registry.py new file mode 100644 index 0000000000..e143501049 --- /dev/null +++ b/tensorflow/python/profiler/internal/flops_registry.py @@ -0,0 +1,446 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Register flops statistics for various TensorFlow operations. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import graph_util +from tensorflow.python.framework import ops + + +# List of all ops which have implemented flops statistics. +IMPLEMENTED_OPS = set([ + # Unary ops + "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd", + "L2Loss", "Softmax", + # Binary ops + "Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad", + "GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual", + "SquaredDifference", + # Reduction ops + "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad", + # Convolution and pooling + "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput", + "Conv2DBackpropFilter", + # Other ops + "AddN", + # Ops implemented in core tensorflow: + "MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D", +]) + + +def _zero_flops(graph, node): + """Returns zero flops.""" + del graph, node # graph and node are unused + return ops.OpStats("flops", 0) + + +def _list_product(lst): + """Computes product of element of the list.""" + result = 1 + for item in lst: + result *= item + return result + +################################################################################ +# Unary operations +################################################################################ + + +def _unary_op_flops(graph, node, ops_per_element=1): + """Common code which compute flops for unary operations.""" + in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) + in_shape.assert_is_fully_defined() + return ops.OpStats("flops", in_shape.num_elements() * ops_per_element) + + +@ops.RegisterStatistics("Reciprocal", "flops") +def _reciprocal_flops(graph, node): + """Compute flops for Reciprocal operation.""" + return _unary_op_flops(graph, node) + + +@ops.RegisterStatistics("Square", "flops") +def _square_flops(graph, node): + """Compute flops for Square operation.""" + return _unary_op_flops(graph, node) + + +@ops.RegisterStatistics("Rsqrt", "flops") +def _rsqrt_flops(graph, node): + """Compute flops for Rsqrt operation.""" + # Rsqrt(x) = 1 / sqrt(x) + return _unary_op_flops(graph, node, ops_per_element=2) + + +@ops.RegisterStatistics("Log", "flops") +def _log_flops(graph, node): + """Compute flops for Log operation.""" + return _unary_op_flops(graph, node) + + +@ops.RegisterStatistics("Neg", "flops") +def _neg_flops(graph, node): + """Compute flops for Neg operation.""" + return _unary_op_flops(graph, node) + + +@ops.RegisterStatistics("AssignSub", "flops") +def _assign_sub_flops(graph, node): + """Compute flops for AssignSub operation.""" + return _unary_op_flops(graph, node) + + +@ops.RegisterStatistics("AssignAdd", "flops") +def _assign_add_flops(graph, node): + """Compute flops for AssignAdd operation.""" + return _unary_op_flops(graph, node) + + +@ops.RegisterStatistics("L2Loss", "flops") +def _l2_loss_flops(graph, node): + """Compute flops for L2Loss operation.""" + in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) + in_shape.assert_is_fully_defined() + # Tensorflow uses inefficient implementation, with (3*N-1) flops: + # Optimal implementation is 2*N flops + return ops.OpStats("flops", in_shape.num_elements() * 3 - 1) + + +@ops.RegisterStatistics("Softmax", "flops") +def _softmax_flops(graph, node): + """Compute flops for Softmax operation.""" + # Softmax implenetation: + # + # Approximate flops breakdown: + # 2*n -- compute shifted logits + # n -- exp of shifted logits + # 2*n -- compute softmax from exp of shifted logits + return _unary_op_flops(graph, node, ops_per_element=5) + +################################################################################ +# Binary operations +################################################################################ + + +def _binary_per_element_op_flops(graph, node, ops_per_element=1): + """Common code which compute flops for binary operations.""" + out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) + out_shape.assert_is_fully_defined() + return ops.OpStats("flops", out_shape.num_elements() * ops_per_element) + + +@ops.RegisterStatistics("Add", "flops") +def _add_flops(graph, node): + """Compute flops for Add operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("Sub", "flops") +def _sub_flops(graph, node): + """Compute flops for Sub operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("Mul", "flops") +def _mul_flops(graph, node): + """Compute flops for Mul operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("RealDiv", "flops") +def _real_div_flops(graph, node): + """Compute flops for RealDiv operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("Maximum", "flops") +def _maximum_flops(graph, node): + """Compute flops for Maximum operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("Minimum", "flops") +def _minimum_flops(graph, node): + """Compute flops for Minimum operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("Pow", "flops") +def _pow_flops(graph, node): + """Compute flops for Pow operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("RsqrtGrad", "flops") +def _rsqrt_grad_flops(graph, node): + """Compute flops for RsqrtGrad operation.""" + return _binary_per_element_op_flops(graph, node, ops_per_element=4) + + +@ops.RegisterStatistics("GreaterEqual", "flops") +def _greater_equal_flops(graph, node): + """Compute flops for GreaterEqual operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("Greater", "flops") +def _greater_flops(graph, node): + """Compute flops for Greater operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("LessEqual", "flops") +def _less_equal_flops(graph, node): + """Compute flops for LessEqual operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("Less", "flops") +def _less_flops(graph, node): + """Compute flops for Less operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("Equal", "flops") +def _equal_flops(graph, node): + """Compute flops for Equal operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("NotEqual", "flops") +def _not_equal_flops(graph, node): + """Compute flops for NotEqual operation.""" + return _binary_per_element_op_flops(graph, node) + + +@ops.RegisterStatistics("SquaredDifference", "flops") +def _squared_difference_flops(graph, node): + """Compute flops for SquaredDifference operation.""" + return _binary_per_element_op_flops(graph, node, ops_per_element=2) + +################################################################################ +# Reduction ops +################################################################################ + + +def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0): + """Common code which compute flops for reduction operations.""" + in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) + in_shape.assert_is_fully_defined() + out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) + out_shape.assert_is_fully_defined() + num_flops = (in_shape.num_elements() * reduce_flops + + out_shape.num_elements() * (finalize_flops - reduce_flops)) + return ops.OpStats("flops", num_flops) + + +@ops.RegisterStatistics("Mean", "flops") +def _mean_flops(graph, node): + """Compute flops for Mean operation.""" + # reduction - sum, finalization - divide + return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1) + + +@ops.RegisterStatistics("Sum", "flops") +def _sum_flops(graph, node): + """Compute flops for Sum operation.""" + # reduction - sum, no finalization + return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0) + + +@ops.RegisterStatistics("ArgMax", "flops") +def _arg_max_flops(graph, node): + """Compute flops for ArgMax operation.""" + # reduction - comparison, no finalization + return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0) + + +@ops.RegisterStatistics("ArgMin", "flops") +def _arg_min_flops(graph, node): + """Compute flops for ArgMin operation.""" + # reduction - comparison, no finalization + return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0) + + +@ops.RegisterStatistics("BiasAddGrad", "flops") +def _bias_add_grad_flops(graph, node): + """Compute flops for BiasAddGrad operation.""" + # Implementation of BiasAddGrad, essentially it's a reduce sum and reshaping: + # So computing flops same way as for "Sum" + return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0) + +################################################################################ +# Convolution and pooling +# Note: all flops statistics are implemented only for NHWC data format +################################################################################ + + +def _verify_conv_data_format(node): + """Verifies data format for pooling and convolutional operations.""" + # TODO(xpan): P1: Support NCHW + if node.attr["data_format"].s != b"NHWC": + raise ValueError("Only NHWC format is supported in flops computations") + + +def _pool_flops(graph, node): + """Common code which compute flops for pooling operations.""" + # compute flops for average and max pooling + _verify_conv_data_format(node) + # + # Pooling declaration: + # Inputs: + # - value + # Outputs: + # - output + # Attributes: + # - ksize + # - strides + # - padding + # - data_format + # + # Pooling implenetation: + out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) + out_shape.assert_is_fully_defined() + kernel_shape = list(node.attr["ksize"].list.i) + kernel_area = _list_product(kernel_shape) + return ops.OpStats("flops", kernel_area * out_shape.num_elements()) + + +@ops.RegisterStatistics("AvgPool", "flops") +def _avg_pool_flops(graph, node): + """Compute flops for AvgPool operation.""" + return _pool_flops(graph, node) + + +@ops.RegisterStatistics("MaxPool", "flops") +def _max_pool_flops(graph, node): + """Compute flops for MaxPool operation.""" + return _pool_flops(graph, node) + + +@ops.RegisterStatistics("AvgPoolGrad", "flops") +def _avg_pool_grad_flops(graph, node): + """Compute flops for AvgPoolGrad operation.""" + _verify_conv_data_format(node) + # Pooling gradient implementation: + out_backprop_shape = graph_util.tensor_shape_from_node_def_name(graph, + node.input[1]) + out_backprop_shape.assert_is_fully_defined() + kernel_shape = list(node.attr["ksize"].list.i) + kernel_area = _list_product(kernel_shape) + # TensorFlow multiply each element of pooling window by coefficient, + # then sum up all of them, thus we have 2 flops per element: + # More optimal implementation - if division is done after. + return ops.OpStats("flops", + kernel_area * out_backprop_shape.num_elements() * 2) + + +@ops.RegisterStatistics("MaxPoolGrad", "flops") +def _max_pool_grad_flops(graph, node): + """Compute flops for MaxPoolGrad operation.""" + _verify_conv_data_format(node) + # + # MaxPoolGrad declaration: + # Inputs: + # - orig_input -- original input tensor (of max_pool) + # - orig_output -- original output tensor (of max_pool) + # - grad -- gradient with respect to output of max_pool + # Outputs: + # - output -- gradient with respect to input of max_pool + # Attributes: + # - ksize + # - strides + # - padding + # - data_format + # It computes MaxPool first, then one flop per each element of original output + # + kernel_shape = list(node.attr["ksize"].list.i) + kernel_area = _list_product(kernel_shape) + orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph, + node.input[1]) + max_pool_ops = kernel_area * orig_out_shape.num_elements() + return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements()) + + +@ops.RegisterStatistics("Conv2DBackpropInput", "flops") +def _conv_2d_backprop_input_flops(graph, node): + """Compute flops for Conv2DBackpropInput operation.""" + # Formula: + # batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim + # * input_depth * output_depth * 2 / (image_x_stride * image_x_stride) + # + # Where: + # image_x_dim, image_y_dim and input_depth --- size of input to source (no + # backprop) convolution, in other words they are sizes of backprop output. + # output_depth --- number of filters in the original convolution, thus + # depth of backprop input. + # kernel_x_dim and kernel_y_dim --- sizes of filter in spatial dimension + # image_x_stride and image_x_stride --- strides of the convolution + # + _verify_conv_data_format(node) + # out_shape = [batch_size, image_y_dim, image_x_dim, input_depth] + out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) + out_shape.assert_is_fully_defined() + # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth] + kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, + node.input[1]) + kernel_shape.assert_is_fully_defined() + # strides + strides_shape = list(node.attr["strides"].list.i) + strides_product = strides_shape[1] * strides_shape[2] + return ops.OpStats("flops", + (2 * out_shape.num_elements() + * kernel_shape.num_elements() + / (out_shape[-1].value * strides_product))) + + +@ops.RegisterStatistics("Conv2DBackpropFilter", "flops") +def _conv_2d_backprop_filter_flops(graph, node): + """Compute flops for Conv2DBackpropFilter operation.""" + # Formula same as for Conv2DBackpropInput: + # batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim + # * input_depth * output_depth * 2 / (image_x_stride * image_x_stride) + # + _verify_conv_data_format(node) + # image_shape = [batch_size, image_y_dim, image_x_dim, input_depth] + image_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) + image_shape.assert_is_fully_defined() + # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth] + kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name) + kernel_shape.assert_is_fully_defined() + # strides + strides_shape = list(node.attr["strides"].list.i) + strides_product = strides_shape[1] * strides_shape[2] + return ops.OpStats("flops", + (2 * image_shape.num_elements() + * kernel_shape.num_elements() + / (image_shape[-1].value * strides_product))) + +################################################################################ +# Other ops +################################################################################ + + +@ops.RegisterStatistics("AddN", "flops") +def _add_n_flops(graph, node): + """Compute flops for AddN operation.""" + if not node.input: + return _zero_flops(graph, node) + in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0]) + in_shape.assert_is_fully_defined() + return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1)) |