From 9b336b4a33158061535fd6ba4973605248055b69 Mon Sep 17 00:00:00 2001
From: Benoit Steiner
Date: Thu, 4 May 2017 13:05:05 -0800
Subject: Open sourced op level cost prediction

Change: 155123817
---
 tensorflow/core/grappler/costs/BUILD                |  27 +
 .../core/grappler/costs/op_level_cost_estimator.cc  | 554 +++++++++++++++++++++
 .../core/grappler/costs/op_level_cost_estimator.h   | 143 ++++++
 .../grappler/costs/op_level_cost_estimator_test.cc  | 113 +++++
 tensorflow/core/grappler/costs/utils.cc             |   4 +-
 5 files changed, 840 insertions(+), 1 deletion(-)
 create mode 100644 tensorflow/core/grappler/costs/op_level_cost_estimator.cc
 create mode 100644 tensorflow/core/grappler/costs/op_level_cost_estimator.h
 create mode 100644 tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc

(limited to 'tensorflow/core/grappler/costs')

diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index 22f4708d03..372092f42a 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -111,6 +111,7 @@ cc_library(
     name = "utils",
     srcs = ["utils.cc"],
     hdrs = ["utils.h"],
+    defines = if_cuda(["GOOGLE_CUDA=1"]),
     visibility = ["//visibility:public"],
     deps = [
         ":op_performance_data_cc",
@@ -167,3 +168,29 @@ cc_library(
         "//tensorflow/core/kernels:ops_util",
     ],
 )
+
+cc_library(
+    name = "op_level_cost_estimator",
+    srcs = ["op_level_cost_estimator.cc"],
+    hdrs = ["op_level_cost_estimator.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":cost_estimator",
+        ":op_performance_data_cc",
+        ":utils",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+    ],
+)
+
+cc_test(
+    name = "op_level_cost_estimator_test",
+    srcs = ["op_level_cost_estimator_test.cc"],
+    deps = [
+        ":op_level_cost_estimator",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
new file mode 100644
index 0000000000..baed7a8899
--- /dev/null
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -0,0 +1,554 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/grappler/costs/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+constexpr int kOpsPerMac = 2;
+constexpr char kConv2d[] = "Conv2D";
+constexpr char kConv2dBackPropFilter[] = "Conv2DBackpropFilter";
+constexpr char kConv2dBackPropInput[] = "Conv2DBackpropInput";
+constexpr char kMatMul[] = "MatMul";
+constexpr char kSparseMatMul[] = "SparseMatMul";
+constexpr char kIdentity[] = "Identity";
+constexpr char kNoOp[] = "NoOp";
+constexpr char kReshape[] = "Reshape";
+
+OpLevelCostEstimator::OpLevelCostEstimator() {
+  // Syntactic sugar to build and return a lambda that takes an OpInfo and
+  // returns a cost.
+  typedef Costs (OpLevelCostEstimator::*CostImpl)(const OpInfo& op_feature)
+      const;
+  auto wrap = [this](CostImpl impl) -> std::function<Costs(const OpInfo&)> {
+    return [this, impl](const OpInfo& op) { return (this->*impl)(op); };
+  };
+
+  device_cost_impl_ = {
+      {kConv2d, wrap(&OpLevelCostEstimator::PredictConv2D)},
+      {kConv2dBackPropFilter,
+       wrap(&OpLevelCostEstimator::PredictConv2DBackPropFilter)},
+      {kConv2dBackPropInput,
+       wrap(&OpLevelCostEstimator::PredictConv2DBackPropInput)},
+      {kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
+      {kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
+      {kIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)},
+      {kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
+      {kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)}};
+}
+
+Costs OpLevelCostEstimator::PredictCosts(const OpInfo& op_features) const {
+  auto it = device_cost_impl_.find(op_features.op());
+  if (it == device_cost_impl_.end()) {
+    VLOG(1) << "Missing implementation for op: " << op_features.op();
+    Costs costs;
+    costs = DummyExecutionTime(op_features);
+    return costs;
+  }
+
+  std::function<Costs(const OpInfo&)> estimator = it->second;
+  Costs costs = estimator(op_features);
+  VLOG(1) << "Operation " << op_features.op() << " takes "
+          << costs.execution_time.count() << " ns.";
+  return costs;
+}
+
+std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
+    const OpInfo::DeviceProperties& device) const {
+  double gflops = -1;
+  double bandwidth = -1;
+  if (device.bandwidth() > 0) {
+    bandwidth = device.bandwidth() / 1e6;
+  }
+
+  if (device.type() == "CPU") {
+    const OpInfo::DeviceProperties local_cpu = GetLocalCPUInfo();
+    // Check if vector instructions are available, and refine performance
+    // prediction based on this.
+    gflops = local_cpu.num_cores() * local_cpu.frequency();
+    if (bandwidth < 0) {
+      if (local_cpu.bandwidth() > 0) {
+        bandwidth = local_cpu.bandwidth() / 1e6;
+      } else {
+        bandwidth = 32;
+      }
+    }
+  } else if (device.type() == "GPU") {
+    const OpInfo::DeviceProperties local_gpu = GetLocalGPUInfo(0);
+    const string architecture = local_gpu.environment().at("architecture");
+    int cores_per_multiprocessor;
+    if (architecture < "3") {
+      // Fermi
+      cores_per_multiprocessor = 32;
+    } else if (architecture < "4") {
+      // Kepler
+      cores_per_multiprocessor = 192;
+    } else if (architecture < "6") {
+      // Maxwell
+      cores_per_multiprocessor = 128;
+    } else {
+      // Pascal.
+      cores_per_multiprocessor = 64;
+    }
+    gflops = local_gpu.num_cores() * local_gpu.frequency() *
+             cores_per_multiprocessor * kOpsPerMac;
+    if (bandwidth < 0) {
+      CHECK(local_gpu.bandwidth() > 0);
+      bandwidth = local_gpu.bandwidth() / 1e6;
+    }
+  }
+
+  return std::make_pair(gflops, bandwidth);
+}
+
+Costs OpLevelCostEstimator::DummyExecutionTime(
+    const OpInfo& op_features) const {
+  Costs costs = PredictOpCountBasedCost(0, op_features);
+  costs.inaccurate = true;
+  return costs;
+}
+
+Costs OpLevelCostEstimator::PredictOpCountBasedCost(
+    double operations, const OpInfo& op_features) const {
+  std::pair<double, double> device_perf = GetDeviceInfo(op_features.device());
+  Costs::NanoSeconds compute_cost(operations / device_perf.first);
+  VLOG(1) << "Op:" << op_features.op() << " GOps:" << operations / 1e9
+          << " Execution Time (ns):" << compute_cost.count();
+
+  bool found_unknown_shapes = false;
+  double total_input_size =
+      CalculateInputSize(op_features, &found_unknown_shapes);
+  double total_output_size =
+      CalculateOutputSize(op_features, &found_unknown_shapes);
+  double total_io_size = total_input_size + total_output_size;
+
+  Costs::NanoSeconds memory_cost(total_io_size / device_perf.second);
+  VLOG(1) << "Op:" << op_features.op() << " Size (KB):" << (total_io_size) / 1e3
+          << " Memory Time (ns):" << memory_cost.count();
+
+  Costs costs;
+  costs.compute_time = compute_cost;
+  costs.memory_time = memory_cost;
+  costs.execution_time = compute_cost + memory_cost;
+  costs.inaccurate = found_unknown_shapes;
+  return costs;
+}
+
+int64 OpLevelCostEstimator::CountConv2DOperations(
+    const OpInfo& op_features, bool* found_unknown_shapes) const {
+  return CountConv2DOperations(op_features, nullptr, found_unknown_shapes);
+}
+
+namespace {
+
+string GetDataFormat(const OpInfo& op_features) {
+  string data_format = "NHWC";  // Default format.
+  if (op_features.attr().find("data_format") != op_features.attr().end()) {
+    data_format = op_features.attr().at("data_format").s();
+  }
+  return data_format;
+}
+
+Padding GetPadding(const OpInfo& op_features) {
+  if (op_features.attr().find("padding") != op_features.attr().end() &&
+      op_features.attr().at("padding").s() == "VALID") {
+    return Padding::VALID;
+  }
+  return Padding::SAME;  // Default padding.
+}
+
+std::vector<int64> GetStrides(const OpInfo& op_features) {
+  if (op_features.attr().find("strides") != op_features.attr().end()) {
+    const auto strides = op_features.attr().at("strides").list().i();
+    return {strides[0], strides[1], strides[2], strides[3]};
+  }
+  return {1, 1, 1, 1};
+}
+
+int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride,
+                    const Padding& padding) {
+  // Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
+  // function in third_party/tensorflow/core/framework/common_shape_fns.cc.
+  if (padding == Padding::VALID) {
+    return (input - filter + stride) / stride;
+  } else {  // SAME.
+    return (input + stride - 1) / stride;
+  }
+}
+
+// Return a minimum shape if the shape is unknown. If known, return the original
+// shape.
+TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
+                                      int rank, bool* found_unknown_shapes) {
+  auto shape = original_shape;
+  if (shape.unknown_rank()) {
+    *found_unknown_shapes = true;
+  }
+  if (shape.unknown_rank() || shape.dim_size() == 0) {
+    TensorShapeProto::Dim dim;
+    VLOG(1) << "WARNING: Use minimum shape because the shape is unknown.";
+    // The size of each dimension is at least 1, if unknown.
+    dim.set_size(1);
+    for (int i = 0; i < rank; i++) {
+      *shape.add_dim() = dim;
+    }
+  } else {
+    CHECK_EQ(shape.dim_size(), rank);
+    for (int i = 0; i < rank; i++) {
+      if (shape.dim(i).size() == -1) {
+        *found_unknown_shapes = true;
+        VLOG(1)
+            << "WARNING: Use minimum dim size 1 because the shape is unknown.";
+        // The size of each dimension is at least 1, if unknown.
+        shape.mutable_dim(i)->set_size(1);
+      }
+    }
+  }
+  return shape;
+}
+}  // namespace
+
+// Helper to translate the positional arguments into named fields.
+OpLevelCostEstimator::ConvolutionDimensions
+OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
+    const TensorShapeProto& original_image_shape,
+    const TensorShapeProto& original_filter_shape, const OpInfo& op_features,
+    bool* found_unknown_shapes) {
+  auto image_shape =
+      MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
+  auto filter_shape =
+      MaybeGetMinimumShape(original_filter_shape, 4, found_unknown_shapes);
+
+  int x_index, y_index, channel_index;
+  const string& data_format = GetDataFormat(op_features);
+  if (data_format == "NCHW") {
+    x_index = 2;
+    y_index = 3;
+    channel_index = 1;
+  } else {
+    x_index = 1;
+    y_index = 2;
+    channel_index = 3;
+  }
+  int64 batch = image_shape.dim(0).size();
+  int64 ix = image_shape.dim(x_index).size();
+  int64 iy = image_shape.dim(y_index).size();
+  int64 iz = image_shape.dim(channel_index).size();
+  int64 kx = filter_shape.dim(0).size();
+  int64 ky = filter_shape.dim(1).size();
+  std::vector<int64> strides = GetStrides(op_features);
+  const auto padding = GetPadding(op_features);
+  int64 sx = strides[x_index];
+  int64 sy = strides[y_index];
+  int64 ox = GetOutputSize(ix, kx, sx, padding);
+  int64 oy = GetOutputSize(iy, ky, sy, padding);
+  int64 oz = filter_shape.dim(3).size();
+  // Only check equality when both sizes are known (in other words, when
+  // neither is set to a minimum dimension size of 1).
+  if (iz != 1 && filter_shape.dim(2).size() != 1) {
+    CHECK_EQ(iz, filter_shape.dim(2).size());
+  } else {
+    iz = std::max(iz, filter_shape.dim(2).size());
+  }
+  OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
+      batch, ix, iy, iz, kx, ky, oz, ox, oy, sx, sy, padding};
+
+  VLOG(1) << "Batch Size:" << batch;
+  VLOG(1) << "Image Dims:" << ix << "," << iy;
+  VLOG(1) << "Input Features:" << iz;
+  VLOG(1) << "Kernel Dims:" << kx << "," << ky;
+  VLOG(1) << "Output Features:" << oz;
+  VLOG(1) << "Output Dims:" << ox << "," << oy;
+  VLOG(1) << "Strides:" << sx << "," << sy;
+  VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME");
+  return conv_dims;
+}
+
+int64 OpLevelCostEstimator::CountConv2DOperations(
+    const OpInfo& op_features, ConvolutionDimensions* conv_info,
+    bool* found_unknown_shapes) const {
+  if (op_features.op() != kConv2d) {
+    LOG(ERROR) << "Invalid Operation";
+    return 0;
+  }
+  ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
+      op_features.inputs(0).shape(), op_features.inputs(1).shape(),
+      op_features, found_unknown_shapes);
+
+  int64 ops = conv_dims.batch;
+  ops *= conv_dims.ox * conv_dims.oy;
+  ops *= conv_dims.kx * conv_dims.ky;
+  ops *= conv_dims.iz * conv_dims.oz;
+  ops *= kOpsPerMac;
+  VLOG(1) << "Operations for Conv2D" << ops;
+
+  if (conv_info != nullptr) {
+    *conv_info = conv_dims;
+  }
+  return ops;
+}
+
+int64 OpLevelCostEstimator::CountMatMulOperations(
+    const OpInfo& op_features, bool* found_unknown_shapes) const {
+  return CountMatMulOperations(op_features, nullptr, found_unknown_shapes);
+}
+
+int64 OpLevelCostEstimator::CountMatMulOperations(
+    const OpInfo& op_features, MatMulDimensions* mat_mul,
+    bool* found_unknown_shapes) const {
+  double ops = 0;
+
+  // TODO(nishantpatil): Create separate estimator for Sparse Matmul
+  if ((op_features.op() != kMatMul) && (op_features.op() != kSparseMatMul)) {
+    LOG(ERROR) << "Invalid Operation";
+    return ops;
+  }
+
+  // First matrix.
+  auto& a_matrix = op_features.inputs(0);
+  auto& b_matrix = op_features.inputs(1);
+
+  bool transpose_a = false;
+  bool transpose_b = false;
+
+  double m_dim, n_dim, k_dim, k_dim_b = 0;
+
+  for (const auto& item : op_features.attr()) {
+    VLOG(1) << "Key:" << item.first
+            << " Value:" << SummarizeAttrValue(item.second);
+    if (item.first == "transpose_a" && item.second.b() == true)
+      transpose_a = true;
+    if (item.first == "transpose_b" && item.second.b() == true)
+      transpose_b = true;
+  }
+  VLOG(1) << "transpose_a:" << transpose_a;
+  VLOG(1) << "transpose_b:" << transpose_b;
+  auto a_matrix_shape =
+      MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes);
+  auto b_matrix_shape =
+      MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes);
+  if (transpose_a) {
+    m_dim = a_matrix_shape.dim(1).size();
+    k_dim = a_matrix_shape.dim(0).size();
+  } else {
+    m_dim = a_matrix_shape.dim(0).size();
+    k_dim = a_matrix_shape.dim(1).size();
+  }
+  if (transpose_b) {
+    k_dim_b = b_matrix_shape.dim(1).size();
+    n_dim = b_matrix_shape.dim(0).size();
+  } else {
+    k_dim_b = b_matrix_shape.dim(0).size();
+    n_dim = b_matrix_shape.dim(1).size();
+  }
+
+  VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim;
+  // Only check equality when both sizes are known (in other words, when
+  // neither is set to a minimum dimension size of 1).
+  if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) {
+    LOG(ERROR) << "Incompatible Matrix dimensions";
+    return ops;
+  } else {
+    // One of k_dim and k_dim_b might be 1 (minimum dimension size).
+    k_dim = std::max(k_dim, k_dim_b);
+  }
+
+  ops = m_dim * n_dim * k_dim * 2;
+  VLOG(1) << "Operations for Matmul" << ops;
+
+  if (mat_mul != nullptr) {
+    mat_mul->m = m_dim;
+    mat_mul->n = n_dim;
+    mat_mul->k = k_dim;
+  }
+  return ops;
+}
+
+// TODO(cliffy): Dedup this method and CountConv2DBackPropFilterOperations.
+int64 OpLevelCostEstimator::CountConv2DBackPropInputOperations(
+    const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
+    bool* found_unknown_shapes) const {
+  int64 ops = 0;
+
+  if (op_features.op() != kConv2dBackPropInput) {
+    LOG(ERROR) << "Invalid Operation";
+    return ops;
+  }
+
+  if (op_features.attr().find("_output_shapes") == op_features.attr().end()) {
+    // Need _output_shapes for input shape.
+    LOG(ERROR) << "No output shape in Conv2DBackPropInput op feature.";
+    return ops;
+  }
+
+  const auto& input_shape =
+      op_features.attr().at("_output_shapes").list().shape(0);
+  ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
+      input_shape, op_features.inputs(1).shape(), op_features,
+      found_unknown_shapes);
+
+  ops = conv_dims.batch;
+  ops *= conv_dims.ox * conv_dims.oy;
+  ops *= conv_dims.kx * conv_dims.ky;
+  ops *= conv_dims.iz * conv_dims.oz;
+  ops *= kOpsPerMac;
+
+  VLOG(1) << "Operations for Conv2DBackPropInput" << ops;
+
+  if (returned_conv_dims != nullptr) {
+    *returned_conv_dims = conv_dims;
+  }
+  return ops;
+}
+
+int64 OpLevelCostEstimator::CountConv2DBackPropFilterOperations(
+    const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims,
+    bool* found_unknown_shapes) const {
+  int64 ops = 0;
+  if (op_features.op() != kConv2dBackPropFilter) {
+    LOG(ERROR) << "Invalid Operation";
+    return ops;
+  }
+
+  if (op_features.attr().find("_output_shapes") == op_features.attr().end()) {
+    // Need _output_shapes for filter shape.
+    LOG(ERROR) << "No output shape in Conv2DBackPropFilter op feature.";
+    return ops;
+  }
+
+  const auto& filter_shape =
+      op_features.attr().at("_output_shapes").list().shape(0);
+  ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
+      op_features.inputs(0).shape(), filter_shape, op_features,
+      found_unknown_shapes);
+
+  ops = conv_dims.batch;
+  ops *= conv_dims.ox * conv_dims.oy;
+  ops *= conv_dims.kx * conv_dims.ky;
+  ops *= conv_dims.iz * conv_dims.oz;
+  ops *= kOpsPerMac;
+
+  VLOG(1) << "Operations for Conv2DBackPropFilter" << ops;
+
+  if (returned_conv_dims != nullptr) {
+    *returned_conv_dims = conv_dims;
+  }
+  return ops;
+}
+
+int64 OpLevelCostEstimator::CalculateSingleInputSize(
+    const OpInfo::TensorProperties& input, bool* found_unknown_shapes) const {
+  VLOG(1) << " with " << input.dtype() << " input of shape "
+          << input.shape().DebugString();
+  int64 input_size = 1;
+  int num_dims = std::max(1, input.shape().dim_size());
+  auto input_shape =
+      MaybeGetMinimumShape(input.shape(), num_dims, found_unknown_shapes);
+  for (const auto& dim : input_shape.dim()) {
+    input_size *= dim.size();
+  }
+  return input_size * DataTypeSize(input.dtype());
+}
+
+int64 OpLevelCostEstimator::CalculateInputSize(
+    const OpInfo& op_features, bool* found_unknown_shapes) const {
+  int64 total_input_size = 0;
+  for (auto& input : op_features.inputs()) {
+    int64 input_size = CalculateSingleInputSize(input, found_unknown_shapes);
+    total_input_size += input_size;
+    VLOG(1) << "Input Size: " << input_size
+            << " Total Input Size:" << total_input_size;
+  }
+  return total_input_size;
+}
+
+int64 OpLevelCostEstimator::CalculateOutputSize(
+    const OpInfo& op_features, bool* found_unknown_shapes) const {
+  int64 total_output_size = 0;
+  // Use float as the default for calculations.
+  DataType dt = DT_FLOAT;
+  for (const auto& item : op_features.attr()) {
+    VLOG(1) << "Key:" << item.first
+            << " Value:" << SummarizeAttrValue(item.second);
+    if (item.first == "_output_shapes") {
+      for (const auto& original_output_shape : item.second.list().shape()) {
+        int64 output_size = 1;
+        int num_dims = std::max(1, original_output_shape.dim_size());
+        auto output_shape = MaybeGetMinimumShape(
+            original_output_shape, num_dims, found_unknown_shapes);
+        for (const auto& dim : output_shape.dim()) {
+          output_size *= dim.size();
+        }
+        output_size *= DataTypeSize(dt);
+        total_output_size += output_size;
+        VLOG(1) << "Output Size: " << output_size
+                << " Total Output Size:" << total_output_size;
+      }
+    }
+    if (item.first == "T") {
+      dt = item.second.type();
+    }
+  }
+  return total_output_size;
+}
+
+Costs OpLevelCostEstimator::PredictConv2D(const OpInfo& op_features) const {
+  bool found_unknown_shapes = false;
+  auto costs = PredictOpCountBasedCost(
+      CountConv2DOperations(op_features, &found_unknown_shapes), op_features);
+  costs.inaccurate = found_unknown_shapes;
+  return costs;
+}
+
+Costs OpLevelCostEstimator::PredictConv2DBackPropInput(
+    const OpInfo& op_features) const {
+  bool found_unknown_shapes = false;
+  auto costs =
+      PredictOpCountBasedCost(CountConv2DBackPropInputOperations(
+                                  op_features, nullptr, &found_unknown_shapes),
+                              op_features);
+  costs.inaccurate = found_unknown_shapes;
+  return costs;
+}
+
+Costs OpLevelCostEstimator::PredictConv2DBackPropFilter(
+    const OpInfo& op_features) const {
+  bool found_unknown_shapes = false;
+  auto costs =
+      PredictOpCountBasedCost(CountConv2DBackPropFilterOperations(
+                                  op_features, nullptr, &found_unknown_shapes),
+                              op_features);
+  costs.inaccurate = found_unknown_shapes;
+  return costs;
+}
+
+Costs OpLevelCostEstimator::PredictMatMul(const OpInfo& op_features) const {
+  bool found_unknown_shapes = false;
+  auto costs = PredictOpCountBasedCost(
+      CountMatMulOperations(op_features, &found_unknown_shapes), op_features);
+  costs.inaccurate = found_unknown_shapes;
+  return costs;
+}
+
+Costs OpLevelCostEstimator::PredictNoOp(const OpInfo& op_features) const {
+  VLOG(1) << "Op:" << op_features.op() << " Execution Time 0 (ns)";
+  return Costs::ZeroCosts();
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
new file mode 100644
index 0000000000..5bb20cc6bb
--- /dev/null
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -0,0 +1,143 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
+
+#include <functional>
+#include <map>
+#include <string>
+
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/grappler/costs/cost_estimator.h"
+#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+#include "tensorflow/core/util/padding.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class OpLevelCostEstimator {
+ public:
+  OpLevelCostEstimator();
+  virtual ~OpLevelCostEstimator() {}
+
+  Costs PredictCosts(const OpInfo& op_features) const;
+
+ protected:
+  // Returns an estimate of device performance (in billions of operations
+  // executed per second) and memory bandwidth (in GigaBytes/second) for the
+  // specified device.
+  virtual std::pair<double, double> GetDeviceInfo(
+      const OpInfo::DeviceProperties& device) const;
+
+  // For operations for which we haven't yet built estimates, returns a dummy
+  // value based on input size.
+  Costs DummyExecutionTime(const OpInfo& op_features) const;
+
+  // Naive cost estimate based on operations divided by device ops/sec.
+  Costs PredictOpCountBasedCost(double operations,
+                                const OpInfo& op_features) const;
+
+  // This family of routines counts the number of operations to perform the
+  // specified TensorFlow Op.
+  struct MatMulDimensions {
+    int m;
+    int n;
+    int k;
+  };
+  struct ConvolutionDimensions {
+    int64 batch;      // Batch size.
+    int64 ix;         // Input size x.
+    int64 iy;         // Input size y.
+    int64 iz;         // Input depth.
+    int64 kx;         // Kernel x.
+    int64 ky;         // Kernel y.
+    int64 oz;         // Output depth.
+    int64 ox;         // Output size x.
+    int64 oy;         // Output size y.
+    int64 sx;         // Stride x.
+    int64 sy;         // Stride y.
+    Padding padding;  // SAME or VALID.
+  };
+  int64 CountConv2DOperations(const OpInfo& op_features,
+                              bool* found_unknown_shapes) const;
+  int64 CountConv2DOperations(const OpInfo& op_features,
+                              ConvolutionDimensions* conv_info,
+                              bool* found_unknown_shapes) const;
+  int64 CountMatMulOperations(const OpInfo& op_features,
+                              bool* found_unknown_shapes) const;
+  int64 CountMatMulOperations(const OpInfo& op_features,
+                              MatMulDimensions* mat_mul,
+                              bool* found_unknown_shapes) const;
+  int64 CountConv2DBackPropInputOperations(const OpInfo& op_features,
+                                           ConvolutionDimensions* conv_info,
+                                           bool* found_unknown_shapes) const;
+  int64 CountConv2DBackPropFilterOperations(const OpInfo& op_features,
+                                            ConvolutionDimensions* conv_info,
+                                            bool* found_unknown_shapes) const;
+
+  // Calculate the total size in bytes of a single input to a TensorFlow op.
+  int64 CalculateSingleInputSize(const OpInfo::TensorProperties& input,
+                                 bool* found_unknown_shapes) const;
+
+  // Calculate the total size in bytes of all
+  // the inputs of the specified TensorFlow Op.
+  int64 CalculateInputSize(const OpInfo& op_features,
+                           bool* found_unknown_shapes) const;
+
+  // Calculate the total size in bytes of all
+  // the outputs of the specified TensorFlow Op.
+  int64 CalculateOutputSize(const OpInfo& op_features,
+                            bool* found_unknown_shapes) const;
+
+  // This family of routines predicts the costs to
+  // perform the specified TensorFlow Op on the
+  // device represented by a subclass. The default
+  // implementation just divides the operations to
+  // perform the op (from the "Count" routines,
+  // above) by the device peak operations per
+  // second. Override to supply a better estimate.
+  // Implementation of costs other than
+  // execution_time is optional, depending on the
+  // device.
+  Costs PredictConv2D(const OpInfo& op_features) const;
+  Costs PredictConv2DBackPropInput(const OpInfo& op_features) const;
+  Costs PredictConv2DBackPropFilter(const OpInfo& op_features) const;
+  Costs PredictMatMul(const OpInfo& op_features) const;
+  Costs PredictNoOp(const OpInfo& op_features) const;
+
+  // Utility function for safe division. Returns 0
+  // if rhs is 0 or negative.
+  static double SafeDiv(const double lhs, const double rhs) {
+    if (rhs > 0) {
+      return lhs / rhs;
+    } else {
+      return 0.0;
+    }
+  }
+
+  static ConvolutionDimensions ConvolutionDimensionsFromInputs(
+      const TensorShapeProto& original_image_shape,
+      const TensorShapeProto& original_filter_shape, const OpInfo& op_features,
+      bool* found_unknown_shapes);
+
+ private:
+  typedef std::function<Costs(const OpInfo& op_feature)> CostImpl;
+  std::map<string, CostImpl> device_cost_impl_;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+#endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
new file mode 100644
index 0000000000..e0b0348c8e
--- /dev/null
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+
+namespace {
+// Wrangles the minimum number of proto fields to set up a matrix.
+void DescribeMatrix(int rows, int columns, OpInfo *op_features) {
+  auto input = op_features->add_inputs();
+  auto shape = input->mutable_shape();
+  auto shape_rows = shape->add_dim();
+  shape_rows->set_size(rows);
+  auto shape_columns = shape->add_dim();
+  shape_columns->set_size(columns);
+  input->set_dtype(DT_FLOAT);
+}
+
+// Returns an OpInfo for MatMul with the minimum set of fields set up.
+OpInfo DescribeMatMul(int m, int n, int l, int k) {
+  OpInfo op_features;
+  auto device = op_features.mutable_device();
+  device->set_type("CPU");
+  op_features.set_op("MatMul");
+
+  DescribeMatrix(m, l, &op_features);
+  DescribeMatrix(k, n, &op_features);
+  return op_features;
+}
+
+// Returns an OpInfo for MatMul with unknown input shapes.
+OpInfo DescribeMatMulUnknownShape() {
+  OpInfo op_features;
+  auto device = op_features.mutable_device();
+  device->set_type("CPU");
+  op_features.set_op("MatMul");
+
+  auto input = op_features.add_inputs();
+  auto shape = input->mutable_shape();
+  shape->set_unknown_rank(true);
+
+  input = op_features.add_inputs();
+  shape = input->mutable_shape();
+  shape->set_unknown_rank(true);
+
+  return op_features;
+}
+
+// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
+// estimation purposes.
+void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
+                      OpInfo *op_features) {
+  auto input = op_features->add_inputs();
+  auto shape = input->mutable_shape();
+  shape->add_dim()->set_size(dim0);
+  shape->add_dim()->set_size(dim1);
+  shape->add_dim()->set_size(dim2);
+  shape->add_dim()->set_size(dim3);
+}
+
+// Returns an OpInfo for Conv2D with the minimum set of fields set up.
+OpInfo DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2, int kx,
+                           int ky, int oz) {
+  OpInfo op_features;
+  auto device = op_features.mutable_device();
+  device->set_type("CPU");
+  op_features.set_op("Conv2D");
+
+  DescribeTensor4D(batch, ix, iy, iz1, &op_features);
+  DescribeTensor4D(kx, ky, iz2, oz, &op_features);
+  return op_features;
+}
+}  // namespace
+
+TEST(OpLevelCostEstimatorTest, UnknownOrPartialShape) {
+  OpLevelCostEstimator estimator;
+
+  EXPECT_EQ(false,
+            estimator.PredictCosts(DescribeMatMul(2, 4, 7, 7)).inaccurate);
+  EXPECT_EQ(true,
+            estimator.PredictCosts(DescribeMatMul(-1, 4, 7, 7)).inaccurate);
+  EXPECT_EQ(true,
+            estimator.PredictCosts(DescribeMatMul(2, 4, -1, 7)).inaccurate);
+
+  EXPECT_EQ(
+      false,
+      estimator.PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256))
+          .inaccurate);
+  EXPECT_EQ(
+      true,
+      estimator.PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256))
+          .inaccurate);
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index 4e35de9d4a..0852cb4fd3 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -147,7 +147,7 @@ OpInfo::DeviceProperties GetLocalCPUInfo() {
   // Combine cpu family and model into the model string.
   device.set_model(
       strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum()));
-  device.set_frequency(port::NominalCPUFrequency());
+  device.set_frequency(port::NominalCPUFrequency() * 1e-9);
   device.set_num_cores(port::NumSchedulableCPUs());
   device.set_l1_cache_size(Eigen::l1CacheSize());
   device.set_l2_cache_size(Eigen::l2CacheSize());
@@ -195,6 +195,8 @@ OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id) {
                            properties.memoryClockRate * 2);
   }
 
+  (*device.mutable_environment())["architecture"] =
+      strings::StrCat(properties.major, ".", properties.minor);
   (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION);
   (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION);
 #endif
--
cgit v1.2.3