Diffstat (limited to 'tensorflow/core/grappler')
 tensorflow/core/grappler/clusters/BUILD                          |  1
 tensorflow/core/grappler/clusters/cluster.h                      |  3
 tensorflow/core/grappler/clusters/single_machine.cc              |  9
 tensorflow/core/grappler/clusters/single_machine.h               |  3
 tensorflow/core/grappler/clusters/single_machine_test.cc         |  8
 tensorflow/core/grappler/clusters/virtual_cluster.cc             | 12
 tensorflow/core/grappler/clusters/virtual_cluster.h              |  5
 tensorflow/core/grappler/costs/graph_properties.cc               | 72
 tensorflow/core/grappler/costs/graph_properties.h                |  5
 tensorflow/core/grappler/costs/graph_properties_test.cc          | 53
 tensorflow/core/grappler/op_types.cc                             | 15
 tensorflow/core/grappler/op_types.h                              |  1
 tensorflow/core/grappler/optimizers/BUILD                        |  1
 tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc      | 66
 tensorflow/core/grappler/optimizers/arithmetic_optimizer.h       |  1
 tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 72
 tensorflow/core/grappler/optimizers/constant_folding.cc          | 11
 tensorflow/core/grappler/optimizers/constant_folding_test.cc     | 18
 tensorflow/core/grappler/optimizers/dependency_optimizer.cc      | 12
 tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc | 64
 22 files changed, 383 insertions(+), 79 deletions(-)
 tensorflow/core/grappler/optimizers/meta_optimizer.cc            | 27
 tensorflow/core/grappler/optimizers/meta_optimizer_test.cc       |  3
 22 files changed, 383 insertions(+), 79 deletions(-)
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
index d0b2cf01be..ab8f4bebb3 100644
--- a/tensorflow/core/grappler/clusters/BUILD
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -77,6 +77,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":cluster",
+ ":utils",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h
index d33aaa7e4c..06db36b3aa 100644
--- a/tensorflow/core/grappler/clusters/cluster.h
+++ b/tensorflow/core/grappler/clusters/cluster.h
@@ -95,7 +95,7 @@ class Cluster {
// The DeviceSet is not always available, but when it is it contains a
// superset of the devices listed in GetDevices/GetDeviceNames().
- const DeviceSet* GetDeviceSet() const { return device_set_; }
+ virtual const DeviceSet* GetDeviceSet() const { return nullptr; }
// Enables collecting the allocator stats. A call with enable=true must be
// made before Provision().
@@ -124,7 +124,6 @@ class Cluster {
protected:
std::unordered_map<string, DeviceProperties> devices_;
- const DeviceSet* device_set_ = nullptr; // Not owned
const int timeout_s_;
SessionOptions options_;
RunOptions run_options_;
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index 313ef90d81..b97603c890 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -368,6 +368,15 @@ Status SingleMachine::ResetSession() {
}
coordinator_.reset(new Coordinator());
+ // Build the DeviceSet.
+ device_set_.reset(new DeviceSet);
+ const DeviceMgr* device_mgr;
+ TF_RETURN_IF_ERROR(session_->LocalDeviceManager(&device_mgr));
+ for (auto d : device_mgr->ListDevices()) {
+ device_set_->AddDevice(d);
+ // We currently don't care about the client device.
+ }
+
return Status::OK();
}
diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h
index 0ae188e0d6..c0421dd4de 100644
--- a/tensorflow/core/grappler/clusters/single_machine.h
+++ b/tensorflow/core/grappler/clusters/single_machine.h
@@ -43,6 +43,8 @@ class SingleMachine : public Cluster {
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch, RunMetadata* metadata) override;
+ const DeviceSet* GetDeviceSet() const override { return device_set_.get(); }
+
Status EnablePeakMemoryStats(bool enable) override;
// Requires EnableAllocatorStats(true) to be called before Provision().
@@ -73,6 +75,7 @@ class SingleMachine : public Cluster {
int64 expected_init_time_s_;
std::unique_ptr<Coordinator> coordinator_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
+ std::unique_ptr<DeviceSet> device_set_;
RunMetadata init_metadata_;
diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc
index 352f08fede..31b19cfcfd 100644
--- a/tensorflow/core/grappler/clusters/single_machine_test.cc
+++ b/tensorflow/core/grappler/clusters/single_machine_test.cc
@@ -546,7 +546,7 @@ TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) {
TF_CHECK_OK(cluster_->GetPeakMemoryUsage(&device_peak_memory_before));
EXPECT_EQ(device_peak_memory_before.size(), 1);
// There might be a bit of memory used before the session runs anything.
- EXPECT_LT(device_peak_memory_before.begin()->second, 200);
+ EXPECT_LT(device_peak_memory_before.begin()->second, 400);
RunMetadata metadata;
TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
@@ -567,8 +567,8 @@ TEST_F(SingleMachineTest, ReleaseMemoryAfterDestruction) {
// Check memory used by resources are released after cluster destruction.
EXPECT_EQ(device_peak_memory_before.size(), 1);
EXPECT_EQ(device_peak_memory_after.size(), 1);
- EXPECT_LT(device_peak_memory_before.begin()->second, 200);
- EXPECT_LT(device_peak_memory_after.begin()->second, 200);
+ EXPECT_LT(device_peak_memory_before.begin()->second, 400);
+ EXPECT_LT(device_peak_memory_after.begin()->second, 400);
}
TEST_F(SingleMachineTest, PeakMemory) {
@@ -597,7 +597,7 @@ TEST_F(SingleMachineTest, PeakMemory) {
device_peak_memory.end());
cpu_memory =
device_peak_memory["/job:localhost/replica:0/task:0/device:CPU:0"];
- EXPECT_LT(cpu_memory, 100);
+ EXPECT_LT(cpu_memory, 200);
}
TEST_F(SingleMachineTest, PeakMemoryStatsNotEnabled) {
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc
index 5c9b2320b5..12e3e46f65 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
@@ -38,11 +39,14 @@ VirtualCluster::VirtualCluster(
devices_ = devices;
}
-VirtualCluster::VirtualCluster(
- const std::unordered_map<string, DeviceProperties>& devices,
- const DeviceSet* device_set)
- : VirtualCluster(devices) {
+VirtualCluster::VirtualCluster(const DeviceSet* device_set)
+ : VirtualCluster(std::unordered_map<string, DeviceProperties>()) {
device_set_ = device_set;
+ for (const auto& device : device_set_->devices()) {
+ DeviceProperties props = GetDeviceInfo(device->parsed_name());
+ if (props.type() == "UNKNOWN") continue;
+ devices_[device->name()] = props;
+ }
}
VirtualCluster::~VirtualCluster() {}
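
[Note: the new single-argument constructor derives DeviceProperties for every device in the DeviceSet and drops devices whose type cannot be modeled. Below is a minimal self-contained sketch of that filtering pattern; Device, DeviceProperties, and GetDeviceInfo here are simplified stand-ins for the TensorFlow types, not the real API.]

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-ins for tensorflow::Device / DeviceProperties.
struct DeviceProperties { std::string type; };
struct Device { std::string name; std::string device_type; };

// Hypothetical analog of grappler::GetDeviceInfo(): returns "UNKNOWN" for
// device types it cannot model.
DeviceProperties GetDeviceInfo(const Device& d) {
  if (d.device_type == "CPU" || d.device_type == "GPU") return {d.device_type};
  return {"UNKNOWN"};
}

int main() {
  std::vector<Device> device_set = {
      {"/job:localhost/replica:0/task:0/device:CPU:0", "CPU"},
      {"/job:localhost/replica:0/task:0/device:XYZ:0", "XYZ"}};  // unmodeled
  std::unordered_map<std::string, DeviceProperties> devices;
  for (const auto& device : device_set) {
    DeviceProperties props = GetDeviceInfo(device);
    if (props.type == "UNKNOWN") continue;  // skip devices we cannot model
    devices[device.name] = props;
  }
  std::cout << devices.size() << " modeled device(s)\n";  // prints 1
}
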
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h
index eebac68e1b..6adb0b99bc 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.h
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.h
@@ -36,8 +36,7 @@ class VirtualCluster : public Cluster {
VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices,
OpLevelCostEstimator* node_estimator,
ReadyNodeManager* node_manager);
- VirtualCluster(const std::unordered_map<string, DeviceProperties>& devices,
- const DeviceSet* device_set);
+ VirtualCluster(const DeviceSet* device_set);
~VirtualCluster() override;
@@ -48,10 +47,12 @@ class VirtualCluster : public Cluster {
Status Run(const GraphDef& item,
const std::vector<std::pair<string, Tensor>>& feed,
const std::vector<string>& fetch, RunMetadata* metadata) override;
+ const DeviceSet* GetDeviceSet() const override { return device_set_; }
private:
std::unique_ptr<OpLevelCostEstimator> node_estimator_;
std::unique_ptr<ReadyNodeManager> node_manager_;
+ const DeviceSet* device_set_ = nullptr; // Not owned
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index d9a08d42db..0c02876ac5 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -353,12 +353,12 @@ void VerboseLogUnknownDimensionSources(
class TopoQueue {
public:
explicit TopoQueue(const std::unordered_map<const NodeDef*, int>& topo_order)
- : queue_(CompareNodes(topo_order)) {}
- void push(const NodeDef* n) { queue_.insert(n); }
+ : topo_order_(topo_order) {}
+ void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
const NodeDef* pop() {
CHECK(!empty());
auto it = queue_.begin();
- const NodeDef* n = *it;
+ const NodeDef* n = it->first;
queue_.erase(it);
return n;
}
@@ -367,20 +367,16 @@ class TopoQueue {
std::size_t size() const { return queue_.size(); }
private:
+ using NodeAndId = std::pair<const NodeDef*, int>;
// Graph nodes are created in (roughly) topological order. Therefore we can
// use their id to ensure they're sorted topologically.
- struct CompareNodes {
- explicit CompareNodes(
- const std::unordered_map<const NodeDef*, int>& topo_ordering)
- : topo_order(topo_ordering) {}
- bool operator()(const NodeDef* lhs, const NodeDef* rhs) const {
- return topo_order.at(lhs) < topo_order.at(rhs);
+ struct OrderByIdAscending {
+ bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
+ return lhs.second < rhs.second;
}
-
- private:
- const std::unordered_map<const NodeDef*, int>& topo_order;
};
- std::set<const NodeDef*, CompareNodes> queue_;
+ const std::unordered_map<const NodeDef*, int>& topo_order_;
+ std::set<NodeAndId, OrderByIdAscending> queue_;
};
// Processes symbolic shapes.
@@ -1082,6 +1078,9 @@ Status GraphProperties::UpdateShapes(
// itself.
TF_RETURN_IF_ERROR(
UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
+ } else if (IsQueue(*n)) {
+ // Set shapes and types of Queue ops, if needed.
+ TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
} else {
auto c = shape_refiner->GetNodeContext(n);
if (c && c->op_data && c->op_data->is_function_op) {
@@ -1147,6 +1146,53 @@ Status GraphProperties::PropagateShapes(
return Status::OK();
}
+Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
+ SymbolicShapeRefiner* shape_refiner,
+ bool* new_shapes) {
+ auto ctx = shape_refiner->GetNodeContext(queue_node);
+ if (!ctx) {
+ TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
+ ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
+ }
+ auto* ic = ctx->inference_context.get();
+
+ auto* outputs = ic->output_handle_shapes_and_types(0);
+ if (outputs) {
+ // Shapes and types are already set, presumably by Enqueue ops.
+ return shape_refiner->UpdateNode(queue_node, new_shapes);
+ }
+
+ if (queue_node->attr().count("shapes") <= 0 ||
+ queue_node->attr().count("component_types") <= 0 ||
+ queue_node->attr().at("shapes").list().shape_size() !=
+ queue_node->attr().at("component_types").list().type_size()) {
+ // The shapes and component_types attrs are missing or inconsistent.
+ return shape_refiner->UpdateNode(queue_node, new_shapes);
+ }
+
+ // Extract types and shapes from Queue attr.
+ const auto& shapes = queue_node->attr().at("shapes").list().shape();
+ const auto& types = queue_node->attr().at("component_types").list().type();
+ std::vector<ShapeAndType> shapes_and_types;
+ for (int i = 0; i < types.size(); i++) {
+ const auto& shape = shapes[i];
+ ShapeHandle shape_handle;
+ TF_RETURN_IF_ERROR(
+ ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
+ DataType data_type =
+ queue_node->attr().at("component_types").list().type(i);
+ ShapeAndType shape_and_type(shape_handle, data_type);
+ shapes_and_types.push_back(shape_and_type);
+ }
+ ic->set_output_handle_shapes_and_types(0, shapes_and_types);
+
+ // The queue node has been updated with output_handle_shapes_and_types, so
+ // set *new_shapes here and ignore the result of UpdateNode().
+ *new_shapes = true;
+ bool dummy_new_shapes = false;
+ return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
+}
+
Status GraphProperties::UpdateEnqueue(
const NodeDef* enqueue_node,
const std::unordered_map<const NodeDef*, const NodeDef*>& resource_handles,
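
[Note: UpdateQueue only trusts the queue's own attributes when the shapes and component_types lists are present and parallel; otherwise it falls back to regular shape inference. A self-contained sketch of that validation-and-extraction step follows; the plain structs stand in for the attr protos and InferenceContext, and -1 marks an unknown dimension as in the QueueWithOnlyDequeue_PartialShapeAttr test below.]

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-ins for the attr values on a queue NodeDef.
struct PartialShape { std::vector<int> dims; };  // -1 == unknown dimension
using ShapeAndType = std::pair<PartialShape, std::string>;

// Fills `out` only if the two attr lists are present and parallel,
// mirroring the checks in GraphProperties::UpdateQueue().
bool ExtractQueueShapesAndTypes(const std::vector<PartialShape>& shapes,
                                const std::vector<std::string>& types,
                                std::vector<ShapeAndType>* out) {
  if (shapes.empty() || types.empty() || shapes.size() != types.size())
    return false;  // inconsistent attrs: fall back to regular inference
  for (size_t i = 0; i < types.size(); ++i)
    out->emplace_back(shapes[i], types[i]);
  return true;
}

int main() {
  std::vector<ShapeAndType> result;
  // Analog of FIFOQueue(..., Attrs().Shapes({{3, 7, -1}})) with DT_FLOAT.
  bool ok = ExtractQueueShapesAndTypes({{{3, 7, -1}}}, {"float"}, &result);
  std::printf("ok=%d components=%zu\n", ok, result.size());  // ok=1 components=1
}
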
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 8703613a12..f716cd72c9 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -91,6 +91,11 @@ class GraphProperties {
resource_handles,
SymbolicShapeRefiner* shape_refiner, bool* new_shapes);
+ // Update the shapes and types of the Queue node, if not set by Enqueue node.
+ static Status UpdateQueue(const NodeDef* queue_node,
+ SymbolicShapeRefiner* shape_refiner,
+ bool* new_shapes);
+
// Update the output shapes of a Merge node, and enqueue its fanout in
// new_shapes if needed.
Status UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 3e44b222fd..aa787ae620 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -262,6 +262,59 @@ TEST_F(GraphPropertiesTest, VarHandles) {
EXPECT_EQ(7, prop.shape().dim(1).size());
}
+TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_NoShapeAttr) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT});
+ auto dequeue1 =
+ ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
+
+ GrapplerItem item;
+ TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto props1 = properties.GetOutputProperties("Dequeue1");
+ ASSERT_EQ(1, props1.size());
+ EXPECT_EQ("float: ?", PropToString(props1[0]));
+}
+
+TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_ShapeAttr) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT},
+ ops::FIFOQueue::Attrs().Shapes({{3, 7, 1}}));
+ auto dequeue1 =
+ ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
+
+ GrapplerItem item;
+ TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto props1 = properties.GetOutputProperties("Dequeue1");
+ ASSERT_EQ(1, props1.size());
+ EXPECT_EQ("float: [3,7,1]", PropToString(props1[0]));
+}
+
+TEST_F(GraphPropertiesTest, QueueWithOnlyDequeue_PartialShapeAttr) {
+ tensorflow::Scope root = tensorflow::Scope::NewRootScope();
+ auto q1 = ops::FIFOQueue(root.WithOpName("Queue1"), {DataType::DT_FLOAT},
+ ops::FIFOQueue::Attrs().Shapes({{3, 7, -1}}));
+ auto dequeue1 =
+ ops::QueueDequeue(root.WithOpName("Dequeue1"), q1, {DataType::DT_FLOAT});
+
+ GrapplerItem item;
+ TF_CHECK_OK(root.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ const auto props1 = properties.GetOutputProperties("Dequeue1");
+ ASSERT_EQ(1, props1.size());
+ EXPECT_EQ("float: [3,7,-1]", PropToString(props1[0]));
+}
+
TEST_F(GraphPropertiesTest, Queues) {
// Create a graph with known input shapes, and propagate the shapes through a
// couple of queues.
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 2227904dbf..bdeb5c66fc 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -135,6 +135,18 @@ bool IsDequeueOp(const NodeDef& node) {
bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
+bool IsElementWiseMonotonic(const NodeDef& node) {
+ static const std::unordered_set<string>* element_wise_monotonic_ops =
+ CHECK_NOTNULL((new std::unordered_set<string>{
+ "Relu",
+ "Relu6",
+ "Sigmoid",
+ "Sqrt",
+ "Tanh",
+ }));
+ return element_wise_monotonic_ops->count(node.op()) > 0;
+}
+
bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
bool IsEnter(const NodeDef& node) {
@@ -617,7 +629,8 @@ bool HasOpDef(const NodeDef& node) {
}
bool IsIdempotent(const NodeDef& node) {
- return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node);
+ return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
+ !ModifiesFrameInfo(node);
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 7110a9c63d..2de7d8cc9a 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -55,6 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
+bool IsElementWiseMonotonic(const NodeDef& node);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 33c2a0d420..8ca726df0b 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -679,6 +679,7 @@ cc_library(
deps = [
":constant_folding",
":graph_optimizer",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 9d500f8f54..90be051764 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1722,19 +1722,15 @@ class RemoveIdempotentStage : public ArithmeticOptimizerStage {
~RemoveIdempotentStage() override = default;
bool IsSupported(const NodeDef* node) const override {
- return IsIdempotent(*node) && !IsInPreserveSet(*node);
+ return node->input_size() == 1 && IsIdempotent(*node) &&
+ !IsInPreserveSet(*node);
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
NodeDef* input;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
- auto root_scope_and_name = ParseNodeScopeAndName(node->name());
- const string new_name = OptimizedNodeName(root_scope_and_name);
- if (input->op() == node->op() && input->device() == node->device() &&
- IsIdempotent(*input) && !ctx().node_map->NodeExists(new_name)) {
- NodeDef* new_input_node = AddCopyNode(new_name, input);
- ForwardControlDependencies(new_input_node, {node});
- *simplified_node_name = new_input_node->name();
+ if (input->op() == node->op() && input->device() == node->device()) {
+ *simplified_node_name = node->input(0);
}
return Status::OK();
}
@@ -2600,6 +2596,58 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
};
+// Performs conversions like:
+// Max(Sqrt(x)) => Sqrt(Max(x))
+// Checks for a max/min reduction over element-wise monotonic functions, such
+// as Sqrt, Sigmoid, Tanh, etc.
+class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
+ public:
+ explicit OptimizeMaxOrMinOfMonotonicStage(
+ const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
+ ctx_ext) {}
+ ~OptimizeMaxOrMinOfMonotonicStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ return IsMax(*node) || IsMin(*node);
+ }
+
+ Status TrySimplify(NodeDef* reduction_node,
+ string* simplified_node_name) override {
+ NodeDef* inner_function;
+ TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
+ // Optimize only if:
+ // 1. inner_function's Op is element-wise monotonic
+ // 2. inner_function's output is not being consumed elsewhere.
+ if (IsElementWiseMonotonic(*inner_function) &&
+ (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ // Swap the first inputs of the inner function Op & the reduction Op.
+ NodeDef* inner_input;
+ TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
+ reduction_node->set_input(0, inner_input->name());
+ UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ }
+ return Status::OK();
+ }
+
+ void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ const string& node_name = node->name();
+ const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
+ for (NodeDef* consumer : consumers) {
+ for (int i = 0; i < consumer->input_size(); ++i) {
+ if (consumer->input(i) == node_name && consumer->name() != new_input) {
+ consumer->set_input(i, new_input);
+ ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
+ }
+ }
+ AddToOptimizationQueue(consumer);
+ }
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -2878,6 +2926,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
if (options_.convert_log1p)
pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
+ if (options_.optimize_max_or_min_of_monotonic)
+ pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
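
[Note: the rewrite is sound because a max (or min) reduction commutes with any element-wise non-decreasing function f, i.e. max_i f(x_i) = f(max_i x_i), and reducing first means the unary op runs on a smaller tensor. A quick self-contained check of the identity for Sqrt; this is illustrative only, not grappler code.]

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const std::vector<float> x = {1.0f, 4.0f, 9.0f, 2.0f};
  // Max(Sqrt(x)): apply the monotonic op first, then reduce.
  float max_of_sqrt = 0.0f;
  for (float v : x) max_of_sqrt = std::max(max_of_sqrt, std::sqrt(v));
  // Sqrt(Max(x)): reduce first, then apply the op once.
  float sqrt_of_max = std::sqrt(*std::max_element(x.begin(), x.end()));
  assert(max_of_sqrt == sqrt_of_max);  // both are 3.0f
  return 0;
}
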
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 9a6081dcd8..824ef35ef6 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -63,6 +63,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool hoist_common_factor_out_of_aggregation = true;
bool hoist_cwise_unary_chains = false;
bool minimize_broadcasts = true;
+ bool optimize_max_or_min_of_monotonic = true;
bool remove_idempotent = true;
bool remove_identity_transpose = true;
bool remove_involution = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 177c237fe7..d0e6b04679 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -269,6 +269,11 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.convert_log1p = true;
}
+
+ void EnableOnlyOptimizeMaxOrMinOfMonotonic(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.optimize_max_or_min_of_monotonic = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -2971,12 +2976,8 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 3.14f, {32});
- Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
- Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
- Output sn1 =
- ops::Snapshot(s.WithOpName("sn1").WithControlDependencies(ctrl1), a);
- Output sn2 =
- ops::Snapshot(s.WithOpName("sn2").WithControlDependencies(ctrl2), sn1);
+ Output sn1 = ops::Snapshot(s.WithOpName("sn1"), a);
+ Output sn2 = ops::Snapshot(s.WithOpName("sn2"), sn1);
Output out1 = ops::Identity(s.WithOpName("out1"), sn2);
Output id1 = ops::Identity(s.WithOpName("id1"), a);
Output id2 = ops::Identity(s.WithOpName("id2"), id1);
@@ -2992,32 +2993,24 @@ TEST_F(ArithmeticOptimizerTest, RemoveIdempotent) {
EnableOnlyRemoveIdempotent(&optimizer);
OptimizeTwice(&optimizer, &item, &output);
- EXPECT_EQ(11, output.node_size());
+ EXPECT_EQ(7, output.node_size());
int found = 0;
for (const NodeDef& node : output.node()) {
if (node.name() == "out1") {
EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_sn2", node.input(0));
- found++;
- } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_sn2") {
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ("Snapshot", node.op());
- EXPECT_EQ("a", node.input(0));
- EXPECT_EQ("^ctrl1", node.input(1));
- EXPECT_EQ("^ctrl2", node.input(2));
+ EXPECT_EQ("sn1", node.input(0));
found++;
} else if (node.name() == "out2") {
EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("ArithmeticOptimizer/RemoveIdempotent_id2", node.input(0));
+ EXPECT_EQ("id1", node.input(0));
found++;
- } else if (node.name() == "ArithmeticOptimizer/RemoveIdempotent_id2") {
- EXPECT_EQ("Identity", node.op());
+ } else if (node.name() == "sn1") {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("a", node.input(0));
found++;
}
}
- EXPECT_EQ(4, found);
+ EXPECT_EQ(3, found);
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(tensors.size(), tensors_expected.size());
@@ -3125,5 +3118,46 @@ TEST_F(ArithmeticOptimizerTest, RemoveLogicalNot) {
}
}
+TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ // Check that the inputs have been swapped.
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "sqrt") {
+ EXPECT_EQ("Sqrt", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("reduce_max", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "reduce_max") {
+ EXPECT_EQ("Max", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index f4b384ec1e..76c928f995 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -354,12 +354,14 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
}
if (op == "TensorArraySizeV3") {
- const NodeDef* array = node_map_->GetNode(node->input(0));
- if (array->attr().count("dynamic_size") != 0 &&
- array->attr().at("dynamic_size").b()) {
+ const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
+ if (array->input_size() == 0 ||
+ (array->attr().count("dynamic_size") != 0 &&
+ array->attr().at("dynamic_size").b())) {
continue;
}
- const NodeDef* array_size = node_map_->GetNode(array->input(0));
+ const NodeDef* array_size =
+ CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
if (IsReallyConstant(*array_size)) {
// Don't materialize 0 sizes to avoid triggering incorrect static
// checks. A 0 sized array that can't grow isn't useful anyway.
@@ -374,6 +376,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
if (value.flat<int32>()(0) == 0) {
continue;
}
+
node->set_op("Const");
*node->mutable_attr() = array_size->attr();
node->set_input(0, AsControlDependency(NodeName(node->input(0))));
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 9f051ca248..b9765b9292 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -3000,6 +3000,10 @@ TEST_F(ConstantFoldingTest, Enter) {
TEST_F(ConstantFoldingTest, TensorArraySize) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output size = ops::Const(scope.WithOpName("size"), 5, TensorShape({}));
+ Output placeholder =
+ ops::Placeholder(scope.WithOpName("placeholder"), DT_RESOURCE,
+ ops::Placeholder::Shape(TensorShape({2})));
+ Output foo = ops::Const(scope.WithOpName("foo"), 5.0f, TensorShape({}));
auto dynamic_array =
ops::TensorArray(scope.WithOpName("dynamic"), size, DT_FLOAT,
ops::TensorArray::DynamicSize(true));
@@ -3010,6 +3014,8 @@ TEST_F(ConstantFoldingTest, TensorArraySize) {
scope.WithOpName("dynamic_sz"), dynamic_array.handle, dynamic_array.flow);
auto static_sz = ops::TensorArraySize(scope.WithOpName("static_sz"),
static_array.handle, static_array.flow);
+ auto placeholder_sz = ops::TensorArraySize(scope.WithOpName("placeholder_sz"),
+ placeholder, foo);
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
@@ -3026,11 +3032,13 @@ TEST_F(ConstantFoldingTest, TensorArraySize) {
status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(5, output.node_size());
- EXPECT_EQ("dynamic_sz", output.node(3).name());
- EXPECT_EQ("TensorArraySizeV3", output.node(3).op());
- EXPECT_EQ("static_sz", output.node(4).name());
- EXPECT_EQ("Const", output.node(4).op());
+ EXPECT_EQ(8, output.node_size());
+ EXPECT_EQ("dynamic_sz", output.node(5).name());
+ EXPECT_EQ("TensorArraySizeV3", output.node(5).op());
+ EXPECT_EQ("static_sz", output.node(6).name());
+ EXPECT_EQ("Const", output.node(6).op());
+ EXPECT_EQ("placeholder_sz", output.node(7).name());
+ EXPECT_EQ("TensorArraySizeV3", output.node(7).op());
auto tensors_actual = EvaluateNodes(output, {"dynamic_sz", "static_sz"});
EXPECT_EQ(2, tensors_expected.size());
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 3f5bab9d3b..fdd82b9603 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -260,14 +260,14 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
}
continue;
}
+ // Replace a normal input with a control input.
const string ctrl_input = ConstantFolding::AddControlDependency(
old_input, optimized_graph_, node_map_.get());
- if (ctrl_inputs.insert(ctrl_input).second) {
- node->set_input(pos, ctrl_input);
- node_map_->UpdateInput(node_name, old_input, ctrl_input);
- const NodeDef* old_input_node = node_map_->GetNode(old_input);
- nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
- }
+ ctrl_inputs.insert(ctrl_input);
+ node->set_input(pos, ctrl_input);
+ node_map_->UpdateInput(node_name, old_input, ctrl_input);
+ const NodeDef* old_input_node = node_map_->GetNode(old_input);
+ nodes_to_simplify->PushBack(node_to_idx_[old_input_node]);
++pos;
}
node->set_op("NoOp");
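
[Note: the old code skipped an input whenever its control form was already in ctrl_inputs, which left a stale data input behind for graphs like Add(x, x): the second occurrence of x was never rewritten. The fix rewrites every occurrence unconditionally. A minimal sketch of the corrected loop; plain strings stand in for NodeDef inputs, and AsControlDependency is a simplified analog of ConstantFolding::AddControlDependency.]

#include <iostream>
#include <set>
#include <string>
#include <vector>

// Simplified analog of AddControlDependency: "x" -> "^x".
std::string AsControlDependency(const std::string& input) {
  return input[0] == '^' ? input : "^" + input;
}

int main() {
  // A node being turned into a NoOp: Add(x, x) lists the same input twice.
  std::vector<std::string> inputs = {"x", "x"};
  std::set<std::string> ctrl_inputs;
  for (auto& input : inputs) {
    const std::string ctrl = AsControlDependency(input);
    // Rewrite unconditionally; the set only records what we've seen. The old
    // `if (ctrl_inputs.insert(ctrl).second)` guard skipped the rewrite for
    // repeated inputs, leaving a dangling data input.
    ctrl_inputs.insert(ctrl);
    input = ctrl;
  }
  for (const auto& i : inputs) std::cout << i << " ";  // prints: ^x ^x
}
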
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
index 0ae3b4ec34..c0f07562af 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
@@ -124,25 +124,62 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size(), output.node_size());
+ int found = 0;
for (int i = 0; i < item.graph.node_size(); ++i) {
const NodeDef& node = item.graph.node(i);
- if (node.name() == "add") {
- EXPECT_EQ("NoOp", node.op());
- EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("^x", node.input(0));
- EXPECT_EQ("^y", node.input(1));
- } else if (node.name() == "id1") {
+ // "add" should get turned into a NoOp and removed.
+ EXPECT_NE("add", node.name());
+ if (node.name() == "id1") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^y", node.input(1));
+ ++found;
} else if (node.name() == "id2") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("^x", node.input(1));
+ ++found;
+ }
+ }
+ EXPECT_EQ(2, found);
+}
+
+TEST_F(DependencyOptimizerTest, ChangeToNoop_RepeatedInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::RandomUniform(s.WithOpName("x"), {1, 2}, DT_FLOAT);
+ Output add = ops::Add(s.WithOpName("add"), x, x);
+ Output id1 =
+ ops::Identity(s.WithOpName("id1").WithControlDependencies(add), x);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch = {"id1"};
+
+ DependencyOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ LOG(INFO) << output.DebugString();
+
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ int found = 0;
+ for (int i = 0; i < item.graph.node_size(); ++i) {
+ const NodeDef& node = item.graph.node(i);
+ // "add" should get turned into a NoOp and removed.
+ EXPECT_NE("add", node.name());
+ if (node.name() == "id1") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++found;
}
}
+ EXPECT_EQ(1, found);
}
TEST_F(DependencyOptimizerTest, ChangeToNoop_SwitchIdentity) {
@@ -400,6 +437,7 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size() - 3, output.node_size());
+ int found = 0;
for (const NodeDef& node : output.node()) {
EXPECT_NE("id_a", node.name());
EXPECT_NE("id_b", node.name());
@@ -407,30 +445,36 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity) {
if (node.name() == "a_a" || node.name() == "a_b") {
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("x", node.input(0));
+ ++found;
}
if (node.name() == "a_c" || node.name() == "a_d") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("z", node.input(0));
EXPECT_EQ("^x", node.input(1));
+ ++found;
}
if (node.name() == "b_a") {
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^y", node.input(1));
EXPECT_EQ("^z", node.input(2));
+ ++found;
}
if (node.name() == "c_a") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^y", node.input(1));
+ ++found;
}
if (node.name() == "c_b") {
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("z", node.input(0));
EXPECT_EQ("^x", node.input(1));
EXPECT_EQ("^y", node.input(2));
+ ++found;
}
}
+ EXPECT_EQ(found, 7);
}
TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
@@ -460,17 +504,20 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size() - 1, output.node_size());
+ int found = 0;
for (const NodeDef& node : output.node()) {
EXPECT_NE("id0", node.name());
if (node.name() == "or0") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("switch:1", node.input(0));
EXPECT_EQ("switch:1", node.input(1));
+ ++found;
}
if (node.name() == "or1") {
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("switch:1", node.input(0));
EXPECT_EQ("y", node.input(1));
+ ++found;
}
if (node.name() == "or2") {
// or1 should be unchanged.
@@ -478,8 +525,10 @@ TEST_F(DependencyOptimizerTest, RemoveIdentity_RepeatedInputs) {
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("y", node.input(1));
EXPECT_EQ("^id1", node.input(2));
+ ++found;
}
}
+ EXPECT_EQ(found, 3);
}
TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
@@ -535,6 +584,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
TF_EXPECT_OK(status);
EXPECT_EQ(item.graph.node_size() - 2, output.node_size());
+ bool found = false;
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
// "id0" and "id1" but neither "ConstantFoldingCtrl/switch_1",
@@ -545,8 +595,10 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_Identity) {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0));
+ found = true;
}
}
+ EXPECT_TRUE(found);
}
TEST_F(DependencyOptimizerTest, IdentityInputs) {
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 143d9dc1c6..b1f31ad0d0 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -42,6 +42,7 @@ namespace grappler {
namespace {
constexpr int kDefaultNumberOfIterations = 2;
+constexpr int kDefaultMinGraphNodes = 4;
int64 NumEdges(const GraphDef& graph) {
int64 num_edges = 0;
@@ -194,6 +195,15 @@ Status MetaOptimizer::InitializeOptimizersByName(
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
+ int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
+ : cfg_.min_graph_nodes();
+ if (item.graph.node_size() < min_graph_nodes) {
+ VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes
+ << " nodes.";
+ *optimized_graph = item.graph;
+ return Status::OK();
+ }
+
std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) {
TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
@@ -202,10 +212,11 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
}
VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
- << " num_optimizers=" << optimizers.size();
+ << " num_optimizers=" << optimizers.size()
+ << ", num nodes = " << item.graph.node_size();
if (optimizers.empty()) {
- VLOG(3) << "Skip graph optimization, no optimizers registered";
+ VLOG(3) << "Skipping graph optimization, no optimizers registered";
*optimized_graph = item.graph;
return Status::OK();
}
@@ -221,8 +232,15 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphOptimizer* sa_optimizer = nullptr;
for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
- VLOG(4) << "Starting optimization iteration " << iteration + 1;
+ // Don't bother optimizing further if the graph is already tiny.
+ if (optimized_graph->node_size() < min_graph_nodes) {
+ VLOG(3) << "Stopping after iteration " << iteration
+ << ", graph is tiny (#nodes = " << optimized_graph->node_size()
+ << " < " << min_graph_nodes << ")";
+ break;
+ }
+ VLOG(4) << "Starting optimization iteration " << iteration;
for (const auto& optimizer : optimizers) {
// Some optimizers can run only once.
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
@@ -235,7 +253,6 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
if (fusion_optimizer == nullptr) fusion_optimizer = optimizer.get();
continue;
}
-
Status status = RunOptimizer(optimizer.get(), cluster, &optimized_item,
optimized_graph, &optimization_result);
if (status.ok()) is_optimized = true;
@@ -297,7 +314,7 @@ Status MetaOptimizer::RunOptimizer(
PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
", time = ", duration_ms, "ms.");
}
- VLOG(4) << optimizer->name() << ": " << result;
+ VLOG(1) << optimizer->name() << ": " << result;
OptimizerResult optimizer_result{optimizer->name(), result};
optimization_result->results.push_back(optimizer_result);
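
[Note: with the new threshold, an unset (zero) min_graph_nodes falls back to kDefaultMinGraphNodes, and the tests below pass -1 to disable the check entirely, since node_size() is never negative. A tiny sketch of the effective-threshold logic, mirroring the code above for illustration.]

#include <iostream>

constexpr int kDefaultMinGraphNodes = 4;

// Returns true if a graph with `node_count` nodes should skip optimization.
bool SkipOptimization(int configured_min_graph_nodes, int node_count) {
  const int min_graph_nodes = configured_min_graph_nodes == 0
                                  ? kDefaultMinGraphNodes
                                  : configured_min_graph_nodes;
  return node_count < min_graph_nodes;
}

int main() {
  std::cout << SkipOptimization(0, 3) << "\n";   // 1: below the default of 4
  std::cout << SkipOptimization(0, 10) << "\n";  // 0: large enough
  std::cout << SkipOptimization(-1, 1) << "\n";  // 0: -1 disables the check
}
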
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 8247cce339..9a03c7dfef 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -74,6 +74,7 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
TestOptimizer::SetOptimized(false);
RewriterConfig rewriter_config;
rewriter_config.add_optimizers("TestOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
MetaOptimizer optimizer(nullptr, rewriter_config);
GraphDef output;
@@ -89,6 +90,7 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
RewriterConfig rewriter_config;
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.set_min_graph_nodes(-1);
MetaOptimizer optimizer(nullptr, rewriter_config);
GraphDef output;
@@ -104,6 +106,7 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
rewriter_config.set_function_optimization(RewriterConfig::ON);
rewriter_config.add_optimizers("function");
+ rewriter_config.set_min_graph_nodes(-1);
MetaOptimizer optimizer(nullptr, rewriter_config);