aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar Max Galkin <maxgalkin@google.com>2017-12-01 14:48:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-01 14:52:24 -0800
commitd0ae1064ed0bb4bd1aed00afd4235f4dd5c853f0 (patch)
treeb077af5e338ac90aa8d10046e188943c049b24b0 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent10f77231b005c76b5a771243e18384b4b66be325 (diff)
Prefix inaccurate costs with "~" in VirtualScheduler verbose log.
Fix some inaccurate estimates exposed by this approach: - propagate the inaccuracy flag when merging device stats; - estimate Const as no-op; - estimate RandomUniform, Relu and Softmax as element-wise; - consider estimates accurate for known element-wise ops in op_level_cost_estimator. PiperOrigin-RevId: 177643976
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc32
1 files changed, 23 insertions, 9 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index b1e04ceec8..1c278a1030 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -25,6 +25,7 @@ namespace tensorflow {
namespace grappler {
constexpr int kOpsPerMac = 2;
+constexpr char kConst[] = "Const";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
@@ -167,6 +168,7 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kSend, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kConst, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)},
@@ -221,6 +223,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
Eigen::internal::scalar_square_op<float>>::Cost},
{"Tanh", Eigen::internal::functor_traits<
Eigen::internal::scalar_tanh_op<float>>::Cost},
+ {"Relu", Eigen::internal::functor_traits<
+ Eigen::internal::scalar_max_op<float>>::Cost},
{"Sigmoid", Eigen::internal::functor_traits<
Eigen::internal::scalar_sigmoid_op<float>>::Cost},
{"Sign", Eigen::internal::functor_traits<
@@ -283,8 +287,10 @@ Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
if (elementwise_ops_.find(op_features.op()) != elementwise_ops_.end()) {
return PredictCwiseOp(op_context);
}
- VLOG(1) << "Missing implementation for op: " << op_features.op();
- return DummyExecutionTime(op_context);
+
+ VLOG(1) << "Missing accurate estimator for op: " << op_features.op();
+
+ return PredictCostOfAnUnknownOp(op_context);
}
std::function<Costs(const OpContext&)> estimator = it->second;
@@ -366,19 +372,27 @@ Costs OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context) const {
}
int op_cost = 1;
+ bool is_known_elementwise_op = false;
auto it = elementwise_ops_.find(op_features.op());
if (it != elementwise_ops_.end()) {
op_cost = it->second;
+ is_known_elementwise_op = true;
+ } else {
+ LOG(WARNING) << "Not a cwise op: " << op_features.op();
}
+
Costs costs = PredictOpCountBasedCost(op_count * op_cost, op_features);
- costs.inaccurate = found_unknown_shapes;
+ if (found_unknown_shapes || !is_known_elementwise_op) {
+ costs.inaccurate = true;
+ }
return costs;
}
-Costs OpLevelCostEstimator::DummyExecutionTime(
+Costs OpLevelCostEstimator::PredictCostOfAnUnknownOp(
const OpContext& op_context) const {
- // Use CwiseOp time as an estimation
- auto costs = PredictCwiseOp(op_context);
+ // Don't assume the operation is cwise, return cost based on input/output size
+ // and admit that it is inaccurate...
+ auto costs = PredictOpCountBasedCost(0, op_context.op_info);
costs.inaccurate = true;
return costs;
}
@@ -391,11 +405,11 @@ Costs OpLevelCostEstimator::PredictOpCountBasedCost(
<< " Execution Time (ns):" << compute_cost.count();
bool found_unknown_shapes = false;
- double total_input_size =
+ const double total_input_size =
CalculateInputSize(op_features, &found_unknown_shapes);
- double total_output_size =
+ const double total_output_size =
CalculateOutputSize(op_features, &found_unknown_shapes);
- double total_io_size = total_input_size + total_output_size;
+ const double total_io_size = total_input_size + total_output_size;
Costs::NanoSeconds memory_cost(
std::ceil(total_io_size / device_perf.gb_per_sec));