-rw-r--r--  tensorflow/BUILD | 5
-rw-r--r--  tensorflow/core/grappler/BUILD | 67
-rw-r--r--  tensorflow/core/grappler/clusters/BUILD | 64
-rw-r--r--  tensorflow/core/grappler/clusters/cluster.cc | 48
-rw-r--r--  tensorflow/core/grappler/clusters/cluster.h | 88
-rw-r--r--  tensorflow/core/grappler/clusters/single_machine.cc | 248
-rw-r--r--  tensorflow/core/grappler/clusters/single_machine.h | 70
-rw-r--r--  tensorflow/core/grappler/clusters/single_machine_test.cc | 134
-rw-r--r--  tensorflow/core/grappler/costs/BUILD | 116
-rw-r--r--  tensorflow/core/grappler/costs/cost_estimator.h | 146
-rw-r--r--  tensorflow/core/grappler/costs/graph_memory.cc | 109
-rw-r--r--  tensorflow/core/grappler/costs/graph_memory.h | 61
-rw-r--r--  tensorflow/core/grappler/costs/graph_memory_test.cc | 53
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc | 159
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.h | 57
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_test.cc | 133
-rw-r--r--  tensorflow/core/grappler/costs/op_performance_data.proto | 116
-rw-r--r--  tensorflow/core/grappler/costs/utils.cc | 159
-rw-r--r--  tensorflow/core/grappler/costs/utils.h | 53
-rw-r--r--  tensorflow/core/grappler/grappler_item.cc | 251
-rw-r--r--  tensorflow/core/grappler/grappler_item.h | 80
-rw-r--r--  tensorflow/core/grappler/grappler_item_test.cc | 49
-rw-r--r--  tensorflow/core/grappler/inputs/BUILD | 72
-rw-r--r--  tensorflow/core/grappler/inputs/input_yielder.h | 35
-rw-r--r--  tensorflow/core/grappler/inputs/testdata/test_file.txt | 1
-rw-r--r--  tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc | 111
-rw-r--r--  tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h | 47
-rw-r--r--  tensorflow/core/grappler/inputs/utils.cc | 33
-rw-r--r--  tensorflow/core/grappler/inputs/utils.h | 35
-rw-r--r--  tensorflow/core/grappler/inputs/utils_test.cc | 64
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD | 45
-rw-r--r--  tensorflow/core/grappler/optimizers/graph_optimizer.h | 53
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc | 996
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.h | 42
-rw-r--r--  tensorflow/core/grappler/utils.cc | 90
-rw-r--r--  tensorflow/core/grappler/utils.h | 41
-rw-r--r--  tensorflow/core/grappler/utils_test.cc | 59
-rw-r--r--  tensorflow/core/platform/env.cc | 46
-rw-r--r--  tensorflow/core/platform/env.h | 6
-rw-r--r--  tensorflow/core/platform/file_system.cc | 16
-rw-r--r--  tensorflow/core/platform/file_system.h | 6
41 files changed, 4064 insertions, 0 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 755bf679d3..01244d7261 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -213,6 +213,11 @@ filegroup(
"//tensorflow/core/debug:all_files",
"//tensorflow/core/distributed_runtime:all_files",
"//tensorflow/core/distributed_runtime/rpc:all_files",
+ "//tensorflow/core/grappler:all_files",
+ "//tensorflow/core/grappler/clusters:all_files",
+ "//tensorflow/core/grappler/costs:all_files",
+ "//tensorflow/core/grappler/inputs:all_files",
+ "//tensorflow/core/grappler/optimizers:all_files",
"//tensorflow/core/kernels:all_files",
"//tensorflow/core/kernels/cloud:all_files",
"//tensorflow/core/kernels/hexagon:all_files",
diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD
new file mode 100644
index 0000000000..53714367f5
--- /dev/null
+++ b/tensorflow/core/grappler/BUILD
@@ -0,0 +1,67 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+cc_library(
+ name = "utils",
+ srcs = ["utils.cc"],
+ hdrs = ["utils.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:gpu_runtime",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:stream_executor",
+ ],
+)
+
+cc_test(
+ name = "utils_test",
+ srcs = ["utils_test.cc"],
+ deps = [
+ ":utils",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "grappler_item",
+ srcs = ["grappler_item.cc"],
+ hdrs = ["grappler_item.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":utils",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler/inputs:utils",
+ ],
+)
+
+cc_test(
+ name = "grappler_item_test",
+ srcs = ["grappler_item_test.cc"],
+ deps = [
+ ":grappler_item",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ ],
+)
diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD
new file mode 100644
index 0000000000..c420818333
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/BUILD
@@ -0,0 +1,64 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+cc_library(
+ name = "cluster",
+ srcs = ["cluster.cc"],
+ hdrs = [
+ "cluster.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ ],
+)
+
+cc_library(
+ name = "single_machine",
+ srcs = ["single_machine.cc"],
+ hdrs = [
+ "single_machine.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cluster",
+ "//tensorflow/cc:coordinator",
+ "//tensorflow/cc:queue_runner",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:direct_session",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+cc_test(
+ name = "single_machine_test",
+ srcs = ["single_machine_test.cc"],
+ args = ["--heap_check=local"], # The GPU tracer leaks memory
+ deps = [
+ ":single_machine",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ ],
+)
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
new file mode 100644
index 0000000000..089729e770
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include <atomic>
+
+namespace tensorflow {
+namespace grappler {
+
+static std::atomic<bool> already_created(false);
+
+Cluster::Cluster(int timeout_s) : timeout_s_(timeout_s) {
+ // This is really ugly: to avoid leaking variables, we need to reset the tf
+ // session every time we're done processing a grappler item. However,
+ // variables are global, and therefore we can't have more than 1 session alive
+ // at a time. This check detects when more than one cluster is created.
+ CHECK(!already_created);
+ already_created = true;
+
+ options_.config.mutable_graph_options()->set_build_cost_model(1);
+
+ run_options_.set_trace_level(RunOptions::HARDWARE_TRACE);
+}
+
+Cluster::~Cluster() {
+ CHECK(already_created);
+ already_created = false;
+}
+
+void Cluster::SetNumWarmupSteps(int num_steps) {
+ options_.config.mutable_graph_options()->set_build_cost_model_after(
+ num_steps);
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h
new file mode 100644
index 0000000000..02915cc132
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/cluster.h
@@ -0,0 +1,88 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_
+#define TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// A cluster represents a collection of hardware resources available to run
+// the TensorFlow model.
+// A process can only create a single cluster at a time.
+class Cluster {
+ public:
+ explicit Cluster(int timeout_s);
+ virtual ~Cluster();
+
+ // Provision the hardware resources needed to run TensorFlow and start a
+ // TensorFlow session that can take advantage of these resources.
+ // The actual resources that are leveraged depend on the type of cluster
+ // instantiated.
+ // Returns OK iff all the requested resources could be reserved and a
+ // TensorFlow session successfully created. Returns an error otherwise.
+ // There is no graceful degradation to handle the case where only a subset
+ // of the requested resources are available.
+ virtual Status Provision() = 0;
+
+ // Set the number of steps required to warmup TensorFlow. Must be called
+ // before Provision().
+ void SetNumWarmupSteps(int num_steps);
+
+ // Return the list of TensorFlow devices that are available to execute a
+ // graph. This is empty until Provision() is called.
+ const std::vector<DeviceAttributes>& GetDevices() const { return devices_; }
+
+ // Convenience method that returns the set of device names.
+ const std::vector<string> GetDeviceNames() const {
+ std::vector<string> device_names;
+ device_names.reserve(devices_.size());
+ for (const auto& device : devices_) {
+ device_names.push_back(device.name());
+ }
+ return device_names;
+ }
+
+ // Prepare the session to run the specified grappler item. This includes
+ // initializing all the model variables.
+ virtual Status Initialize(const GrapplerItem& item) = 0;
+
+ // Run the specified graph_def and return the corresponding metadata.
+ virtual Status Run(const GraphDef& graph_def,
+ const std::vector<std::pair<string, Tensor>>& feed,
+ const std::vector<string>& fetch,
+ RunMetadata* metadata) = 0;
+
+ protected:
+ std::vector<DeviceAttributes> devices_;
+ const int timeout_s_;
+ SessionOptions options_;
+ RunOptions run_options_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
new file mode 100644
index 0000000000..6e4a6a648e
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -0,0 +1,248 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/clusters/single_machine.h"
+#include "tensorflow/cc/training/queue_runner.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace grappler {
+
+SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus)
+ : Cluster(timeout_s),
+ num_gpus_(num_gpus),
+ running_(false),
+ closing_(false) {
+ thread_pool_.reset(new thread::ThreadPool(
+ Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
+
+ (*options_.config.mutable_device_count())["CPU"] = 1;
+ if (num_gpus > 0) {
+ (*options_.config.mutable_device_count())["GPU"] = num_gpus;
+ }
+ CHECK_GE(num_cpu_cores, 1);
+ options_.config.set_intra_op_parallelism_threads(num_cpu_cores);
+ options_.config.set_inter_op_parallelism_threads(num_cpu_cores);
+}
+
+SingleMachine::~SingleMachine() {
+ CloseSession(false /*use_timeout*/);
+
+ // Prevent the destructor from deleting mu_ until CloseSession() is done.
+ mutex_lock l(mu_);
+}
+
+Status SingleMachine::Provision() {
+ Status status = ResetSession();
+ if (!status.ok()) {
+ return status;
+ }
+
+ DeviceAttributes attr;
+ attr.set_name("/job:localhost/replica:0/task:0/cpu:0");
+ attr.set_device_type("CPU");
+ devices_.push_back(attr);
+
+ for (int i = 0; i < num_gpus_; ++i) {
+ DeviceAttributes attr;
+ attr.set_name(strings::StrCat("/job:localhost/replica:0/task:0/gpu:", i));
+ attr.set_device_type("GPU");
+ devices_.push_back(attr);
+ }
+ return Status::OK();
+}
+
+Status SingleMachine::Initialize(const GrapplerItem& item) {
+ if (last_graph_ != &item.graph) {
+ init_ops_ = item.init_ops;
+ last_graph_ = nullptr;
+ queue_runner_defs_ = item.queue_runners;
+ }
+ return Status::OK();
+}
+
+Status SingleMachine::Run(const GraphDef& graph_def,
+ const std::vector<std::pair<string, Tensor>>& feed,
+ const std::vector<string>& fetch,
+ RunMetadata* metadata) {
+ if (last_graph_ != &graph_def) {
+ Status status = ResetSession();
+ if (status.ok()) {
+ status = session_->Create(graph_def);
+ }
+ if (!init_ops_.empty() && status.ok()) {
+ status = RunWithTimeout({}, init_ops_, nullptr);
+ }
+ for (int i = 0; i < queue_runner_defs_.size() && status.ok(); ++i) {
+ std::unique_ptr<QueueRunner> queue_runner;
+ TF_RETURN_IF_ERROR(QueueRunner::New(queue_runner_defs_[i],
+ coordinator_.get(), &queue_runner));
+ TF_RETURN_IF_ERROR(queue_runner->Start(session_.get()));
+ TF_RETURN_IF_ERROR(coordinator_->RegisterRunner(std::move(queue_runner)));
+ status = coordinator_->GetStatus();
+ }
+
+ if (status.ok()) {
+ last_graph_ = &graph_def;
+ } else {
+ return status;
+ }
+
+ // Warmup TensorFlow if needed
+ for (int i = 0;
+ i < options_.config.graph_options().build_cost_model_after(); ++i) {
+ status = RunWithTimeout(feed, fetch, nullptr);
+ if (!status.ok()) {
+ return status;
+ }
+ }
+ }
+
+ return RunWithTimeout(feed, fetch, metadata);
+}
+
+Status SingleMachine::RunWithTimeout(
+ const std::vector<std::pair<string, Tensor>>& feed,
+ const std::vector<string>& fetch, RunMetadata* run_metadata) {
+ mutex_lock l(mu_);
+ // We shouldn't be running or closing the session at this point.
+ CHECK(!running_);
+ CHECK(!closing_);
+
+ running_ = true;
+ metadata_ = RunMetadata();
+
+ thread_pool_->Schedule([this, feed, fetch] {
+ Status status =
+ session_->Run(run_options_, feed, {}, fetch, nullptr, &this->metadata_);
+ mutex_lock l(mu_);
+ status_ = status;
+ running_ = false;
+ done_running_.notify_all();
+ });
+
+ while (running_) {
+ std::cv_status timeout =
+ done_running_.wait_for(l, std::chrono::milliseconds(timeout_s_ * 1000));
+ if (timeout != std::cv_status::no_timeout) {
+ last_graph_ = nullptr;
+ return Status(error::DEADLINE_EXCEEDED,
+ strings::StrCat("Failed to run the graph after ",
+ timeout_s_, " seconds, aborting"));
+ }
+ }
+ if (run_metadata && status_.ok()) {
+ *run_metadata = metadata_;
+ }
+ return status_;
+}
+
+Status SingleMachine::CloseSession(bool use_timeout) {
+ if (!session_) {
+ return Status::OK();
+ }
+
+ mutex_lock l(close_mu_);
+
+ if (!closing_) {
+ closing_ = true;
+
+ thread_pool_->Schedule([this] {
+ if (this->coordinator_) {
+ this->coordinator_->RequestStop();
+ // Wait for all the runners to have closed their queues.
+ while (!this->coordinator_->AllRunnersStopped()) {
+ sleep(1);
+ }
+ // Now we can close the session. This should cancel any pending I/O
+ // operation.
+ this->session_->Close();
+ // Last but not least, we can delete the coordinator.
+ this->coordinator_.reset();
+ } else {
+ this->session_->Close();
+ }
+
+ // Wait for any previous run to finish.
+ mutex_lock l(mu_);
+ while (running_) {
+ done_running_.wait(l);
+ }
+
+ mutex_lock l2(close_mu_);
+ closing_ = false;
+ done_closing_.notify_all();
+ });
+ }
+
+ while (closing_) {
+ if (!use_timeout) {
+ done_closing_.wait(l);
+ } else {
+ std::cv_status timeout = done_closing_.wait_for(
+ l, std::chrono::milliseconds(timeout_s_ * 1000));
+ if (timeout != std::cv_status::no_timeout) {
+ // Let the caller know that we can't shut down the session, and therefore
+ // can't process any further.
+ return Status(
+ error::UNAVAILABLE,
+ strings::StrCat("Failed to close the previous session after ",
+ timeout_s_, " seconds, aborting"));
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+Status SingleMachine::ResetSession() {
+ if (session_) {
+ LOG(INFO) << "Cleaning up previous session";
+
+ // Make sure the session is properly closed
+ Status status = CloseSession(true /*use_timeout*/);
+ if (!status.ok()) {
+ return status;
+ }
+
+ // Flush all the pending closures (if any).
+ thread_pool_.reset(new thread::ThreadPool(
+ Env::Default(), SanitizeThreadSuffix("single_machine"), 2));
+
+ // We need to Reset the session to ensure that all the variables are
+ // deleted. But first we need to delete the session since Reset()
+ // deletes some of the containers referenced by the session.
+ session_.reset();
+ status = Reset(options_, {});
+ if (!status.ok()) {
+ return status;
+ }
+ }
+
+ LOG(INFO) << "Starting new session";
+
+ session_.reset(NewSession(options_));
+ CHECK(session_ != nullptr);
+
+ coordinator_.reset(new Coordinator());
+
+ return Status::OK();
+}
+
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h
new file mode 100644
index 0000000000..f53a3e849b
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/single_machine.h
@@ -0,0 +1,70 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_
+#define TENSORFLOW_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_
+
+#include "tensorflow/cc/training/coordinator.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Create a simple cluster that makes a subset of the nodes available on a
+// single local computer accessible to grappler.
+class SingleMachine : public Cluster {
+ public:
+ SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus);
+ ~SingleMachine() override;
+
+ Status Provision() override;
+ Status Initialize(const GrapplerItem& item) override;
+ Status Run(const GraphDef& item,
+ const std::vector<std::pair<string, Tensor>>& feed,
+ const std::vector<string>& fetch, RunMetadata* metadata) override;
+
+ private:
+ Status RunWithTimeout(const std::vector<std::pair<string, Tensor>>& feed,
+ const std::vector<string>& fetch,
+ RunMetadata* run_metadata);
+ Status ResetSession();
+ Status CloseSession(bool use_timeout);
+
+ const int num_gpus_;
+ std::unique_ptr<Session> session_;
+ std::vector<QueueRunnerDef> queue_runner_defs_;
+ const GraphDef* last_graph_ = nullptr;
+ std::vector<string> init_ops_;
+ std::unique_ptr<Coordinator> coordinator_;
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+
+ Status status_;
+ RunMetadata metadata_;
+
+ mutex mu_;
+ bool running_;
+ condition_variable done_running_;
+
+ mutex close_mu_;
+ bool closing_;
+ condition_variable done_closing_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_
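
A minimal usage sketch, not part of the patch, assuming a GrapplerItem named item has already been obtained (e.g. from an input yielder). It shows the lifecycle implied by the Cluster interface: provision the local machine, initialize the item, then run it and collect the measured cost metadata.

    SingleMachine cluster(/*timeout_s=*/60, /*num_cpu_cores=*/4, /*num_gpus=*/0);
    cluster.SetNumWarmupSteps(1);  // must be called before Provision()
    TF_CHECK_OK(cluster.Provision());
    TF_CHECK_OK(cluster.Initialize(item));
    RunMetadata metadata;
    TF_CHECK_OK(cluster.Run(item.graph, item.feed, item.fetch, &metadata));
    // metadata.cost_graph() now holds the per-node costs collected by the run.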
diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc
new file mode 100644
index 0000000000..3b39e5be61
--- /dev/null
+++ b/tensorflow/core/grappler/clusters/single_machine_test.cc
@@ -0,0 +1,134 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/clusters/single_machine.h"
+#include "tensorflow/core/framework/cost_graph.pb.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class SingleMachineTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ // Provision a single machine with 3 cpu cores
+ cluster_.reset(new SingleMachine(5 * 60, 3, 0));
+ TF_CHECK_OK(cluster_->Provision());
+ }
+
+ void TearDown() override {
+ cluster_.reset();
+ }
+
+ protected:
+ std::unique_ptr<SingleMachine> cluster_;
+};
+
+TEST_F(SingleMachineTest, CostModel) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
+ cluster_->GetDeviceNames());
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TF_CHECK_OK(cluster_->Initialize(item));
+
+ RunMetadata metadata;
+ const int64 start_micros = Env::Default()->NowMicros();
+ TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
+ const int64 run_duration_micros = Env::Default()->NowMicros() - start_micros;
+
+ // There should be at least 4 nodes corresponding to the 4 stages we created
+ // in the fake input.
+ EXPECT_LE(4, metadata.cost_graph().node_size());
+ for (const auto& node : metadata.cost_graph().node()) {
+ // Skip the special nodes inserted by TF: these are prefixed with an
+ // underscore.
+ if (node.name()[0] == '_' || node.name().find("/_") != string::npos) {
+ continue;
+ }
+ EXPECT_EQ(1, node.output_info_size());
+ EXPECT_LE(8, node.output_info(0).size());
+ const TensorShapeProto& shape = node.output_info(0).shape();
+ EXPECT_EQ(2, shape.dim_size());
+ EXPECT_EQ(10, shape.dim(0).size());
+ EXPECT_EQ(1, shape.dim(1).size());
+ EXPECT_LE(0, node.compute_cost());
+ EXPECT_GE(run_duration_micros, node.compute_cost());
+ }
+}
+
+TEST_F(SingleMachineTest, Queue) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, true,
+ cluster_->GetDeviceNames());
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TF_CHECK_OK(cluster_->Initialize(item));
+ RunMetadata metadata;
+ TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata));
+}
+
+TEST_F(SingleMachineTest, MultipleItems) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
+ cluster_->GetDeviceNames());
+
+ for (int i = 0; i < 3; ++i) {
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+ TF_CHECK_OK(cluster_->Initialize(item));
+ RunMetadata metadata1;
+ TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata1));
+ RunMetadata metadata2;
+ TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata2));
+
+ // There should be at least 4 nodes corresponding to the 4 stages we created
+ // in the fake input, plus 1 enqueue and 1 dequeue node.
+ EXPECT_LE(6, metadata1.cost_graph().node_size());
+ for (const auto& node : metadata1.cost_graph().node()) {
+ if (node.name()[0] == '_' || node.name().find("/_") != string::npos ||
+ node.name() == "queue") {
+ continue;
+ }
+ EXPECT_EQ(1, node.output_info_size());
+ const TensorShapeProto& shape = node.output_info(0).shape();
+ EXPECT_EQ(2, shape.dim_size());
+ EXPECT_EQ(10, shape.dim(0).size());
+ EXPECT_EQ(1, shape.dim(1).size());
+ }
+
+ for (int i = 0; i < metadata1.cost_graph().node_size(); ++i) {
+ metadata1.mutable_cost_graph()->mutable_node(i)->set_compute_cost(0);
+ metadata1.clear_step_stats();
+ }
+ for (int i = 0; i < metadata2.cost_graph().node_size(); ++i) {
+ metadata2.mutable_cost_graph()->mutable_node(i)->set_compute_cost(0);
+ metadata2.clear_step_stats();
+ }
+ string s1;
+ ::tensorflow::protobuf::TextFormat::PrintToString(metadata1, &s1);
+ string s2;
+ ::tensorflow::protobuf::TextFormat::PrintToString(metadata2, &s2);
+ EXPECT_EQ(s1, s2);
+ }
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
new file mode 100644
index 0000000000..161d1aa617
--- /dev/null
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -0,0 +1,116 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_proto_library",
+)
+load(
+ "@local_config_cuda//cuda:build_defs.bzl",
+ "if_cuda",
+)
+
+tf_proto_library(
+ name = "op_performance_data",
+ srcs = ["op_performance_data.proto"],
+ cc_api_version = 2,
+ protodeps = ["//tensorflow/core:protos_all"],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "graph_properties",
+ srcs = ["graph_properties.cc"],
+ hdrs = ["graph_properties.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":op_performance_data_cc",
+ ":utils",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/clusters:cluster",
+ ],
+)
+
+cc_test(
+ name = "graph_properties_test",
+ srcs = ["graph_properties_test.cc"],
+ args = ["--heap_check=local"], # The GPU tracer leaks memory
+ deps = [
+ ":graph_properties",
+ "//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/clusters:single_machine",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ ],
+)
+
+cc_library(
+ name = "graph_memory",
+ srcs = ["graph_memory.cc"],
+ hdrs = ["graph_memory.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_properties",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/clusters:cluster",
+ ],
+)
+
+cc_test(
+ name = "graph_memory_test",
+ srcs = ["graph_memory_test.cc"],
+ args = ["--heap_check=local"], # The GPU tracer leaks memory
+ deps = [
+ ":graph_memory",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ ],
+)
+
+cc_library(
+ name = "utils",
+ srcs = ["utils.cc"],
+ hdrs = ["utils.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":op_performance_data_cc",
+ "//third_party/eigen3",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ] + if_cuda([
+ "//tensorflow/core:cuda",
+ "@local_config_cuda//cuda:cuda_headers",
+ ]),
+)
+
+cc_library(
+ name = "cost_estimator",
+ hdrs = ["cost_estimator.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
diff --git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h
new file mode 100644
index 0000000000..3c65c34f8d
--- /dev/null
+++ b/tensorflow/core/grappler/costs/cost_estimator.h
@@ -0,0 +1,146 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_
+#define TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_
+
+#include <chrono>
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+class GraphDef;
+class CostGraphDef;
+
+namespace grappler {
+struct GrapplerItem;
+
+constexpr int64 kMemoryUnknown = -1ll;
+constexpr int64 kZeroMemory = 0ll;
+
+// Holds the set of things we might want to estimate or measure in Grappler.
+// Execution time is always produced; the other fields are optional and depend
+// on the estimator being used.
+struct Costs {
+ // Returns a Costs structure with default values for all of the fields.
+ inline Costs();
+
+ // Builds a Costs structure with all zero values, rather than unknowns.
+ static inline Costs ZeroCosts();
+
+ struct MicroSeconds : std::chrono::microseconds {
+ MicroSeconds() : std::chrono::microseconds(0) {}
+ MicroSeconds(double d) : std::chrono::microseconds(static_cast<int64>(d)) {}
+ MicroSeconds(std::chrono::microseconds& d) : std::chrono::microseconds(d) {}
+ MicroSeconds& operator=(const std::chrono::microseconds& d) {
+ std::chrono::microseconds::operator=(d);
+ return *this;
+ }
+ };
+ struct NanoSeconds : std::chrono::nanoseconds {
+ NanoSeconds() : std::chrono::nanoseconds(0) {}
+ NanoSeconds(double d) : std::chrono::nanoseconds(static_cast<int64>(d)) {}
+ NanoSeconds(std::chrono::nanoseconds& d) : std::chrono::nanoseconds(d) {}
+ NanoSeconds& operator=(const std::chrono::nanoseconds& d) {
+ std::chrono::nanoseconds::operator=(d);
+ return *this;
+ }
+ MicroSeconds asMicroSeconds() const {
+ std::chrono::microseconds us =
+ std::chrono::duration_cast<std::chrono::microseconds>(*this);
+ return MicroSeconds(us);
+ }
+ };
+
+ // We store all our times in nanoseconds. If need be, we can always switch to
+ // picoseconds in the future by updating this typedef.
+ typedef NanoSeconds Duration;
+
+ // Overall cost of running the graph; latency.
+ Duration execution_time;
+
+ // Computation cost of running the graph.
+ Duration compute_time;
+
+ // Memory access cost of running the graph.
+ Duration memory_time;
+
+ // This field can be a very pessimistic estimate of the main memory
+ // requirements of a graph. For example, it might assume that all activations
+ // are live for all of a graph's execution.
+ int64 max_memory; // Maximum main memory requirement in bytes over all ops.
+
+ // These fields are used for TPU-related estimations. They are per-op
+ // maximums, so each op is evaluated independently, but we want the maximum of
+ // the value over all ops.
+ int64 max_per_op_buffers; // Sum of all buffers used by the ops.
+ int64 max_per_op_streaming; // Ignore largest input buffer, assuming it
+ // streams from main memory.
+};
+
+inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) {
+ os << d.count() << "us";
+ return os;
+}
+inline std::ostream& operator<<(std::ostream& os, const Costs::NanoSeconds d) {
+ os << d.count() << "ns";
+ return os;
+}
+
+Costs::Costs() {
+ execution_time = Duration::zero();
+ compute_time = Duration::zero();
+ memory_time = Duration::zero();
+ max_memory = kMemoryUnknown;
+ max_per_op_buffers = kMemoryUnknown;
+ max_per_op_streaming = kMemoryUnknown;
+}
+
+Costs Costs::ZeroCosts() {
+ Costs costs;
+ costs.execution_time = Duration::zero();
+ costs.max_memory = kZeroMemory;
+ costs.max_per_op_buffers = kZeroMemory;
+ costs.max_per_op_streaming = kZeroMemory;
+ return costs;
+}
+
+// Given a GrapplerItem and an optimized implementation of the corresponding
+// TensorFlow graph, the CostEstimator attempts to predict the actual cost of
+// running the graph.
+class CostEstimator {
+ public:
+ virtual ~CostEstimator() {}
+
+ // Initializes the estimator for the specified grappler item.
+ // The estimator shouldn't be used if this function returns any status other
+ // than OK.
+ virtual Status Initialize(const GrapplerItem& item) = 0;
+
+ // Predicts the cost of running the given optimized version of the grappler
+ // item.
+ // If a CostGraphDef is passed, it will be populated with detailed information
+ // about the cost of running each operation of the optimized graph.
+ // If a Costs object is passed, it will be set to a value that reflects the
+ // overall cost of running the graph (e.g. the latency of the computation).
+ // Returns a status that indicates whether the performance could be estimated
+ // or not.
+ virtual Status PredictCosts(const GraphDef& optimized_graph,
+ CostGraphDef* cost_graph, Costs* cost) const = 0;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_
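
A minimal sketch, not part of the patch and with a hypothetical class name, of a CostEstimator implementation that reports zero costs for every graph; it only illustrates the contract above.

    // Hypothetical trivial estimator: accepts any item and predicts zero cost.
    class ZeroCostEstimator : public CostEstimator {
     public:
      Status Initialize(const GrapplerItem& item) override { return Status::OK(); }
      Status PredictCosts(const GraphDef& optimized_graph,
                          CostGraphDef* cost_graph, Costs* cost) const override {
        if (cost) *cost = Costs::ZeroCosts();
        return Status::OK();
      }
    };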
diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc
new file mode 100644
index 0000000000..b7827fc1ad
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_memory.cc
@@ -0,0 +1,109 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/graph_memory.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status GraphMemory::InferStatically() {
+ GraphProperties properties(item_);
+ TF_RETURN_IF_ERROR(properties.InferStatically());
+ return InferFromGraphProperties(&properties);
+}
+
+Status GraphMemory::InferDynamically(Cluster* cluster) {
+ GraphProperties properties(item_);
+ TF_RETURN_IF_ERROR(properties.InferDynamically(cluster));
+ return InferFromGraphProperties(&properties);
+}
+
+Status GraphMemory::InferFromGraphProperties(GraphProperties* properties) {
+ // Compute the worst case usage between initialization and normal mode.
+ // TODO(bsteiner): we should consider persistent memory usage separately.
+ int64 worst_case_init_mem_usage;
+ int64 best_case_init_mem_usage;
+ InferMemUsageForNodes(item_.InitOpsFanin(), properties,
+ &worst_case_init_mem_usage, &best_case_init_mem_usage);
+ int64 worst_case_main_mem_usage;
+ int64 best_case_main_mem_usage;
+ InferMemUsageForNodes(item_.MainOpsFanin(), properties,
+ &worst_case_main_mem_usage, &best_case_main_mem_usage);
+
+ worst_case_memory_usage_ =
+ std::max(worst_case_init_mem_usage, worst_case_main_mem_usage);
+ best_case_memory_usage_ =
+ std::max(best_case_init_mem_usage, best_case_main_mem_usage);
+
+ return Status::OK();
+}
+
+void GraphMemory::InferMemUsageForNodes(
+ const std::vector<const NodeDef*>& nodes, GraphProperties* properties,
+ int64* worst_case_memory_usage, int64* best_case_memory_usage) const {
+ // TODO(bsteiner) refine this: we should consider the multidevice case.
+ *worst_case_memory_usage = 0;
+ *best_case_memory_usage = 0;
+ for (const auto& node : item_.graph.node()) {
+ // Estimate the memory required to store the tensors generated by the node.
+ std::vector<OpInfo::TensorProperties> outputs =
+ properties->GetOutputProperties(node.name());
+ int64 node_memory_usage = InferMemUsageForNeighbors(outputs);
+
+ // Worst case memory usage corresponds to the case where all the nodes are
+ // alive.
+ *worst_case_memory_usage += node_memory_usage;
+
+ // Estimate the memory required to store the input tensors needed by the
+ // node.
+ std::vector<OpInfo::TensorProperties> inputs =
+ properties->GetInputProperties(node.name());
+ node_memory_usage += InferMemUsageForNeighbors(inputs);
+
+ *best_case_memory_usage =
+ std::max(*best_case_memory_usage, node_memory_usage);
+ }
+}
+
+int64 GraphMemory::InferMemUsageForNeighbors(
+ const std::vector<OpInfo::TensorProperties>& props) const {
+ int64 neighbors_memory_usage = 0;
+ for (const auto& prop : props) {
+ DataType dtype = prop.dtype();
+ int size = DataTypeSize(dtype);
+ TensorShapeProto shape = prop.shape();
+ if (shape.unknown_rank()) {
+ // Can't infer the size if the rank is unknown, just skip.
+ continue;
+ }
+ // If one of the dimensions is unknown statically, assume it's one.
+ for (int i = 0; i < shape.dim_size(); ++i) {
+ if (shape.dim(i).size() < 0) {
+ shape.mutable_dim(i)->set_size(1);
+ }
+ }
+ int num_elems = TensorShape(shape).num_elements();
+ neighbors_memory_usage += num_elems * size;
+ }
+ return neighbors_memory_usage;
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/graph_memory.h b/tensorflow/core/grappler/costs/graph_memory.h
new file mode 100644
index 0000000000..a3e152a0e1
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_memory.h
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_COSTS_GRAPH_MEMORY_H_
+#define TENSORFLOW_GRAPPLER_COSTS_GRAPH_MEMORY_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Infer the worst case memory usage for a given grappler item.
+class GraphMemory {
+ public:
+ explicit GraphMemory(const GrapplerItem& item)
+ : item_(item), worst_case_memory_usage_(-1) {}
+
+ Status InferStatically();
+ Status InferDynamically(Cluster* cluster);
+ Status InferFromGraphProperties(GraphProperties* properties);
+
+ // Worst case memory usage in bytes, or -1 if the usage is unknown.
+ int64 GetWorstCaseMemoryUsage() const { return worst_case_memory_usage_; }
+
+ // Best case memory usage in bytes, or -1 if the usage is unknown.
+ // This corresponds to the case where all the data is swapped out except
+ // that which is needed for a single node to perform its computations.
+ int64 GetBestCaseMemoryUsage() const { return best_case_memory_usage_; }
+
+ private:
+ void InferMemUsageForNodes(const std::vector<const NodeDef*>& nodes,
+ GraphProperties* properties, int64* worst_case,
+ int64* best_case) const;
+ int64 InferMemUsageForNeighbors(
+ const std::vector<OpInfo::TensorProperties>& props) const;
+
+ // Inputs
+ GrapplerItem item_;
+ int64 worst_case_memory_usage_;
+ int64 best_case_memory_usage_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_COSTS_GRAPH_MEMORY_H_
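
A short sketch of the intended use, assuming a populated GrapplerItem named item: infer the memory bounds without running the graph, then read the two estimates.

    GraphMemory memory(item);
    TF_CHECK_OK(memory.InferStatically());
    int64 worst = memory.GetWorstCaseMemoryUsage();  // all tensors live at once
    int64 best = memory.GetBestCaseMemoryUsage();    // a single node's working set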
diff --git a/tensorflow/core/grappler/costs/graph_memory_test.cc b/tensorflow/core/grappler/costs/graph_memory_test.cc
new file mode 100644
index 0000000000..a3c58a1d76
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_memory_test.cc
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/graph_memory.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class GraphMemoryTest : public ::testing::Test {};
+
+TEST_F(GraphMemoryTest, Basic) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {{"CPU:0"}});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ GraphMemory memory(item);
+ Status s = memory.InferStatically();
+ TF_CHECK_OK(s);
+ EXPECT_EQ(240, memory.GetWorstCaseMemoryUsage());
+ EXPECT_EQ(80, memory.GetBestCaseMemoryUsage());
+}
+
+TEST_F(GraphMemoryTest, UnknownBatchSize) {
+ TrivialTestGraphInputYielder fake_input(4, 1, -1, false, {{"CPU:0"}});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ GraphMemory memory(item);
+ Status s = memory.InferStatically();
+ TF_CHECK_OK(s);
+ EXPECT_EQ(24, memory.GetWorstCaseMemoryUsage());
+ EXPECT_EQ(8, memory.GetBestCaseMemoryUsage());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
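
The expected values above can be checked by hand: each 10x1 float tensor in the trivial test graph occupies 10 * 4 = 40 bytes, so the 240-byte worst case matches six such outputs being live at once, while the 80-byte best case is a single node's working set (a 40-byte input plus a 40-byte output). With an unknown batch size, every unknown dimension is assumed to be 1, shrinking each tensor to 4 bytes and the bounds to 24 and 8 bytes.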
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
new file mode 100644
index 0000000000..345c7e2f21
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -0,0 +1,159 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/grappler/costs/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status GraphProperties::InferStatically() {
+ Graph graph(OpRegistry::Global());
+ ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry());
+ ImportGraphDefOptions options;
+ Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
+ TF_RETURN_IF_ERROR(s);
+
+ for (const Node* const node : graph.nodes()) {
+ VLOG(1) << "<Node> " << node->name();
+ auto ctx = shape_refiner.GetContext(node);
+ if (!ctx) {
+ continue;
+ }
+ CHECK_EQ(ctx->num_inputs(), node->num_inputs());
+ std::vector<OpInfo::TensorProperties> input_properties;
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ OpInfo::TensorProperties properties;
+ properties.set_dtype(node->input_type(i));
+ shape_inference::ShapeHandle shp = ctx->input(i);
+ if (!ctx->RankKnown(shp)) {
+ properties.mutable_shape()->set_unknown_rank(true);
+ } else {
+ for (int j = 0; j < ctx->Rank(shp); ++j) {
+ shape_inference::DimensionHandle dim = ctx->Dim(shp, j);
+ int64 d = ctx->Value(dim);
+ properties.mutable_shape()->add_dim()->set_size(d);
+ }
+ }
+ input_properties.push_back(properties);
+ }
+ input_properties_[node->name()] = input_properties;
+
+ // TODO(bsteiner): share this code with the input processing above.
+ CHECK_EQ(ctx->num_outputs(), node->num_outputs());
+ std::vector<OpInfo::TensorProperties> output_properties;
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ OpInfo::TensorProperties properties;
+ properties.set_dtype(node->output_type(i));
+ shape_inference::ShapeHandle shp = ctx->output(i);
+ if (!ctx->RankKnown(shp)) {
+ properties.mutable_shape()->set_unknown_rank(true);
+ } else {
+ for (int j = 0; j < ctx->Rank(shp); ++j) {
+ shape_inference::DimensionHandle dim = ctx->Dim(shp, j);
+ int64 d = ctx->Value(dim);
+ properties.mutable_shape()->add_dim()->set_size(d);
+ }
+ }
+ output_properties.push_back(properties);
+ }
+ output_properties_[node->name()] = output_properties;
+
+ if (!node->assigned_device_name().empty()) {
+ device_names_[node->name()] = node->assigned_device_name();
+ } else if (!node->def().device().empty()) {
+ device_names_[node->name()] = node->def().device();
+ } else {
+ device_names_[node->name()] = "not set";
+ }
+ }
+
+ return Status::OK();
+}
+
+Status GraphProperties::InferDynamically(Cluster* cluster) {
+ TF_RETURN_IF_ERROR(cluster->Initialize(item_));
+
+ // Runs the model once to collect the shapes in the cost model.
+ RunMetadata metadata;
+ TF_RETURN_IF_ERROR(
+ cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
+
+ std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
+ for (auto& node : metadata.cost_graph().node()) {
+ name_to_cost[node.name()] = &node;
+
+ std::vector<OpInfo::TensorProperties> output_properties;
+ for (const auto& out : node.output_info()) {
+ OpInfo::TensorProperties properties;
+ properties.set_dtype(out.dtype());
+ *properties.mutable_shape() = out.shape();
+ output_properties.push_back(properties);
+ }
+ output_properties_[node.name()] = output_properties;
+ }
+
+ for (const auto& node : item_.graph.node()) {
+ // Skip the nodes that are not in the cost graph: these are nodes that
+ // aren't run, because they aren't in the intersection of transitive fan-in
+ // of a fetch node and the transitive fan-out of an input, or nodes that
+ // were optimized away by the optimizer.
+ auto it = name_to_cost.find(node.name());
+ if (it == name_to_cost.end()) {
+ continue;
+ }
+ std::vector<OpInfo::TensorProperties> inputs =
+ FindInputFeatures(node, name_to_cost);
+
+ input_properties_[node.name()] = inputs;
+
+ const CostGraphDef::Node* cost_node = it->second;
+ device_names_[node.name()] = cost_node->device();
+ }
+ return Status::OK();
+}
+
+std::vector<OpInfo::TensorProperties> GraphProperties::GetInputProperties(
+ const string& node_name) const {
+ auto it = input_properties_.find(node_name);
+ if (it != input_properties_.end()) {
+ return it->second;
+ }
+ return std::vector<OpInfo::TensorProperties>();
+}
+
+std::vector<OpInfo::TensorProperties> GraphProperties::GetOutputProperties(
+ const string& node_name) const {
+ auto it = output_properties_.find(node_name);
+ if (it != output_properties_.end()) {
+ return it->second;
+ }
+ return std::vector<OpInfo::TensorProperties>();
+}
+
+string GraphProperties::GetDeviceName(const string& node_name) const {
+ auto it = device_names_.find(node_name);
+ if (it != device_names_.end()) {
+ return it->second;
+ }
+ return "";
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
new file mode 100644
index 0000000000..c49313a220
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_COSTS_GRAPH_PROPERTIES_H_
+#define TENSORFLOW_GRAPPLER_COSTS_GRAPH_PROPERTIES_H_
+
+#include <unordered_map>
+#include <vector>
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// A TensorFlow model to optimize.
+// Models are represented by the combination of a graph, one or more fetch
+// nodes, and potentially a set of nodes to feed.
+class GraphProperties {
+ public:
+ // Builds the properties for the given grappler item.
+ explicit GraphProperties(const GrapplerItem& item) : item_(item) {}
+
+ Status InferStatically();
+ Status InferDynamically(Cluster* cluster);
+
+ std::vector<OpInfo::TensorProperties> GetInputProperties(
+ const string& node_name) const;
+ std::vector<OpInfo::TensorProperties> GetOutputProperties(
+ const string& node_name) const;
+ string GetDeviceName(const string& node_name) const;
+
+ private:
+ // Inputs
+ GrapplerItem item_;
+ std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
+ std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
+ std::map<string, string> device_names_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_COSTS_GRAPH_PROPERTIES_H_
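
A minimal sketch, not part of the patch, of the static inference path, again assuming a populated GrapplerItem named item: infer the properties once, then look up the output shapes per node.

    GraphProperties properties(item);
    TF_CHECK_OK(properties.InferStatically());
    for (const auto& node : item.graph.node()) {
      for (const auto& out : properties.GetOutputProperties(node.name())) {
        LOG(INFO) << node.name() << ": " << DataTypeString(out.dtype()) << " "
                  << out.shape().DebugString();
      }
    }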
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
new file mode 100644
index 0000000000..d2f448e6d3
--- /dev/null
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -0,0 +1,133 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/clusters/single_machine.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class GraphPropertiesTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ // Provision a single machine with 3 cpu cores
+ cluster_.reset(new SingleMachine(5 * 60, 3, 0));
+ TF_CHECK_OK(cluster_->Provision());
+ }
+
+ void TearDown() override { cluster_.reset(); }
+
+ protected:
+ std::unique_ptr<SingleMachine> cluster_;
+};
+
+TEST_F(GraphPropertiesTest, StaticProperties) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
+ cluster_->GetDeviceNames());
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ GraphProperties properties(item);
+ Status s = properties.InferStatically();
+ TF_CHECK_OK(s);
+
+ for (const auto& node : item.graph.node()) {
+ if (node.op() == "Const") {
+ // The const node has no input.
+ EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
+ // The const node has one output.
+ const auto props = properties.GetOutputProperties(node.name());
+ EXPECT_EQ(1, props.size());
+ const OpInfo::TensorProperties& prop = props[0];
+ EXPECT_EQ(DT_FLOAT, prop.dtype());
+ EXPECT_FALSE(prop.shape().unknown_rank());
+ EXPECT_EQ(2, prop.shape().dim_size());
+ EXPECT_EQ(10, prop.shape().dim(0).size());
+ EXPECT_EQ(1, prop.shape().dim(1).size());
+ } else if (node.op() == "AddN") {
+ const auto in_props = properties.GetInputProperties(node.name());
+ EXPECT_EQ(1, in_props.size());
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ EXPECT_EQ(10, in_prop.shape().dim(0).size());
+ EXPECT_EQ(1, in_prop.shape().dim(1).size());
+ const auto out_props = properties.GetOutputProperties(node.name());
+ EXPECT_EQ(1, out_props.size());
+ string in_prop_str;
+ ::tensorflow::protobuf::TextFormat::PrintToString(in_prop, &in_prop_str);
+ string out_prop_str;
+ ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
+ &out_prop_str);
+ EXPECT_EQ(in_prop_str, out_prop_str);
+ }
+ }
+}
+
+TEST_F(GraphPropertiesTest, DynamicProperties) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
+ cluster_->GetDeviceNames());
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(cluster_->Initialize(item));
+ Status s = properties.InferDynamically(cluster_.get());
+ TF_CHECK_OK(s);
+
+ for (const auto& node : item.graph.node()) {
+ if (node.op() == "Const") {
+ // The constant node is missing from the cost graph
+ EXPECT_EQ(0, properties.GetInputProperties(node.name()).size());
+ } else if (node.op() == "AddN") {
+ // Since the const node is missing, we can't infer the input properties of
+ // the first AddN node. The other AddN nodes have the expected properties.
+ if (node.name() == "AddN") {
+ const auto props = properties.GetInputProperties(node.name());
+ EXPECT_EQ(1, props.size());
+ const OpInfo::TensorProperties& prop = props[0];
+ EXPECT_EQ(DT_INVALID, prop.dtype());
+ EXPECT_TRUE(prop.shape().unknown_rank());
+ } else {
+ const auto props = properties.GetInputProperties(node.name());
+ EXPECT_EQ(1, props.size());
+ const OpInfo::TensorProperties& prop = props[0];
+ EXPECT_EQ(DT_FLOAT, prop.dtype());
+ EXPECT_FALSE(prop.shape().unknown_rank());
+ EXPECT_EQ(2, prop.shape().dim_size());
+ EXPECT_EQ(10, prop.shape().dim(0).size());
+ EXPECT_EQ(1, prop.shape().dim(1).size());
+ const auto out_props = properties.GetOutputProperties(node.name());
+ EXPECT_EQ(1, out_props.size());
+ string prop_str;
+ ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str);
+ string out_prop_str;
+ ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0],
+ &out_prop_str);
+ EXPECT_EQ(prop_str, out_prop_str);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/op_performance_data.proto b/tensorflow/core/grappler/costs/op_performance_data.proto
new file mode 100644
index 0000000000..2d3fb1939c
--- /dev/null
+++ b/tensorflow/core/grappler/costs/op_performance_data.proto
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/framework/attr_value.proto";
+
+// Description of an operation as well as the parameters expected to impact its
+// performance.
+message OpInfo {
+ // The operation name. There may be custom parameters in attrs.
+ string op = 1;
+
+ // Custom parameters impacting the behavior of the op.
+ map<string, AttrValue> attr = 2;
+
+ // Input types and shapes
+ message TensorProperties {
+ DataType dtype = 1;
+ TensorShapeProto shape = 2;
+ };
+ repeated TensorProperties inputs = 3;
+
+ // Device on which the operation is run.
+ message DeviceProperties {
+ // Device type (CPU, GPU, ...)
+ string type = 1;
+ // Vendor (Intel, nvidia, ...)
+ string vendor = 2;
+ // Model (Haswell, K40, ...)
+ string model = 3;
+ // Core Frequency in Mhz
+ int64 frequency = 4;
+ // Number of cores
+ int64 num_cores = 5;
+ // Version of the tools and libraries used with this device (e.g. gcc 4.9,
+ // cudnn 5.1)
+ map<string, string> environment = 6;
+ // Number of registers per core.
+ int64 num_registers = 7;
+ // L1 cache size in bytes
+ int64 l1_cache_size = 8;
+ // L2 cache size in bytes
+ int64 l2_cache_size = 9;
+ // L3 cache size in bytes
+ int64 l3_cache_size = 10;
+ // Shared memory size per multiprocessor in bytes. This field is
+ // applicable to GPUs only.
+ int64 shared_memory_size_per_multiprocessor = 11;
+ // Memory size in bytes
+ int64 memory_size = 12;
+ // Memory bandwidth in KB/s
+ int64 bandwidth = 13;
+ }
+ DeviceProperties device = 4;
+}
+
+// Performance data for tensorflow operations
+message OpPerformance {
+ // The op
+ OpInfo op = 1;
+
+ // The node name (optional). Makes it easier to associate the performance data
+ // with a specific graph node.
+ string node = 5;
+
+ // Temporary memory used by this node (in bytes).
+ int64 temporary_memory_size = 2;
+
+ // Time it takes to run the op (in nanoseconds).
+ int64 compute_cost = 3;
+
+ // Analytical compute cost (in nanoseconds).
+ int64 compute_time = 6;
+
+ // Analytical memory access cost (in nanoseconds).
+ int64 memory_time = 7;
+
+ // Percentage of theoretical compute performance.
+ double compute_efficiency = 4;
+
+ // Percentage of theoretical memory performance.
+ double memory_efficiency = 8;
+
+ // Memory usage data for a tensorflow operation.
+ message OpMemory {
+ // The output information may have memory usage and output shapes.
+ repeated int64 output_memory = 1;
+
+ // Temporary memory allocated by this node.
+ int64 host_temp_memory = 2;
+ int64 device_temp_memory = 3;
+
+ // The persisted_memory doesn't include outputs.
+ int64 host_persistent_memory = 4;
+ int64 device_persistent_memory = 5;
+ }
+ OpMemory op_memory = 9;
+}
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
new file mode 100644
index 0000000000..19266208fe
--- /dev/null
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -0,0 +1,159 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/costs/utils.h"
+
+#include <stddef.h>
+#include <utility>
+
+#include "third_party/eigen3/Eigen/Core"
+
+#if GOOGLE_CUDA
+#include "cuda/include/cuda.h"
+#include "cuda/include/cuda_runtime_api.h"
+#include "cuda/include/cudnn.h"
+#endif
+
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+std::vector<OpInfo::TensorProperties> FindInputFeatures(
+ const NodeDef& node,
+ const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost) {
+ std::vector<OpInfo::TensorProperties> inputs;
+ for (const auto& input_name : node.input()) {
+ // Skip control inputs. These are prefixed with the ^ character.
+ CHECK(!input_name.empty());
+ if (input_name[0] == '^') {
+ continue;
+ }
+
+ // Each input is "node_name:output_imdex" with "node_name" being a string
+ // name and "output_index" indicating which output tensor to use from
+ // "node_name". If "output_index" is 0 the ":0" suffix can be omitted.
+ string input_node_name;
+ int output_index = -1;
+ const size_t pos = input_name.rfind(':');
+ if (pos == string::npos) {
+ input_node_name = input_name;
+ output_index = 0;
+ } else {
+ string index = input_name.substr(pos);
+ if (strings::safe_strto32(index, &output_index)) {
+ input_node_name = input_name.substr(0, pos);
+ }
+ }
+
+ auto it = name_to_cost.find(input_name);
+ if (it == name_to_cost.end() || output_index < 0) {
+ OpInfo::TensorProperties input;
+ input.set_dtype(DataType::DT_INVALID);
+ input.mutable_shape()->set_unknown_rank(true);
+ inputs.push_back(input);
+ } else {
+ const CostGraphDef::Node* input_cost = it->second;
+ const CostGraphDef::Node::OutputInfo& output =
+ input_cost->output_info(output_index);
+ OpInfo::TensorProperties input;
+ input.set_dtype(output.dtype());
+ *input.mutable_shape() = output.shape();
+ inputs.push_back(input);
+ }
+ }
+
+ return inputs;
+}
+
+OpInfo::DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) {
+ DeviceNameUtils::ParsedName parsed;
+ if (DeviceNameUtils::ParseFullName(node.device(), &parsed)) {
+ if (parsed.type == "GPU") {
+ return GetLocalGPUInfo(parsed.id);
+ } else if (parsed.type == "CPU") {
+ return GetLocalCPUInfo();
+ }
+ }
+ OpInfo::DeviceProperties device;
+ device.set_type("UNKNOWN");
+ return device;
+}
+
+OpInfo::DeviceProperties GetLocalCPUInfo() {
+ OpInfo::DeviceProperties device;
+ device.set_type("CPU");
+
+ device.set_num_cores(port::NumSchedulableCPUs());
+ device.set_l1_cache_size(Eigen::l1CacheSize());
+ device.set_l2_cache_size(Eigen::l2CacheSize());
+ device.set_l3_cache_size(Eigen::l3CacheSize());
+
+ (*device.mutable_environment())["cpu_instruction_set"] =
+ Eigen::SimdInstructionSetsInUse();
+
+ (*device.mutable_environment())["eigen"] = strings::StrCat(
+ EIGEN_WORLD_VERSION, ".", EIGEN_MAJOR_VERSION, ".", EIGEN_MINOR_VERSION);
+#ifdef EIGEN_USE_LIBXSMM
+ (*device.mutable_environment())["libxsmm"] = LIBXSMM_VERSION;
+#endif
+
+ return device;
+}
+
+OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id) {
+ OpInfo::DeviceProperties device;
+ device.set_type("GPU");
+
+#if GOOGLE_CUDA
+ cudaDeviceProp properties;
+ cudaError_t error = cudaGetDeviceProperties(&properties, gpu_id);
+ if (error == cudaSuccess) {
+ device.set_vendor("NVidia");
+ device.set_model(properties.name);
+ device.set_frequency(properties.clockRate / 1000);
+ device.set_num_cores(properties.multiProcessorCount);
+ device.set_num_registers(properties.regsPerMultiprocessor);
+ // For compute capability less than 5, l1 cache size is configurable to
+ // either 16 KB or 48 KB. We use the initial configuration 16 KB here. For
+ // compute capability larger or equal to 5, l1 cache (unified with texture
+ // cache) size is 24 KB. This number may need to be updated for future
+ // compute capabilities.
+ device.set_l1_cache_size((properties.major < 5) ? 16 * 1024 : 24 * 1024);
+ device.set_l2_cache_size(properties.l2CacheSize);
+ device.set_l3_cache_size(0);
+ device.set_shared_memory_size_per_multiprocessor(
+ properties.sharedMemPerMultiprocessor);
+ device.set_memory_size(properties.totalGlobalMem);
+ // 8 is the number of bits per byte. 2 is accounted for
+ // double data rate (DDR).
+ device.set_bandwidth(properties.memoryBusWidth / 8 *
+ properties.memoryClockRate * 2);
+ }
+
+ (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION);
+ (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION);
+#endif
+
+ return device;
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
new file mode 100644
index 0000000000..79be906128
--- /dev/null
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_COSTS_UTILS_H_
+#define TENSORFLOW_GRAPPLER_COSTS_UTILS_H_
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/cost_graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/graph/types.h"
+#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Returns a vector of InputProperties for 'node'. The vector will contain one
+// entry for each input of 'node'.
+// For each node in the graph, the 'name_to_cost' map stores a pointer to the
+// corresponding cost graph node indexed by node name.
+std::vector<OpInfo::TensorProperties> FindInputFeatures(
+ const NodeDef& node,
+ const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost);
+
+// Returns the DeviceProperties of the device on which 'node' runs.
+OpInfo::DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node);
+
+// Returns the DeviceProperties of the CPU on which grappler is running.
+OpInfo::DeviceProperties GetLocalCPUInfo();
+
+// Returns the DeviceProperties for the specified GPU attached to the server on
+// which grappler is running.
+OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id);
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_COSTS_UTILS_H_
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
new file mode 100644
index 0000000000..907ef2096a
--- /dev/null
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -0,0 +1,251 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/grappler_item.h"
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/variable.pb.h"
+#include "tensorflow/core/framework/versions.h"
+#include "tensorflow/core/grappler/inputs/utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// static
+std::unique_ptr<GrapplerItem> GrapplerItem::FromMetaGraphDef(
+ const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
+ // Check if the graph is compatible with the current version of TensorFlow.
+ Status status =
+ CheckVersions(meta_graph.graph_def().versions(), TF_GRAPH_DEF_VERSION,
+ TF_GRAPH_DEF_VERSION_MIN_PRODUCER, "GraphDef", "graph");
+ if (!status.ok()) {
+ LOG(ERROR) << "Cannot process item: " << status.error_message();
+ return nullptr;
+ }
+
+ std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
+ if (id.empty()) {
+ LOG(ERROR) << "id must be non-empty.";
+ return nullptr;
+ }
+ new_item->id = id;
+ new_item->graph = meta_graph.graph_def();
+
+ // Attempt to detect the fetch node(s).
+ if (meta_graph.collection_def().count("train_op") > 0) {
+ const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
+ if (nodes.has_node_list()) {
+ for (const auto& node : nodes.node_list().value()) {
+ const string name = NodeName(node);
+ if (name.empty()) {
+ LOG(ERROR) << "Invalid fetch node name " << node
+ << ", skipping this input";
+ return nullptr;
+ }
+ LOG(INFO) << "Will use fetch node " << name;
+ new_item->fetch.push_back(name);
+ }
+ }
+ }
+ if (new_item->fetch.empty()) {
+ LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
+ return nullptr;
+ }
+
+ for (auto& node : *new_item->graph.mutable_node()) {
+ // Delete user specified placement if requested.
+ if (cfg.ignore_user_placement) {
+ node.clear_device();
+ }
+
+ if (node.op() == "Placeholder" || node.op() == "PlaceholderV2") {
+ if (node.attr().count("dtype") == 0) {
+ LOG(ERROR) << "Unknown type for placeholder " << node.name()
+ << ", skipping this input";
+ return nullptr;
+ }
+ DataType type = node.attr().at("dtype").type();
+
+ if (node.attr().count("shape") == 0) {
+ LOG(INFO) << "Unknown shape for placeholder " << node.name()
+ << ", skipping this input";
+ return nullptr;
+ }
+ TensorShape shape(node.attr().at("shape").shape());
+ // Some placeholder nodes have a mis-match between the node
+ // attribute "shape" and a different node attribute "_output_shapes".
+ // Specifically, a shape with shape.dims() == 0 could indicate either
+ // a scalar or an unknown shape. In those cases, we check _output_shapes
+ // for additional information.
+ // This case is observed in the bnmt graphs. Have not observed any
+ // cases where there was more than 1 _output_shapes, so limit it
+ // to cases where there is only 1 _output_shapes.
+ // We only do this if cfg.placeholder_unknown_output_shape_dim has
+ // been set to avoid crashing non-BNMT graphs.
+ if ((cfg.placeholder_unknown_output_shape_dim >= 0) &&
+ (shape.dims() == 0) && (node.attr().count("_output_shapes") == 1) &&
+ (node.attr().at("_output_shapes").list().shape(0).dim_size() != 0)) {
+ shape.Clear();
+ for (int dim_i = 0;
+ dim_i <
+ node.attr().at("_output_shapes").list().shape(0).dim_size();
+ dim_i++) {
+ const ::tensorflow::TensorShapeProto_Dim dim =
+ node.attr().at("_output_shapes").list().shape(0).dim(dim_i);
+ if (dim.size() == -1) {
+ shape.AddDim(cfg.placeholder_unknown_output_shape_dim);
+ } else {
+ shape.AddDim(node.attr()
+ .at("_output_shapes")
+ .list()
+ .shape(0)
+ .dim(dim_i)
+ .size());
+ }
+ }
+ }
+ Tensor fake_input(type, shape);
+ // TODO(bsteiner): figure out a better way to initialize the feeds, for
+ // example by recording a sample of the fed inputs in mldash when running
+ // the graph.
+ memset(const_cast<char*>(fake_input.tensor_data().data()), 0,
+ fake_input.tensor_data().size());
+ new_item->feed.emplace_back(node.name(), fake_input);
+ }
+
+ if (cfg.ignore_colocation) {
+ auto attr = node.mutable_attr();
+ auto it = attr->find("_class");
+ if (it != attr->end()) {
+ attr->erase(it);
+ }
+ }
+ }
+
+ for (const string& var_collection :
+ {"variables", "local_variables", "model_variables",
+ "trainable_variables"}) {
+ if (meta_graph.collection_def().count(var_collection) == 0) {
+ continue;
+ }
+ const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
+ for (const auto& raw_var : vars.bytes_list().value()) {
+ VariableDef var;
+ var.ParseFromString(raw_var);
+ if (!var.initializer_name().empty()) {
+ new_item->init_ops.push_back(var.initializer_name());
+ }
+ }
+ }
+
+ if (meta_graph.collection_def().count("table_initializer") > 0) {
+ const CollectionDef& inits =
+ meta_graph.collection_def().at("table_initializer");
+ if (inits.has_node_list()) {
+ for (const auto& node : inits.node_list().value()) {
+ new_item->init_ops.push_back(node);
+ }
+ }
+ }
+
+ if (meta_graph.collection_def().count("queue_runners") > 0) {
+ const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
+ for (const auto& raw : vars.bytes_list().value()) {
+ QueueRunnerDef queue_runner;
+ if (!queue_runner.ParseFromString(raw)) {
+ LOG(ERROR) << "Could parse queue_runners, skipping this input";
+ return nullptr;
+ }
+ if (queue_runner.cancel_op_name().empty()) {
+ LOG(ERROR) << "Queue without a cancel op, skipping this input";
+ return nullptr;
+ }
+ new_item->queue_runners.push_back(queue_runner);
+ }
+ }
+
+ // Make sure we still can access the input files (aka "asset_filepaths") since
+ // these might have been moved or deleted, the cns cell might have been shut
+ // down, or we might be running as a user who does not have access to the
+ // files.
+ if (meta_graph.collection_def().count("asset_filepaths") > 0) {
+ const CollectionDef& file_paths =
+ meta_graph.collection_def().at("asset_filepaths");
+ std::vector<string> paths;
+ for (const auto& raw_path : file_paths.bytes_list().value()) {
+ paths.push_back(raw_path);
+ }
+ if (!FilesExist(paths, nullptr)) {
+ LOG(ERROR)
+ << "Can't access one or more of the asset files, skipping this input";
+ return nullptr;
+ }
+ }
+
+ return new_item;
+}
+
+std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
+ return ComputeTransitiveFanin(graph, fetch);
+}
+
+std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
+ return ComputeTransitiveFanin(graph, init_ops);
+}
+
+std::vector<const NodeDef*> ComputeTransitiveFanin(
+ const GraphDef& graph, const std::vector<string>& terminal_nodes) {
+ std::unordered_map<string, const NodeDef*> name_to_node;
+ for (const auto& node : graph.node()) {
+ name_to_node[node.name()] = &node;
+ }
+
+ std::vector<const NodeDef*> queue;
+ for (const string& root : terminal_nodes) {
+ const NodeDef* node = name_to_node[NodeName(root)];
+ CHECK(node);
+ queue.push_back(node);
+ }
+
+ std::vector<const NodeDef*> result;
+ std::unordered_set<const NodeDef*> visited;
+
+ while (!queue.empty()) {
+ const NodeDef* node = queue.back();
+ queue.pop_back();
+ if (!visited.insert(node).second) {
+ // The node has already been visited.
+ continue;
+ }
+ result.push_back(node);
+ for (const string& input : node->input()) {
+ const NodeDef* in = name_to_node[NodeName(input)];
+ CHECK(in);
+ queue.push_back(in);
+ }
+ }
+ return result;
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
new file mode 100644
index 0000000000..dff288de9f
--- /dev/null
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_GRAPPLER_ITEM_H_
+#define TENSORFLOW_GRAPPLER_GRAPPLER_ITEM_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/protobuf/queue_runner.pb.h"
+
+namespace tensorflow {
+
+class MetaGraphDef;
+
+namespace grappler {
+
+struct ItemConfig {
+ // If true, ignore all user specified node placement.
+ bool ignore_user_placement = true;
+ // If true, ignore all user specified colocation attributes.
+ bool ignore_colocation = true;
+ // Dimension to use if a placeholder node has an _output_shapes attribute with
+ // a dimension of -1.
+ int32 placeholder_unknown_output_shape_dim = -1;
+};
+
+// A TensorFlow model to optimize.
+// Models are represented by the combination of a graph, one of more fetch
+// nodes, and potentially a set of nodes to feed.
+// TODO(volunteer_needed): turn this struct into a class.
+struct GrapplerItem {
+ // Factory method for creating a GrapplerItem from a MetaGraphDef.
+ // Returns nullptr if the given meta_graph cannot be converted.
+ static std::unique_ptr<GrapplerItem> FromMetaGraphDef(
+ const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg);
+
+ string id; // A unique id for this item
+
+ // Inputs
+ GraphDef graph;
+ std::vector<std::pair<string, Tensor>> feed;
+ std::vector<string> fetch;
+
+ // Initialization op(s).
+ std::vector<string> init_ops;
+
+ // Queue runner(s) required to run the queue(s) of this model.
+ std::vector<QueueRunnerDef> queue_runners;
+
+ // Return the set of node evaluated during a regular train/inference step.
+ std::vector<const NodeDef*> MainOpsFanin() const;
+ // Return the set nodes used by TensorFlow to initialize the graph.
+ std::vector<const NodeDef*> InitOpsFanin() const;
+};
+
+// Return the transitive fanin of a set of terminal nodes.
+std::vector<const NodeDef*> ComputeTransitiveFanin(
+ const GraphDef& graph, const std::vector<string>& terminal_nodes);
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_GRAPPLER_ITEM_H_
diff --git a/tensorflow/core/grappler/grappler_item_test.cc b/tensorflow/core/grappler/grappler_item_test.cc
new file mode 100644
index 0000000000..72a9f481ca
--- /dev/null
+++ b/tensorflow/core/grappler/grappler_item_test.cc
@@ -0,0 +1,49 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class GrapplerItemTest : public ::testing::Test {};
+
+TEST_F(GrapplerItemTest, Basic) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {{"CPU:0"}});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ EXPECT_TRUE(item.InitOpsFanin().empty());
+
+ std::vector<string> graph_nodes;
+ for (const auto& node : item.graph.node()) {
+ graph_nodes.push_back(node.name());
+ }
+ std::vector<string> main_ops;
+ for (const auto& node : item.MainOpsFanin()) {
+ main_ops.push_back(node->name());
+ }
+ std::sort(graph_nodes.begin(), graph_nodes.end());
+ std::sort(main_ops.begin(), main_ops.end());
+ EXPECT_EQ(main_ops, graph_nodes);
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/inputs/BUILD b/tensorflow/core/grappler/inputs/BUILD
new file mode 100644
index 0000000000..f0ca36d85d
--- /dev/null
+++ b/tensorflow/core/grappler/inputs/BUILD
@@ -0,0 +1,72 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+cc_library(
+ name = "utils",
+ srcs = [
+ "utils.cc",
+ ],
+ hdrs = [
+ "utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "utils_test",
+ srcs = [
+ "utils_test.cc",
+ ],
+ deps = [
+ ":utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "input_yielder",
+ hdrs = [
+ "input_yielder.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [],
+)
+
+cc_library(
+ name = "trivial_test_graph_input_yielder",
+ srcs = ["trivial_test_graph_input_yielder.cc"],
+ hdrs = [
+ "trivial_test_graph_input_yielder.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":input_yielder",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/kernels:aggregate_ops",
+ "//tensorflow/core/kernels:array",
+ ],
+)
diff --git a/tensorflow/core/grappler/inputs/input_yielder.h b/tensorflow/core/grappler/inputs/input_yielder.h
new file mode 100644
index 0000000000..c9f90820a9
--- /dev/null
+++ b/tensorflow/core/grappler/inputs/input_yielder.h
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_INPUTS_INPUT_YIELDER_H_
+#define TENSORFLOW_GRAPPLER_INPUTS_INPUT_YIELDER_H_
+
+namespace tensorflow {
+namespace grappler {
+
+struct GrapplerItem;
+
+// Abstract interface for yielding graphs that we want to optimize.
+class InputYielder {
+ public:
+ virtual ~InputYielder() {}
+
+ virtual bool NextItem(GrapplerItem* item) = 0;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_INPUTS_INPUT_YIELDER_H_
diff --git a/tensorflow/core/grappler/inputs/testdata/test_file.txt b/tensorflow/core/grappler/inputs/testdata/test_file.txt
new file mode 100644
index 0000000000..557db03de9
--- /dev/null
+++ b/tensorflow/core/grappler/inputs/testdata/test_file.txt
@@ -0,0 +1 @@
+Hello World
diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
new file mode 100644
index 0000000000..8370133fc4
--- /dev/null
+++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc
@@ -0,0 +1,111 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The builtin inputs provide a mechanism to generate simple TensorFlow graphs
+// and feed them as inputs to Grappler. This can be used for quick experiments
+// or to derive small regression tests.
+
+#include "tensorflow/cc/ops/standard_ops.h"
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Make a program with specified number of stages and "width" ops per stage.
+namespace {
+GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
+ bool use_multiple_devices, bool insert_queue,
+ const std::vector<string>& device_names) {
+ CHECK_GE(device_names.size(), width);
+
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ // x is from the feed.
+ const int batch_size = tensor_size < 0 ? 1 : tensor_size;
+ Output x = Const(s.WithOpName("x"), 0.0f, {batch_size, 1});
+
+ // Create stages.
+ std::vector<Output> last_stage;
+ last_stage.push_back(x);
+ for (int i = 0; i < num_stages; i++) {
+ std::vector<Output> this_stage;
+ for (int j = 0; j < width; j++) {
+ Output combine = AddN(
+ s.WithDevice(device_names[use_multiple_devices ? j : 0]), last_stage);
+ this_stage.push_back(combine);
+ }
+ last_stage = this_stage;
+ }
+
+ if (insert_queue) {
+ FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_FLOAT});
+ QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, last_stage);
+ QueueDequeue dequeue(s.WithOpName("dequeue"), queue, {DataType::DT_FLOAT});
+ QueueClose cancel(s.WithOpName("cancel"), queue,
+ QueueClose::CancelPendingEnqueues(true));
+ last_stage = {dequeue[0]};
+ }
+
+ // Create output.
+ AddN output(s.WithOpName("y"), last_stage);
+
+ GraphDef def;
+ TF_CHECK_OK(s.ToGraphDef(&def));
+ return def;
+}
+} // namespace
+
+TrivialTestGraphInputYielder::TrivialTestGraphInputYielder(
+ int num_stages, int width, int tensor_size, bool insert_queue,
+ const std::vector<string>& device_names)
+ : num_stages_(num_stages),
+ width_(width),
+ tensor_size_(tensor_size),
+ insert_queue_(insert_queue),
+ device_names_(device_names) {}
+
+bool TrivialTestGraphInputYielder::NextItem(GrapplerItem* item) {
+ GrapplerItem r;
+ r.id = strings::StrCat("ns:", num_stages_, "/", // wrap
+ "w:", width_, "/", // wrap
+ "ts:", tensor_size_);
+ r.graph = CreateGraphDef(num_stages_, width_, tensor_size_,
+ true /*use_multiple_devices*/, insert_queue_,
+ device_names_);
+ // If the batch size is variable, we need to choose a value to create a feed
+ const int batch_size = tensor_size_ < 0 ? 1 : tensor_size_;
+ Tensor x(DT_FLOAT, TensorShape({batch_size, 1}));
+ r.feed.push_back(std::make_pair("x", x));
+ r.fetch.push_back("y");
+
+ if (insert_queue_) {
+ QueueRunnerDef queue_runner;
+ queue_runner.set_queue_name("queue");
+ queue_runner.set_cancel_op_name("cancel");
+ *queue_runner.add_enqueue_op_name() = "enqueue";
+ r.queue_runners.push_back(queue_runner);
+ }
+
+ *item = std::move(r);
+ return true;
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h
new file mode 100644
index 0000000000..4c5600c816
--- /dev/null
+++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_INPUTS_TRIVIAL_TEST_GRAPH_INPUT_YIELDER_H_
+#define TENSORFLOW_GRAPPLER_INPUTS_TRIVIAL_TEST_GRAPH_INPUT_YIELDER_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/core/grappler/inputs/input_yielder.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class Cluster;
+class GrapplerItem;
+
+class TrivialTestGraphInputYielder : public InputYielder {
+ public:
+ TrivialTestGraphInputYielder(int num_stages, int width, int tensor_size,
+ bool insert_queue,
+ const std::vector<string>& device_names);
+ bool NextItem(GrapplerItem* item) override;
+
+ private:
+ const int num_stages_;
+ const int width_;
+ const int tensor_size_;
+ const bool insert_queue_;
+ std::vector<string> device_names_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_INPUTS_TRIVIAL_TEST_GRAPH_INPUT_YIELDER_H_
diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc
new file mode 100644
index 0000000000..17f41105b2
--- /dev/null
+++ b/tensorflow/core/grappler/inputs/utils.cc
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/inputs/utils.h"
+#include "tensorflow/core/platform/env.h"
+
+#include <vector>
+
+namespace tensorflow {
+namespace grappler {
+
+bool FilesExist(const std::vector<string>& files, std::vector<Status>* status) {
+ return Env::Default()->FilesExist(files, status);
+}
+
+bool FilesExist(const std::set<string>& files) {
+ return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr);
+}
+
+} // End namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h
new file mode 100644
index 0000000000..ee65ca031d
--- /dev/null
+++ b/tensorflow/core/grappler/inputs/utils.h
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_INPUTS_UTILS_H_
+#define TENSORFLOW_GRAPPLER_INPUTS_UTILS_H_
+
+#include <set>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+
+bool FilesExist(const std::vector<string>& files,
+ std::vector<Status>* status = nullptr);
+bool FilesExist(const std::set<string>& files);
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_INPUTS_UTILS_H_
diff --git a/tensorflow/core/grappler/inputs/utils_test.cc b/tensorflow/core/grappler/inputs/utils_test.cc
new file mode 100644
index 0000000000..2a7c4834f1
--- /dev/null
+++ b/tensorflow/core/grappler/inputs/utils_test.cc
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/inputs/utils.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class UtilsTest : public ::testing::Test {
+ protected:
+ string BaseDir() { return io::JoinPath(testing::TmpDir(), "base_dir"); }
+
+ void SetUp() override {
+ TF_CHECK_OK(env_->CreateDir(BaseDir()));
+ non_existent_file_ = io::JoinPath(BaseDir(), "non_existent_file.txt");
+ actual_file_ = io::JoinPath(BaseDir(), "test_file.txt");
+ TF_CHECK_OK(WriteStringToFile(env_, actual_file_, "Some test data"));
+ }
+
+ void TearDown() override {
+ int64 undeleted_files, undeleted_dirs;
+ TF_CHECK_OK(
+ env_->DeleteRecursively(BaseDir(), &undeleted_files, &undeleted_dirs));
+ }
+
+ string non_existent_file_;
+ string actual_file_;
+ Env* env_ = Env::Default();
+};
+
+TEST_F(UtilsTest, FilesExist) {
+ EXPECT_FALSE(FilesExist(std::vector<string>{{non_existent_file_}}));
+ EXPECT_FALSE(
+ FilesExist(std::vector<string>{{non_existent_file_}, {actual_file_}}));
+ EXPECT_TRUE(FilesExist(std::vector<string>{{actual_file_}}));
+
+ std::vector<Status> status;
+ EXPECT_FALSE(FilesExist(
+ std::vector<string>{{non_existent_file_}, {actual_file_}}, &status));
+ EXPECT_EQ(status.size(), 2);
+ EXPECT_FALSE(status[0].ok());
+ EXPECT_TRUE(status[1].ok());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
new file mode 100644
index 0000000000..518d12e3ab
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -0,0 +1,45 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+cc_library(
+ name = "graph_optimizer",
+ hdrs = [
+ "graph_optimizer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "layout_optimizer",
+ srcs = ["layout_optimizer.cc"],
+ hdrs = [
+ "layout_optimizer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_optimizer",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h
new file mode 100644
index 0000000000..34e126b0af
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
+#define TENSORFLOW_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class Cluster;
+class GrapplerItem;
+
+// An abstract interface for an algorithm for generating a candidate
+// optimization of a GrapplerItem for running on a cluster.
+class GraphOptimizer {
+ public:
+ virtual ~GraphOptimizer() {}
+
+ virtual string name() const = 0;
+
+ // Routine called to allow an algorithm to propose a rewritten graph
+ // for the graph, feeds and fetches in "item" to run more efficiently
+ // on "cluster".
+ // Returns true iff it managed to generate a solution, false otherwise.
+ virtual Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) = 0;
+
+ // Method invoked by the framework so that it can provide feedback
+ // on how well the "optimize_output" (produced as *output from a
+ // call to Optimize) performed. Lower "result" scores are better.
+ virtual void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) = 0;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
new file mode 100644
index 0000000000..417f71fed2
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -0,0 +1,996 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <unordered_set>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace grappler {
+
+const char kConcatConst[] = "LayoutOptimizerConcatConst";
+const char kPermNHWCToNCHW[] = "LayoutOptimizerPermConstNHWCToNCHW";
+const char kPermNCHWToNHWC[] = "LayoutOptimizerPermConstNCHWToNHWC";
+const char kTransposeNHWCToNCHW[] = "LayoutOptimizerTransposeNHWCToNCHW";
+const char kTransposeNCHWToNHWC[] = "LayoutOptimizerTransposeNCHWToNHWC";
+const char kPermVecNHWCToNCHW[] = "LayoutOptimizerPermVecNHWCToNCHW";
+const char kReshapeNHWCToNCHW[] = "LayoutOptimizerReshapeNHWCToNCHW";
+const char kReshapeConst[] = "LayoutOptimizerReshapeConst";
+const char kReductionConst[] = "LayoutOptimizerReductionConst";
+
+std::set<string> GetOpsFormatSupported() {
+ std::set<string> ops_format_supported = {"AvgPool",
+ "AvgPoolGrad",
+ "Conv2D",
+ "Conv2DBackpropFilter",
+ "Conv2DBackpropInput",
+ "BiasAdd",
+ "BiasAddGrad",
+ "FusedBatchNorm",
+ "FusedBatchNormGrad",
+ "MaxPool",
+ "MaxPoolGrad"};
+ return ops_format_supported;
+}
+
+std::set<string> GetOpsFormatAgnostic() {
+ std::set<string> ops_format_agnostic = {"Add",
+ "AddN",
+ "Concat",
+ "ConcatV2",
+ "Floor",
+ "Identity",
+ "Mul",
+ "Neg",
+ "RealDiv",
+ "Relu",
+ "ReluGrad",
+ "Slice",
+ "SquaredDifference",
+ "Squeeze",
+ "Sub",
+ "Sum"};
+ return ops_format_agnostic;
+}
+
+class NodeMap {
+ public:
+ explicit NodeMap(GraphDef* graph) : graph_(graph) {
+ for (int i = 0; i < graph_->node_size(); i++) {
+ auto node = graph_->mutable_node(i);
+ nodes_.insert(std::make_pair(node->name(), node));
+ for (const auto& input : node->input()) {
+ outputs_[input].insert(nodes_[node->name()]);
+ }
+ }
+ }
+
+ NodeDef* GetNode(const string& name) {
+ string node_name = NodeName(name);
+ return nodes_[node_name];
+ }
+
+ std::set<NodeDef*> GetOutputs(const string& name) { return outputs_[name]; }
+
+ void AddNode(const string& name, NodeDef* node) {
+ nodes_.insert(std::make_pair(name, node));
+ }
+
+ void AddOutput(const string& node, const string& output) {
+ outputs_[node].insert(nodes_[output]);
+ }
+
+ void UpdateOutput(const string& node, const string& old_output,
+ const string& new_output) {
+ outputs_[node].erase(nodes_[old_output]);
+ outputs_[node].insert(nodes_[new_output]);
+ }
+
+ private:
+ GraphDef* graph_;
+ std::unordered_map<string, NodeDef*> nodes_;
+ std::unordered_map<string, std::set<NodeDef*>> outputs_;
+};
+
+bool IsNodeNHWCToNCHW(const string& node_name) {
+ const string transpose_node_prefix = kTransposeNHWCToNCHW;
+ string prefix = node_name.substr(0, transpose_node_prefix.length());
+ if (prefix.compare(transpose_node_prefix) == 0) {
+ return true;
+ }
+ return false;
+}
+
+bool IsNodeNCHWToNHWC(const string& node_name) {
+ const string transpose_node_prefix = kTransposeNCHWToNHWC;
+ string prefix = node_name.substr(0, transpose_node_prefix.length());
+ if (prefix.compare(transpose_node_prefix) == 0) {
+ return true;
+ }
+ return false;
+}
+
+class NodeProcessor {
+ public:
+ NodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : graph_(graph), node_(node), node_map_(node_map) {}
+ virtual ~NodeProcessor() {}
+ virtual void ConvertNode() {
+ if (ShouldProcess()) {
+ UpdateAttrDataFormat();
+ UpdateAttrKSize();
+ UpdateAttrStrides();
+ UpdateAttrShape();
+ AddLayoutTransposeToInputs();
+ AddLayoutTransposeToOutputs();
+ CustomizedProcessing();
+ }
+ }
+
+ protected:
+ bool IsDimsN(NodeDef* node, int n) const {
+ if (node->attr().find("_output_shapes") != node->attr().end()) {
+ auto shape = node->attr().at("_output_shapes").list().shape(0);
+ if (shape.dim_size() == n) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool IsDimsFour(NodeDef* node) const { return IsDimsN(node, 4); }
+
+ bool IsNHWC() const {
+ if (node_->attr().find("data_format") != node_->attr().end()) {
+ if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool HasOutputs() const {
+ auto outputs = node_map_->GetOutputs(node_->name());
+ return !outputs.empty();
+ }
+
+ virtual bool ShouldProcess() const {
+ return IsNHWC() && IsDimsFour(node_) && HasOutputs();
+ }
+
+ void UpdateAttrDataFormat() {
+ if (node_->attr().find("data_format") != node_->attr().end()) {
+ if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
+ string* data_format =
+ node_->mutable_attr()->at("data_format").mutable_s();
+ *data_format = "NCHW";
+ }
+ }
+ }
+
+ virtual void UpdateAttrShape() {
+ if (node_->attr().find("_output_shapes") != node_->attr().end()) {
+ auto shape = node_->mutable_attr()
+ ->at("_output_shapes")
+ .mutable_list()
+ ->mutable_shape(0);
+ if (shape->dim_size() == 4) {
+ int64 h = shape->dim(1).size();
+ int64 w = shape->dim(2).size();
+ int64 c = shape->dim(3).size();
+ shape->mutable_dim(1)->set_size(c);
+ shape->mutable_dim(2)->set_size(h);
+ shape->mutable_dim(3)->set_size(w);
+ }
+ }
+ }
+
+ void UpdateAttrKSize() {
+ if (node_->attr().find("ksize") != node_->attr().end()) {
+ auto list = node_->mutable_attr()->at("ksize").mutable_list();
+ UpdateTuple(list);
+ }
+ }
+
+ void UpdateAttrStrides() {
+ if (node_->attr().find("strides") != node_->attr().end()) {
+ auto list = node_->mutable_attr()->at("strides").mutable_list();
+ UpdateTuple(list);
+ }
+ }
+
+ void UpdateAttrValue(const string& name) {
+ NodeDef* node = node_map_->GetNode(name);
+ Tensor tensor;
+ auto success =
+ tensor.FromProto(node->mutable_attr()->at({"value"}).tensor());
+ if (!success) {
+ LOG(ERROR) << "Failed to parse TensorProto.";
+ }
+ int c = tensor.flat<int>()(3);
+ tensor.flat<int>()(3) = tensor.flat<int>()(2);
+ tensor.flat<int>()(2) = tensor.flat<int>()(1);
+ tensor.flat<int>()(1) = c;
+ tensor.AsProtoTensorContent(
+ node->mutable_attr()->at({"value"}).mutable_tensor());
+ }
+
+ virtual std::vector<int> GetInputPos() const {
+ std::vector<int> input_pos = {0};
+ return input_pos;
+ }
+
+ void AddNodeTranspose(const string& node_name, const string& input_name,
+ DataType data_type, const TensorShapeProto& input_shape,
+ bool NHWCToNCHW) {
+ NodeDef* node = graph_->add_node();
+ node_map_->AddNode(node_name, node);
+ node->set_name(node_name);
+ *node->add_input() = input_name;
+ *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
+ node->set_op("Transpose");
+ AttrValue attr_data_type;
+ attr_data_type.set_type(data_type);
+ node->mutable_attr()->insert({"T", attr_data_type});
+ AttrValue attr_data_type_perm;
+ attr_data_type_perm.set_type(DT_INT32);
+ node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
+ AttrValue attr_output_shape;
+ auto output_shape = attr_output_shape.mutable_list()->add_shape();
+ if (NHWCToNCHW) {
+ output_shape->add_dim()->set_size(input_shape.dim(0).size());
+ output_shape->add_dim()->set_size(input_shape.dim(3).size());
+ output_shape->add_dim()->set_size(input_shape.dim(1).size());
+ output_shape->add_dim()->set_size(input_shape.dim(2).size());
+ } else {
+ output_shape->add_dim()->set_size(input_shape.dim(0).size());
+ output_shape->add_dim()->set_size(input_shape.dim(2).size());
+ output_shape->add_dim()->set_size(input_shape.dim(3).size());
+ output_shape->add_dim()->set_size(input_shape.dim(1).size());
+ }
+ node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
+ }
+
+ virtual void AddLayoutTransposeToInputs() {
+ std::vector<int> input_pos = GetInputPos();
+ for (const auto& pos : input_pos) {
+ string node_name_NHWCToNCHW = strings::StrCat(
+ kTransposeNHWCToNCHW, "-", node_->name(), "-", node_->input(pos));
+ auto input_node = node_map_->GetNode(node_->input(pos));
+ int output_pos = NodePosition(node_->input(pos));
+ AddNodeTranspose(
+ node_name_NHWCToNCHW, node_->input(pos), node_->attr().at("T").type(),
+ input_node->attr().at("_output_shapes").list().shape(output_pos),
+ true);
+ node_map_->UpdateOutput(node_->input(pos), node_->name(),
+ node_name_NHWCToNCHW);
+ node_map_->AddOutput(node_name_NHWCToNCHW, node_->name());
+ *node_->mutable_input(pos) = node_name_NHWCToNCHW;
+ }
+ }
+
+ virtual void AddLayoutTransposeToOutputs() {
+ auto outputs = node_map_->GetOutputs(node_->name());
+ for (const auto& output : outputs) {
+ string node_name_NCHWToNHWC = strings::StrCat(
+ kTransposeNCHWToNHWC, "-", node_->name(), "-", output->name());
+ auto it = std::find_if(output->mutable_input()->begin(),
+ output->mutable_input()->end(),
+ [this](const string& input) {
+ return input.compare(node_->name()) == 0;
+ });
+ int output_pos = NodePosition(*it);
+ AddNodeTranspose(
+ node_name_NCHWToNHWC, node_->name(), node_->attr().at("T").type(),
+ node_->attr().at("_output_shapes").list().shape(output_pos), false);
+ *it = node_name_NCHWToNHWC;
+ node_map_->UpdateOutput(node_->name(), output->name(),
+ node_name_NCHWToNHWC);
+ node_map_->AddOutput(node_name_NCHWToNHWC, output->name());
+ }
+ }
+
+ virtual void CustomizedProcessing() {}
+
+ GraphDef* graph_;
+ NodeDef* node_;
+ NodeMap* node_map_;
+
+ private:
+ void UpdateTuple(AttrValue_ListValue* list) {
+ int64 h = list->i(1);
+ int64 w = list->i(2);
+ int64 c = list->i(3);
+ list->set_i(1, c);
+ list->set_i(2, h);
+ list->set_i(3, w);
+ }
+};
+
+class AvgPoolGradProcessor : public NodeProcessor {
+ public:
+ AvgPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+ std::vector<int> GetInputPos() const override {
+ std::vector<int> input_pos = {1};
+ return input_pos;
+ }
+ void CustomizedProcessing() override { UpdateAttrValue(node_->input(0)); }
+};
+
+class BiasAddGradProcessor : public NodeProcessor {
+ public:
+ BiasAddGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+ bool ShouldProcess() const override {
+ auto input = node_map_->GetNode(node_->input(0));
+ if (input) {
+ if ((IsNHWC() && IsDimsFour(input)) || IsNodeNCHWToNHWC(input->name())) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ void AddLayoutTransposeToOutputs() override {}
+};
+
+class Conv2DBackpropFilterProcessor : public NodeProcessor {
+ public:
+ Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node,
+ NodeMap* node_map)
+ : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+ std::vector<int> GetInputPos() const override {
+ std::vector<int> input_pos = {0, 2};
+ return input_pos;
+ }
+
+ void AddLayoutTransposeToOutputs() override {}
+ // No need to update output shape, as it is always of shape
+ // [filter_height, filter_width, in_channels, out_channels], regardless of
+ // whether NCHW or NHWC is used.
+ void UpdateAttrShape() override {}
+};
+
+class Conv2DBackpropInputProcessor : public NodeProcessor {
+ public:
+ Conv2DBackpropInputProcessor(GraphDef* graph, NodeDef* node,
+ NodeMap* node_map)
+ : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+ std::vector<int> GetInputPos() const override {
+ std::vector<int> input_pos = {2};
+ return input_pos;
+ }
+ void CustomizedProcessing() override { UpdateAttrValue(node_->input(0)); }
+};
+
+class FusedBatchNormGradProcessor : public NodeProcessor {
+ public:
+ FusedBatchNormGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+ std::vector<int> GetInputPos() const override {
+ std::vector<int> input_pos = {0, 1};
+ return input_pos;
+ }
+};
+
+class MaxPoolGradProcessor : public NodeProcessor {
+ public:
+ MaxPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+ std::vector<int> GetInputPos() const override {
+ std::vector<int> input_pos = {0, 1, 2};
+ return input_pos;
+ }
+};
+
+class AgnosticNodeProcessor : public NodeProcessor {
+ public:
+ AgnosticNodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+ bool ShouldProcess() const override {
+ return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC();
+ }
+
+ bool IsNodeAfterNCHWToNHWC() const {
+ std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
+ auto node = node_map_->GetNode(node_->name());
+ while (node->input_size() > 0) {
+ int data_input_pos = 0;
+ if (node->op().compare("Concat") == 0) {
+ data_input_pos = 1;
+ }
+ node = node_map_->GetNode(node->input(data_input_pos));
+ if (IsNodeNCHWToNHWC(node->name())) {
+ return true;
+ }
+      bool connected =
+          ops_format_agnostic.find(node->op()) != ops_format_agnostic.end();
+ if (!connected) {
+ return false;
+ }
+ }
+ return false;
+ }
+};
+
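+// All inputs of AddN have the same shape, so every input gets a layout
+// transpose.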
+class AddNProcessor : public AgnosticNodeProcessor {
+ public:
+ AddNProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+ std::vector<int> GetInputPos() const override {
+ std::vector<int> input_pos;
+ for (int i = 0; i < node_->input_size(); i++) {
+ input_pos.push_back(i);
+ }
+ return input_pos;
+ }
+};
+
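+// Handles element-wise binary ops (Add, Mul, RealDiv, SquaredDifference and
+// Sub) whose first operand is 4-D. The second operand may be 4-D, a scalar,
+// or a vector along the channel dimension; a vector is reshaped to
+// [1, C, 1, 1] so that it still broadcasts correctly after the conversion to
+// NCHW.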
+class BinaryOpProcessor : public AgnosticNodeProcessor {
+ public:
+ BinaryOpProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : AgnosticNodeProcessor(graph, node, node_map) {
+ is_4d_with_vector_ = Is4DOperateWithVector();
+ }
+
+ protected:
+ bool ShouldProcess() const override {
+ return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+ (Is4DOperateWithND(4) || Is4DOperateWithScalar() ||
+ Is4DOperateWithVector());
+ }
+
+ std::vector<int> GetInputPos() const override {
+ std::vector<int> input_pos = {0};
+ if (Is4DOperateWithND(4)) {
+ input_pos.push_back(1);
+ }
+ return input_pos;
+ }
+
+ bool Is4DOperateWithND(int n) const {
+ auto input0 = node_map_->GetNode(node_->input(0));
+ auto input1 = node_map_->GetNode(node_->input(1));
+ if (input0 && input1) {
+ return (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) &&
+ ((n == 4)
+ ? (IsDimsFour(input1) || IsNodeNCHWToNHWC(input1->name()))
+ : IsDimsN(input1, n));
+ }
+ return false;
+ }
+
+ bool Is4DOperateWithScalar() const { return Is4DOperateWithND(0); }
+
+ bool Is4DOperateWithVector() const { return Is4DOperateWithND(1); }
+
+ void AddNodeShapeConst(const string& name, int num_channels) {
+ NodeDef* node = graph_->add_node();
+ node_map_->AddNode(name, node);
+ node->set_name(name);
+ node->set_op("Const");
+ AttrValue attr_data_type;
+ attr_data_type.set_type(DT_INT32);
+ node->mutable_attr()->insert({"dtype", attr_data_type});
+
+ AttrValue attr_tensor;
+ Tensor tensor(DT_INT32, TensorShape({4}));
+ std::vector<int> shape = {1, num_channels, 1, 1};
+ for (int i = 0; i < shape.size(); i++) {
+ tensor.flat<int>()(i) = shape[i];
+ }
+ tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+ node->mutable_attr()->insert({"value", attr_tensor});
+ }
+
+ void AddNodeReshape(const string& node_name, const string& input_name,
+ const string& shape_const_node_name, DataType data_type) {
+ NodeDef* node = graph_->add_node();
+ node_map_->AddNode(node_name, node);
+ node->set_name(node_name);
+ *node->add_input() = input_name;
+ *node->add_input() = shape_const_node_name;
+ node->set_op("Reshape");
+
+ AttrValue attr_type_indices;
+ attr_type_indices.set_type(DT_INT32);
+ node->mutable_attr()->insert({"Tshape", attr_type_indices});
+
+ AttrValue attr_type_params;
+ attr_type_params.set_type(data_type);
+ node->mutable_attr()->insert({"T", attr_type_params});
+ }
+
+ void CustomizedProcessing() override {
+ if (is_4d_with_vector_) {
+ string suffix = strings::StrCat("-", node_->name(), "-", node_->input(1));
+ string reshape_node_name = strings::StrCat(kReshapeNHWCToNCHW, suffix);
+ string shape_const_node_name = strings::StrCat(kReshapeConst, suffix);
+ int vector_size = node_map_->GetNode(node_->input(1))
+ ->attr()
+ .at("_output_shapes")
+ .list()
+ .shape(0)
+ .dim(0)
+ .size();
+ AddNodeShapeConst(shape_const_node_name, vector_size);
+ AddNodeReshape(reshape_node_name, node_->input(1), shape_const_node_name,
+ node_->attr().at("T").type());
+ node_map_->AddOutput(shape_const_node_name, reshape_node_name);
+ node_map_->UpdateOutput(node_->input(1), node_->name(),
+ reshape_node_name);
+ node_map_->AddOutput(reshape_node_name, node_->name());
+ *node_->mutable_input(1) = reshape_node_name;
+ }
+ }
+
+ private:
+ bool is_4d_with_vector_;
+};
+
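+// Handles Concat and ConcatV2 when the concatenation is along the channel
+// dimension (axis 3 in NHWC). The axis input is rewired to a shared constant
+// holding the value 1, the channel dimension in NCHW.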
+class ConcatProcessor : public AgnosticNodeProcessor {
+ public:
+ ConcatProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : AgnosticNodeProcessor(graph, node, node_map) {
+ // For Concat, the concat axis is the first input; for ConcatV2,
+ // the last input.
+ axis_node_pos_ =
+ (node_->op().compare("Concat") == 0) ? 0 : (node_->input_size() - 1);
+ }
+
+ protected:
+ bool ShouldProcess() const override {
+ return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+ IsAlongDimC();
+ }
+
+ std::vector<int> GetInputPos() const override {
+ std::vector<int> input_pos;
+ int start = (node_->op().compare("Concat") == 0) ? 1 : 0;
+ int end = (node_->op().compare("Concat") == 0) ? node_->input_size()
+ : (node_->input_size() - 1);
+ for (int i = start; i < end; i++) {
+ input_pos.push_back(i);
+ }
+ return input_pos;
+ }
+
+ void CustomizedProcessing() override {
+ node_map_->AddOutput(kConcatConst, node_->name());
+ *node_->mutable_input(axis_node_pos_) = kConcatConst;
+ }
+
+ bool IsAlongDimC() const {
+ auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_));
+ if (axis_node->attr().find("value") != axis_node->attr().end()) {
+ return axis_node->attr().at("value").tensor().int_val(0) == 3;
+ }
+ return false;
+ }
+
+ int axis_node_pos_;
+};
+
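+// Both inputs of ReluGrad (the backpropagated gradient and the activations)
+// are 4-D and need a layout transpose.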
+class ReluGradProcessor : public AgnosticNodeProcessor {
+ public:
+ ReluGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+ std::vector<int> GetInputPos() const override {
+ std::vector<int> input_pos = {0, 1};
+ return input_pos;
+ }
+};
+
+// This is the older, less optimized gather-based SliceProcessor. We keep it as
+// a test case for constant propagation optimization.
+class SliceProcessorGatherBased : public AgnosticNodeProcessor {
+ public:
+ SliceProcessorGatherBased(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+ void CustomizedProcessing() override {
+ // Skip the first input, which is the data to be sliced.
+ for (int i = 1; i < node_->input_size(); i++) {
+ string node_name_NHWCToNCHW =
+ strings::StrCat(kPermVecNHWCToNCHW, "-", node_->name(), "-input", i);
+ AddNodePermVec(node_name_NHWCToNCHW, node_->input(i),
+ node_->attr().at("Index").type(), true);
+ node_map_->UpdateOutput(node_->input(i), node_->name(),
+ node_name_NHWCToNCHW);
+ node_map_->AddOutput(node_name_NHWCToNCHW, node_->name());
+ *node_->mutable_input(i) = node_name_NHWCToNCHW;
+ }
+ }
+
+ private:
+ void AddNodePermVec(const string& node_name, const string& input_name,
+ DataType data_type, bool NHWCToNCHW) {
+ NodeDef* node = graph_->add_node();
+ node_map_->AddNode(node_name, node);
+ node->set_name(node_name);
+ *node->add_input() = input_name;
+ *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
+ node->set_op("Gather");
+
+ AttrValue attr_type_indices;
+ attr_type_indices.set_type(DT_INT32);
+ node->mutable_attr()->insert({"Tindices", attr_type_indices});
+
+ AttrValue attr_type_params;
+ attr_type_params.set_type(data_type);
+ node->mutable_attr()->insert({"Tparams", attr_type_params});
+
+ AttrValue attr_validate;
+ attr_validate.set_b(true);
+ node->mutable_attr()->insert({"validate_indices", attr_validate});
+ }
+};
+
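+// Handles Slice nodes whose begin offsets are produced by a ConcatOffset
+// node: the shape constants feeding the ConcatOffset are permuted from NHWC
+// to NCHW order, and the concat axis is changed from 3 to 1.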
+class SliceProcessor : public AgnosticNodeProcessor {
+ public:
+ SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+ void CustomizedProcessing() override {
+ auto maybe_concatoffset_node =
+ node_map_->GetNode(NodeName(node_->input(1)));
+ if (maybe_concatoffset_node->op() == "ConcatOffset") {
+ auto axis_node = node_map_->GetNode(maybe_concatoffset_node->input(0));
+      // We only need to process the node if the channel is at dimension 3,
+      // which indicates that the NHWC format is being used. As multiple Slice
+      // nodes may share the same ConcatOffset node, the NHWC-to-NCHW
+      // conversion may have already been performed when processing other
+      // Slice nodes.
+ if (axis_node->attr().at("value").tensor().int_val(0) == 3) {
+ for (int i = 1; i < maybe_concatoffset_node->input_size(); i++) {
+ auto shape_node =
+ node_map_->GetNode(maybe_concatoffset_node->input(i));
+ AttrValue attr_tensor;
+ Tensor tensor;
+ CHECK(tensor.FromProto(shape_node->attr().at({"value"}).tensor()));
+ int h = tensor.flat<int>()(1);
+ int w = tensor.flat<int>()(2);
+ int c = tensor.flat<int>()(3);
+ tensor.flat<int>()(1) = c;
+ tensor.flat<int>()(2) = h;
+ tensor.flat<int>()(3) = w;
+ tensor.AsProtoTensorContent(
+ shape_node->mutable_attr()->at({"value"}).mutable_tensor());
+ }
+        // Set the concat axis to 1 (the channel dimension in NCHW), as we
+        // have converted the offset vectors from NHWC to NCHW order.
+ axis_node->mutable_attr()->at("value").mutable_tensor()->set_int_val(0,
+ 1);
+ }
+ }
+ }
+};
+
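+// Handles Squeeze nodes that drop the H and W dimensions of a 4-D input whose
+// spatial dimensions are both of size 1 (e.g. after global pooling); the
+// squeeze_dims attribute is rewritten from {1, 2} to {2, 3}.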
+class SqueezeProcessor : public AgnosticNodeProcessor {
+ public:
+ SqueezeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+ bool ShouldProcess() const override {
+ return IsDimsN(node_, 2) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+ IsInputConvertible() && IsAlongDimHW();
+ }
+
+ void AddLayoutTransposeToOutputs() override {}
+
+ bool IsInputConvertible() const {
+ auto input = node_map_->GetNode(node_->input(0));
+ if (IsNodeNCHWToNHWC(input->name())) {
+ input = node_map_->GetNode(input->input(0));
+ }
+ if (input->attr().find("_output_shapes") != input->attr().end()) {
+ auto shape = input->attr().at("_output_shapes").list().shape(0);
+ if (shape.dim_size() != 4) {
+ return false;
+ }
+ if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool IsAlongDimHW() const {
+ if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
+ auto list = node_->attr().at("squeeze_dims").list();
+ if (list.i(0) == 1 && list.i(1) == 2) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ void CustomizedProcessing() override {
+ auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
+ list->set_i(0, 2);
+ list->set_i(1, 3);
+ }
+};
+
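+// Handles Sum nodes that reduce a 4-D input over the N, H and W dimensions
+// (reduction indices {0, 1, 2} in NHWC); the reduction indices input is
+// rewired to a shared constant holding {0, 2, 3}, the NCHW equivalent.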
+class SumProcessor : public AgnosticNodeProcessor {
+ public:
+ SumProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+ : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+ bool ShouldProcess() const override {
+ auto input0 = node_map_->GetNode(node_->input(0));
+ return HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+ (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) &&
+ IsAlongDimNHW();
+ }
+
+ void AddLayoutTransposeToOutputs() override {}
+
+ void CustomizedProcessing() override {
+ node_map_->AddOutput(kReductionConst, node_->name());
+ *node_->mutable_input(1) = kReductionConst;
+ }
+
+ private:
+ bool IsAlongDimNHW() const {
+ NodeDef* node = node_map_->GetNode(node_->input(1));
+ Tensor tensor;
+ if (node->attr().find({"value"}) == node->attr().end()) {
+ return false;
+ }
+ auto success = tensor.FromProto(node->attr().at({"value"}).tensor());
+ if (!success) {
+ LOG(ERROR) << "Failed to parse TensorProto.";
+ return false;
+ }
+ if (tensor.flat<int>().size() != 3) {
+ return false;
+ }
+ if (tensor.flat<int>()(0) == 0 && tensor.flat<int>()(1) == 1 &&
+ tensor.flat<int>()(2) == 2) {
+ return true;
+ }
+ return false;
+ }
+};
+
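+// Rewrites a graph from NHWC to NCHW in two phases: Expand converts every
+// eligible node and wraps it in NHWC-to-NCHW and NCHW-to-NHWC transposes,
+// and Collapse removes the transpose pairs that cancel each other out.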
+class DataLayoutOptimizer {
+ public:
+ explicit DataLayoutOptimizer(GraphDef* graph)
+ : graph_(graph), node_map_(graph_) {
+ LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size();
+ Expand();
+ LOG(INFO) << "Number of nodes after Expand: " << graph_->node_size();
+ Collapse();
+ LOG(INFO) << "Number of nodes after Collapse: " << graph_->node_size();
+ }
+
+ private:
+ void AddNodePermConst(const string& name,
+ const std::vector<int>& permutation) {
+ NodeDef* node = graph_->add_node();
+ node_map_.AddNode(name, node);
+ node->set_name(name);
+ node->set_op("Const");
+ AttrValue attr_data_type;
+ attr_data_type.set_type(DT_INT32);
+ node->mutable_attr()->insert({"dtype", attr_data_type});
+ AttrValue attr_tensor;
+ Tensor tensor(DT_INT32, TensorShape({4}));
+ for (int i = 0; i < permutation.size(); i++) {
+ tensor.flat<int>()(i) = permutation[i];
+ }
+ tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+ node->mutable_attr()->insert({"value", attr_tensor});
+ }
+
+ void AddNodeConcatConst() {
+ NodeDef* node = graph_->add_node();
+ node_map_.AddNode(kConcatConst, node);
+ node->set_name(kConcatConst);
+ node->set_op("Const");
+ AttrValue attr_data_type;
+ attr_data_type.set_type(DT_INT32);
+ node->mutable_attr()->insert({"dtype", attr_data_type});
+ AttrValue attr_tensor;
+ Tensor tensor(DT_INT32, TensorShape({}));
+ tensor.scalar<int>()() = 1;
+ tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+ node->mutable_attr()->insert({"value", attr_tensor});
+ }
+
+ void AddNodeReductionConst() {
+ NodeDef* node = graph_->add_node();
+ node_map_.AddNode(kReductionConst, node);
+ node->set_name(kReductionConst);
+ node->set_op("Const");
+ AttrValue attr_data_type;
+ attr_data_type.set_type(DT_INT32);
+ node->mutable_attr()->insert({"dtype", attr_data_type});
+
+ AttrValue attr_tensor;
+ Tensor tensor(DT_INT32, TensorShape({3}));
+ std::vector<int> axis = {0, 2, 3};
+ for (int i = 0; i < axis.size(); i++) {
+ tensor.flat<int>()(i) = axis[i];
+ }
+ tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+ node->mutable_attr()->insert({"value", attr_tensor});
+ }
+
+  // Expand all nodes that are in NHWC but either support NCHW or are layout
+  // agnostic.
+ void Expand() {
+ int node_size_original = graph_->node_size();
+ // This is the first pass where we expand the nodes which support NCHW.
+ std::set<string> ops_format_supported = GetOpsFormatSupported();
+ for (int i = 0; i < graph_->node_size(); i++) {
+ if (ops_format_supported.find(graph_->node(i).op()) !=
+ ops_format_supported.end()) {
+ auto node = graph_->mutable_node(i);
+ std::unique_ptr<NodeProcessor> node_processor;
+ if (node->op().compare("AvgPoolGrad") == 0) {
+ node_processor.reset(
+ new AvgPoolGradProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("BiasAddGrad") == 0) {
+ node_processor.reset(
+ new BiasAddGradProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("Conv2DBackpropFilter") == 0) {
+ node_processor.reset(
+ new Conv2DBackpropFilterProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("Conv2DBackpropInput") == 0) {
+ node_processor.reset(
+ new Conv2DBackpropInputProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("FusedBatchNormGrad") == 0) {
+ node_processor.reset(
+ new FusedBatchNormGradProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("MaxPoolGrad") == 0) {
+ node_processor.reset(
+ new MaxPoolGradProcessor(graph_, node, &node_map_));
+ } else {
+ node_processor.reset(new NodeProcessor(graph_, node, &node_map_));
+ }
+ node_processor->ConvertNode();
+ }
+ }
+
+    // This is the second pass, in which we expand layout-agnostic nodes. This
+    // pass only needs to be performed if at least one node was expanded in
+    // the previous pass.
+ if (graph_->node_size() > node_size_original) {
+ AddNodePermConst(kPermNHWCToNCHW, {0, 3, 1, 2});
+ AddNodePermConst(kPermNCHWToNHWC, {0, 2, 3, 1});
+ AddNodeConcatConst();
+ AddNodeReductionConst();
+ std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
+ for (int i = 0; i < graph_->node_size(); i++) {
+ if (ops_format_agnostic.find(graph_->node(i).op()) !=
+ ops_format_agnostic.end()) {
+ auto node = graph_->mutable_node(i);
+ std::unique_ptr<NodeProcessor> node_processor;
+ if (node->op().compare("AddN") == 0) {
+ node_processor.reset(new AddNProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("Add") == 0 ||
+ node->op().compare("Mul") == 0 ||
+ node->op().compare("RealDiv") == 0 ||
+ node->op().compare("SquaredDifference") == 0 ||
+ node->op().compare("Sub") == 0) {
+ node_processor.reset(
+ new BinaryOpProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("Concat") == 0 ||
+ node->op().compare("ConcatV2") == 0) {
+ node_processor.reset(new ConcatProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("ReluGrad") == 0) {
+ node_processor.reset(
+ new ReluGradProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("Slice") == 0) {
+ node_processor.reset(new SliceProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("Squeeze") == 0) {
+ node_processor.reset(
+ new SqueezeProcessor(graph_, node, &node_map_));
+ } else if (node->op().compare("Sum") == 0) {
+ node_processor.reset(new SumProcessor(graph_, node, &node_map_));
+ } else {
+ node_processor.reset(
+ new AgnosticNodeProcessor(graph_, node, &node_map_));
+ }
+ node_processor->ConvertNode();
+ }
+ }
+ }
+ }
+
+  // Remove all node pairs in which an NCHW-to-NHWC node is immediately
+  // followed by an NHWC-to-NCHW node.
+ void Collapse() {
+ std::unordered_set<string> nodes_removable;
+ for (int i = 0; i < graph_->node_size(); i++) {
+ auto node = graph_->mutable_node(i);
+ if (IsNodeNHWCToNCHW(node->name())) {
+ if (IsNodeNCHWToNHWC(node->input(0))) {
+ const string& trans_first = node->input(0);
+ const string& trans_second = node->name();
+ auto outputs = node_map_.GetOutputs(trans_second);
+ CHECK(outputs.size() == 1)
+ << "There is always only a single output for a Transpose node, "
+ << "due to the way it is added by NodeProcessor.";
+ NodeDef* output = *outputs.begin();
+ string input = node_map_.GetNode(trans_first)->input(0);
+          for (int j = 0; j < output->input_size(); j++) {
+            if (output->input(j).compare(trans_second) == 0) {
+              *output->mutable_input(j) = input;
+ break;
+ }
+ }
+ nodes_removable.insert(trans_first);
+ nodes_removable.insert(trans_second);
+ }
+ }
+ }
+ graph_->mutable_node()->erase(
+ std::remove_if(
+ graph_->mutable_node()->begin(), graph_->mutable_node()->end(),
+ [nodes_removable](const NodeDef& node) {
+ return nodes_removable.find(node.name()) != nodes_removable.end();
+ }),
+ graph_->mutable_node()->end());
+ }
+
+ GraphDef* graph_;
+ NodeMap node_map_;
+};
+
+Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ DataLayoutOptimizer layout_optimizer(output);
+ return Status::OK();
+}
+
+void LayoutOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) {
+ // Nothing to do for LayoutOptimizer.
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.h b/tensorflow/core/grappler/optimizers/layout_optimizer.h
new file mode 100644
index 0000000000..66dec17a35
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
+#define TENSORFLOW_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
+
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Convert the NHWC layout to NCHW for Conv-related ops on GPUs.
+class LayoutOptimizer : public GraphOptimizer {
+ public:
+ LayoutOptimizer() {}
+ ~LayoutOptimizer() override {}
+
+  string name() const override { return "layout"; }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
new file mode 100644
index 0000000000..daef641fd5
--- /dev/null
+++ b/tensorflow/core/grappler/utils.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/stream_executor.h"
+
+namespace tensorflow {
+namespace grappler {
+
+int GetNumAvailableGPUs() {
+ int num_eligible_gpus = 0;
+ if (ValidateGPUMachineManager().ok()) {
+ perftools::gputools::Platform* gpu_manager = GPUMachineManager();
+ if (gpu_manager != nullptr) {
+ int num_gpus = gpu_manager->VisibleDeviceCount();
+ for (int i = 0; i < num_gpus; i++) {
+ auto exec_status = gpu_manager->ExecutorForDevice(i);
+ if (exec_status.ok()) {
+ perftools::gputools::StreamExecutor* se = exec_status.ValueOrDie();
+ const perftools::gputools::DeviceDescription& desc =
+ se->GetDeviceDescription();
+ int min_gpu_core_count = 8;
+ if (desc.core_count() >= min_gpu_core_count) {
+ num_eligible_gpus++;
+ }
+ }
+ }
+ }
+ }
+ LOG(INFO) << "Number of eligible GPUs (core count >= 8): "
+ << num_eligible_gpus;
+ return num_eligible_gpus;
+}
+
+int GetNumAvailableLogicalCPUCores() { return port::NumSchedulableCPUs(); }
+
+string ParseNodeName(const string& name, int* position) {
+  // Strip the prefix '^' (if any) and the trailing ":{digits}" (if any) to
+  // get a node name.
+ strings::Scanner scan(name);
+ scan.ZeroOrOneLiteral("^")
+ .RestartCapture()
+ .One(strings::Scanner::LETTER_DIGIT_DOT)
+ .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
+ StringPiece capture;
+ StringPiece remaining;
+ if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
+ *position = 0;
+ return "";
+ } else {
+ if (name[0] == '^') {
+ *position = -1;
+ } else if (remaining.empty()) {
+ *position = 0;
+ } else {
+ // Skip the first ':' character.
+ *position = std::stoi(remaining.substr(1).ToString());
+ }
+ return capture.ToString();
+ }
+}
+
+string NodeName(const string& name) {
+ int position;
+ return ParseNodeName(name, &position);
+}
+
+int NodePosition(const string& name) {
+ int position;
+ ParseNodeName(name, &position);
+ return position;
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
new file mode 100644
index 0000000000..1aa91e25e4
--- /dev/null
+++ b/tensorflow/core/grappler/utils.h
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_UTILS_H_
+#define TENSORFLOW_GRAPPLER_UTILS_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Get the number of available GPUs whose number of multiprocessors is no less
+// than 8.
+int GetNumAvailableGPUs();
+
+// Get the number of logical CPU cores (aka hyperthreads) available.
+int GetNumAvailableLogicalCPUCores();
+
+// Return the node name corresponding to 'name' if name is valid, or the empty
+// string otherwise.
+string NodeName(const string& name);
+
+// Get the trailing position number ":{digits}" (if any) of a node name.
+int NodePosition(const string& name);
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_GRAPPLER_UTILS_H_
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
new file mode 100644
index 0000000000..57ba45352a
--- /dev/null
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class UtilsTest : public ::testing::Test {};
+
+TEST_F(UtilsTest, NodeName) {
+ EXPECT_EQ("abc", NodeName("abc"));
+ EXPECT_EQ("abc", NodeName("^abc"));
+ EXPECT_EQ("abc", NodeName("abc:0"));
+ EXPECT_EQ("abc", NodeName("^abc:0"));
+
+ EXPECT_EQ("abc/def", NodeName("abc/def"));
+ EXPECT_EQ("abc/def", NodeName("^abc/def"));
+ EXPECT_EQ("abc/def", NodeName("abc/def:1"));
+ EXPECT_EQ("abc/def", NodeName("^abc/def:1"));
+
+ EXPECT_EQ("abc/def0", NodeName("abc/def0"));
+ EXPECT_EQ("abc/def0", NodeName("^abc/def0"));
+ EXPECT_EQ("abc/def0", NodeName("abc/def0:0"));
+ EXPECT_EQ("abc/def0", NodeName("^abc/def0:0"));
+
+ EXPECT_EQ("abc/def_0", NodeName("abc/def_0"));
+ EXPECT_EQ("abc/def_0", NodeName("^abc/def_0"));
+ EXPECT_EQ("abc/def_0", NodeName("abc/def_0:3"));
+ EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3"));
+
+ EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3214"));
+}
+
+TEST_F(UtilsTest, NodePosition) {
+ EXPECT_EQ(2, NodePosition("abc:2"));
+ EXPECT_EQ(123, NodePosition("abc:123"));
+ EXPECT_EQ(-1, NodePosition("^abc:123"));
+ EXPECT_EQ(-1, NodePosition("^abc"));
+ EXPECT_EQ(0, NodePosition(""));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index d2b252629e..d729963616 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -132,6 +132,52 @@ Status Env::FileExists(const string& fname) {
return fs->FileExists(fname);
}
+bool Env::FilesExist(const std::vector<string>& files,
+ std::vector<Status>* status) {
+ std::unordered_map<string, std::vector<string>> files_per_fs;
+ for (const auto& file : files) {
+ StringPiece scheme, host, path;
+ io::ParseURI(file, &scheme, &host, &path);
+ files_per_fs[scheme.ToString()].push_back(file);
+ }
+
+ std::unordered_map<string, Status> per_file_status;
+ bool result = true;
+ for (auto itr : files_per_fs) {
+ FileSystem* file_system = file_system_registry_->Lookup(itr.first);
+ bool fs_result;
+ std::vector<Status> local_status;
+ std::vector<Status>* fs_status = status ? &local_status : nullptr;
+ if (!file_system) {
+ fs_result = false;
+ if (fs_status) {
+ Status s = errors::Unimplemented("File system scheme ", itr.first,
+ " not implemented");
+ local_status.resize(itr.second.size(), s);
+ }
+ } else {
+ fs_result = file_system->FilesExist(itr.second, fs_status);
+ }
+ if (fs_status) {
+ result &= fs_result;
+ for (int i = 0; i < itr.second.size(); ++i) {
+ per_file_status[itr.second[i]] = fs_status->at(i);
+ }
+ } else if (!fs_result) {
+ // Return early
+ return false;
+ }
+ }
+
+ if (status) {
+ for (const auto& file : files) {
+ status->push_back(per_file_status[file]);
+ }
+ }
+
+ return result;
+}
+
Status Env::GetChildren(const string& dir, std::vector<string>* result) {
FileSystem* fs;
TF_RETURN_IF_ERROR(GetFileSystemForFile(dir, &fs));
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index 354748eb3e..1b7e024b0f 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -136,6 +136,12 @@ class Env {
/// Returns OK if the named path exists and NOT_FOUND otherwise.
Status FileExists(const string& fname);
+ /// Returns true if all the listed files exist, false otherwise.
+ /// If status is not null, populates the vector with a detailed status
+ /// for each file.
+ bool FilesExist(const std::vector<string>& files,
+ std::vector<Status>* status);
+
/// \brief Stores in *result the names of the children of the specified
/// directory. The names are relative to "dir".
///
diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc
index 4564297b5f..3d7553e6da 100644
--- a/tensorflow/core/platform/file_system.cc
+++ b/tensorflow/core/platform/file_system.cc
@@ -76,6 +76,22 @@ WritableFile::~WritableFile() {}
FileSystemRegistry::~FileSystemRegistry() {}
+bool FileSystem::FilesExist(const std::vector<string>& files,
+ std::vector<Status>* status) {
+ bool result = true;
+ for (const auto& file : files) {
+ Status s = FileExists(file);
+ result &= s.ok();
+ if (status != nullptr) {
+ status->push_back(s);
+ } else if (!result) {
+ // Return early since there is no need to check other files.
+ return false;
+ }
+ }
+ return result;
+}
+
Status FileSystem::GetMatchingPaths(const string& pattern,
std::vector<string>* results) {
results->clear();
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index b2499769aa..903df96b58 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -105,6 +105,12 @@ class FileSystem {
/// Returns OK if the named path exists and NOT_FOUND otherwise.
virtual Status FileExists(const string& fname) = 0;
+ /// Returns true if all the listed files exist, false otherwise.
+ /// If status is not null, populates the vector with a detailed status
+ /// for each file.
+ virtual bool FilesExist(const std::vector<string>& files,
+ std::vector<Status>* status);
+
/// \brief Returns the immediate children in the given directory.
///
/// The returned paths are relative to 'dir'.