aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-07-21 12:44:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-21 12:49:02 -0700
commita4538ecd63699f3efb91c9e7c54409d40944e434 (patch)
tree65d95fd7c6f141d2e6ccf048403c4cba8eb41c46 /tensorflow/core/grappler/costs/op_level_cost_estimator.cc
parent9d8f6ac30db4c922b34997667dc43ba4bc27cf79 (diff)
Set the op cost of RefIdentity, StopGradient, and PreventGradient to zero in analytical cost estimator.
PiperOrigin-RevId: 162773156
Diffstat (limited to 'tensorflow/core/grappler/costs/op_level_cost_estimator.cc')
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc6
1 files changed, 6 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 7f4cc95f31..5148a4b99e 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -31,6 +31,7 @@ constexpr char kConv2dBackPropInput[] = "Conv2DBackpropInput";
constexpr char kMatMul[] = "MatMul";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kIdentity[] = "Identity";
+constexpr char kRefIdentity[] = "RefIdentity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
constexpr char kRecv[] = "_Recv";
@@ -40,6 +41,8 @@ constexpr char kVariableV2[] = "VariableV2";
constexpr char kRank[] = "Rank";
constexpr char kShape[] = "Shape";
constexpr char kSize[] = "Size";
+constexpr char kStopGradient[] = "StopGradient";
+constexpr char kPreventGradient[] = "PreventGradient";
namespace {
@@ -155,6 +158,9 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kSparseMatMul, wrap(&OpLevelCostEstimator::PredictMatMul)},
{kIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kRefIdentity, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kStopGradient, wrap(&OpLevelCostEstimator::PredictNoOp)},
+ {kPreventGradient, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kReshape, wrap(&OpLevelCostEstimator::PredictNoOp)},
{kRecv, wrap(&OpLevelCostEstimator::PredictNoOp)},