aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/profiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-11 11:16:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-11 11:20:24 -0700
commit509372c2ee09d26f21e7040defa45bf7c63ce4c1 (patch)
tree0b785f14ef13a11d056eb29c553f9656f6073ce4 /tensorflow/python/profiler
parent80ed8afc02b16adb209bffeb551e7f0c435985f6 (diff)
Add a lot of operations' flops calculations
PiperOrigin-RevId: 168256746
Diffstat (limited to 'tensorflow/python/profiler')
-rw-r--r--tensorflow/python/profiler/BUILD1
-rw-r--r--tensorflow/python/profiler/internal/BUILD10
-rw-r--r--tensorflow/python/profiler/internal/flops_registry.py446
-rw-r--r--tensorflow/python/profiler/model_analyzer_test.py9
-rw-r--r--tensorflow/python/profiler/tfprof_logger.py1
5 files changed, 463 insertions, 4 deletions
diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD
index 8dd2413661..26cc5f0b74 100644
--- a/tensorflow/python/profiler/BUILD
+++ b/tensorflow/python/profiler/BUILD
@@ -82,6 +82,7 @@ py_library(
"//tensorflow/core/profiler:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
+ "//tensorflow/python/profiler/internal:flops_registry",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/profiler/internal/BUILD b/tensorflow/python/profiler/internal/BUILD
index f9cc8af19d..dcac070a3f 100644
--- a/tensorflow/python/profiler/internal/BUILD
+++ b/tensorflow/python/profiler/internal/BUILD
@@ -8,6 +8,16 @@ load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
py_library(
+ name = "flops_registry",
+ srcs = ["flops_registry.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:graph_util",
+ ],
+)
+
+py_library(
name = "model_analyzer_testlib",
srcs = ["model_analyzer_testlib.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/profiler/internal/flops_registry.py b/tensorflow/python/profiler/internal/flops_registry.py
new file mode 100644
index 0000000000..e143501049
--- /dev/null
+++ b/tensorflow/python/profiler/internal/flops_registry.py
@@ -0,0 +1,446 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Register flops statistics for various TensorFlow operations.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import ops
+
+
+# List of all ops which have implemented flops statistics.
+IMPLEMENTED_OPS = set([
+ # Unary ops
+ "Reciprocal", "Square", "Rsqrt", "Log", "Neg", "AssignSub", "AssignAdd",
+ "L2Loss", "Softmax",
+ # Binary ops
+ "Add", "Sub", "Mul", "RealDiv", "Maximum", "Minimum", "Pow", "RsqrtGrad",
+ "GreaterEqual", "Greater", "LessEqual", "Less", "Equal", "NotEqual",
+ "SquaredDifference",
+ # Reduction ops
+ "Mean", "Sum", "ArgMax", "ArgMin", "BiasAddGrad",
+ # Convolution and pooling
+ "AvgPool", "MaxPool", "AvgPoolGrad", "MaxPoolGrad", "Conv2DBackpropInput",
+ "Conv2DBackpropFilter",
+ # Other ops
+ "AddN",
+ # Ops implemented in core tensorflow:
+ "MatMul", "Conv2D", "DepthwiseConv2dNative", "BiasAdd", "Dilation2D",
+])
+
+
+def _zero_flops(graph, node):
+ """Returns zero flops."""
+ del graph, node # graph and node are unused
+ return ops.OpStats("flops", 0)
+
+
+def _list_product(lst):
+ """Computes product of element of the list."""
+ result = 1
+ for item in lst:
+ result *= item
+ return result
+
+################################################################################
+# Unary operations
+################################################################################
+
+
+def _unary_op_flops(graph, node, ops_per_element=1):
+ """Common code which compute flops for unary operations."""
+ in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+ in_shape.assert_is_fully_defined()
+ return ops.OpStats("flops", in_shape.num_elements() * ops_per_element)
+
+
+@ops.RegisterStatistics("Reciprocal", "flops")
+def _reciprocal_flops(graph, node):
+ """Compute flops for Reciprocal operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Square", "flops")
+def _square_flops(graph, node):
+ """Compute flops for Square operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Rsqrt", "flops")
+def _rsqrt_flops(graph, node):
+ """Compute flops for Rsqrt operation."""
+ # Rsqrt(x) = 1 / sqrt(x)
+ return _unary_op_flops(graph, node, ops_per_element=2)
+
+
+@ops.RegisterStatistics("Log", "flops")
+def _log_flops(graph, node):
+ """Compute flops for Log operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Neg", "flops")
+def _neg_flops(graph, node):
+ """Compute flops for Neg operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("AssignSub", "flops")
+def _assign_sub_flops(graph, node):
+ """Compute flops for AssignSub operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("AssignAdd", "flops")
+def _assign_add_flops(graph, node):
+ """Compute flops for AssignAdd operation."""
+ return _unary_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("L2Loss", "flops")
+def _l2_loss_flops(graph, node):
+ """Compute flops for L2Loss operation."""
+ in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+ in_shape.assert_is_fully_defined()
+ # Tensorflow uses inefficient implementation, with (3*N-1) flops:
+ # Optimal implementation is 2*N flops
+ return ops.OpStats("flops", in_shape.num_elements() * 3 - 1)
+
+
+@ops.RegisterStatistics("Softmax", "flops")
+def _softmax_flops(graph, node):
+ """Compute flops for Softmax operation."""
+ # Softmax implenetation:
+ #
+ # Approximate flops breakdown:
+ # 2*n -- compute shifted logits
+ # n -- exp of shifted logits
+ # 2*n -- compute softmax from exp of shifted logits
+ return _unary_op_flops(graph, node, ops_per_element=5)
+
+################################################################################
+# Binary operations
+################################################################################
+
+
+def _binary_per_element_op_flops(graph, node, ops_per_element=1):
+ """Common code which compute flops for binary operations."""
+ out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
+ out_shape.assert_is_fully_defined()
+ return ops.OpStats("flops", out_shape.num_elements() * ops_per_element)
+
+
+@ops.RegisterStatistics("Add", "flops")
+def _add_flops(graph, node):
+ """Compute flops for Add operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Sub", "flops")
+def _sub_flops(graph, node):
+ """Compute flops for Sub operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Mul", "flops")
+def _mul_flops(graph, node):
+ """Compute flops for Mul operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("RealDiv", "flops")
+def _real_div_flops(graph, node):
+ """Compute flops for RealDiv operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Maximum", "flops")
+def _maximum_flops(graph, node):
+ """Compute flops for Maximum operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Minimum", "flops")
+def _minimum_flops(graph, node):
+ """Compute flops for Minimum operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Pow", "flops")
+def _pow_flops(graph, node):
+ """Compute flops for Pow operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("RsqrtGrad", "flops")
+def _rsqrt_grad_flops(graph, node):
+ """Compute flops for RsqrtGrad operation."""
+ return _binary_per_element_op_flops(graph, node, ops_per_element=4)
+
+
+@ops.RegisterStatistics("GreaterEqual", "flops")
+def _greater_equal_flops(graph, node):
+ """Compute flops for GreaterEqual operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Greater", "flops")
+def _greater_flops(graph, node):
+ """Compute flops for Greater operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("LessEqual", "flops")
+def _less_equal_flops(graph, node):
+ """Compute flops for LessEqual operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Less", "flops")
+def _less_flops(graph, node):
+ """Compute flops for Less operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("Equal", "flops")
+def _equal_flops(graph, node):
+ """Compute flops for Equal operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("NotEqual", "flops")
+def _not_equal_flops(graph, node):
+ """Compute flops for NotEqual operation."""
+ return _binary_per_element_op_flops(graph, node)
+
+
+@ops.RegisterStatistics("SquaredDifference", "flops")
+def _squared_difference_flops(graph, node):
+ """Compute flops for SquaredDifference operation."""
+ return _binary_per_element_op_flops(graph, node, ops_per_element=2)
+
+################################################################################
+# Reduction ops
+################################################################################
+
+
+def _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0):
+ """Common code which compute flops for reduction operations."""
+ in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+ in_shape.assert_is_fully_defined()
+ out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
+ out_shape.assert_is_fully_defined()
+ num_flops = (in_shape.num_elements() * reduce_flops
+ + out_shape.num_elements() * (finalize_flops - reduce_flops))
+ return ops.OpStats("flops", num_flops)
+
+
+@ops.RegisterStatistics("Mean", "flops")
+def _mean_flops(graph, node):
+ """Compute flops for Mean operation."""
+ # reduction - sum, finalization - divide
+ return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=1)
+
+
+@ops.RegisterStatistics("Sum", "flops")
+def _sum_flops(graph, node):
+ """Compute flops for Sum operation."""
+ # reduction - sum, no finalization
+ return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
+
+
+@ops.RegisterStatistics("ArgMax", "flops")
+def _arg_max_flops(graph, node):
+ """Compute flops for ArgMax operation."""
+ # reduction - comparison, no finalization
+ return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
+
+
+@ops.RegisterStatistics("ArgMin", "flops")
+def _arg_min_flops(graph, node):
+ """Compute flops for ArgMin operation."""
+ # reduction - comparison, no finalization
+ return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
+
+
+@ops.RegisterStatistics("BiasAddGrad", "flops")
+def _bias_add_grad_flops(graph, node):
+ """Compute flops for BiasAddGrad operation."""
+ # Implementation of BiasAddGrad, essentially it's a reduce sum and reshaping:
+ # So computing flops same way as for "Sum"
+ return _reduction_op_flops(graph, node, reduce_flops=1, finalize_flops=0)
+
+################################################################################
+# Convolution and pooling
+# Note: all flops statistics are implemented only for NHWC data format
+################################################################################
+
+
+def _verify_conv_data_format(node):
+ """Verifies data format for pooling and convolutional operations."""
+ # TODO(xpan): P1: Support NCHW
+ if node.attr["data_format"].s != b"NHWC":
+ raise ValueError("Only NHWC format is supported in flops computations")
+
+
+def _pool_flops(graph, node):
+ """Common code which compute flops for pooling operations."""
+ # compute flops for average and max pooling
+ _verify_conv_data_format(node)
+ #
+ # Pooling declaration:
+ # Inputs:
+ # - value
+ # Outputs:
+ # - output
+ # Attributes:
+ # - ksize
+ # - strides
+ # - padding
+ # - data_format
+ #
+ # Pooling implenetation:
+ out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
+ out_shape.assert_is_fully_defined()
+ kernel_shape = list(node.attr["ksize"].list.i)
+ kernel_area = _list_product(kernel_shape)
+ return ops.OpStats("flops", kernel_area * out_shape.num_elements())
+
+
+@ops.RegisterStatistics("AvgPool", "flops")
+def _avg_pool_flops(graph, node):
+ """Compute flops for AvgPool operation."""
+ return _pool_flops(graph, node)
+
+
+@ops.RegisterStatistics("MaxPool", "flops")
+def _max_pool_flops(graph, node):
+ """Compute flops for MaxPool operation."""
+ return _pool_flops(graph, node)
+
+
+@ops.RegisterStatistics("AvgPoolGrad", "flops")
+def _avg_pool_grad_flops(graph, node):
+ """Compute flops for AvgPoolGrad operation."""
+ _verify_conv_data_format(node)
+ # Pooling gradient implementation:
+ out_backprop_shape = graph_util.tensor_shape_from_node_def_name(graph,
+ node.input[1])
+ out_backprop_shape.assert_is_fully_defined()
+ kernel_shape = list(node.attr["ksize"].list.i)
+ kernel_area = _list_product(kernel_shape)
+ # TensorFlow multiply each element of pooling window by coefficient,
+ # then sum up all of them, thus we have 2 flops per element:
+ # More optimal implementation - if division is done after.
+ return ops.OpStats("flops",
+ kernel_area * out_backprop_shape.num_elements() * 2)
+
+
+@ops.RegisterStatistics("MaxPoolGrad", "flops")
+def _max_pool_grad_flops(graph, node):
+ """Compute flops for MaxPoolGrad operation."""
+ _verify_conv_data_format(node)
+ #
+ # MaxPoolGrad declaration:
+ # Inputs:
+ # - orig_input -- original input tensor (of max_pool)
+ # - orig_output -- original output tensor (of max_pool)
+ # - grad -- gradient with respect to output of max_pool
+ # Outputs:
+ # - output -- gradient with respect to input of max_pool
+ # Attributes:
+ # - ksize
+ # - strides
+ # - padding
+ # - data_format
+ # It computes MaxPool first, then one flop per each element of original output
+ #
+ kernel_shape = list(node.attr["ksize"].list.i)
+ kernel_area = _list_product(kernel_shape)
+ orig_out_shape = graph_util.tensor_shape_from_node_def_name(graph,
+ node.input[1])
+ max_pool_ops = kernel_area * orig_out_shape.num_elements()
+ return ops.OpStats("flops", max_pool_ops + orig_out_shape.num_elements())
+
+
+@ops.RegisterStatistics("Conv2DBackpropInput", "flops")
+def _conv_2d_backprop_input_flops(graph, node):
+ """Compute flops for Conv2DBackpropInput operation."""
+ # Formula:
+ # batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
+ # * input_depth * output_depth * 2 / (image_x_stride * image_x_stride)
+ #
+ # Where:
+ # image_x_dim, image_y_dim and input_depth --- size of input to source (no
+ # backprop) convolution, in other words they are sizes of backprop output.
+ # output_depth --- number of filters in the original convolution, thus
+ # depth of backprop input.
+ # kernel_x_dim and kernel_y_dim --- sizes of filter in spatial dimension
+ # image_x_stride and image_x_stride --- strides of the convolution
+ #
+ _verify_conv_data_format(node)
+ # out_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
+ out_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
+ out_shape.assert_is_fully_defined()
+ # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
+ kernel_shape = graph_util.tensor_shape_from_node_def_name(graph,
+ node.input[1])
+ kernel_shape.assert_is_fully_defined()
+ # strides
+ strides_shape = list(node.attr["strides"].list.i)
+ strides_product = strides_shape[1] * strides_shape[2]
+ return ops.OpStats("flops",
+ (2 * out_shape.num_elements()
+ * kernel_shape.num_elements()
+ / (out_shape[-1].value * strides_product)))
+
+
+@ops.RegisterStatistics("Conv2DBackpropFilter", "flops")
+def _conv_2d_backprop_filter_flops(graph, node):
+ """Compute flops for Conv2DBackpropFilter operation."""
+ # Formula same as for Conv2DBackpropInput:
+ # batch_size * image_x_dim * image_y_dim * kernel_x_dim * kernel_y_dim
+ # * input_depth * output_depth * 2 / (image_x_stride * image_x_stride)
+ #
+ _verify_conv_data_format(node)
+ # image_shape = [batch_size, image_y_dim, image_x_dim, input_depth]
+ image_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+ image_shape.assert_is_fully_defined()
+ # kernel_shape = [kernel_y_dim, kernel_x_dim, input_depth, output_depth]
+ kernel_shape = graph_util.tensor_shape_from_node_def_name(graph, node.name)
+ kernel_shape.assert_is_fully_defined()
+ # strides
+ strides_shape = list(node.attr["strides"].list.i)
+ strides_product = strides_shape[1] * strides_shape[2]
+ return ops.OpStats("flops",
+ (2 * image_shape.num_elements()
+ * kernel_shape.num_elements()
+ / (image_shape[-1].value * strides_product)))
+
+################################################################################
+# Other ops
+################################################################################
+
+
+@ops.RegisterStatistics("AddN", "flops")
+def _add_n_flops(graph, node):
+ """Compute flops for AddN operation."""
+ if not node.input:
+ return _zero_flops(graph, node)
+ in_shape = graph_util.tensor_shape_from_node_def_name(graph, node.input[0])
+ in_shape.assert_is_fully_defined()
+ return ops.OpStats("flops", in_shape.num_elements() * (len(node.input) - 1))
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index 494ba2e2a0..81c628289e 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -35,6 +35,7 @@ from tensorflow.python.profiler import model_analyzer
from tensorflow.python.profiler import option_builder
from tensorflow.python.profiler import profile_context
from tensorflow.python.profiler.internal import model_analyzer_testlib as lib
+from tensorflow.python.util import compat
builder = option_builder.ProfileOptionBuilder
@@ -158,7 +159,7 @@ class PrintModelAnalysisTest(test.TestCase):
with gfile.Open(outfile, 'r') as f:
# pylint: disable=line-too-long
self.assertEqual(
- 'node name | # parameters | # float_ops | assigned devices | op types | op count (run|defined) | input shapes\n_TFProfRoot (--/451 params, --/10.44k flops, _kTFScopeParent, --/8|--/36, )\n Conv2D (0/0 params, 5.83k/5.83k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 1/1|1/1, 0:2x6x6x3|1:3x3x3x6)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 1/1|1/1, 0:2x3x3x6|1:2x2x6x12)\n DW (3x3x3x6, 162/162 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:3x3x3x6|1:3x3x3x6)\n DW/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, 0/0|1/7, )\n DW/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0/0|1/6, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0/0|1/1, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 1/1|1/1, 0:3x3x3x6)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW2/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:2x2x6x12|1:2x2x6x12)\n DW2/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, 0/0|1/7, )\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0/0|1/6, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0/0|1/1, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 1/1|1/1, 0:2x2x6x12)\n ScalarW (1, 1/1 params, 0/0 flops, VariableV2|_trainable_variables, 0/0|1/10, )\n ScalarW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, 0/0|1/7, )\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0/0|1/6, 0:1|1:1)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:0)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/read (0/0 params, 0/0 flops, Identity, 0/0|1/1, 0:1)\n _retval_Conv2D_1_0_0 (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|RunTimeOp, 1/1|1/1, )\n init (0/0 params, 0/0 flops, NoOp, 0/0|1/1, 0:1|1:3x3x3x6|2:2x2x6x12)\n zeros (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const, 1/1|1/1, )\n',
+ 'node name | # parameters | # float_ops | assigned devices | op types | op count (run|defined) | input shapes\n_TFProfRoot (--/451 params, --/11.34k flops, _kTFScopeParent, --/8|--/36, )\n Conv2D (0/0 params, 5.83k/5.83k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 1/1|1/1, 0:2x6x6x3|1:3x3x3x6)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 1/1|1/1, 0:2x3x3x6|1:2x2x6x12)\n DW (3x3x3x6, 162/162 params, 0/324 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:3x3x3x6|1:3x3x3x6)\n DW/Initializer (0/0 params, 0/324 flops, _kTFScopeParent, 0/0|1/7, )\n DW/Initializer/random_normal (0/0 params, 162/324 flops, Add, 0/0|1/6, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/mul (0/0 params, 162/162 flops, Mul, 0/0|1/1, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 1/1|1/1, 0:3x3x3x6)\n DW2 (2x2x6x12, 288/288 params, 0/576 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, 1/2|1/10, )\n DW2/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:2x2x6x12|1:2x2x6x12)\n DW2/Initializer (0/0 params, 0/576 flops, _kTFScopeParent, 0/0|1/7, )\n DW2/Initializer/random_normal (0/0 params, 288/576 flops, Add, 0/0|1/6, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:4)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/mul (0/0 params, 288/288 flops, Mul, 0/0|1/1, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n DW2/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 1/1|1/1, 0:2x2x6x12)\n ScalarW (1, 1/1 params, 0/2 flops, VariableV2|_trainable_variables, 0/0|1/10, )\n ScalarW/Assign (0/0 params, 0/0 flops, Assign, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer (0/0 params, 0/2 flops, _kTFScopeParent, 0/0|1/7, )\n ScalarW/Initializer/random_normal (0/0 params, 1/2 flops, Add, 0/0|1/6, 0:1|1:1)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0/0|1/1, 0:0)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/mul (0/0 params, 1/1 flops, Mul, 0/0|1/1, 0:1|1:1)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, 0/0|1/1, )\n ScalarW/read (0/0 params, 0/0 flops, Identity, 0/0|1/1, 0:1)\n _retval_Conv2D_1_0_0 (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|RunTimeOp, 1/1|1/1, )\n init (0/0 params, 0/0 flops, NoOp, 0/0|1/1, 0:1|1:3x3x3x6|2:2x2x6x12)\n zeros (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const, 1/1|1/1, )\n',
f.read())
# pylint: enable=line-too-long
@@ -221,12 +222,12 @@ class PrintModelAnalysisTest(test.TestCase):
with gfile.Open(outfile, 'r') as f:
lines = f.read().split('\n')
result = '\n'.join([l[:min(len(l), 80)] for l in lines])
- self.assertEqual('node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/91.04k flops)\n model_analyzer_testlib.py:63:BuildFullModel (0/1.80k params, 0/41.76k flops)\n model_analyzer_testlib.py:40:BuildSmallModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (0/4 params, 0/0 flops)\n model_analyzer_testlib.py:48:BuildSmallModel (0/648 params, 0/0 flops)\n model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:53:BuildSmallModel (0/1.15k params, 0/0 flops)\n model_analyzer_testlib.py:54:BuildSmallModel (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:63:BuildFullModel (gradient) (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/0 flo\n model_analyzer_testlib.py:54:BuildSmallModel (gradient) (0/0 params, 0/0 flo\n model_analyzer_testlib.py:67:BuildFullModel (0/1.04k params, 0/16.51k flops)\n model_analyzer_testlib.py:67:BuildFullModel (gradient) (0/0 params, 0/32.77k f\n model_analyzer_testlib.py:69:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (gradient) (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:72:BuildFullModel (0/0 params, 0/0 flops)\n',
- result)
+ self.assertEqual(compat.as_bytes('node name | # parameters | # float_ops\n_TFProfRoot (--/2.84k params, --/168.85k flops)\n model_analyzer_testlib.py:63:BuildFullModel (0/1.80k params, 0/45.37k flops)\n model_analyzer_testlib.py:40:BuildSmallModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:44:BuildSmallModel (0/4 params, 0/8 flops)\n model_analyzer_testlib.py:48:BuildSmallModel (0/648 params, 0/1.30k flops)\n model_analyzer_testlib.py:49:BuildSmallModel (0/0 params, 0/23.33k flops)\n model_analyzer_testlib.py:53:BuildSmallModel (0/1.15k params, 0/2.30k flops)\n model_analyzer_testlib.py:54:BuildSmallModel (0/0 params, 0/18.43k flops)\n model_analyzer_testlib.py:63:BuildFullModel (gradient) (0/0 params, 0/67.39k f\n model_analyzer_testlib.py:49:BuildSmallModel (gradient) (0/0 params, 0/46.66\n model_analyzer_testlib.py:54:BuildSmallModel (gradient) (0/0 params, 0/20.74\n model_analyzer_testlib.py:67:BuildFullModel (0/1.04k params, 0/18.57k flops)\n model_analyzer_testlib.py:67:BuildFullModel (gradient) (0/0 params, 0/37.00k f\n model_analyzer_testlib.py:69:BuildFullModel (0/0 params, 0/0 flops)\n model_analyzer_testlib.py:70:BuildFullModel (0/0 params, 0/258 flops)\n model_analyzer_testlib.py:70:BuildFullModel (gradient) (0/0 params, 0/130 flop\n model_analyzer_testlib.py:72:BuildFullModel (0/0 params, 0/141 flops)\n'),
+ compat.as_bytes(result))
self.assertLess(0, tfprof_node.total_exec_micros)
self.assertEqual(2844, tfprof_node.total_parameters)
- self.assertEqual(91040, tfprof_node.total_float_ops)
+ self.assertEqual(168855, tfprof_node.total_float_ops)
self.assertEqual(8, len(tfprof_node.children))
self.assertEqual('_TFProfRoot', tfprof_node.name)
self.assertEqual(
diff --git a/tensorflow/python/profiler/tfprof_logger.py b/tensorflow/python/profiler/tfprof_logger.py
index 675eb98f8e..9020f60421 100644
--- a/tensorflow/python/profiler/tfprof_logger.py
+++ b/tensorflow/python/profiler/tfprof_logger.py
@@ -28,6 +28,7 @@ from tensorflow.core.profiler import tfprof_log_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
+from tensorflow.python.profiler.internal import flops_registry # pylint: disable=unused-import
TRAINABLE_VARIABLES = '_trainable_variables'
REGISTERED_FLOP_STATS = 'flops'