aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-26 14:27:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 14:37:18 -0700
commit2116c6649cfe339ce8a3859eb425806db8ae32b9 (patch)
treefa151a4dae84ad09e8ebd7c5e509c11c8c594e28 /tensorflow/core/grappler
parentc551a7dbd08685160c233ccecd444f774666f98e (diff)
Misc. micro-optimizations in Grappler optimizers.
Make shape inference lazy in optimizers that may not trigger. PiperOrigin-RevId: 214669034
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/graph_view.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc12
-rw-r--r--tensorflow/core/grappler/optimizers/remapper.cc8
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer.cc12
4 files changed, 26 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index 2619a9a8f3..0b8cb5e919 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -72,7 +72,7 @@ void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
void GraphView::AddFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin;
- string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
+ const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
fanin.node = nodes_[fanin_name];
InputPort input;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 75ed12635e..3388ee8035 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -276,7 +276,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
for (int i = 0; i < output->input_size(); ++i) {
auto input = output->input(i);
- string name = ParseNodeName(input, &position);
+ StringPiece name = ParseNodeNameAsStringPiece(input, &position);
if (name == node.name() && /*control input*/ position < 0) {
return true;
}
@@ -1568,7 +1568,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
for (NodeDef* output : outputs) {
if (IsControlInput(output->input(0))) continue;
int port;
- const string node_name = ParseNodeName(output->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(output->input(0), &port);
if (node_name == node.name()) {
tails->insert(ChainLink(output, port));
} else {
@@ -1618,7 +1619,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
} else {
for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
int port;
- const string node_name = ParseNodeName(new_tail->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(new_tail->input(0), &port);
if (node_name != tail->name()) {
return Status::OK();
}
@@ -2929,8 +2931,8 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const {
for (const auto& input : node.input()) {
int pos;
- string node_name = ParseNodeName(input, &pos);
- h = Hash64CombineUnordered(Hash64(node_name), h);
+ const StringPiece node_name = ParseNodeNameAsStringPiece(input, &pos);
+ h = Hash64CombineUnordered(Hash64(node_name.data(), node_name.size()), h);
h = Hash64CombineUnordered(std::hash<int>()(pos), h);
}
for (const auto& attr : node.attr()) {
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 008a289cfd..9ada8b7ff9 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -168,11 +168,12 @@ void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
GraphDef* optimized_graph) {
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(const_cast<GraphDef*>(&item.graph));
// During inference, most of the inputs to FusedBatchNorm are constant, and we
// can therefore replace the op with a much cheaper set of primitives.
+ optimized_graph->mutable_node()->Reserve(item.graph.node_size());
for (const NodeDef& node : item.graph.node()) {
if (node.op() == "FusedBatchNorm" || node.op() == "FusedBatchNormV2") {
bool optimizable = (node.attr().count("T") == 0 ||
@@ -181,6 +182,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
!node.attr().at("is_training").b());
if (optimizable) {
int const_inputs = 0;
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& props = properties.GetInputProperties(node.name());
for (const auto& prop : props) {
if (prop.has_value()) {
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index 4542d17ccc..6ccb1cd783 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -33,7 +33,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
*optimized_graph = item.graph;
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(optimized_graph);
// The product of all the dimensions in a tensor shape can be expressed more
@@ -55,6 +55,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
const GraphView::OutputPort reduce_indices =
graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop =
properties.GetOutputProperties(reduce_indices.node->name());
if (prop.size() < reduce_indices.port_id) {
@@ -92,6 +97,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
continue;
}
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop1 = properties.GetInputProperties(input1.node->name());
const auto& prop2 = properties.GetInputProperties(input2.node->name());
if (prop1.size() != 1 || prop2.size() != 1) {