diff options
author | 2017-06-09 17:45:58 -0700 | |
---|---|---|
committer | 2017-06-09 17:50:06 -0700 | |
commit | 19b4ccd95fa667adf6240d24368022594f320e73 (patch) | |
tree | 077038a26ec4fbdf70d76cec519cbf3d81d5ed85 | |
parent | ef8b5fd1876ba2be3a10a363e121e6d2005bb480 (diff) |
Estimate cost for element wise ops and a minimum compute cost for dummy execution time.
PiperOrigin-RevId: 158585464
6 files changed, 356 insertions, 103 deletions
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index e36928a3fb..2b30facd84 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -232,6 +232,7 @@ cc_library( ":op_performance_data_cc", "//tensorflow/core:framework", "//tensorflow/core/grappler/clusters:utils", + "//third_party/eigen3", ], ) diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc index 9e3dd38b09..b1d04f4562 100644 --- a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc @@ -102,7 +102,7 @@ TEST_F(AnalyticalCostEstimatorTest, SimpleTest) { Costs summary; TF_ASSERT_OK(estimator.PredictCosts(item.graph, &cost_graph, &summary)); - EXPECT_EQ(Costs::NanoSeconds(9108), summary.execution_time); + EXPECT_EQ(Costs::NanoSeconds(9136), summary.execution_time); EXPECT_FALSE(summary.inaccurate); } diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 75ff75123e..3668b7d7bf 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" + +#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/grappler/clusters/utils.h" @@ -34,6 +36,105 @@ constexpr char kBatchMatMul[] = "BatchMatMul"; constexpr char kVariable[] = "Variable"; constexpr char kVariableV2[] = "VariableV2"; +namespace { + +string GetDataFormat(const OpInfo& op_features) { + string data_format = "NHWC"; // Default format. + if (op_features.attr().find("data_format") != op_features.attr().end()) { + data_format = op_features.attr().at("data_format").s(); + } + return data_format; +} + +Padding GetPadding(const OpInfo& op_features) { + if (op_features.attr().find("padding") != op_features.attr().end() && + op_features.attr().at("padding").s() == "VALID") { + return Padding::VALID; + } + return Padding::SAME; // Default padding. +} + +std::vector<int64> GetStrides(const OpInfo& op_features) { + if (op_features.attr().find("strides") != op_features.attr().end()) { + const auto strides = op_features.attr().at("strides").list().i(); + return {strides[0], strides[1], strides[2], strides[3]}; + } + return {1, 1, 1, 1}; +} + +int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride, + const Padding& padding) { + // Logic for calculating output shape is from GetWindowedOutputSizeVerbose() + // function in third_party/tensorflow/core/framework/common_shape_fns.cc. + if (padding == Padding::VALID) { + return (input - filter + stride) / stride; + } else { // SAME. + return (input + stride - 1) / stride; + } +} + +// Return a minimum shape if the shape is unknown. If known, return the original +// shape. +TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, + int rank, bool* found_unknown_shapes) { + auto shape = original_shape; + if (shape.unknown_rank()) { + *found_unknown_shapes = true; + } + if (shape.unknown_rank() || shape.dim_size() == 0) { + TensorShapeProto::Dim dim; + VLOG(1) << "WARNING: Use minimum shape because the shape is unknown."; + // The size of each dimension is at least 1, if unknown. + dim.set_size(1); + for (int i = 0; i < rank; i++) { + *shape.add_dim() = dim; + } + } else { + CHECK_EQ(shape.dim_size(), rank); + for (int i = 0; i < rank; i++) { + if (shape.dim(i).size() == -1) { + *found_unknown_shapes = true; + VLOG(1) + << "WARNING: Use minimum dim size 1 because the shape is unknown."; + // The size of each dimension is at least 1, if unknown. + shape.mutable_dim(i)->set_size(1); + } + } + } + return shape; +} + +// Return the output element count of a binary element-wise op considering +// broadcasting. +int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1, + const TensorShapeProto& input_shape_2) { + bool found_unknown_shapes; + int rank = std::max(1, input_shape_1.dim_size()); + TensorShapeProto output_shape = + MaybeGetMinimumShape(input_shape_1, rank, &found_unknown_shapes); + + if (input_shape_1.dim_size() == input_shape_2.dim_size()) { + auto shape_1 = + MaybeGetMinimumShape(input_shape_1, rank, &found_unknown_shapes); + auto shape_2 = + MaybeGetMinimumShape(input_shape_2, rank, &found_unknown_shapes); + if (shape_1.dim_size() == shape_2.dim_size()) { + for (int i = 0; i < shape_1.dim_size(); i++) { + output_shape.mutable_dim(i)->set_size( + std::max(shape_1.dim(i).size(), shape_2.dim(i).size())); + } + } + } + + int64 count = 1; + for (int i = 0; i < output_shape.dim_size(); i++) { + count *= output_shape.dim(i).size(); + } + return count; +} + +} // namespace + OpLevelCostEstimator::OpLevelCostEstimator() { // Syntactic sugar to build and return a lambda that takes an OpInfo and // returns a cost. @@ -58,15 +159,114 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)}, {kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)}, {kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}}; + + elementwise_ops_ = { + // Unary ops alphabetically sorted + {"Acos", Eigen::internal::functor_traits< + Eigen::internal::scalar_acos_op<float>>::Cost}, + {"Asin", Eigen::internal::functor_traits< + Eigen::internal::scalar_asin_op<float>>::Cost}, + {"Atan", Eigen::internal::functor_traits< + Eigen::internal::scalar_atan_op<float>>::Cost}, + {"Atan2", Eigen::internal::functor_traits< + Eigen::internal::scalar_quotient_op<float>>::Cost + + Eigen::internal::functor_traits< + Eigen::internal::scalar_atan_op<float>>::Cost}, + {"Ceil", Eigen::internal::functor_traits< + Eigen::internal::scalar_ceil_op<float>>::Cost}, + {"Cos", Eigen::internal::functor_traits< + Eigen::internal::scalar_cos_op<float>>::Cost}, + {"Erf", 1}, + {"Erfc", 1}, + {"Exp", Eigen::internal::functor_traits< + Eigen::internal::scalar_exp_op<float>>::Cost}, + {"Expm1", Eigen::internal::functor_traits< + Eigen::internal::scalar_expm1_op<float>>::Cost}, + {"Floor", Eigen::internal::functor_traits< + Eigen::internal::scalar_floor_op<float>>::Cost}, + {"Inv", Eigen::internal::functor_traits< + Eigen::internal::scalar_inverse_op<float>>::Cost}, + {"InvGrad", 1}, + {"Lgamma", 1}, + {"Log", Eigen::internal::functor_traits< + Eigen::internal::scalar_log_op<float>>::Cost}, + {"Log1p", Eigen::internal::functor_traits< + Eigen::internal::scalar_log1p_op<float>>::Cost}, + {"Neg", Eigen::internal::functor_traits< + Eigen::internal::scalar_opposite_op<float>>::Cost}, + {"Reciprocal", Eigen::internal::functor_traits< + Eigen::internal::scalar_inverse_op<float>>::Cost}, + {"Rint", 1}, + {"Round", Eigen::internal::functor_traits< + Eigen::internal::scalar_round_op<float>>::Cost}, + {"Rsqrt", Eigen::internal::functor_traits< + Eigen::internal::scalar_rsqrt_op<float>>::Cost}, + {"Sqrt", Eigen::internal::functor_traits< + Eigen::internal::scalar_sqrt_op<float>>::Cost}, + {"Square", Eigen::internal::functor_traits< + Eigen::internal::scalar_square_op<float>>::Cost}, + {"Tanh", Eigen::internal::functor_traits< + Eigen::internal::scalar_tanh_op<float>>::Cost}, + {"Sigmoid", Eigen::internal::functor_traits< + Eigen::internal::scalar_sigmoid_op<float>>::Cost}, + {"Sign", Eigen::internal::functor_traits< + Eigen::internal::scalar_sign_op<float>>::Cost}, + {"Sin", Eigen::internal::functor_traits< + Eigen::internal::scalar_sin_op<float>>::Cost}, + {"Tan", Eigen::internal::functor_traits< + Eigen::internal::scalar_tan_op<float>>::Cost}, + // Binary ops alphabetically sorted + {"Add", Eigen::internal::functor_traits< + Eigen::internal::scalar_sum_op<float>>::Cost}, + {"ApproximateEqual", 1}, + {"Div", Eigen::internal::functor_traits< + Eigen::internal::scalar_quotient_op<float>>::Cost}, + {"Equal", 1}, + {"FloorDiv", Eigen::internal::functor_traits< + Eigen::internal::scalar_quotient_op<float>>::Cost}, + {"FloorMod", Eigen::internal::functor_traits< + Eigen::internal::scalar_mod_op<float>>::Cost}, + {"Greater", 1}, + {"GreaterEqual", 1}, + {"Less", 1}, + {"LessEqual", 1}, + {"LogicalAnd", Eigen::internal::functor_traits< + Eigen::internal::scalar_boolean_and_op>::Cost}, + {"LogicalNot", 1}, + {"LogicalOr", Eigen::internal::functor_traits< + Eigen::internal::scalar_boolean_or_op>::Cost}, + {"Maximum", Eigen::internal::functor_traits< + Eigen::internal::scalar_max_op<float>>::Cost}, + {"Minimum", Eigen::internal::functor_traits< + Eigen::internal::scalar_min_op<float>>::Cost}, + {"Mod", Eigen::internal::functor_traits< + Eigen::internal::scalar_mod_op<float>>::Cost}, + {"Mul", Eigen::internal::functor_traits< + Eigen::internal::scalar_product_op<float>>::Cost}, + {"NotEqual", 1}, + {"QuantizedAdd", Eigen::internal::functor_traits< + Eigen::internal::scalar_sum_op<float>>::Cost}, + {"QuantizedMul", Eigen::internal::functor_traits< + Eigen::internal::scalar_product_op<float>>::Cost}, + {"RealDiv", Eigen::internal::functor_traits< + Eigen::internal::scalar_quotient_op<float>>::Cost}, + {"SquareDifference", 1}, + {"Sub", Eigen::internal::functor_traits< + Eigen::internal::scalar_difference_op<float>>::Cost}, + {"TruncateDiv", Eigen::internal::functor_traits< + Eigen::internal::scalar_quotient_op<float>>::Cost}, + {"TruncateMod", Eigen::internal::functor_traits< + Eigen::internal::scalar_mod_op<float>>::Cost}}; } Costs OpLevelCostEstimator::PredictCosts(const OpInfo& op_features) const { auto it = device_cost_impl_.find(op_features.op()); if (it == device_cost_impl_.end()) { + if (elementwise_ops_.find(op_features.op()) != elementwise_ops_.end()) { + return PredictCwiseOp(op_features); + } VLOG(1) << "Missing implementation for op: " << op_features.op(); - Costs costs; - costs = DummyExecutionTime(op_features); - return costs; + return DummyExecutionTime(op_features); } std::function<Costs(const OpInfo&)> estimator = it->second; @@ -121,9 +321,44 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo( return std::make_pair(gflops, bandwidth); } +Costs OpLevelCostEstimator::PredictCwiseOp(const OpInfo& op_features) const { + bool found_unknown_shapes; + // For unary or binary element-wise operations, op count is the element count + // of any input. We use the count for the largest input here to be more robust + // in case that the shape is unknown or partially known for other input. + int64 op_count = + CalculateLargestInputCount(op_features, &found_unknown_shapes); + // If output shape is available, try use the element count calcuated from + // that. + if (op_features.outputs_size() > 0) { + op_count = + std::max(op_count, CalculateTensorElementCount(op_features.outputs(0), + &found_unknown_shapes)); + } + // For binary ops, calculate the output shape possibly resulting from + // broadcasting. + if (op_features.inputs_size() >= 2) { + op_count = std::max(op_count, + CwiseOutputElementCount(op_features.inputs(0).shape(), + op_features.inputs(1).shape())); + } + + int op_cost = 1; + auto it = elementwise_ops_.find(op_features.op()); + if (it != elementwise_ops_.end()) { + op_cost = it->second; + } + Costs costs = PredictOpCountBasedCost(op_count * op_cost, op_features); + if (found_unknown_shapes) { + costs.inaccurate = true; + } + return costs; +} + Costs OpLevelCostEstimator::DummyExecutionTime( const OpInfo& op_features) const { - Costs costs = PredictOpCountBasedCost(0, op_features); + // Use CwiseOp time as an estimation + auto costs = PredictCwiseOp(op_features); costs.inaccurate = true; return costs; } @@ -159,75 +394,6 @@ int64 OpLevelCostEstimator::CountConv2DOperations( return CountConv2DOperations(op_features, nullptr, found_unknown_shapes); } -namespace { - -string GetDataFormat(const OpInfo& op_features) { - string data_format = "NHWC"; // Default format. - if (op_features.attr().find("data_format") != op_features.attr().end()) { - data_format = op_features.attr().at("data_format").s(); - } - return data_format; -} - -Padding GetPadding(const OpInfo& op_features) { - if (op_features.attr().find("padding") != op_features.attr().end() && - op_features.attr().at("padding").s() == "VALID") { - return Padding::VALID; - } - return Padding::SAME; // Default padding. -} - -std::vector<int64> GetStrides(const OpInfo& op_features) { - if (op_features.attr().find("strides") != op_features.attr().end()) { - const auto strides = op_features.attr().at("strides").list().i(); - return {strides[0], strides[1], strides[2], strides[3]}; - } - return {1, 1, 1, 1}; -} - -int64 GetOutputSize(const int64 input, const int64 filter, const int64 stride, - const Padding& padding) { - // Logic for calculating output shape is from GetWindowedOutputSizeVerbose() - // function in third_party/tensorflow/core/framework/common_shape_fns.cc. - if (padding == Padding::VALID) { - return (input - filter + stride) / stride; - } else { // SAME. - return (input + stride - 1) / stride; - } -} - -// Return a minimum shape if the shape is unknown. If known, return the original -// shape. -TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape, - int rank, bool* found_unknown_shapes) { - auto shape = original_shape; - if (shape.unknown_rank()) { - *found_unknown_shapes = true; - } - if (shape.unknown_rank() || shape.dim_size() == 0) { - TensorShapeProto::Dim dim; - VLOG(1) << "WARNING: Use minimum shape because the shape is unknown."; - // The size of each dimension is at least 1, if unknown. - dim.set_size(1); - for (int i = 0; i < rank; i++) { - *shape.add_dim() = dim; - } - } else { - CHECK_EQ(shape.dim_size(), rank); - for (int i = 0; i < rank; i++) { - if (shape.dim(i).size() == -1) { - *found_unknown_shapes = true; - VLOG(1) - << "WARNING: Use minimum dim size 1 because the shape is unknown."; - // The size of each dimension is at least 1, if unknown. - shape.mutable_dim(i)->set_size(1); - } - } - } - return shape; -} -} // namespace - // Helper to translate the positional arguments into named fields. OpLevelCostEstimator::ConvolutionDimensions OpLevelCostEstimator::ConvolutionDimensionsFromInputs( @@ -560,25 +726,31 @@ int64 OpLevelCostEstimator::CountConv2DBackPropFilterOperations( return ops; } -int64 OpLevelCostEstimator::CalculateSingleInputSize( - const OpInfo::TensorProperties& input, bool* found_unknown_shapes) const { - VLOG(1) << " with " << input.dtype() << " input of shape " - << input.shape().DebugString(); - int64 input_size = 1; - int num_dims = std::max(1, input.shape().dim_size()); - auto input_shape = - MaybeGetMinimumShape(input.shape(), num_dims, found_unknown_shapes); - for (const auto& dim : input_shape.dim()) { - input_size *= dim.size(); - } - return input_size * DataTypeSize(BaseType(input.dtype())); +int64 OpLevelCostEstimator::CalculateTensorElementCount( + const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) const { + VLOG(1) << " with " << tensor.dtype() << " tensor of shape " + << tensor.shape().DebugString(); + int64 tensor_size = 1; + int num_dims = std::max(1, tensor.shape().dim_size()); + auto tensor_shape = + MaybeGetMinimumShape(tensor.shape(), num_dims, found_unknown_shapes); + for (const auto& dim : tensor_shape.dim()) { + tensor_size *= dim.size(); + } + return tensor_size; +} + +int64 OpLevelCostEstimator::CalculateTensorSize( + const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) const { + return CalculateTensorElementCount(tensor, found_unknown_shapes) * + DataTypeSize(BaseType(tensor.dtype())); } int64 OpLevelCostEstimator::CalculateInputSize( const OpInfo& op_features, bool* found_unknown_shapes) const { int64 total_input_size = 0; for (auto& input : op_features.inputs()) { - int64 input_size = CalculateSingleInputSize(input, found_unknown_shapes); + int64 input_size = CalculateTensorSize(input, found_unknown_shapes); total_input_size += input_size; VLOG(1) << "Input Size: " << input_size << " Total Input Size:" << total_input_size; @@ -586,6 +758,21 @@ int64 OpLevelCostEstimator::CalculateInputSize( return total_input_size; } +int64 OpLevelCostEstimator::CalculateLargestInputCount( + const OpInfo& op_features, bool* found_unknown_shapes) const { + int64 largest_input_count = 0; + for (auto& input : op_features.inputs()) { + int64 input_count = + CalculateTensorElementCount(input, found_unknown_shapes); + if (input_count > largest_input_count) { + largest_input_count = input_count; + } + VLOG(1) << "Input Count: " << input_count + << " Largest Input Count:" << largest_input_count; + } + return largest_input_count; +} + int64 OpLevelCostEstimator::CalculateOutputSize( const OpInfo& op_features, bool* found_unknown_shapes) const { int64 total_output_size = 0; diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index ec7f21622f..28d49a7703 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -89,17 +89,26 @@ class OpLevelCostEstimator { ConvolutionDimensions* conv_info, bool* found_unknown_shapes) const; - // Calculate the total size in bytes of a single input to a TensorFlow op. - int64 CalculateSingleInputSize(const OpInfo::TensorProperties& input, - bool* found_unknown_shapes) const; + // Calculate the element count of an input/output tensor. + int64 CalculateTensorElementCount(const OpInfo::TensorProperties& tensor, + bool* found_unknown_shapes) const; + + // Calculate the total size in bytes of an input/output tensor. + int64 CalculateTensorSize(const OpInfo::TensorProperties& tensor, + bool* found_unknown_shapes) const; + + // Calculate the element count of the largest + // input of specified TensorFlow op. + int64 CalculateLargestInputCount(const OpInfo& op_features, + bool* found_unknown_shapes) const; // Calculate the total size in bytes of the all - // the inputs of specified TensorFlow Op + // the inputs of specified TensorFlow op. int64 CalculateInputSize(const OpInfo& op_features, bool* found_unknown_shapes) const; // Calculate the total size in bytes of the all - // the outputs of specified TensorFlow Op + // the outputs of specified TensorFlow op. int64 CalculateOutputSize(const OpInfo& op_features, bool* found_unknown_shapes) const; @@ -114,6 +123,7 @@ class OpLevelCostEstimator { // execution_time is optional, depending on the // device. Costs PredictConv2D(const OpInfo& op_features) const; + Costs PredictCwiseOp(const OpInfo& op_features) const; Costs PredictConv2DBackPropInput(const OpInfo& op_features) const; Costs PredictConv2DBackPropFilter(const OpInfo& op_features) const; Costs PredictMatMul(const OpInfo& op_features) const; @@ -136,6 +146,7 @@ class OpLevelCostEstimator { bool* found_unknown_shapes); protected: + std::map<string, int> elementwise_ops_; typedef std::function<Costs(const OpInfo& op_feature)> CostImpl; std::map<string, CostImpl> device_cost_impl_; diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index 013673ea8e..1f0e02c160 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -37,8 +37,9 @@ void DescribeMatrix(int rows, int columns, OpInfo *op_features) { void SetCpuDevice(OpInfo* op_features) { auto device = op_features->mutable_device(); device->set_type("CPU"); - device->set_num_cores(1); - device->set_frequency(2000); // Mhz + device->set_num_cores(10); + device->set_bandwidth(10000000); // 10000000 KB/s = 10 GB/s + device->set_frequency(1000); // 1000 Mhz = 1 GHz } // Returns an OpInfo for MatMul with the minimum set of fields set up. @@ -103,6 +104,7 @@ void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3, shape->add_dim()->set_size(dim1); shape->add_dim()->set_size(dim2); shape->add_dim()->set_size(dim3); + input->set_dtype(DT_FLOAT); } // Returns an OpInfo for Conv2D with the minimum set of fields set up. @@ -116,6 +118,26 @@ OpInfo DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2, int kx, DescribeTensor4D(kx, ky, iz2, oz, &op_features); return op_features; } + +OpInfo DescribeOp(const string& op, int size1, int size2) { + OpInfo op_features; + SetCpuDevice(&op_features); + op_features.set_op(op); + + DescribeTensor4D(size1, 1, 1, 1, &op_features); + DescribeTensor4D(2 * size1, size2, 1, 1, &op_features); + + auto output = op_features.add_outputs(); + auto shape = output->mutable_shape(); + shape->add_dim()->set_size(2 * size1); + shape->add_dim()->set_size(size2); + shape->add_dim()->set_size(1); + shape->add_dim()->set_size(1); + output->set_dtype(DT_FLOAT); + + SetCpuDevice(&op_features); + return op_features; +} } // namespace class OpLevelCostEstimatorTest : public ::testing::Test { @@ -138,6 +160,38 @@ class OpLevelCostEstimatorTest : public ::testing::Test { OpLevelCostEstimator estimator_; }; +TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) { + auto cost = PredictCosts(DescribeOp("Dummy", 1000, 1)); + EXPECT_EQ(Costs::Duration(2000), cost.memory_time); + EXPECT_EQ(Costs::Duration(200), cost.compute_time); + EXPECT_EQ(Costs::Duration(2200), cost.execution_time); + EXPECT_TRUE(cost.inaccurate); +} + +TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) { + auto cost = PredictCosts(DescribeOp("Mul", 1000, 1)); + EXPECT_EQ(Costs::Duration(2000), cost.memory_time); + EXPECT_EQ(Costs::Duration(200), cost.compute_time); + EXPECT_EQ(Costs::Duration(2200), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); +} + +TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) { + auto cost = PredictCosts(DescribeOp("Mul", 1000, 2)); + EXPECT_EQ(Costs::Duration(3600), cost.memory_time); + EXPECT_EQ(Costs::Duration(400), cost.compute_time); + EXPECT_EQ(Costs::Duration(4000), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); +} + +TEST_F(OpLevelCostEstimatorTest, ModExecutionTime) { + auto cost = PredictCosts(DescribeOp("Mod", 1000, 1)); + EXPECT_EQ(Costs::Duration(2000), cost.memory_time); + EXPECT_EQ(Costs::Duration(1600), cost.compute_time); + EXPECT_EQ(Costs::Duration(3600), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); +} + TEST_F(OpLevelCostEstimatorTest, UnknownOrPartialShape) { EXPECT_FALSE(PredictCosts(DescribeMatMul(2, 4, 7, 7)).inaccurate); EXPECT_TRUE(PredictCosts(DescribeMatMul(-1, 4, 7, 7)).inaccurate); diff --git a/tensorflow/core/grappler/optimizers/static_schedule_test.cc b/tensorflow/core/grappler/optimizers/static_schedule_test.cc index f53feaca4c..b5e21b0c40 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule_test.cc +++ b/tensorflow/core/grappler/optimizers/static_schedule_test.cc @@ -66,15 +66,15 @@ TEST_F(StaticScheduleTest, BasicGraph) { } else if (time.first->name() == "x") { EXPECT_EQ(Costs::NanoSeconds(250001), time.second); } else if (time.first->name() == "AddN") { - EXPECT_EQ(Costs::NanoSeconds(1500001), time.second); + EXPECT_EQ(Costs::NanoSeconds(1500003), time.second); } else if (time.first->name() == "AddN_1") { - EXPECT_EQ(Costs::NanoSeconds(2750001), time.second); + EXPECT_EQ(Costs::NanoSeconds(2750005), time.second); } else if (time.first->name() == "AddN_2") { - EXPECT_EQ(Costs::NanoSeconds(4000001), time.second); + EXPECT_EQ(Costs::NanoSeconds(4000007), time.second); } else if (time.first->name() == "AddN_3") { - EXPECT_EQ(Costs::NanoSeconds(5250001), time.second); + EXPECT_EQ(Costs::NanoSeconds(5250009), time.second); } else if (time.first->name() == "y") { - EXPECT_EQ(Costs::NanoSeconds(6500001), time.second); + EXPECT_EQ(Costs::NanoSeconds(6500011), time.second); } } } @@ -110,13 +110,13 @@ TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) { if (time.first->name() == "a") { EXPECT_EQ(Costs::NanoSeconds(1), time.second); } else if (time.first->name() == "b") { - EXPECT_EQ(Costs::NanoSeconds(12500001), time.second); + EXPECT_EQ(Costs::NanoSeconds(12500026), time.second); } else if (time.first->name() == "c") { - EXPECT_EQ(Costs::NanoSeconds(12500002), time.second); + EXPECT_EQ(Costs::NanoSeconds(12500027), time.second); } else if (time.first->name() == "d") { - EXPECT_EQ(Costs::NanoSeconds(12500003), time.second); + EXPECT_EQ(Costs::NanoSeconds(12500028), time.second); } else if (time.first->name() == "e") { - EXPECT_EQ(Costs::NanoSeconds(25000003), time.second); + EXPECT_EQ(Costs::NanoSeconds(25000053), time.second); } } } |