aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Max Galkin <maxgalkin@google.com>2018-06-06 14:38:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-06 14:41:33 -0700
commit4a2104ce30cd2a931ca3bae260d7394815f5dcae (patch)
tree062f8795d13f4b07afe7ad1cdfcda7605a618d16
parent2cce1a8504f53a5d8bdc08b6d0b5c036b672ca0e (diff)
Estimate Squeeze cost in the same way as Reshape.
PiperOrigin-RevId: 199531069
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc2
1 files changed, 2 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index b8e337582c..b994d26397 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -45,6 +45,7 @@ constexpr char kIdentityN[] = "IdentityN";
constexpr char kRefIdentity[] = "RefIdentity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
+constexpr char kSqueeze[] = "Squeeze";
constexpr char kRecv[] = "_Recv";
constexpr char kSend[] = "_Send";
constexpr char kBatchMatMul[] = "BatchMatMul";
@@ -232,6 +233,7 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
{kStopGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kPreventGradient, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kReshape, wrap(&OpLevelCostEstimator::PredictIdentity)},
+ {kSqueeze, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},