aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.h
blob: ec7f21622f6dd294443fe2caa48c3b23970aeea2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_

#include <functional>
#include <map>
#include <string>

#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/util/padding.h"

namespace tensorflow {
namespace grappler {

class OpLevelCostEstimator {
 public:
  OpLevelCostEstimator();
  virtual ~OpLevelCostEstimator() {}

  virtual Costs PredictCosts(const OpInfo& op_features) const;

 protected:
  // Returns an estimate of device performance (in billions of operations
  // executed per second) and memory bandwidth (in GigaBytes/second) for the
  // specified device.
  virtual std::pair<double, double> GetDeviceInfo(
      const DeviceProperties& device) const;

  // For operations for which we haven't yet built estimates, returns a dummy
  // value based on input size.
  Costs DummyExecutionTime(const OpInfo& op_features) const;

  // Naive cost estimate based on operations divided by device ops/sec.
  Costs PredictOpCountBasedCost(double operations,
                                const OpInfo& op_features) const;

  // This family of routines counts the number of operations to perform the
  // specified TensorFlow Op.
  struct MatMulDimensions {
    int m;
    int n;
    int k;
  };
  struct ConvolutionDimensions {
    int64 batch;      // Batch size.
    int64 ix;         // Input size x.
    int64 iy;         // Input size y.
    int64 iz;         // Input depth.
    int64 kx;         // Kernel x.
    int64 ky;         // Kernel y.
    int64 oz;         // Output depth.
    int64 ox;         // Output size x.
    int64 oy;         // Output size y.
    int64 sx;         // Stride x.
    int64 sy;         // Stride y.
    Padding padding;  // SAME or VALID.
  };
  int64 CountConv2DOperations(const OpInfo& op_features,
                              bool* found_unknown_shapes) const;
  int64 CountConv2DOperations(const OpInfo& op_features,
                              ConvolutionDimensions* conv_info,
                              bool* found_unknown_shapes) const;
  int64 CountMatMulOperations(const OpInfo& op_features,
                              bool* found_unknown_shapes) const;
  int64 CountMatMulOperations(const OpInfo& op_features,
                              MatMulDimensions* mat_mul,
                              bool* found_unknown_shapes) const;
  int64 CountBatchMatMulOperations(const OpInfo& op_features,
                                   bool* found_unknown_shapes) const;
  int64 CountConv2DBackPropInputOperations(const OpInfo& op_features,
                                           ConvolutionDimensions* conv_info,
                                           bool* found_unknown_shapes) const;
  int64 CountConv2DBackPropFilterOperations(const OpInfo& op_features,
                                            ConvolutionDimensions* conv_info,
                                            bool* found_unknown_shapes) const;

  // Calculate the total size in bytes of a single input to a TensorFlow op.
  int64 CalculateSingleInputSize(const OpInfo::TensorProperties& input,
                                 bool* found_unknown_shapes) const;

  // Calculate the total size in bytes of the all
  // the inputs of specified TensorFlow Op
  int64 CalculateInputSize(const OpInfo& op_features,
                           bool* found_unknown_shapes) const;

  // Calculate the total size in bytes of the all
  // the outputs of specified TensorFlow Op
  int64 CalculateOutputSize(const OpInfo& op_features,
                            bool* found_unknown_shapes) const;

  // This family of routines predicts the costs to
  // perform the specified TensorFlow Op on the
  // device represented by a subclass. The default
  // implementation just divides the operations to
  // perform the op (from the "Count" routines,
  // above) by the device peak operations per
  // second. Override to supply a better estimate.
  // Implementation of costs other than
  // execution_time is optional, depending on the
  // device.
  Costs PredictConv2D(const OpInfo& op_features) const;
  Costs PredictConv2DBackPropInput(const OpInfo& op_features) const;
  Costs PredictConv2DBackPropFilter(const OpInfo& op_features) const;
  Costs PredictMatMul(const OpInfo& op_features) const;
  Costs PredictNoOp(const OpInfo& op_features) const;
  Costs PredictBatchMatMul(const OpInfo& op_features) const;

  // Utility function for safe division. Returns 0
  // if rhs is 0 or negative.
  static double SafeDiv(const double lhs, const double rhs) {
    if (rhs > 0) {
      return lhs / rhs;
    } else {
      return 0.0;
    }
  }

  static ConvolutionDimensions ConvolutionDimensionsFromInputs(
      const TensorShapeProto& original_image_shape,
      const TensorShapeProto& original_filter_shape, const OpInfo& op_features,
      bool* found_unknown_shapes);

 protected:
  typedef std::function<Costs(const OpInfo& op_feature)> CostImpl;
  std::map<string, CostImpl> device_cost_impl_;

 private:
  friend class OpLevelCostEstimatorTest;
};

}  // end namespace grappler
}  // end namespace tensorflow
#endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_OP_LEVEL_COST_ESTIMATOR_H_