aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-06 16:45:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-06 16:49:04 -0700
commit8f89b654f4d49a1b5d4462303ef27f7f7a2958b3 (patch)
tree4be61a6867376b086c1cb0b770f38c408df91b69 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent0ea0bf5aae2961be4edbe00c205bed01d293dce3 (diff)
Profile memory usage in VirtualScheduler and report peak memory usage.
To do so, NodeState now handles different output ports of a node (in case a node has multiple outputs). Also, VirtualScheduler code is cleaned up with more comments. PiperOrigin-RevId: 158209068
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc8
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 11a57921e5..75ff75123e 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -31,6 +31,8 @@ constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
constexpr char kRecv[] = "_Recv";
constexpr char kBatchMatMul[] = "BatchMatMul";
+constexpr char kVariable[] = "Variable";
+constexpr char kVariableV2[] = "VariableV2";
OpLevelCostEstimator::OpLevelCostEstimator() {
// Syntactic sugar to build and return a lambda that takes an OpInfo and
@@ -53,6 +55,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kVariable, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kVariableV2, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kBatchMatMul, wrap(&OpLevelCostEstimator::PredictBatchMatMul)}};
}
@@ -567,7 +571,7 @@ int64 OpLevelCostEstimator::CalculateSingleInputSize(
for (const auto& dim : input_shape.dim()) {
input_size *= dim.size();
}
- return input_size * DataTypeSize(input.dtype());
+ return input_size * DataTypeSize(BaseType(input.dtype()));
}
int64 OpLevelCostEstimator::CalculateInputSize(
@@ -589,7 +593,7 @@ int64 OpLevelCostEstimator::CalculateOutputSize(
for (const auto& output : op_features.outputs()) {
DataType dt = output.dtype();
const auto& original_output_shape = output.shape();
- int64 output_size = DataTypeSize(dt);
+ int64 output_size = DataTypeSize(BaseType(dt));
int num_dims = std::max(1, original_output_shape.dim_size());
auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
found_unknown_shapes);