diff options
41 files changed, 4064 insertions, 0 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 755bf679d3..01244d7261 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -213,6 +213,11 @@ filegroup( "//tensorflow/core/debug:all_files", "//tensorflow/core/distributed_runtime:all_files", "//tensorflow/core/distributed_runtime/rpc:all_files", + "//tensorflow/core/grappler:all_files", + "//tensorflow/core/grappler/clusters:all_files", + "//tensorflow/core/grappler/costs:all_files", + "//tensorflow/core/grappler/inputs:all_files", + "//tensorflow/core/grappler/optimizers:all_files", "//tensorflow/core/kernels:all_files", "//tensorflow/core/kernels/cloud:all_files", "//tensorflow/core/kernels/hexagon:all_files", diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD new file mode 100644 index 0000000000..53714367f5 --- /dev/null +++ b/tensorflow/core/grappler/BUILD @@ -0,0 +1,67 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +cc_library( + name = "utils", + srcs = ["utils.cc"], + hdrs = ["utils.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:gpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:stream_executor", + ], +) + +cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + deps = [ + ":utils", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "grappler_item", + srcs = ["grappler_item.cc"], + hdrs = ["grappler_item.h"], + visibility = ["//visibility:public"], + deps = [ + ":utils", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core/grappler/inputs:utils", + ], +) + +cc_test( + name = 
"grappler_item_test", + srcs = ["grappler_item_test.cc"], + deps = [ + ":grappler_item", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD new file mode 100644 index 0000000000..c420818333 --- /dev/null +++ b/tensorflow/core/grappler/clusters/BUILD @@ -0,0 +1,64 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +cc_library( + name = "cluster", + srcs = ["cluster.cc"], + hdrs = [ + "cluster.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + ], +) + +cc_library( + name = "single_machine", + srcs = ["single_machine.cc"], + hdrs = [ + "single_machine.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":cluster", + "//tensorflow/cc:coordinator", + "//tensorflow/cc:queue_runner", + "//tensorflow/core:core_cpu", + "//tensorflow/core:direct_session", + "//tensorflow/core:lib", + "//tensorflow/core/kernels:ops_util", + ], +) + +cc_test( + name = "single_machine_test", + srcs = ["single_machine_test.cc"], + args = ["--heap_check=local"], # The GPU tracer leaks memory + deps = [ + ":single_machine", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc new file mode 
100644 index 0000000000..089729e770 --- /dev/null +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/clusters/cluster.h" +#include <atomic> + +namespace tensorflow { +namespace grappler { + +static std::atomic<bool> already_created(false); + +Cluster::Cluster(int timeout_s) : timeout_s_(timeout_s) { + // This is really ugly: to avoid leaking variables, we need to reset the tf + // session every time we're done processing a grappler item. However, + // variables are global, and therefore we can't have more than 1 session alive + // at a time. This check detects when more that one cluster is created. 
+ CHECK(!already_created); + already_created = true; + + options_.config.mutable_graph_options()->set_build_cost_model(1); + + run_options_.set_trace_level(RunOptions::HARDWARE_TRACE); +} + +Cluster::~Cluster() { + CHECK(already_created); + already_created = false; +} + +void Cluster::SetNumWarmupSteps(int num_steps) { + options_.config.mutable_graph_options()->set_build_cost_model_after( + num_steps); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/clusters/cluster.h b/tensorflow/core/grappler/clusters/cluster.h new file mode 100644 index 0000000000..02915cc132 --- /dev/null +++ b/tensorflow/core/grappler/clusters/cluster.h @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_ +#define TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_ + +#include <string> +#include <utility> +#include <vector> + +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace grappler { + +// A cluster represents of collection of hardware resources available to run +// the TensorFlow model. 
+// A process can only create a single cluster at a time. +class Cluster { + public: + explicit Cluster(int timeout_s); + virtual ~Cluster(); + + // Provision the hardware resources needed to run TensorFlow and start a + // TensorFlow session that can take advantage of these resources. + // The actual resources that are leveraged depend on the type of cluster + // instantiated. + // Returns OK iff all the requested resources could be reserved and a + // TensorFlow session successfully created. Returns an error otherwise. + // There is no graceful degradation to handle the case where only a subset + // of the requested resources are available. + virtual Status Provision() = 0; + + // Set the number of steps required to warmup TensorFlow. Must be called + // before Provision(). + void SetNumWarmupSteps(int num_steps); + + // Return the list of TensorFlow devices that are available to execute a + // graph. This is empty until provision() is called. + const std::vector<DeviceAttributes>& GetDevices() const { return devices_; } + + // Convenience method that returns the set of device names. + const std::vector<string> GetDeviceNames() const { + std::vector<string> device_names; + device_names.reserve(devices_.size()); + for (const auto& device : devices_) { + device_names.push_back(device.name()); + } + return device_names; + } + + // Prepare the session to run the specified grappler item. This include + // initializing all the model variables. + virtual Status Initialize(const GrapplerItem& item) = 0; + + // Run the specified graph_def and return the corresponding metadata. 
+ virtual Status Run(const GraphDef& graph_def, + const std::vector<std::pair<string, Tensor>>& feed, + const std::vector<string>& fetch, + RunMetadata* metadata) = 0; + + protected: + std::vector<DeviceAttributes> devices_; + const int timeout_s_; + SessionOptions options_; + RunOptions run_options_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_CLUSTERS_CLUSTER_H_ diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc new file mode 100644 index 0000000000..6e4a6a648e --- /dev/null +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/clusters/single_machine.h" +#include "tensorflow/cc/training/queue_runner.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace grappler { + +SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus) + : Cluster(timeout_s), + num_gpus_(num_gpus), + running_(false), + closing_(false) { + thread_pool_.reset(new thread::ThreadPool( + Env::Default(), SanitizeThreadSuffix("single_machine"), 2)); + + (*options_.config.mutable_device_count())["CPU"] = 1; + if (num_gpus > 0) { + (*options_.config.mutable_device_count())["GPU"] = num_gpus; + } + CHECK_GE(num_cpu_cores, 1); + options_.config.set_intra_op_parallelism_threads(num_cpu_cores); + options_.config.set_inter_op_parallelism_threads(num_cpu_cores); +} + +SingleMachine::~SingleMachine() { + CloseSession(false /*use_timeout*/); + + // Prevent the destructor from deleting mu_ until CloseSession() is done. 
+ mutex_lock l(mu_); +} + +Status SingleMachine::Provision() { + Status status = ResetSession(); + if (!status.ok()) { + return status; + } + + DeviceAttributes attr; + attr.set_name("/job:localhost/replica:0/task:0/cpu:0"); + attr.set_device_type("CPU"); + devices_.push_back(attr); + + for (int i = 0; i < num_gpus_; ++i) { + DeviceAttributes attr; + attr.set_name(strings::StrCat("/job:localhost/replica:0/task:0/gpu:", i)); + attr.set_device_type("GPU"); + devices_.push_back(attr); + } + return Status::OK(); +} + +Status SingleMachine::Initialize(const GrapplerItem& item) { + if (last_graph_ != &item.graph) { + init_ops_ = item.init_ops; + last_graph_ = nullptr; + queue_runner_defs_ = item.queue_runners; + } + return Status::OK(); +} + +Status SingleMachine::Run(const GraphDef& graph_def, + const std::vector<std::pair<string, Tensor>>& feed, + const std::vector<string>& fetch, + RunMetadata* metadata) { + if (last_graph_ != &graph_def) { + Status status = ResetSession(); + if (status.ok()) { + status = session_->Create(graph_def); + } + if (!init_ops_.empty() && status.ok()) { + status = RunWithTimeout({}, init_ops_, nullptr); + } + for (int i = 0; i < queue_runner_defs_.size() && status.ok(); ++i) { + std::unique_ptr<QueueRunner> queue_runner; + TF_RETURN_IF_ERROR(QueueRunner::New(queue_runner_defs_[i], + coordinator_.get(), &queue_runner)); + TF_RETURN_IF_ERROR(queue_runner->Start(session_.get())); + TF_RETURN_IF_ERROR(coordinator_->RegisterRunner(std::move(queue_runner))); + status = coordinator_->GetStatus(); + } + + if (status.ok()) { + last_graph_ = &graph_def; + } else { + return status; + } + + // Warmup TensorFlow if needed + for (int i = 0; + i < options_.config.graph_options().build_cost_model_after(); ++i) { + status = RunWithTimeout(feed, fetch, nullptr); + if (!status.ok()) { + return status; + } + } + } + + return RunWithTimeout(feed, fetch, metadata); +} + +Status SingleMachine::RunWithTimeout( + const std::vector<std::pair<string, Tensor>>& feed, + 
const std::vector<string>& fetch, RunMetadata* run_metadata) { + mutex_lock l(mu_); + // We shouldn't be running or closing the session at this point. + CHECK(!running_); + CHECK(!closing_); + + running_ = true; + metadata_ = RunMetadata(); + + thread_pool_->Schedule([this, feed, fetch] { + Status status = + session_->Run(run_options_, feed, {}, fetch, nullptr, &this->metadata_); + mutex_lock l(mu_); + status_ = status; + running_ = false; + done_running_.notify_all(); + }); + + while (running_) { + std::cv_status timeout = + done_running_.wait_for(l, std::chrono::milliseconds(timeout_s_ * 1000)); + if (timeout != std::cv_status::no_timeout) { + last_graph_ = nullptr; + return Status(error::DEADLINE_EXCEEDED, + strings::StrCat("Failed to run the graph after ", + timeout_s_, " seconds, aborting")); + } + } + if (run_metadata && status_.ok()) { + *run_metadata = metadata_; + } + return status_; +} + +Status SingleMachine::CloseSession(bool use_timeout) { + if (!session_) { + return Status::OK(); + } + + mutex_lock l(close_mu_); + + if (!closing_) { + closing_ = true; + + thread_pool_->Schedule([this] { + if (this->coordinator_) { + this->coordinator_->RequestStop(); + // Wait for all the runners to have closed their queues. + while (!this->coordinator_->AllRunnersStopped()) { + sleep(1); + } + // Now we can close the session. This should cancel any pending I/O + // operation. + this->session_->Close(); + // Last but not least, we can delete the coordinator. + this->coordinator_.reset(); + } else { + this->session_->Close(); + } + + // Wait for any previous run to finish. 
+ mutex_lock l(mu_); + while (running_) { + done_running_.wait(l); + } + + mutex_lock l2(close_mu_); + closing_ = false; + done_closing_.notify_all(); + }); + } + + while (closing_) { + if (!use_timeout) { + done_closing_.wait(l); + } else { + std::cv_status timeout = done_closing_.wait_for( + l, std::chrono::milliseconds(timeout_s_ * 1000)); + if (timeout != std::cv_status::no_timeout) { + // Let the caller know that we can't shutdown the session, and therefore + // can't process any further. + return Status( + error::UNAVAILABLE, + strings::StrCat("Failed to close the previous session after ", + timeout_s_, " seconds, aborting")); + } + } + } + + return Status::OK(); +} + +Status SingleMachine::ResetSession() { + if (session_) { + LOG(INFO) << "Cleaning up previous session"; + + // Make sure the session is properly closed + Status status = CloseSession(true /*use_timeout*/); + if (!status.ok()) { + return status; + } + + // Flush all the pending closures (if any). + thread_pool_.reset(new thread::ThreadPool( + Env::Default(), SanitizeThreadSuffix("single_machine"), 2)); + + // We need to Reset the session to ensure that all the variables are + // deleted. But first we need to delete the session since Reset() + // deletes some of the containers referenced by the session. + session_.reset(); + status = Reset(options_, {}); + if (!status.ok()) { + return status; + } + } + + LOG(INFO) << "Starting new session"; + + session_.reset(NewSession(options_)); + CHECK(session_ != nullptr); + + coordinator_.reset(new Coordinator()); + + return Status::OK(); +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/clusters/single_machine.h b/tensorflow/core/grappler/clusters/single_machine.h new file mode 100644 index 0000000000..f53a3e849b --- /dev/null +++ b/tensorflow/core/grappler/clusters/single_machine.h @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_ +#define TENSORFLOW_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_ + +#include "tensorflow/cc/training/coordinator.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace grappler { + +// Create a simple cluster that makes available to grappler a subset of the +// nodes available on a single local computer. 
+class SingleMachine : public Cluster { + public: + SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus); + ~SingleMachine() override; + + Status Provision() override; + Status Initialize(const GrapplerItem& item) override; + Status Run(const GraphDef& item, + const std::vector<std::pair<string, Tensor>>& feed, + const std::vector<string>& fetch, RunMetadata* metadata) override; + + private: + Status RunWithTimeout(const std::vector<std::pair<string, Tensor>>& feed, + const std::vector<string>& fetch, + RunMetadata* run_metadata); + Status ResetSession(); + Status CloseSession(bool use_timeout); + + const int num_gpus_; + std::unique_ptr<Session> session_; + std::vector<QueueRunnerDef> queue_runner_defs_; + const GraphDef* last_graph_ = nullptr; + std::vector<string> init_ops_; + std::unique_ptr<Coordinator> coordinator_; + std::unique_ptr<thread::ThreadPool> thread_pool_; + + Status status_; + RunMetadata metadata_; + + mutex mu_; + bool running_; + condition_variable done_running_; + + mutex close_mu_; + bool closing_; + condition_variable done_closing_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_CLUSTERS_SINGLE_MACHINE_H_ diff --git a/tensorflow/core/grappler/clusters/single_machine_test.cc b/tensorflow/core/grappler/clusters/single_machine_test.cc new file mode 100644 index 0000000000..3b39e5be61 --- /dev/null +++ b/tensorflow/core/grappler/clusters/single_machine_test.cc @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/clusters/single_machine.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class SingleMachineTest : public ::testing::Test { + public: + void SetUp() override { + // Provision a single machine with 3 cpu cores + cluster_.reset(new SingleMachine(5 * 60, 3, 0)); + TF_CHECK_OK(cluster_->Provision()); + } + + void TearDown() override { + cluster_.reset(); + } + + protected: + std::unique_ptr<SingleMachine> cluster_; +}; + +TEST_F(SingleMachineTest, CostModel) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, + cluster_->GetDeviceNames()); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TF_CHECK_OK(cluster_->Initialize(item)); + + RunMetadata metadata; + const int64 start_micros = Env::Default()->NowMicros(); + TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata)); + const int64 run_duration_micros = Env::Default()->NowMicros() - start_micros; + + // There should be at least 4 nodes corresponding to the 4 stages we created + // in the fake input. + EXPECT_LE(4, metadata.cost_graph().node_size()); + for (const auto& node : metadata.cost_graph().node()) { + // Skip the special nodes inserted by TF: these are prefixed with an + // underscore. 
+ if (node.name()[0] == '_' || node.name().find("/_") != string::npos) { + continue; + } + EXPECT_EQ(1, node.output_info_size()); + EXPECT_LE(8, node.output_info(0).size()); + const TensorShapeProto& shape = node.output_info(0).shape(); + EXPECT_EQ(2, shape.dim_size()); + EXPECT_EQ(10, shape.dim(0).size()); + EXPECT_EQ(1, shape.dim(1).size()); + EXPECT_LE(0, node.compute_cost()); + EXPECT_GE(run_duration_micros, node.compute_cost()); + } +} + +TEST_F(SingleMachineTest, Queue) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, true, + cluster_->GetDeviceNames()); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TF_CHECK_OK(cluster_->Initialize(item)); + RunMetadata metadata; + TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata)); +} + +TEST_F(SingleMachineTest, MultipleItems) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, + cluster_->GetDeviceNames()); + + for (int i = 0; i < 3; ++i) { + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + TF_CHECK_OK(cluster_->Initialize(item)); + RunMetadata metadata1; + TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata1)); + RunMetadata metadata2; + TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata2)); + + // There should be at least 4 nodes corresponding to the 4 stages we created + // in the fake input, plus 1 enqueue and 1 dequeue node. 
+ EXPECT_LE(6, metadata1.cost_graph().node_size()); + for (const auto& node : metadata1.cost_graph().node()) { + if (node.name()[0] == '_' || node.name().find("/_") != string::npos || + node.name() == "queue") { + continue; + } + EXPECT_EQ(1, node.output_info_size()); + const TensorShapeProto& shape = node.output_info(0).shape(); + EXPECT_EQ(2, shape.dim_size()); + EXPECT_EQ(10, shape.dim(0).size()); + EXPECT_EQ(1, shape.dim(1).size()); + } + + for (int i = 0; i < metadata1.cost_graph().node_size(); ++i) { + metadata1.mutable_cost_graph()->mutable_node(i)->set_compute_cost(0); + metadata1.clear_step_stats(); + } + for (int i = 0; i < metadata2.cost_graph().node_size(); ++i) { + metadata2.mutable_cost_graph()->mutable_node(i)->set_compute_cost(0); + metadata2.clear_step_stats(); + } + string s1; + ::tensorflow::protobuf::TextFormat::PrintToString(metadata1, &s1); + string s2; + ::tensorflow::protobuf::TextFormat::PrintToString(metadata2, &s2); + EXPECT_EQ(s1, s2); + } +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD new file mode 100644 index 0000000000..161d1aa617 --- /dev/null +++ b/tensorflow/core/grappler/costs/BUILD @@ -0,0 +1,116 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library", +) +load( + "@local_config_cuda//cuda:build_defs.bzl", + "if_cuda", +) + +tf_proto_library( + name = "op_performance_data", + srcs = ["op_performance_data.proto"], + cc_api_version = 2, + protodeps = ["//tensorflow/core:protos_all"], + visibility = ["//visibility:public"], +) + +cc_library( + name = "graph_properties", + srcs = ["graph_properties.cc"], + hdrs = ["graph_properties.h"], + visibility = 
["//visibility:public"], + deps = [ + ":op_performance_data_cc", + ":utils", + "//tensorflow/core:core_cpu", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:cluster", + ], +) + +cc_test( + name = "graph_properties_test", + srcs = ["graph_properties_test.cc"], + args = ["--heap_check=local"], # The GPU tracer leaks memory + deps = [ + ":graph_properties", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:single_machine", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) + +cc_library( + name = "graph_memory", + srcs = ["graph_memory.cc"], + hdrs = ["graph_memory.h"], + visibility = ["//visibility:public"], + deps = [ + ":graph_properties", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/clusters:cluster", + ], +) + +cc_test( + name = "graph_memory_test", + srcs = ["graph_memory_test.cc"], + args = ["--heap_check=local"], # The GPU tracer leaks memory + deps = [ + ":graph_memory", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) + +cc_library( + name = "utils", + srcs = ["utils.cc"], + hdrs = ["utils.h"], + visibility = ["//visibility:public"], + deps = [ + ":op_performance_data_cc", + "//third_party/eigen3", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ] + if_cuda([ + "//tensorflow/core:cuda", + "@local_config_cuda//cuda:cuda_headers", + ]), +) + +cc_library( + name = "cost_estimator", + hdrs = ["cost_estimator.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib", + ], +) diff 
--git a/tensorflow/core/grappler/costs/cost_estimator.h b/tensorflow/core/grappler/costs/cost_estimator.h new file mode 100644 index 0000000000..3c65c34f8d --- /dev/null +++ b/tensorflow/core/grappler/costs/cost_estimator.h @@ -0,0 +1,146 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_ +#define TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_ + +#include <chrono> +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +class GraphDef; +class CostGraphDef; + +namespace grappler { +struct GrapplerItem; + +constexpr int64 kMemoryUnknown = -1ll; +constexpr int64 kZeroMemory = 0ll; + +// Holds the set of things we might want to estimate or measure in Grappler. +// Always produce execution time. Other fields are optional depending on the +// estimator being used. +struct Costs { + // Returns a Costs structure with default values for all of the fields. + inline Costs(); + + // Builds a Costs structure with all zero values, rather than unknowns. 
+ static inline Costs ZeroCosts(); + + struct MicroSeconds : std::chrono::microseconds { + MicroSeconds() : std::chrono::microseconds(0) {} + MicroSeconds(double d) : std::chrono::microseconds(static_cast<int64>(d)) {} + MicroSeconds(std::chrono::microseconds& d) : std::chrono::microseconds(d) {} + MicroSeconds& operator=(const std::chrono::microseconds& d) { + std::chrono::microseconds::operator=(d); + return *this; + } + }; + struct NanoSeconds : std::chrono::nanoseconds { + NanoSeconds() : std::chrono::nanoseconds(0) {} + NanoSeconds(double d) : std::chrono::nanoseconds(static_cast<int64>(d)) {} + NanoSeconds(std::chrono::nanoseconds& d) : std::chrono::nanoseconds(d) {} + NanoSeconds& operator=(const std::chrono::nanoseconds& d) { + std::chrono::nanoseconds::operator=(d); + return *this; + } + MicroSeconds asMicroSeconds() const { + std::chrono::microseconds us = + std::chrono::duration_cast<std::chrono::microseconds>(*this); + return MicroSeconds(us); + } + }; + + // We store all our times in nanoseconds. If needs be, we can always switch to + // picoseconds in the future by updating this typedef. + typedef NanoSeconds Duration; + + // Overall cost of running the graph; latency. + Duration execution_time; + + // Computation cost of running the graph. + Duration compute_time; + + // Memory access cost of running the graph. + Duration memory_time; + + // This field can be a very pessimistic estimate of the main memory + // requirements of a graph. For example, it might assume that all activations + // are live for all of a graph's execution. + int64 max_memory; // Maximum main memory requirement in bytes over all ops. + + // These fields are used for TPU-related estimations. They are per-op + // maximums, so each op is evaluated independently, but we want the maximum of + // the value over all ops. + int64 max_per_op_buffers; // Sum of all buffers used by the ops. 
+ int64 max_per_op_streaming; // Ignore largest input buffer, assuming it + // streams from main memory. +}; + +inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) { + os << d.count() << "us"; + return os; +} +inline std::ostream& operator<<(std::ostream& os, const Costs::NanoSeconds d) { + os << d.count() << "ns"; + return os; +} + +Costs::Costs() { + execution_time = Duration::zero(); + compute_time = Duration::zero(); + memory_time = Duration::zero(); + max_memory = kMemoryUnknown; + max_per_op_buffers = kMemoryUnknown; + max_per_op_streaming = kMemoryUnknown; +} + +Costs Costs::ZeroCosts() { + Costs costs; + costs.execution_time = Duration::zero(); + costs.max_memory = kZeroMemory; + costs.max_per_op_buffers = kZeroMemory; + costs.max_per_op_streaming = kZeroMemory; + return costs; +} + +// Given a GrapperItem and an optimized implementation of the corresponding +// TensorFlow graph, the CostEstimator attempts to predicts the actual cost of +// running the graph. +class CostEstimator { + public: + virtual ~CostEstimator() {} + + // Initalizes the estimator for the specified grappler item. + // The estimator shouldn't be used if this function returns any status other + // that OK. + virtual Status Initialize(const GrapplerItem& item) = 0; + + // Predicts the cost of running the given optimized version of the grappler + // item. + // If a CostGraphDef is passed, it will be populated with detailed information + // about the cost of running each operation of the optimized graph. + // if a double value is passed, it will be set to a value that reflects the + // overall cost of running the graph (e.g. the latency of the computation). + // Returns a status that indicate is the performance could be estimated or + // not. 
+ virtual Status PredictCosts(const GraphDef& optimized_graph, + CostGraphDef* cost_graph, Costs* cost) const = 0; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_COSTS_COST_ESTIMATOR_H_ diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc new file mode 100644 index 0000000000..b7827fc1ad --- /dev/null +++ b/tensorflow/core/grappler/costs/graph_memory.cc @@ -0,0 +1,109 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/costs/graph_memory.h" + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" + +namespace tensorflow { +namespace grappler { + +Status GraphMemory::InferStatically() { + GraphProperties properties(item_); + TF_RETURN_IF_ERROR(properties.InferStatically()); + return InferFromGraphProperties(&properties); +} + +Status GraphMemory::InferDynamically(Cluster* cluster) { + GraphProperties properties(item_); + TF_RETURN_IF_ERROR(properties.InferDynamically(cluster)); + return InferFromGraphProperties(&properties); +} + +Status GraphMemory::InferFromGraphProperties(GraphProperties* properties) { + // Compute the worst case usage between initialization and normal mode. + // TODO(bsteiner): we should consider persistent memory usage separately. + int64 worst_case_init_mem_usage; + int64 best_case_init_mem_usage; + InferMemUsageForNodes(item_.InitOpsFanin(), properties, + &worst_case_init_mem_usage, &best_case_init_mem_usage); + int64 worst_case_main_mem_usage; + int64 best_case_main_mem_usage; + InferMemUsageForNodes(item_.MainOpsFanin(), properties, + &worst_case_main_mem_usage, &best_case_main_mem_usage); + + worst_case_memory_usage_ = + std::max(worst_case_init_mem_usage, worst_case_main_mem_usage); + best_case_memory_usage_ = + std::max(best_case_init_mem_usage, best_case_main_mem_usage); + + return Status::OK(); +} + +void GraphMemory::InferMemUsageForNodes( + const std::vector<const NodeDef*>& nodes, GraphProperties* properties, + int64* worst_case_memory_usage, int64* best_case_memory_usage) const { + // TODO(bsteiner) refine this: we should consider the multidevice case. 
+ *worst_case_memory_usage = 0; + *best_case_memory_usage = 0; + for (const auto& node : item_.graph.node()) { + // Estimate the memory required to store the tensors generated by the node. + std::vector<OpInfo::TensorProperties> outputs = + properties->GetOutputProperties(node.name()); + int64 node_memory_usage = InferMemUsageForNeighbors(outputs); + + // Worst case memory usage corresponds to the case where all the nodes are + // alive. + *worst_case_memory_usage += node_memory_usage; + + // Estimate the memory required to store the input tensors needed by the + // node. + std::vector<OpInfo::TensorProperties> inputs = + properties->GetInputProperties(node.name()); + node_memory_usage += InferMemUsageForNeighbors(inputs); + + *best_case_memory_usage = + std::max(*best_case_memory_usage, node_memory_usage); + } +} + +int64 GraphMemory::InferMemUsageForNeighbors( + const std::vector<OpInfo::TensorProperties>& props) const { + int64 neighbors_memory_usage = 0; + for (const auto& prop : props) { + DataType dtype = prop.dtype(); + int size = DataTypeSize(dtype); + TensorShapeProto shape = prop.shape(); + if (shape.unknown_rank()) { + // Can't infer the size if the rank is unknown, just skip. + continue; + } + // If one of the dimensions is unknown statically, assume it's one. + for (int i = 0; i < shape.dim_size(); ++i) { + if (shape.dim(i).size() < 0) { + shape.mutable_dim(i)->set_size(1); + } + } + int num_elems = TensorShape(shape).num_elements(); + neighbors_memory_usage += num_elems * size; + } + return neighbors_memory_usage; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/graph_memory.h b/tensorflow/core/grappler/costs/graph_memory.h new file mode 100644 index 0000000000..a3e152a0e1 --- /dev/null +++ b/tensorflow/core/grappler/costs/graph_memory.h @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_COSTS_GRAPH_MEMORY_H_ +#define TENSORFLOW_GRAPPLER_COSTS_GRAPH_MEMORY_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { + +// Infer the worst case memory usage for a given grappler item. +class GraphMemory { + public: + explicit GraphMemory(const GrapplerItem& item) + : item_(item), worst_case_memory_usage_(-1) {} + + Status InferStatically(); + Status InferDynamically(Cluster* cluster); + Status InferFromGraphProperties(GraphProperties* properties); + + // Worst case memory usage in bytes, or -1 if the usage is unknown. + int64 GetWorstCaseMemoryUsage() const { return worst_case_memory_usage_; } + + // Best case memory usage in bytes, or -1 if the usage is unknown. + // This corresponds to the case where all the data is swapped out excepted + // that which is needed for a single node to perform its computations. 
+ int64 GetBestCaseMemoryUsage() const { return best_case_memory_usage_; } + + private: + void InferMemUsageForNodes(const std::vector<const NodeDef*>& nodes, + GraphProperties* properties, int64* worst_case, + int64* best_case) const; + int64 InferMemUsageForNeighbors( + const std::vector<OpInfo::TensorProperties>& props) const; + + // Inputs + GrapplerItem item_; + int64 worst_case_memory_usage_; + int64 best_case_memory_usage_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_COSTS_GRAPH_MEMORY_H_ diff --git a/tensorflow/core/grappler/costs/graph_memory_test.cc b/tensorflow/core/grappler/costs/graph_memory_test.cc new file mode 100644 index 0000000000..a3c58a1d76 --- /dev/null +++ b/tensorflow/core/grappler/costs/graph_memory_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/costs/graph_memory.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class GraphMemoryTest : public ::testing::Test {}; + +TEST_F(GraphMemoryTest, Basic) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {{"CPU:0"}}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + GraphMemory memory(item); + Status s = memory.InferStatically(); + TF_CHECK_OK(s); + EXPECT_EQ(240, memory.GetWorstCaseMemoryUsage()); + EXPECT_EQ(80, memory.GetBestCaseMemoryUsage()); +} + +TEST_F(GraphMemoryTest, UnknownBatchSize) { + TrivialTestGraphInputYielder fake_input(4, 1, -1, false, {{"CPU:0"}}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + GraphMemory memory(item); + Status s = memory.InferStatically(); + TF_CHECK_OK(s); + EXPECT_EQ(24, memory.GetWorstCaseMemoryUsage()); + EXPECT_EQ(8, memory.GetBestCaseMemoryUsage()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc new file mode 100644 index 0000000000..345c7e2f21 --- /dev/null +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -0,0 +1,159 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/costs/graph_properties.h" + +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/grappler/costs/utils.h" + +namespace tensorflow { +namespace grappler { + +Status GraphProperties::InferStatically() { + Graph graph(OpRegistry::Global()); + ShapeRefiner shape_refiner(graph.versions().producer(), graph.op_registry()); + ImportGraphDefOptions options; + Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); + TF_RETURN_IF_ERROR(s); + + for (const Node* const node : graph.nodes()) { + VLOG(1) << "<Node> " << node->name(); + auto ctx = shape_refiner.GetContext(node); + if (!ctx) { + continue; + } + CHECK_EQ(ctx->num_inputs(), node->num_inputs()); + std::vector<OpInfo::TensorProperties> input_properties; + for (int i = 0; i < ctx->num_inputs(); ++i) { + OpInfo::TensorProperties properties; + properties.set_dtype(node->input_type(i)); + shape_inference::ShapeHandle shp = ctx->input(i); + if (!ctx->RankKnown(shp)) { + properties.mutable_shape()->set_unknown_rank(true); + } else { + for (int j = 0; j < ctx->Rank(shp); ++j) { + shape_inference::DimensionHandle dim = ctx->Dim(shp, j); + int64 d = ctx->Value(dim); + properties.mutable_shape()->add_dim()->set_size(d); + } + } + input_properties.push_back(properties); + } + input_properties_[node->name()] = input_properties; + + // TODO(bsteiner): share this code with the input processing above. 
+ CHECK_EQ(ctx->num_outputs(), node->num_outputs()); + std::vector<OpInfo::TensorProperties> output_properties; + for (int i = 0; i < ctx->num_outputs(); ++i) { + OpInfo::TensorProperties properties; + properties.set_dtype(node->output_type(i)); + shape_inference::ShapeHandle shp = ctx->output(i); + if (!ctx->RankKnown(shp)) { + properties.mutable_shape()->set_unknown_rank(true); + } else { + for (int j = 0; j < ctx->Rank(shp); ++j) { + shape_inference::DimensionHandle dim = ctx->Dim(shp, j); + int64 d = ctx->Value(dim); + properties.mutable_shape()->add_dim()->set_size(d); + } + } + output_properties.push_back(properties); + } + output_properties_[node->name()] = output_properties; + + if (!node->assigned_device_name().empty()) { + device_names_[node->name()] = node->assigned_device_name(); + } else if (!node->def().device().empty()) { + device_names_[node->name()] = node->def().device(); + } else { + device_names_[node->name()] = "not set"; + } + } + + return Status::OK(); +} + +Status GraphProperties::InferDynamically(Cluster* cluster) { + TF_RETURN_IF_ERROR(cluster->Initialize(item_)); + + // Runs the model once to collect the shapes in the cost model. 
+ RunMetadata metadata; + TF_RETURN_IF_ERROR( + cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata)); + + std::unordered_map<string, const CostGraphDef::Node*> name_to_cost; + for (auto& node : metadata.cost_graph().node()) { + name_to_cost[node.name()] = &node; + + std::vector<OpInfo::TensorProperties> output_properties; + for (const auto& out : node.output_info()) { + OpInfo::TensorProperties properties; + properties.set_dtype(out.dtype()); + *properties.mutable_shape() = out.shape(); + output_properties.push_back(properties); + } + output_properties_[node.name()] = output_properties; + } + + for (const auto& node : item_.graph.node()) { + // Skip the nodes that are not in the cost graph: these are nodes that + // aren't run, because they aren't in the intersection of transitive fan-in + // of a fetch node and the transitive fan-out of an input, or nodes that + // were optimized away by the optimizer. + auto it = name_to_cost.find(node.name()); + if (it == name_to_cost.end()) { + continue; + } + std::vector<OpInfo::TensorProperties> inputs = + FindInputFeatures(node, name_to_cost); + + input_properties_[node.name()] = inputs; + + const CostGraphDef::Node* cost_node = it->second; + device_names_[node.name()] = cost_node->device(); + } + return Status::OK(); +} + +std::vector<OpInfo::TensorProperties> GraphProperties::GetInputProperties( + const string& node_name) const { + auto it = input_properties_.find(node_name); + if (it != input_properties_.end()) { + return it->second; + } + return std::vector<OpInfo::TensorProperties>(); +} + +std::vector<OpInfo::TensorProperties> GraphProperties::GetOutputProperties( + const string& node_name) const { + auto it = output_properties_.find(node_name); + if (it != output_properties_.end()) { + return it->second; + } + return std::vector<OpInfo::TensorProperties>(); +} + +string GraphProperties::GetDeviceName(const string& node_name) const { + auto it = device_names_.find(node_name); + if (it != device_names_.end()) 
// Infers and stores, for the nodes of a grappler item, the properties (dtype
// and shape) of their input and output tensors and the name of the device
// they are placed on.
class GraphProperties {
 public:
  // Keeps a copy of 'item'. No property is available until one of the Infer*
  // methods has been called.
  explicit GraphProperties(const GrapplerItem& item) : item_(item) {}

  // Infers the properties from the graph alone, without running it.
  Status InferStatically();
  // Infers the properties by running the item once on 'cluster'.
  Status InferDynamically(Cluster* cluster);

  // Return the properties recorded for 'node_name', or an empty vector if
  // nothing was recorded for that node.
  std::vector<OpInfo::TensorProperties> GetInputProperties(
      const string& node_name) const;
  std::vector<OpInfo::TensorProperties> GetOutputProperties(
      const string& node_name) const;
  // Returns the device recorded for 'node_name', or an empty string if
  // nothing was recorded for that node.
  string GetDeviceName(const string& node_name) const;

 private:
  // Inputs
  GrapplerItem item_;
  // Results, indexed by node name.
  std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
  std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
  std::map<string, string> device_names_;
};
+==============================================================================*/ + +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/clusters/single_machine.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class GraphPropertiesTest : public ::testing::Test { + public: + void SetUp() override { + // Provision a single machine with 3 cpu cores + cluster_.reset(new SingleMachine(5 * 60, 3, 0)); + TF_CHECK_OK(cluster_->Provision()); + } + + void TearDown() override { cluster_.reset(); } + + protected: + std::unique_ptr<SingleMachine> cluster_; +}; + +TEST_F(GraphPropertiesTest, StaticProperties) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, + cluster_->GetDeviceNames()); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + GraphProperties properties(item); + Status s = properties.InferStatically(); + TF_CHECK_OK(s); + + for (const auto& node : item.graph.node()) { + if (node.op() == "Const") { + // The const node has no input. + EXPECT_EQ(0, properties.GetInputProperties(node.name()).size()); + // The const node has one output. 
+ const auto props = properties.GetOutputProperties(node.name()); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_FALSE(prop.shape().unknown_rank()); + EXPECT_EQ(2, prop.shape().dim_size()); + EXPECT_EQ(10, prop.shape().dim(0).size()); + EXPECT_EQ(1, prop.shape().dim(1).size()); + } else if (node.op() == "AddN") { + const auto in_props = properties.GetInputProperties(node.name()); + EXPECT_EQ(1, in_props.size()); + const OpInfo::TensorProperties& in_prop = in_props[0]; + EXPECT_EQ(DT_FLOAT, in_prop.dtype()); + EXPECT_FALSE(in_prop.shape().unknown_rank()); + EXPECT_EQ(2, in_prop.shape().dim_size()); + EXPECT_EQ(10, in_prop.shape().dim(0).size()); + EXPECT_EQ(1, in_prop.shape().dim(1).size()); + const auto out_props = properties.GetOutputProperties(node.name()); + EXPECT_EQ(1, out_props.size()); + string in_prop_str; + ::tensorflow::protobuf::TextFormat::PrintToString(in_prop, &in_prop_str); + string out_prop_str; + ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0], + &out_prop_str); + EXPECT_EQ(in_prop_str, out_prop_str); + } + } +} + +TEST_F(GraphPropertiesTest, DynamicProperties) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, + cluster_->GetDeviceNames()); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + GraphProperties properties(item); + TF_CHECK_OK(cluster_->Initialize(item)); + Status s = properties.InferDynamically(cluster_.get()); + TF_CHECK_OK(s); + + for (const auto& node : item.graph.node()) { + if (node.op() == "Const") { + // The constant node is missing from the cost graph + EXPECT_EQ(0, properties.GetInputProperties(node.name()).size()); + } else if (node.op() == "AddN") { + // Since the const node is missing, we can't infer the input properties of + // the first AddN node. 
THe other AddN have the expected properties + if (node.name() == "AddN") { + const auto props = properties.GetInputProperties(node.name()); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_INVALID, prop.dtype()); + EXPECT_TRUE(prop.shape().unknown_rank()); + } else { + const auto props = properties.GetInputProperties(node.name()); + EXPECT_EQ(1, props.size()); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_FALSE(prop.shape().unknown_rank()); + EXPECT_EQ(2, prop.shape().dim_size()); + EXPECT_EQ(10, prop.shape().dim(0).size()); + EXPECT_EQ(1, prop.shape().dim(1).size()); + const auto out_props = properties.GetOutputProperties(node.name()); + EXPECT_EQ(1, out_props.size()); + string prop_str; + ::tensorflow::protobuf::TextFormat::PrintToString(prop, &prop_str); + string out_prop_str; + ::tensorflow::protobuf::TextFormat::PrintToString(out_props[0], + &out_prop_str); + EXPECT_EQ(prop_str, out_prop_str); + } + } + } +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/op_performance_data.proto b/tensorflow/core/grappler/costs/op_performance_data.proto new file mode 100644 index 0000000000..2d3fb1939c --- /dev/null +++ b/tensorflow/core/grappler/costs/op_performance_data.proto @@ -0,0 +1,116 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
// Description of an operation as well as the parameters expected to impact its
// performance.
message OpInfo {
  // The operation name.  There may be custom parameters in attrs.
  string op = 1;

  // Custom parameters impacting the behavior of the op.
  map<string, AttrValue> attr = 2;

  // Type and shape of a single tensor.
  message TensorProperties {
    DataType dtype = 1;
    TensorShapeProto shape = 2;
  };
  // Input types and shapes
  repeated TensorProperties inputs = 3;

  // Device on which the operation is run.
  message DeviceProperties {
    // Device type (CPU, GPU, ...)
    string type = 1;
    // Vendor (Intel, nvidia, ...)
    string vendor = 2;
    // Model (Haswell, K40, ...)
    string model = 3;
    // Core Frequency in MHz
    int64 frequency = 4;
    // Number of cores
    int64 num_cores = 5;
    // Version of the tools and libraries used with this device (e.g. gcc 4.9,
    // cudnn 5.1)
    map<string, string> environment = 6;
    // Number of registers per core.
    int64 num_registers = 7;
    // L1 cache size in bytes
    int64 l1_cache_size = 8;
    // L2 cache size in bytes
    int64 l2_cache_size = 9;
    // L3 cache size in bytes
    int64 l3_cache_size = 10;
    // Shared memory size per multiprocessor in bytes. This field is
    // applicable to GPUs only.
    int64 shared_memory_size_per_multiprocessor = 11;
    // Memory size in bytes
    int64 memory_size = 12;
    // Memory bandwidth in KB/s
    int64 bandwidth = 13;
  }
  DeviceProperties device = 4;
}
+ string node = 5; + + // Temporary memory used by this node (in bytes). + int64 temporary_memory_size = 2; + + // Time it takes to run the op (in nanoseconds). + int64 compute_cost = 3; + + // Analytical compute cost (in nanoseconds). + int64 compute_time = 6; + + // Analytical memory access cost (in nanoseconds). + int64 memory_time = 7; + + // Percentage of theoretical compute performance. + double compute_efficiency = 4; + + // Percentage of theoretical memory performance. + double memory_efficiency = 8; + + // Memory usage data for a tensorflow operation. + message OpMemory { + // The output information may have memory usage and output shapes. + repeated int64 output_memory = 1; + + // Temporary memory allocated by this node. + int64 host_temp_memory = 2; + int64 device_temp_memory = 3; + + // The persisted_memory doesn't include outputs. + int64 host_persistent_memory = 4; + int64 device_persistent_memory = 5; + } + OpMemory op_memory = 9; +} diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc new file mode 100644 index 0000000000..19266208fe --- /dev/null +++ b/tensorflow/core/grappler/costs/utils.cc @@ -0,0 +1,159 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/costs/utils.h" + +#include <stddef.h> +#include <utility> + +#include "third_party/eigen3/Eigen/Core" + +#if GOOGLE_CUDA +#include "cuda/include/cuda.h" +#include "cuda/include/cuda_runtime_api.h" +#include "cuda/include/cudnn.h" +#endif + +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +std::vector<OpInfo::TensorProperties> FindInputFeatures( + const NodeDef& node, + const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost) { + std::vector<OpInfo::TensorProperties> inputs; + for (const auto& input_name : node.input()) { + // Skip control inputs. These are prefixed with the ^ character. + CHECK(!input_name.empty()); + if (input_name[0] == '^') { + continue; + } + + // Each input is "node_name:output_imdex" with "node_name" being a string + // name and "output_index" indicating which output tensor to use from + // "node_name". If "output_index" is 0 the ":0" suffix can be omitted. 
+ string input_node_name; + int output_index = -1; + const size_t pos = input_name.rfind(':'); + if (pos == string::npos) { + input_node_name = input_name; + output_index = 0; + } else { + string index = input_name.substr(pos); + if (strings::safe_strto32(index, &output_index)) { + input_node_name = input_name.substr(0, pos); + } + } + + auto it = name_to_cost.find(input_name); + if (it == name_to_cost.end() || output_index < 0) { + OpInfo::TensorProperties input; + input.set_dtype(DataType::DT_INVALID); + input.mutable_shape()->set_unknown_rank(true); + inputs.push_back(input); + } else { + const CostGraphDef::Node* input_cost = it->second; + const CostGraphDef::Node::OutputInfo& output = + input_cost->output_info(output_index); + OpInfo::TensorProperties input; + input.set_dtype(output.dtype()); + *input.mutable_shape() = output.shape(); + inputs.push_back(input); + } + } + + return inputs; +} + +OpInfo::DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) { + DeviceNameUtils::ParsedName parsed; + if (DeviceNameUtils::ParseFullName(node.device(), &parsed)) { + if (parsed.type == "GPU") { + return GetLocalGPUInfo(parsed.id); + } else if (parsed.type == "CPU") { + return GetLocalCPUInfo(); + } + } + OpInfo::DeviceProperties device; + device.set_type("UNKNOWN"); + return device; +} + +OpInfo::DeviceProperties GetLocalCPUInfo() { + OpInfo::DeviceProperties device; + device.set_type("CPU"); + + device.set_num_cores(port::NumSchedulableCPUs()); + device.set_l1_cache_size(Eigen::l1CacheSize()); + device.set_l2_cache_size(Eigen::l2CacheSize()); + device.set_l3_cache_size(Eigen::l3CacheSize()); + + (*device.mutable_environment())["cpu_instruction_set"] = + Eigen::SimdInstructionSetsInUse(); + + (*device.mutable_environment())["eigen"] = strings::StrCat( + EIGEN_WORLD_VERSION, ".", EIGEN_MAJOR_VERSION, ".", EIGEN_MINOR_VERSION); +#ifdef EIGEN_USE_LIBXSMM + (*device.mutable_environment())["libxsmm"] = LIBXSMM_VERSION; +#endif + + return device; +} + 
+OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id) { + OpInfo::DeviceProperties device; + device.set_type("GPU"); + +#if GOOGLE_CUDA + cudaDeviceProp properties; + cudaError_t error = cudaGetDeviceProperties(&properties, gpu_id); + if (error == cudaSuccess) { + device.set_vendor("NVidia"); + device.set_model(properties.name); + device.set_frequency(properties.clockRate / 1000); + device.set_num_cores(properties.multiProcessorCount); + device.set_num_registers(properties.regsPerMultiprocessor); + // For compute capability less than 5, l1 cache size is configurable to + // either 16 KB or 48 KB. We use the initial configuration 16 KB here. For + // compute capability larger or equal to 5, l1 cache (unified with texture + // cache) size is 24 KB. This number may need to be updated for future + // compute capabilities. + device.set_l1_cache_size((properties.major < 5) ? 16 * 1024 : 24 * 1024); + device.set_l2_cache_size(properties.l2CacheSize); + device.set_l3_cache_size(0); + device.set_shared_memory_size_per_multiprocessor( + properties.sharedMemPerMultiprocessor); + device.set_memory_size(properties.totalGlobalMem); + // 8 is the number of bits per byte. 2 is accounted for + // double data rate (DDR). + device.set_bandwidth(properties.memoryBusWidth / 8 * + properties.memoryClockRate * 2); + } + + (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION); + (*device.mutable_environment())["cudnn"] = strings::StrCat(CUDNN_VERSION); +#endif + + return device; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h new file mode 100644 index 0000000000..79be906128 --- /dev/null +++ b/tensorflow/core/grappler/costs/utils.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_COSTS_UTILS_H_ +#define TENSORFLOW_GRAPPLER_COSTS_UTILS_H_ + +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/types.h" +#include "tensorflow/core/grappler/costs/op_performance_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace grappler { + +// Returns a vector of InputProperties for 'node'. The vector will contain one +// entry for each input of 'node'. +// For each node in the graph, the 'name_to_cost' map stores a pointer to the +// corresponding cost graph node indexed by node name. +std::vector<OpInfo::TensorProperties> FindInputFeatures( + const NodeDef& node, + const std::unordered_map<string, const CostGraphDef::Node*>& name_to_cost); + +// Returns the DeviceProperties of the device on which 'node' runs. +OpInfo::DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node); + +// Returns the DeviceProperties of the CPU on which grappler is running. +OpInfo::DeviceProperties GetLocalCPUInfo(); + +// Returns the DeviceProperties for the specified GPU attached to the server on +// which grappler is running. 
+OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id);
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_GRAPPLER_COSTS_UTILS_H_
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
new file mode 100644
index 0000000000..907ef2096a
--- /dev/null
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -0,0 +1,251 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/grappler_item.h"
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/variable.pb.h"
+#include "tensorflow/core/framework/versions.h"
+#include "tensorflow/core/grappler/inputs/utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Builds a GrapplerItem from 'meta_graph':
+//   * fetch nodes come from the "train_op" collection,
+//   * every Placeholder/PlaceholderV2 node becomes a zero-filled feed,
+//   * init ops come from the variable collections and "table_initializer",
+//   * queue runners come from the "queue_runners" collection.
+// Returns nullptr (after logging the reason) if the graph version is
+// incompatible, 'id' is empty, no fetch node is found, a placeholder lacks a
+// dtype or shape, a serialized collection entry cannot be parsed, a queue
+// runner has no cancel op, or an asset file is not accessible.
+// static
+std::unique_ptr<GrapplerItem> GrapplerItem::FromMetaGraphDef(
+    const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
+  // Check if the graph is compatible with the current version of TensorFlow.
+  Status status =
+      CheckVersions(meta_graph.graph_def().versions(), TF_GRAPH_DEF_VERSION,
+                    TF_GRAPH_DEF_VERSION_MIN_PRODUCER, "GraphDef", "graph");
+  if (!status.ok()) {
+    LOG(ERROR) << "Cannot process item: " << status.error_message();
+    return nullptr;
+  }
+
+  // Validate the id before allocating the item.
+  if (id.empty()) {
+    LOG(ERROR) << "id must be non-empty.";
+    return nullptr;
+  }
+  std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
+  new_item->id = id;
+  new_item->graph = meta_graph.graph_def();
+
+  // Attempt to detect the fetch node(s).
+  if (meta_graph.collection_def().count("train_op") > 0) {
+    const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
+    if (nodes.has_node_list()) {
+      for (const auto& node : nodes.node_list().value()) {
+        const string name = NodeName(node);
+        if (name.empty()) {
+          LOG(ERROR) << "Invalid fetch node name " << node
+                     << ", skipping this input";
+          return nullptr;
+        }
+        LOG(INFO) << "Will use fetch node " << name;
+        new_item->fetch.push_back(name);
+      }
+    }
+  }
+  if (new_item->fetch.empty()) {
+    LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
+    return nullptr;
+  }
+
+  for (auto& node : *new_item->graph.mutable_node()) {
+    // Delete user specified placement if requested.
+    if (cfg.ignore_user_placement) {
+      node.clear_device();
+    }
+
+    if (node.op() == "Placeholder" || node.op() == "PlaceholderV2") {
+      if (node.attr().count("dtype") == 0) {
+        LOG(ERROR) << "Unknown type for placeholder " << node.name()
+                   << ", skipping this input";
+        return nullptr;
+      }
+      DataType type = node.attr().at("dtype").type();
+
+      if (node.attr().count("shape") == 0) {
+        LOG(INFO) << "Unknown shape for placeholder " << node.name()
+                  << ", skipping this input";
+        return nullptr;
+      }
+      TensorShape shape(node.attr().at("shape").shape());
+      // Some placeholder nodes have a mis-match between the node
+      // attribute "shape" and a different node attribute "_output_shapes".
+      // Specifically, a shape with shape.dims() == 0 could indicate either
+      // a scalar or an unknown shape. In those cases, we check _output_shapes
+      // for additional information.
+      // This case is observed in the bnmt graphs. Have not observed any
+      // cases where there was more than 1 _output_shapes, so limit it
+      // to cases where there is only 1 _output_shapes.
+      // We only do this if cfg.placeholder_unknown_output_shape_dim has
+      // been set to avoid crashing non-BNMT graphs.
+      if (cfg.placeholder_unknown_output_shape_dim >= 0 && shape.dims() == 0 &&
+          node.attr().count("_output_shapes") > 0) {
+        const auto& output_shapes = node.attr().at("_output_shapes").list();
+        // Check the number of shapes in the list itself: indexing shape(0) on
+        // an empty list is out of bounds.
+        if (output_shapes.shape_size() == 1 &&
+            output_shapes.shape(0).dim_size() != 0) {
+          shape.Clear();
+          for (const auto& dim : output_shapes.shape(0).dim()) {
+            if (dim.size() == -1) {
+              // Unknown dimension: substitute the configured size.
+              shape.AddDim(cfg.placeholder_unknown_output_shape_dim);
+            } else {
+              shape.AddDim(dim.size());
+            }
+          }
+        }
+      }
+      Tensor fake_input(type, shape);
+      // TODO(bsteiner): figure out a better way to initialize the feeds, for
+      // example by recording a sample of the fed inputs in mldash when running
+      // the graph.
+      memset(const_cast<char*>(fake_input.tensor_data().data()), 0,
+             fake_input.tensor_data().size());
+      new_item->feed.emplace_back(node.name(), fake_input);
+    }
+
+    // Drop the colocation constraints if requested.
+    if (cfg.ignore_colocation) {
+      auto attr = node.mutable_attr();
+      auto it = attr->find("_class");
+      if (it != attr->end()) {
+        attr->erase(it);
+      }
+    }
+  }
+
+  // Collect the initializers of all the variables.
+  for (const string& var_collection :
+       {"variables", "local_variables", "model_variables",
+        "trainable_variables"}) {
+    if (meta_graph.collection_def().count(var_collection) == 0) {
+      continue;
+    }
+    const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
+    for (const auto& raw_var : vars.bytes_list().value()) {
+      VariableDef var;
+      // Reject malformed entries instead of silently reading a partially
+      // parsed message, mirroring the handling of queue runners below.
+      if (!var.ParseFromString(raw_var)) {
+        LOG(ERROR) << "Could not parse variable in collection "
+                   << var_collection << ", skipping this input";
+        return nullptr;
+      }
+      if (!var.initializer_name().empty()) {
+        new_item->init_ops.push_back(var.initializer_name());
+      }
+    }
+  }
+
+  if (meta_graph.collection_def().count("table_initializer") > 0) {
+    const CollectionDef& inits =
+        meta_graph.collection_def().at("table_initializer");
+    if (inits.has_node_list()) {
+      for (const auto& node : inits.node_list().value()) {
+        new_item->init_ops.push_back(node);
+      }
+    }
+  }
+
+  if (meta_graph.collection_def().count("queue_runners") > 0) {
+    const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
+    for (const auto& raw : vars.bytes_list().value()) {
+      QueueRunnerDef queue_runner;
+      if (!queue_runner.ParseFromString(raw)) {
+        LOG(ERROR) << "Could not parse queue_runners, skipping this input";
+        return nullptr;
+      }
+      if (queue_runner.cancel_op_name().empty()) {
+        LOG(ERROR) << "Queue without a cancel op, skipping this input";
+        return nullptr;
+      }
+      new_item->queue_runners.push_back(queue_runner);
+    }
+  }
+
+  // Make sure we still can access the input files (aka "asset_filepaths") since
+  // these might have been moved or deleted, the cns cell might have been shut
+  // down, or we might be running as a user who does not have access to the
+  // files.
+  if (meta_graph.collection_def().count("asset_filepaths") > 0) {
+    const CollectionDef& file_paths =
+        meta_graph.collection_def().at("asset_filepaths");
+    std::vector<string> paths;
+    for (const auto& raw_path : file_paths.bytes_list().value()) {
+      paths.push_back(raw_path);
+    }
+    if (!FilesExist(paths, nullptr)) {
+      LOG(ERROR)
+          << "Can't access one or more of the asset files, skipping this input";
+      return nullptr;
+    }
+  }
+
+  return new_item;
+}
+
+// Transitive fanin of the fetch nodes.
+std::vector<const NodeDef*> GrapplerItem::MainOpsFanin() const {
+  return ComputeTransitiveFanin(graph, fetch);
+}
+
+// Transitive fanin of the initialization ops.
+std::vector<const NodeDef*> GrapplerItem::InitOpsFanin() const {
+  return ComputeTransitiveFanin(graph, init_ops);
+}
+
+// Returns every node of 'graph' reachable from 'terminal_nodes' by following
+// inputs, each node exactly once, in depth-first discovery order. CHECK-fails
+// if a terminal node or an input refers to a node that is not in 'graph'.
+std::vector<const NodeDef*> ComputeTransitiveFanin(
+    const GraphDef& graph, const std::vector<string>& terminal_nodes) {
+  std::unordered_map<string, const NodeDef*> name_to_node;
+  for (const auto& node : graph.node()) {
+    name_to_node[node.name()] = &node;
+  }
+
+  // Use find() rather than operator[] so that a failed lookup does not insert
+  // a null entry into the map before the CHECK fires.
+  std::vector<const NodeDef*> queue;
+  for (const string& root : terminal_nodes) {
+    const auto it = name_to_node.find(NodeName(root));
+    CHECK(it != name_to_node.end()) << "Unknown terminal node " << root;
+    queue.push_back(it->second);
+  }
+
+  std::vector<const NodeDef*> result;
+  std::unordered_set<const NodeDef*> visited;
+
+  while (!queue.empty()) {
+    const NodeDef* node = queue.back();
+    queue.pop_back();
+    if (!visited.insert(node).second) {
+      // The node has already been visited.
+      continue;
+    }
+    result.push_back(node);
+    for (const string& input : node->input()) {
+      const auto it = name_to_node.find(NodeName(input));
+      CHECK(it != name_to_node.end())
+          << "Unknown input " << input << " of node " << node->name();
+      queue.push_back(it->second);
+    }
+  }
+  return result;
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
new file mode 100644
index 0000000000..dff288de9f
--- /dev/null
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_GRAPPLER_ITEM_H_ +#define TENSORFLOW_GRAPPLER_GRAPPLER_ITEM_H_ + +#include <memory> +#include <string> +#include <utility> +#include <vector> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/protobuf/queue_runner.pb.h" + +namespace tensorflow { + +class MetaGraphDef; + +namespace grappler { + +struct ItemConfig { + // If true, ignore all user specified node placement. + bool ignore_user_placement = true; + // If true, ignore all user specified colocation attributes. + bool ignore_colocation = true; + // Dimension to use if a placeholder node has an _output_shapes attribute with + // a dimension of -1. + int32 placeholder_unknown_output_shape_dim = -1; +}; + +// A TensorFlow model to optimize. +// Models are represented by the combination of a graph, one of more fetch +// nodes, and potentially a set of nodes to feed. +// TODO(volunteer_needed): turn this struct into a class. +struct GrapplerItem { + // Factory method for creating a GrapplerItem from a MetaGraphDef. + // Returns nullptr if the given meta_graph cannot be converted. 
+ static std::unique_ptr<GrapplerItem> FromMetaGraphDef( + const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg); + + string id; // A unique id for this item + + // Inputs + GraphDef graph; + std::vector<std::pair<string, Tensor>> feed; + std::vector<string> fetch; + + // Initialization op(s). + std::vector<string> init_ops; + + // Queue runner(s) required to run the queue(s) of this model. + std::vector<QueueRunnerDef> queue_runners; + + // Return the set of node evaluated during a regular train/inference step. + std::vector<const NodeDef*> MainOpsFanin() const; + // Return the set nodes used by TensorFlow to initialize the graph. + std::vector<const NodeDef*> InitOpsFanin() const; +}; + +// Return the transitive fanin of a set of terminal nodes. +std::vector<const NodeDef*> ComputeTransitiveFanin( + const GraphDef& graph, const std::vector<string>& terminal_nodes); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_GRAPPLER_ITEM_H_ diff --git a/tensorflow/core/grappler/grappler_item_test.cc b/tensorflow/core/grappler/grappler_item_test.cc new file mode 100644 index 0000000000..72a9f481ca --- /dev/null +++ b/tensorflow/core/grappler/grappler_item_test.cc @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class GrapplerItemTest : public ::testing::Test {}; + +TEST_F(GrapplerItemTest, Basic) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {{"CPU:0"}}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + EXPECT_TRUE(item.InitOpsFanin().empty()); + + std::vector<string> graph_nodes; + for (const auto& node : item.graph.node()) { + graph_nodes.push_back(node.name()); + } + std::vector<string> main_ops; + for (const auto& node : item.MainOpsFanin()) { + main_ops.push_back(node->name()); + } + std::sort(graph_nodes.begin(), graph_nodes.end()); + std::sort(main_ops.begin(), main_ops.end()); + EXPECT_EQ(main_ops, graph_nodes); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/inputs/BUILD b/tensorflow/core/grappler/inputs/BUILD new file mode 100644 index 0000000000..f0ca36d85d --- /dev/null +++ b/tensorflow/core/grappler/inputs/BUILD @@ -0,0 +1,72 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +cc_library( + name = "utils", + srcs = [ + "utils.cc", + ], + hdrs = [ + "utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_test( + name = "utils_test", + srcs = [ + "utils_test.cc", + ], + deps = [ + ":utils", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + 
"//tensorflow/core:test_main", + ], +) + +cc_library( + name = "input_yielder", + hdrs = [ + "input_yielder.h", + ], + visibility = ["//visibility:public"], + deps = [], +) + +cc_library( + name = "trivial_test_graph_input_yielder", + srcs = ["trivial_test_graph_input_yielder.cc"], + hdrs = [ + "trivial_test_graph_input_yielder.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":input_yielder", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/kernels:aggregate_ops", + "//tensorflow/core/kernels:array", + ], +) diff --git a/tensorflow/core/grappler/inputs/input_yielder.h b/tensorflow/core/grappler/inputs/input_yielder.h new file mode 100644 index 0000000000..c9f90820a9 --- /dev/null +++ b/tensorflow/core/grappler/inputs/input_yielder.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_INPUTS_INPUT_YIELDER_H_ +#define TENSORFLOW_GRAPPLER_INPUTS_INPUT_YIELDER_H_ + +namespace tensorflow { +namespace grappler { + +struct GrapplerItem; + +// Abstract interface for yielding graphs that we want to optimize. 
+class InputYielder { + public: + virtual ~InputYielder() {} + + virtual bool NextItem(GrapplerItem* item) = 0; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_INPUTS_INPUT_YIELDER_H_ diff --git a/tensorflow/core/grappler/inputs/testdata/test_file.txt b/tensorflow/core/grappler/inputs/testdata/test_file.txt new file mode 100644 index 0000000000..557db03de9 --- /dev/null +++ b/tensorflow/core/grappler/inputs/testdata/test_file.txt @@ -0,0 +1 @@ +Hello World diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc new file mode 100644 index 0000000000..8370133fc4 --- /dev/null +++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc @@ -0,0 +1,111 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// The builtin inputs provide a mechanism to generate simple TensorFlow graphs +// and feed them as inputs to Grappler. This can be used for quick experiments +// or to derive small regression tests. 
+ +#include "tensorflow/cc/ops/standard_ops.h" + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" + +namespace tensorflow { +namespace grappler { + +// Make a program with specified number of stages and "width" ops per stage. +namespace { +GraphDef CreateGraphDef(int num_stages, int width, int tensor_size, + bool use_multiple_devices, bool insert_queue, + const std::vector<string>& device_names) { + CHECK_GE(device_names.size(), width); + + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + + // x is from the feed. + const int batch_size = tensor_size < 0 ? 1 : tensor_size; + Output x = Const(s.WithOpName("x"), 0.0f, {batch_size, 1}); + + // Create stages. + std::vector<Output> last_stage; + last_stage.push_back(x); + for (int i = 0; i < num_stages; i++) { + std::vector<Output> this_stage; + for (int j = 0; j < width; j++) { + Output combine = AddN( + s.WithDevice(device_names[use_multiple_devices ? j : 0]), last_stage); + this_stage.push_back(combine); + } + last_stage = this_stage; + } + + if (insert_queue) { + FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_FLOAT}); + QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, last_stage); + QueueDequeue dequeue(s.WithOpName("dequeue"), queue, {DataType::DT_FLOAT}); + QueueClose cancel(s.WithOpName("cancel"), queue, + QueueClose::CancelPendingEnqueues(true)); + last_stage = {dequeue[0]}; + } + + // Create output. 
+ AddN output(s.WithOpName("y"), last_stage); + + GraphDef def; + TF_CHECK_OK(s.ToGraphDef(&def)); + return def; +} +} // namespace + +TrivialTestGraphInputYielder::TrivialTestGraphInputYielder( + int num_stages, int width, int tensor_size, bool insert_queue, + const std::vector<string>& device_names) + : num_stages_(num_stages), + width_(width), + tensor_size_(tensor_size), + insert_queue_(insert_queue), + device_names_(device_names) {} + +bool TrivialTestGraphInputYielder::NextItem(GrapplerItem* item) { + GrapplerItem r; + r.id = strings::StrCat("ns:", num_stages_, "/", // wrap + "w:", width_, "/", // wrap + "ts:", tensor_size_); + r.graph = CreateGraphDef(num_stages_, width_, tensor_size_, + true /*use_multiple_devices*/, insert_queue_, + device_names_); + // If the batch size is variable, we need to choose a value to create a feed + const int batch_size = tensor_size_ < 0 ? 1 : tensor_size_; + Tensor x(DT_FLOAT, TensorShape({batch_size, 1})); + r.feed.push_back(std::make_pair("x", x)); + r.fetch.push_back("y"); + + if (insert_queue_) { + QueueRunnerDef queue_runner; + queue_runner.set_queue_name("queue"); + queue_runner.set_cancel_op_name("cancel"); + *queue_runner.add_enqueue_op_name() = "enqueue"; + r.queue_runners.push_back(queue_runner); + } + + *item = std::move(r); + return true; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h new file mode 100644 index 0000000000..4c5600c816 --- /dev/null +++ b/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_INPUTS_TRIVIAL_TEST_GRAPH_INPUT_YIELDER_H_
+#define TENSORFLOW_GRAPPLER_INPUTS_TRIVIAL_TEST_GRAPH_INPUT_YIELDER_H_
+
+#include <string>
+#include <vector>
+#include "tensorflow/core/grappler/inputs/input_yielder.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class Cluster;
+// GrapplerItem is defined as a struct; declare it with the same class-key.
+struct GrapplerItem;
+
+// Yields a small synthetic graph made of 'num_stages' stages of 'width' AddN
+// ops each, fed by a constant "x" and fetched through "y". The ops of a stage
+// are placed on the devices listed in 'device_names', and an optional FIFO
+// queue (with its queue runner) can be inserted before the output.
+class TrivialTestGraphInputYielder : public InputYielder {
+ public:
+  TrivialTestGraphInputYielder(int num_stages, int width, int tensor_size,
+                               bool insert_queue,
+                               const std::vector<string>& device_names);
+  // Overwrites *item with the generated graph, feed and fetch.
+  bool NextItem(GrapplerItem* item) override;
+
+ private:
+  const int num_stages_;
+  const int width_;
+  const int tensor_size_;
+  const bool insert_queue_;
+  std::vector<string> device_names_;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_GRAPPLER_INPUTS_TRIVIAL_TEST_GRAPH_INPUT_YIELDER_H_
diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc
new file mode 100644
index 0000000000..17f41105b2
--- /dev/null
+++ b/tensorflow/core/grappler/inputs/utils.cc
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/inputs/utils.h" +#include "tensorflow/core/platform/env.h" + +#include <vector> + +namespace tensorflow { +namespace grappler { + +bool FilesExist(const std::vector<string>& files, std::vector<Status>* status) { + return Env::Default()->FilesExist(files, status); +} + +bool FilesExist(const std::set<string>& files) { + return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr); +} + +} // End namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h new file mode 100644 index 0000000000..ee65ca031d --- /dev/null +++ b/tensorflow/core/grappler/inputs/utils.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_INPUTS_UTILS_H_ +#define TENSORFLOW_GRAPPLER_INPUTS_UTILS_H_ + +#include <set> +#include <vector> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +bool FilesExist(const std::vector<string>& files, + std::vector<Status>* status = nullptr); +bool FilesExist(const std::set<string>& files); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_INPUTS_UTILS_H_ diff --git a/tensorflow/core/grappler/inputs/utils_test.cc b/tensorflow/core/grappler/inputs/utils_test.cc new file mode 100644 index 0000000000..2a7c4834f1 --- /dev/null +++ b/tensorflow/core/grappler/inputs/utils_test.cc @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/inputs/utils.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class UtilsTest : public ::testing::Test { + protected: + string BaseDir() { return io::JoinPath(testing::TmpDir(), "base_dir"); } + + void SetUp() override { + TF_CHECK_OK(env_->CreateDir(BaseDir())); + non_existent_file_ = io::JoinPath(BaseDir(), "non_existent_file.txt"); + actual_file_ = io::JoinPath(BaseDir(), "test_file.txt"); + TF_CHECK_OK(WriteStringToFile(env_, actual_file_, "Some test data")); + } + + void TearDown() override { + int64 undeleted_files, undeleted_dirs; + TF_CHECK_OK( + env_->DeleteRecursively(BaseDir(), &undeleted_files, &undeleted_dirs)); + } + + string non_existent_file_; + string actual_file_; + Env* env_ = Env::Default(); +}; + +TEST_F(UtilsTest, FilesExist) { + EXPECT_FALSE(FilesExist(std::vector<string>{{non_existent_file_}})); + EXPECT_FALSE( + FilesExist(std::vector<string>{{non_existent_file_}, {actual_file_}})); + EXPECT_TRUE(FilesExist(std::vector<string>{{actual_file_}})); + + std::vector<Status> status; + EXPECT_FALSE(FilesExist( + std::vector<string>{{non_existent_file_}, {actual_file_}}, &status)); + EXPECT_EQ(status.size(), 2); + EXPECT_FALSE(status[0].ok()); + EXPECT_TRUE(status[1].ok()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD new file mode 100644 index 0000000000..518d12e3ab --- /dev/null +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -0,0 +1,45 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + 
"**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +cc_library( + name = "graph_optimizer", + hdrs = [ + "graph_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +cc_library( + name = "layout_optimizer", + srcs = ["layout_optimizer.cc"], + hdrs = [ + "layout_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + ], +) diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h new file mode 100644 index 0000000000..34e126b0af --- /dev/null +++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
+#define TENSORFLOW_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+
+class Cluster;
+// GrapplerItem is defined as a struct; declare it with the same class-key.
+struct GrapplerItem;
+
+// An abstract interface for an algorithm for generating a candidate
+// optimization of a GrapplerItem for running on a cluster.
+class GraphOptimizer {
+ public:
+  virtual ~GraphOptimizer() {}
+
+  // Name identifying this optimizer.
+  virtual string name() const = 0;
+
+  // Routine called to allow an algorithm to propose a rewritten graph
+  // for the graph, feeds and fetches in "item" to run more efficiently
+  // on "cluster". The proposal is written to *optimized_graph.
+  // The returned Status reports whether a solution was generated.
+  virtual Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                          GraphDef* optimized_graph) = 0;
+
+  // Method invoked by the framework so that it can provide feedback
+  // on how well "optimized_graph" (produced as *optimized_graph by a
+  // call to Optimize) performed. Lower "result" scores are better.
+  virtual void Feedback(Cluster* cluster, const GrapplerItem& item,
+                        const GraphDef& optimized_graph, double result) = 0;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
new file mode 100644
index 0000000000..417f71fed2
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -0,0 +1,996 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <memory>
+#include <set>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Names (or name prefixes) of the helper nodes this optimizer inserts into
+// the graph.
+const char kConcatConst[] = "LayoutOptimizerConcatConst";
+const char kPermNHWCToNCHW[] = "LayoutOptimizerPermConstNHWCToNCHW";
+const char kPermNCHWToNHWC[] = "LayoutOptimizerPermConstNCHWToNHWC";
+const char kTransposeNHWCToNCHW[] = "LayoutOptimizerTransposeNHWCToNCHW";
+const char kTransposeNCHWToNHWC[] = "LayoutOptimizerTransposeNCHWToNHWC";
+const char kPermVecNHWCToNCHW[] = "LayoutOptimizerPermVecNHWCToNCHW";
+const char kReshapeNHWCToNCHW[] = "LayoutOptimizerReshapeNHWCToNCHW";
+const char kReshapeConst[] = "LayoutOptimizerReshapeConst";
+const char kReductionConst[] = "LayoutOptimizerReductionConst";
+
+// Returns the op types that support NCHW; these are converted in the first
+// pass of DataLayoutOptimizer::Expand().
+std::set<string> GetOpsFormatSupported() {
+  std::set<string> ops_format_supported = {"AvgPool",
+                                           "AvgPoolGrad",
+                                           "Conv2D",
+                                           "Conv2DBackpropFilter",
+                                           "Conv2DBackpropInput",
+                                           "BiasAdd",
+                                           "BiasAddGrad",
+                                           "FusedBatchNorm",
+                                           "FusedBatchNormGrad",
+                                           "MaxPool",
+                                           "MaxPoolGrad"};
+  return ops_format_supported;
+}
+
+// Returns the op types that are treated as layout agnostic; these are
+// converted in the second pass of DataLayoutOptimizer::Expand().
+std::set<string> GetOpsFormatAgnostic() {
+  std::set<string> ops_format_agnostic = {"Add",
+                                          "AddN",
+                                          "Concat",
+                                          "ConcatV2",
+                                          "Floor",
+                                          "Identity",
+                                          "Mul",
+                                          "Neg",
+                                          "RealDiv",
+                                          "Relu",
+                                          "ReluGrad",
+                                          "Slice",
+                                          "SquaredDifference",
+                                          "Squeeze",
+                                          "Sub",
+                                          "Sum"};
+  return ops_format_agnostic;
+}
+
+// Index over the nodes of a GraphDef: maps a node name to its NodeDef and
+// an input string to the set of nodes consuming it.
+//
+// NOTE: the consumer map is keyed by the exact input string as it appears in
+// the consuming node (see the constructor), while GetNode() first reduces its
+// argument to a plain node name via NodeName().
+class NodeMap {
+ public:
+  explicit NodeMap(GraphDef* graph) : graph_(graph) {
+    for (int i = 0; i < graph_->node_size(); i++) {
+      auto node = graph_->mutable_node(i);
+      nodes_.insert(std::make_pair(node->name(), node));
+      for (const auto& input : node->input()) {
+        outputs_[input].insert(nodes_[node->name()]);
+      }
+    }
+  }
+
+  // Returns the node registered under the node-name part of "name", or
+  // nullptr if there is none. The lookup does not modify the map.
+  NodeDef* GetNode(const string& name) {
+    auto it = nodes_.find(NodeName(name));
+    return it == nodes_.end() ? nullptr : it->second;
+  }
+
+  // Returns a copy of the consumer set of "name", so callers may keep
+  // iterating over it while updating the map.
+  std::set<NodeDef*> GetOutputs(const string& name) { return outputs_[name]; }
+
+  // Registers "node" under "name". An existing entry is left untouched.
+  void AddNode(const string& name, NodeDef* node) {
+    nodes_.insert(std::make_pair(name, node));
+  }
+
+  // Records the node named "output" as a consumer of "node".
+  void AddOutput(const string& node, const string& output) {
+    outputs_[node].insert(nodes_[output]);
+  }
+
+  // Replaces consumer "old_output" of "node" with "new_output".
+  void UpdateOutput(const string& node, const string& old_output,
+                    const string& new_output) {
+    outputs_[node].erase(nodes_[old_output]);
+    outputs_[node].insert(nodes_[new_output]);
+  }
+
+ private:
+  GraphDef* graph_;
+  std::unordered_map<string, NodeDef*> nodes_;
+  std::unordered_map<string, std::set<NodeDef*>> outputs_;
+};
+
+// Returns true iff "node_name" starts with the prefix used for the
+// NHWC-to-NCHW transposes inserted by this optimizer.
+bool IsNodeNHWCToNCHW(const string& node_name) {
+  const string transpose_node_prefix = kTransposeNHWCToNCHW;
+  return node_name.compare(0, transpose_node_prefix.length(),
+                           transpose_node_prefix) == 0;
+}
+
+// Returns true iff "node_name" starts with the prefix used for the
+// NCHW-to-NHWC transposes inserted by this optimizer.
+bool IsNodeNCHWToNHWC(const string& node_name) {
+  const string transpose_node_prefix = kTransposeNCHWToNHWC;
+  return node_name.compare(0, transpose_node_prefix.length(),
+                           transpose_node_prefix) == 0;
+}
+
+// Converts a single node from NHWC to NCHW: rewrites its attributes and
+// wraps its data inputs and its outputs in Transpose nodes. Subclasses
+// customize which inputs are transposed and which steps apply.
+class NodeProcessor {
+ public:
+  NodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : graph_(graph), node_(node), node_map_(node_map) {}
+  virtual ~NodeProcessor() {}
+
+  // Runs all conversion steps if ShouldProcess() accepts the node.
+  virtual void ConvertNode() {
+    if (ShouldProcess()) {
+      UpdateAttrDataFormat();
+      UpdateAttrKSize();
+      UpdateAttrStrides();
+      UpdateAttrShape();
+      AddLayoutTransposeToInputs();
+      AddLayoutTransposeToOutputs();
+      CustomizedProcessing();
+    }
+  }
+
+ protected:
+  // Returns true iff the first recorded output shape of "node" has rank n.
+  bool IsDimsN(NodeDef* node, int n) const {
+    if (node->attr().find("_output_shapes") != node->attr().end()) {
+      const auto& shapes = node->attr().at("_output_shapes").list();
+      // Guard against an empty list before indexing into it.
+      if (shapes.shape_size() > 0 && shapes.shape(0).dim_size() == n) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool IsDimsFour(NodeDef* node) const { return IsDimsN(node, 4); }
+
+  // Returns true iff the node has a "data_format" attribute equal to "NHWC".
+  bool IsNHWC() const {
+    if (node_->attr().find("data_format") != node_->attr().end()) {
+      if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool HasOutputs() const {
+    auto outputs = node_map_->GetOutputs(node_->name());
+    return !outputs.empty();
+  }
+
+  virtual bool ShouldProcess() const {
+    return IsNHWC() && IsDimsFour(node_) && HasOutputs();
+  }
+
+  // Switches "data_format" from "NHWC" to "NCHW".
+  void UpdateAttrDataFormat() {
+    if (node_->attr().find("data_format") != node_->attr().end()) {
+      if (node_->attr().at("data_format").s().compare("NHWC") == 0) {
+        string* data_format =
+            node_->mutable_attr()->at("data_format").mutable_s();
+        *data_format = "NCHW";
+      }
+    }
+  }
+
+  // Permutes the first recorded 4-D output shape from NHWC to NCHW.
+  virtual void UpdateAttrShape() {
+    if (node_->attr().find("_output_shapes") != node_->attr().end()) {
+      auto shapes =
+          node_->mutable_attr()->at("_output_shapes").mutable_list();
+      if (shapes->shape_size() == 0) {
+        return;
+      }
+      auto shape = shapes->mutable_shape(0);
+      if (shape->dim_size() == 4) {
+        int64 h = shape->dim(1).size();
+        int64 w = shape->dim(2).size();
+        int64 c = shape->dim(3).size();
+        shape->mutable_dim(1)->set_size(c);
+        shape->mutable_dim(2)->set_size(h);
+        shape->mutable_dim(3)->set_size(w);
+      }
+    }
+  }
+
+  void UpdateAttrKSize() {
+    if (node_->attr().find("ksize") != node_->attr().end()) {
+      auto list = node_->mutable_attr()->at("ksize").mutable_list();
+      UpdateTuple(list);
+    }
+  }
+
+  void UpdateAttrStrides() {
+    if (node_->attr().find("strides") != node_->attr().end()) {
+      auto list = node_->mutable_attr()->at("strides").mutable_list();
+      UpdateTuple(list);
+    }
+  }
+
+  // Permutes the 4-element int32 "value" tensor of the node called "name"
+  // from NHWC to NCHW order. The node is left untouched if it cannot be
+  // found, has no "value" attribute, or its tensor cannot be parsed.
+  void UpdateAttrValue(const string& name) {
+    NodeDef* node = node_map_->GetNode(name);
+    if (node == nullptr || node->attr().find("value") == node->attr().end()) {
+      LOG(ERROR) << "Failed to find a value attribute for node " << name;
+      return;
+    }
+    Tensor tensor;
+    auto success = tensor.FromProto(node->attr().at("value").tensor());
+    if (!success) {
+      LOG(ERROR) << "Failed to parse TensorProto.";
+      // Do not touch (or write back) a tensor that failed to parse.
+      return;
+    }
+    if (tensor.NumElements() < 4) {
+      LOG(ERROR) << "Expected at least 4 elements in the value of " << name;
+      return;
+    }
+    int c = tensor.flat<int>()(3);
+    tensor.flat<int>()(3) = tensor.flat<int>()(2);
+    tensor.flat<int>()(2) = tensor.flat<int>()(1);
+    tensor.flat<int>()(1) = c;
+    tensor.AsProtoTensorContent(
+        node->mutable_attr()->at("value").mutable_tensor());
+  }
+
+  // Positions of the inputs that get an NHWC-to-NCHW transpose in
+  // AddLayoutTransposeToInputs().
+  virtual std::vector<int> GetInputPos() const {
+    std::vector<int> input_pos = {0};
+    return input_pos;
+  }
+
+  // Adds a Transpose node called "node_name" that reads "input_name" and the
+  // shared permutation constant for the requested direction. "input_shape"
+  // is used to record the transposed output shape.
+  void AddNodeTranspose(const string& node_name, const string& input_name,
+                        DataType data_type, const TensorShapeProto& input_shape,
+                        bool NHWCToNCHW) {
+    NodeDef* node = graph_->add_node();
+    node_map_->AddNode(node_name, node);
+    node->set_name(node_name);
+    *node->add_input() = input_name;
+    *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
+    node->set_op("Transpose");
+    AttrValue attr_data_type;
+    attr_data_type.set_type(data_type);
+    node->mutable_attr()->insert({"T", attr_data_type});
+    AttrValue attr_data_type_perm;
+    attr_data_type_perm.set_type(DT_INT32);
+    node->mutable_attr()->insert({"Tperm", attr_data_type_perm});
+    AttrValue attr_output_shape;
+    auto output_shape = attr_output_shape.mutable_list()->add_shape();
+    if (NHWCToNCHW) {
+      output_shape->add_dim()->set_size(input_shape.dim(0).size());
+      output_shape->add_dim()->set_size(input_shape.dim(3).size());
+      output_shape->add_dim()->set_size(input_shape.dim(1).size());
+      output_shape->add_dim()->set_size(input_shape.dim(2).size());
+    } else {
+      output_shape->add_dim()->set_size(input_shape.dim(0).size());
+      output_shape->add_dim()->set_size(input_shape.dim(2).size());
+      output_shape->add_dim()->set_size(input_shape.dim(3).size());
+      output_shape->add_dim()->set_size(input_shape.dim(1).size());
+    }
+    node->mutable_attr()->insert({"_output_shapes", attr_output_shape});
+  }
+
+  // Inserts an NHWC-to-NCHW transpose in front of every input returned by
+  // GetInputPos() and rewires the node to read from it.
+  virtual void AddLayoutTransposeToInputs() {
+    std::vector<int> input_pos = GetInputPos();
+    for (const auto& pos : input_pos) {
+      string node_name_NHWCToNCHW = strings::StrCat(
+          kTransposeNHWCToNCHW, "-", node_->name(), "-", node_->input(pos));
+      auto input_node = node_map_->GetNode(node_->input(pos));
+      int output_pos = NodePosition(node_->input(pos));
+      AddNodeTranspose(
+          node_name_NHWCToNCHW, node_->input(pos), node_->attr().at("T").type(),
+          input_node->attr().at("_output_shapes").list().shape(output_pos),
+          true);
+      node_map_->UpdateOutput(node_->input(pos), node_->name(),
+                              node_name_NHWCToNCHW);
+      node_map_->AddOutput(node_name_NHWCToNCHW, node_->name());
+      *node_->mutable_input(pos) = node_name_NHWCToNCHW;
+    }
+  }
+
+  // Inserts an NCHW-to-NHWC transpose between this node and each of its
+  // consumers and rewires the consumer to read from it.
+  virtual void AddLayoutTransposeToOutputs() {
+    // GetOutputs() returns a copy, so the map can be updated while iterating.
+    auto outputs = node_map_->GetOutputs(node_->name());
+    for (const auto& output : outputs) {
+      string node_name_NCHWToNHWC = strings::StrCat(
+          kTransposeNCHWToNHWC, "-", node_->name(), "-", output->name());
+      auto it = std::find_if(output->mutable_input()->begin(),
+                             output->mutable_input()->end(),
+                             [this](const string& input) {
+                               return input.compare(node_->name()) == 0;
+                             });
+      if (it == output->mutable_input()->end()) {
+        // The consumer no longer reads this node directly; dereferencing the
+        // end iterator would be undefined behavior, so skip it.
+        continue;
+      }
+      int output_pos = NodePosition(*it);
+      AddNodeTranspose(
+          node_name_NCHWToNHWC, node_->name(), node_->attr().at("T").type(),
+          node_->attr().at("_output_shapes").list().shape(output_pos), false);
+      *it = node_name_NCHWToNHWC;
+      node_map_->UpdateOutput(node_->name(), output->name(),
+                              node_name_NCHWToNHWC);
+      node_map_->AddOutput(node_name_NCHWToNHWC, output->name());
+    }
+  }
+
+  // Hook for op-specific work, run last by ConvertNode().
+  virtual void CustomizedProcessing() {}
+
+  GraphDef* graph_;
+  NodeDef* node_;
+  NodeMap* node_map_;
+
+ private:
+  // Permutes a 4-element integer list attribute from NHWC to NCHW order.
+  void UpdateTuple(AttrValue_ListValue* list) {
+    if (list->i_size() < 4) {
+      // Not a 4-element tuple; indexing below would be out of range.
+      return;
+    }
+    int64 h = list->i(1);
+    int64 w = list->i(2);
+    int64 c = list->i(3);
+    list->set_i(1, c);
+    list->set_i(2, h);
+    list->set_i(3, w);
+  }
+};
+
+// AvgPoolGrad: transposes input 1 and permutes the constant feeding input 0.
+class AvgPoolGradProcessor : public NodeProcessor {
+ public:
+  AvgPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+  std::vector<int> GetInputPos() const override {
+    std::vector<int> input_pos = {1};
+    return input_pos;
+  }
+  void CustomizedProcessing() override { UpdateAttrValue(node_->input(0)); }
+};
+
+// BiasAddGrad: processed when its input is 4-D NHWC or already comes from an
+// NCHW-to-NHWC transpose; its output is not transposed.
+class BiasAddGradProcessor : public NodeProcessor {
+ public:
+  BiasAddGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+  bool ShouldProcess() const override {
+    auto input = node_map_->GetNode(node_->input(0));
+    if (input) {
+      if ((IsNHWC() && IsDimsFour(input)) || IsNodeNCHWToNHWC(input->name())) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  void AddLayoutTransposeToOutputs() override {}
+};
+
+// Conv2DBackpropFilter: transposes inputs 0 and 2; the output is a filter and
+// needs neither a transpose nor a shape update.
+class Conv2DBackpropFilterProcessor : public NodeProcessor {
+ public:
+  Conv2DBackpropFilterProcessor(GraphDef* graph, NodeDef* node,
+                                NodeMap* node_map)
+      : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+  std::vector<int> GetInputPos() const override {
+    std::vector<int> input_pos = {0, 2};
+    return input_pos;
+  }
+
+  void AddLayoutTransposeToOutputs() override {}
+  // No need to update output shape, as it is always of shape
+  // [filter_height, filter_width, in_channels, out_channels], regardless of
+  // whether NCHW or NHWC is used.
+  void UpdateAttrShape() override {}
+};
+
+// Conv2DBackpropInput: transposes input 2 and permutes the constant feeding
+// input 0.
+class Conv2DBackpropInputProcessor : public NodeProcessor {
+ public:
+  Conv2DBackpropInputProcessor(GraphDef* graph, NodeDef* node,
+                               NodeMap* node_map)
+      : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+  std::vector<int> GetInputPos() const override {
+    std::vector<int> input_pos = {2};
+    return input_pos;
+  }
+  void CustomizedProcessing() override { UpdateAttrValue(node_->input(0)); }
+};
+
+// FusedBatchNormGrad: transposes inputs 0 and 1.
+class FusedBatchNormGradProcessor : public NodeProcessor {
+ public:
+  FusedBatchNormGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+  std::vector<int> GetInputPos() const override {
+    std::vector<int> input_pos = {0, 1};
+    return input_pos;
+  }
+};
+
+// MaxPoolGrad: transposes inputs 0, 1 and 2.
+class MaxPoolGradProcessor : public NodeProcessor {
+ public:
+  MaxPoolGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+  std::vector<int> GetInputPos() const override {
+    std::vector<int> input_pos = {0, 1, 2};
+    return input_pos;
+  }
+};
+
+// Base processor for layout-agnostic ops. Such a node is only converted when
+// it is reachable, through layout-agnostic ops, from an NCHW-to-NHWC
+// transpose inserted in the first pass.
+class AgnosticNodeProcessor : public NodeProcessor {
+ public:
+  AgnosticNodeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : NodeProcessor(graph, node, node_map) {}
+
+ protected:
+  bool ShouldProcess() const override {
+    return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC();
+  }
+
+  // Walks up the chain of data inputs and returns true iff an NCHW-to-NHWC
+  // transpose is reached before any op that is not layout agnostic.
+  bool IsNodeAfterNCHWToNHWC() const {
+    std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
+    auto node = node_map_->GetNode(node_->name());
+    while (node != nullptr && node->input_size() > 0) {
+      int data_input_pos = 0;
+      if (node->op().compare("Concat") == 0) {
+        data_input_pos = 1;
+      }
+      if (data_input_pos >= node->input_size()) {
+        return false;
+      }
+      node = node_map_->GetNode(node->input(data_input_pos));
+      if (node == nullptr) {
+        return false;
+      }
+      if (IsNodeNCHWToNHWC(node->name())) {
+        return true;
+      }
+      // The set holds op types, so it must be queried with the node's op,
+      // not with its (unique) name.
+      bool connected =
+          ops_format_agnostic.find(node->op()) != ops_format_agnostic.end();
+      if (!connected) {
+        return false;
+      }
+    }
+    return false;
+  }
+};
+
+// AddN: every input is transposed.
+class AddNProcessor : public AgnosticNodeProcessor {
+ public:
+  AddNProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+  std::vector<int> GetInputPos() const override {
+    std::vector<int> input_pos;
+    for (int i = 0; i < node_->input_size(); i++) {
+      input_pos.push_back(i);
+    }
+    return input_pos;
+  }
+};
+
+// Binary element-wise ops. Handles a 4-D first operand combined with a 4-D,
+// scalar or vector second operand; a vector operand is reshaped to
+// [1, C, 1, 1].
+class BinaryOpProcessor : public AgnosticNodeProcessor {
+ public:
+  BinaryOpProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : AgnosticNodeProcessor(graph, node, node_map) {
+    is_4d_with_vector_ = Is4DOperateWithVector();
+  }
+
+ protected:
+  bool ShouldProcess() const override {
+    return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+           (Is4DOperateWithND(4) || Is4DOperateWithScalar() ||
+            Is4DOperateWithVector());
+  }
+
+  std::vector<int> GetInputPos() const override {
+    std::vector<int> input_pos = {0};
+    if (Is4DOperateWithND(4)) {
+      input_pos.push_back(1);
+    }
+    return input_pos;
+  }
+
+  // Returns true iff input 0 is 4-D (or an NCHW-to-NHWC transpose) and
+  // input 1 has rank n (for n == 4 a transpose is accepted as well).
+  bool Is4DOperateWithND(int n) const {
+    auto input0 = node_map_->GetNode(node_->input(0));
+    auto input1 = node_map_->GetNode(node_->input(1));
+    if (input0 && input1) {
+      return (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) &&
+             ((n == 4)
+                  ? (IsDimsFour(input1) || IsNodeNCHWToNHWC(input1->name()))
+                  : IsDimsN(input1, n));
+    }
+    return false;
+  }
+
+  bool Is4DOperateWithScalar() const { return Is4DOperateWithND(0); }
+
+  bool Is4DOperateWithVector() const { return Is4DOperateWithND(1); }
+
+  // Adds an int32 Const node holding the shape [1, num_channels, 1, 1].
+  void AddNodeShapeConst(const string& name, int num_channels) {
+    NodeDef* node = graph_->add_node();
+    node_map_->AddNode(name, node);
+    node->set_name(name);
+    node->set_op("Const");
+    AttrValue attr_data_type;
+    attr_data_type.set_type(DT_INT32);
+    node->mutable_attr()->insert({"dtype", attr_data_type});
+
+    AttrValue attr_tensor;
+    Tensor tensor(DT_INT32, TensorShape({4}));
+    std::vector<int> shape = {1, num_channels, 1, 1};
+    for (size_t i = 0; i < shape.size(); i++) {
+      tensor.flat<int>()(i) = shape[i];
+    }
+    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+    node->mutable_attr()->insert({"value", attr_tensor});
+  }
+
+  // Adds a Reshape node that reshapes "input_name" to the shape produced by
+  // "shape_const_node_name".
+  void AddNodeReshape(const string& node_name, const string& input_name,
+                      const string& shape_const_node_name, DataType data_type) {
+    NodeDef* node = graph_->add_node();
+    node_map_->AddNode(node_name, node);
+    node->set_name(node_name);
+    *node->add_input() = input_name;
+    *node->add_input() = shape_const_node_name;
+    node->set_op("Reshape");
+
+    AttrValue attr_type_indices;
+    attr_type_indices.set_type(DT_INT32);
+    node->mutable_attr()->insert({"Tshape", attr_type_indices});
+
+    AttrValue attr_type_params;
+    attr_type_params.set_type(data_type);
+    node->mutable_attr()->insert({"T", attr_type_params});
+  }
+
+  // For a vector second operand, routes it through a Reshape to
+  // [1, C, 1, 1] so it lines up with the channel dimension in NCHW.
+  void CustomizedProcessing() override {
+    if (is_4d_with_vector_) {
+      string suffix = strings::StrCat("-", node_->name(), "-", node_->input(1));
+      string reshape_node_name = strings::StrCat(kReshapeNHWCToNCHW, suffix);
+      string shape_const_node_name = strings::StrCat(kReshapeConst, suffix);
+      int vector_size = node_map_->GetNode(node_->input(1))
+                            ->attr()
+                            .at("_output_shapes")
+                            .list()
+                            .shape(0)
+                            .dim(0)
+                            .size();
+      AddNodeShapeConst(shape_const_node_name, vector_size);
+      AddNodeReshape(reshape_node_name, node_->input(1), shape_const_node_name,
+                     node_->attr().at("T").type());
+      node_map_->AddOutput(shape_const_node_name, reshape_node_name);
+      node_map_->UpdateOutput(node_->input(1), node_->name(),
+                              reshape_node_name);
+      node_map_->AddOutput(reshape_node_name, node_->name());
+      *node_->mutable_input(1) = reshape_node_name;
+    }
+  }
+
+ private:
+  bool is_4d_with_vector_;
+};
+
+// Concat / ConcatV2: processed only when concatenating along axis 3; the
+// axis input is then redirected to the shared constant holding axis 1.
+class ConcatProcessor : public AgnosticNodeProcessor {
+ public:
+  ConcatProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : AgnosticNodeProcessor(graph, node, node_map) {
+    // For Concat, the concat axis is the first input; for ConcatV2,
+    // the last input.
+    axis_node_pos_ =
+        (node_->op().compare("Concat") == 0) ? 0 : (node_->input_size() - 1);
+  }
+
+ protected:
+  bool ShouldProcess() const override {
+    return IsDimsFour(node_) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+           IsAlongDimC();
+  }
+
+  // All inputs except the axis input are transposed.
+  std::vector<int> GetInputPos() const override {
+    std::vector<int> input_pos;
+    int start = (node_->op().compare("Concat") == 0) ? 1 : 0;
+    int end = (node_->op().compare("Concat") == 0) ? node_->input_size()
+                                                   : (node_->input_size() - 1);
+    for (int i = start; i < end; i++) {
+      input_pos.push_back(i);
+    }
+    return input_pos;
+  }
+
+  void CustomizedProcessing() override {
+    node_map_->AddOutput(kConcatConst, node_->name());
+    *node_->mutable_input(axis_node_pos_) = kConcatConst;
+  }
+
+  // Returns true iff the axis input has a "value" whose first int_val is 3.
+  bool IsAlongDimC() const {
+    auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_));
+    if (axis_node == nullptr) {
+      return false;
+    }
+    if (axis_node->attr().find("value") != axis_node->attr().end()) {
+      const auto& tensor = axis_node->attr().at("value").tensor();
+      return tensor.int_val_size() > 0 && tensor.int_val(0) == 3;
+    }
+    return false;
+  }
+
+  int axis_node_pos_;
+};
+
+// ReluGrad: transposes inputs 0 and 1.
+class ReluGradProcessor : public AgnosticNodeProcessor {
+ public:
+  ReluGradProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+  std::vector<int> GetInputPos() const override {
+    std::vector<int> input_pos = {0, 1};
+    return input_pos;
+  }
+};
+
+// This is the older, less optimized gather-based SliceProcessor. We keep it as
+// a test case for constant propagation optimization.
+class SliceProcessorGatherBased : public AgnosticNodeProcessor {
+ public:
+  SliceProcessorGatherBased(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+  void CustomizedProcessing() override {
+    // Skip the first input, which is the data to be sliced.
+    for (int i = 1; i < node_->input_size(); i++) {
+      string node_name_NHWCToNCHW =
+          strings::StrCat(kPermVecNHWCToNCHW, "-", node_->name(), "-input", i);
+      AddNodePermVec(node_name_NHWCToNCHW, node_->input(i),
+                     node_->attr().at("Index").type(), true);
+      node_map_->UpdateOutput(node_->input(i), node_->name(),
+                              node_name_NHWCToNCHW);
+      node_map_->AddOutput(node_name_NHWCToNCHW, node_->name());
+      *node_->mutable_input(i) = node_name_NHWCToNCHW;
+    }
+  }
+
+ private:
+  // Adds a Gather node that reorders the vector "input_name" using the
+  // shared permutation constant for the requested direction.
+  void AddNodePermVec(const string& node_name, const string& input_name,
+                      DataType data_type, bool NHWCToNCHW) {
+    NodeDef* node = graph_->add_node();
+    node_map_->AddNode(node_name, node);
+    node->set_name(node_name);
+    *node->add_input() = input_name;
+    *node->add_input() = NHWCToNCHW ? kPermNHWCToNCHW : kPermNCHWToNHWC;
+    node->set_op("Gather");
+
+    AttrValue attr_type_indices;
+    attr_type_indices.set_type(DT_INT32);
+    node->mutable_attr()->insert({"Tindices", attr_type_indices});
+
+    AttrValue attr_type_params;
+    attr_type_params.set_type(data_type);
+    node->mutable_attr()->insert({"Tparams", attr_type_params});
+
+    AttrValue attr_validate;
+    attr_validate.set_b(true);
+    node->mutable_attr()->insert({"validate_indices", attr_validate});
+  }
+};
+
+// Slice: when its begin input comes from a ConcatOffset along axis 3, the
+// ConcatOffset shape constants are permuted in place and its axis set to 1.
+class SliceProcessor : public AgnosticNodeProcessor {
+ public:
+  SliceProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+  void CustomizedProcessing() override {
+    auto maybe_concatoffset_node =
+        node_map_->GetNode(NodeName(node_->input(1)));
+    if (maybe_concatoffset_node == nullptr) {
+      return;
+    }
+    if (maybe_concatoffset_node->op() == "ConcatOffset") {
+      auto axis_node = node_map_->GetNode(maybe_concatoffset_node->input(0));
+      // Need to process if the channel is at dimension 3, which indicates the
+      // NHWC format is being used. As multiple Slice nodes may share the same
+      // ConcatOffset node, the NHWC to NCHW conversion may have already
+      // been performed when processing other Slice nodes.
+      if (axis_node->attr().at("value").tensor().int_val(0) == 3) {
+        for (int i = 1; i < maybe_concatoffset_node->input_size(); i++) {
+          auto shape_node =
+              node_map_->GetNode(maybe_concatoffset_node->input(i));
+          Tensor tensor;
+          CHECK(tensor.FromProto(shape_node->attr().at({"value"}).tensor()));
+          int h = tensor.flat<int>()(1);
+          int w = tensor.flat<int>()(2);
+          int c = tensor.flat<int>()(3);
+          tensor.flat<int>()(1) = c;
+          tensor.flat<int>()(2) = h;
+          tensor.flat<int>()(3) = w;
+          tensor.AsProtoTensorContent(
+              shape_node->mutable_attr()->at({"value"}).mutable_tensor());
+        }
+        // Set the channel dimension to 1, as we have converted the vector
+        // element order from NHWC to NCHW.
+        axis_node->mutable_attr()->at("value").mutable_tensor()->set_int_val(0,
+                                                                             1);
+      }
+    }
+  }
+};
+
+// Squeeze: processed when a [N, 1, 1, C] input is squeezed along dimensions
+// 1 and 2; the squeeze dimensions then become 2 and 3.
+class SqueezeProcessor : public AgnosticNodeProcessor {
+ public:
+  SqueezeProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+  bool ShouldProcess() const override {
+    return IsDimsN(node_, 2) && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
+           IsInputConvertible() && IsAlongDimHW();
+  }
+
+  void AddLayoutTransposeToOutputs() override {}
+
+  // Returns true iff the (pre-transpose) input is 4-D with dimensions 1 and
+  // 2 both of size 1.
+  bool IsInputConvertible() const {
+    auto input = node_map_->GetNode(node_->input(0));
+    if (input != nullptr && IsNodeNCHWToNHWC(input->name())) {
+      input = node_map_->GetNode(input->input(0));
+    }
+    if (input == nullptr) {
+      return false;
+    }
+    if (input->attr().find("_output_shapes") != input->attr().end()) {
+      auto shape = input->attr().at("_output_shapes").list().shape(0);
+      if (shape.dim_size() != 4) {
+        return false;
+      }
+      if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // Returns true iff "squeeze_dims" starts with the dimensions 1 and 2.
+  bool IsAlongDimHW() const {
+    if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
+      auto list = node_->attr().at("squeeze_dims").list();
+      // Check the size first; indexing a shorter list is out of range.
+      if (list.i_size() >= 2 && list.i(0) == 1 && list.i(1) == 2) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  void CustomizedProcessing() override {
+    auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
+    list->set_i(0, 2);
+    list->set_i(1, 3);
+  }
+};
+
+// Sum: processed when a 4-D input is reduced along dimensions {0, 1, 2}; the
+// reduction indices are redirected to the shared constant {0, 2, 3}.
+class SumProcessor : public AgnosticNodeProcessor {
+ public:
+  SumProcessor(GraphDef* graph, NodeDef* node, NodeMap* node_map)
+      : AgnosticNodeProcessor(graph, node, node_map) {}
+
+ protected:
+  bool ShouldProcess() const override {
+    auto input0 = node_map_->GetNode(node_->input(0));
+    return HasOutputs() && IsNodeAfterNCHWToNHWC() && input0 != nullptr &&
+           (IsDimsFour(input0) || IsNodeNCHWToNHWC(input0->name())) &&
+           IsAlongDimNHW();
+  }
+
+  void AddLayoutTransposeToOutputs() override {}
+
+  void CustomizedProcessing() override {
+    node_map_->AddOutput(kReductionConst, node_->name());
+    *node_->mutable_input(1) = kReductionConst;
+  }
+
+ private:
+  // Returns true iff the reduction indices input is the constant {0, 1, 2}.
+  bool IsAlongDimNHW() const {
+    NodeDef* node = node_map_->GetNode(node_->input(1));
+    if (node == nullptr) {
+      return false;
+    }
+    Tensor tensor;
+    if (node->attr().find({"value"}) == node->attr().end()) {
+      return false;
+    }
+    auto success = tensor.FromProto(node->attr().at({"value"}).tensor());
+    if (!success) {
+      LOG(ERROR) << "Failed to parse TensorProto.";
+      return false;
+    }
+    if (tensor.flat<int>().size() != 3) {
+      return false;
+    }
+    if (tensor.flat<int>()(0) == 0 && tensor.flat<int>()(1) == 1 &&
+        tensor.flat<int>()(2) == 2) {
+      return true;
+    }
+    return false;
+  }
+};
+
+// Rewrites "graph" in place. Construction runs the whole optimization:
+// Expand() converts nodes to NCHW and surrounds them with transposes, then
+// Collapse() removes transpose pairs that cancel out.
+class DataLayoutOptimizer {
+ public:
+  explicit DataLayoutOptimizer(GraphDef* graph)
+      : graph_(graph), node_map_(graph_) {
+    LOG(INFO) << "Number of nodes for original graph: " << graph_->node_size();
+    Expand();
+    LOG(INFO) << "Number of nodes after Expand: " << graph_->node_size();
+    Collapse();
+    LOG(INFO) << "Number of nodes after Collapse: " << graph_->node_size();
+  }
+
+ private:
+  // Adds an int32 Const node "name" holding the 4-element "permutation".
+  void AddNodePermConst(const string& name,
+                        const std::vector<int>& permutation) {
+    NodeDef* node = graph_->add_node();
+    node_map_.AddNode(name, node);
+    node->set_name(name);
+    node->set_op("Const");
+    AttrValue attr_data_type;
+    attr_data_type.set_type(DT_INT32);
+    node->mutable_attr()->insert({"dtype", attr_data_type});
+    AttrValue attr_tensor;
+    Tensor tensor(DT_INT32, TensorShape({4}));
+    for (size_t i = 0; i < permutation.size(); i++) {
+      tensor.flat<int>()(i) = permutation[i];
+    }
+    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+    node->mutable_attr()->insert({"value", attr_tensor});
+  }
+
+  // Adds the shared scalar Const (value 1) used as the NCHW concat axis.
+  void AddNodeConcatConst() {
+    NodeDef* node = graph_->add_node();
+    node_map_.AddNode(kConcatConst, node);
+    node->set_name(kConcatConst);
+    node->set_op("Const");
+    AttrValue attr_data_type;
+    attr_data_type.set_type(DT_INT32);
+    node->mutable_attr()->insert({"dtype", attr_data_type});
+    AttrValue attr_tensor;
+    Tensor tensor(DT_INT32, TensorShape({}));
+    tensor.scalar<int>()() = 1;
+    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+    node->mutable_attr()->insert({"value", attr_tensor});
+  }
+
+  // Adds the shared Const {0, 2, 3} used as NCHW reduction indices.
+  void AddNodeReductionConst() {
+    NodeDef* node = graph_->add_node();
+    node_map_.AddNode(kReductionConst, node);
+    node->set_name(kReductionConst);
+    node->set_op("Const");
+    AttrValue attr_data_type;
+    attr_data_type.set_type(DT_INT32);
+    node->mutable_attr()->insert({"dtype", attr_data_type});
+
+    AttrValue attr_tensor;
+    Tensor tensor(DT_INT32, TensorShape({3}));
+    std::vector<int> axis = {0, 2, 3};
+    for (size_t i = 0; i < axis.size(); i++) {
+      tensor.flat<int>()(i) = axis[i];
+    }
+    tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
+    node->mutable_attr()->insert({"value", attr_tensor});
+  }
+
+  // Expand all nodes which is in NHWC, but supports NCHW or is layout agnostic.
+  void Expand() {
+    int node_size_original = graph_->node_size();
+    // This is the first pass where we expand the nodes which support NCHW.
+    std::set<string> ops_format_supported = GetOpsFormatSupported();
+    for (int i = 0; i < graph_->node_size(); i++) {
+      if (ops_format_supported.find(graph_->node(i).op()) !=
+          ops_format_supported.end()) {
+        auto node = graph_->mutable_node(i);
+        std::unique_ptr<NodeProcessor> node_processor;
+        if (node->op().compare("AvgPoolGrad") == 0) {
+          node_processor.reset(
+              new AvgPoolGradProcessor(graph_, node, &node_map_));
+        } else if (node->op().compare("BiasAddGrad") == 0) {
+          node_processor.reset(
+              new BiasAddGradProcessor(graph_, node, &node_map_));
+        } else if (node->op().compare("Conv2DBackpropFilter") == 0) {
+          node_processor.reset(
+              new Conv2DBackpropFilterProcessor(graph_, node, &node_map_));
+        } else if (node->op().compare("Conv2DBackpropInput") == 0) {
+          node_processor.reset(
+              new Conv2DBackpropInputProcessor(graph_, node, &node_map_));
+        } else if (node->op().compare("FusedBatchNormGrad") == 0) {
+          node_processor.reset(
+              new FusedBatchNormGradProcessor(graph_, node, &node_map_));
+        } else if (node->op().compare("MaxPoolGrad") == 0) {
+          node_processor.reset(
+              new MaxPoolGradProcessor(graph_, node, &node_map_));
+        } else {
+          node_processor.reset(new NodeProcessor(graph_, node, &node_map_));
+        }
+        node_processor->ConvertNode();
+      }
+    }
+
+    // This is the second pass where we expand layout-agnostic nodes. This pass
+    // only needs to be performed if at least one node in the previous pass is
+    // expanded.
+    if (graph_->node_size() > node_size_original) {
+      AddNodePermConst(kPermNHWCToNCHW, {0, 3, 1, 2});
+      AddNodePermConst(kPermNCHWToNHWC, {0, 2, 3, 1});
+      AddNodeConcatConst();
+      AddNodeReductionConst();
+      std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
+      for (int i = 0; i < graph_->node_size(); i++) {
+        if (ops_format_agnostic.find(graph_->node(i).op()) !=
+            ops_format_agnostic.end()) {
+          auto node = graph_->mutable_node(i);
+          std::unique_ptr<NodeProcessor> node_processor;
+          if (node->op().compare("AddN") == 0) {
+            node_processor.reset(new AddNProcessor(graph_, node, &node_map_));
+          } else if (node->op().compare("Add") == 0 ||
+                     node->op().compare("Mul") == 0 ||
+                     node->op().compare("RealDiv") == 0 ||
+                     node->op().compare("SquaredDifference") == 0 ||
+                     node->op().compare("Sub") == 0) {
+            node_processor.reset(
+                new BinaryOpProcessor(graph_, node, &node_map_));
+          } else if (node->op().compare("Concat") == 0 ||
+                     node->op().compare("ConcatV2") == 0) {
+            node_processor.reset(new ConcatProcessor(graph_, node, &node_map_));
+          } else if (node->op().compare("ReluGrad") == 0) {
+            node_processor.reset(
+                new ReluGradProcessor(graph_, node, &node_map_));
+          } else if (node->op().compare("Slice") == 0) {
+            node_processor.reset(new SliceProcessor(graph_, node, &node_map_));
+          } else if (node->op().compare("Squeeze") == 0) {
+            node_processor.reset(
+                new SqueezeProcessor(graph_, node, &node_map_));
+          } else if (node->op().compare("Sum") == 0) {
+            node_processor.reset(new SumProcessor(graph_, node, &node_map_));
+          } else {
+            node_processor.reset(
+                new AgnosticNodeProcessor(graph_, node, &node_map_));
+          }
+          node_processor->ConvertNode();
+        }
+      }
+    }
+  }
+
+  // Remove all node pairs, where a NCHW-to-NHWC node is followed by
+  // a NHWC-to-NCHW node.
+  void Collapse() {
+    std::unordered_set<string> nodes_removable;
+    for (int i = 0; i < graph_->node_size(); i++) {
+      auto node = graph_->mutable_node(i);
+      if (IsNodeNHWCToNCHW(node->name())) {
+        if (IsNodeNCHWToNHWC(node->input(0))) {
+          const string& trans_first = node->input(0);
+          const string& trans_second = node->name();
+          auto outputs = node_map_.GetOutputs(trans_second);
+          CHECK(outputs.size() == 1)
+              << "There is always only a single output for a Transpose node, "
+              << "due to the way it is added by NodeProcessor.";
+          NodeDef* output = *outputs.begin();
+          NodeDef* first_transpose = node_map_.GetNode(trans_first);
+          if (first_transpose == nullptr) {
+            // Cannot bypass a transpose that is not in the map.
+            continue;
+          }
+          string input = first_transpose->input(0);
+          // Reconnect the consumer directly to the original producer.
+          for (int j = 0; j < output->input_size(); j++) {
+            if (output->input(j).compare(trans_second) == 0) {
+              *output->mutable_input(j) = input;
+              break;
+            }
+          }
+          nodes_removable.insert(trans_first);
+          nodes_removable.insert(trans_second);
+        }
+      }
+    }
+    // The set outlives the call below, so capture it by reference instead of
+    // copying it into the predicate.
+    graph_->mutable_node()->erase(
+        std::remove_if(
+            graph_->mutable_node()->begin(), graph_->mutable_node()->end(),
+            [&nodes_removable](const NodeDef& node) {
+              return nodes_removable.find(node.name()) != nodes_removable.end();
+            }),
+        graph_->mutable_node()->end());
+  }
+
+  GraphDef* graph_;
+  NodeMap node_map_;
+};
+
+// Copies the graph of "item" into "*output" and rewrites the copy in place.
+// Always returns an OK status.
+Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+                                 GraphDef* output) {
+  *output = item.graph;
+  DataLayoutOptimizer layout_optimizer(output);
+  return Status::OK();
+}
+
+void LayoutOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
+                               const GraphDef& optimize_output, double result) {
+  // Nothing to do for LayoutOptimizer.
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.h b/tensorflow/core/grappler/optimizers/layout_optimizer.h
new file mode 100644
index 0000000000..66dec17a35
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
+#define TENSORFLOW_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
+
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Convert the NHWC layout to NCHW for Conv-related ops on GPUs.
+class LayoutOptimizer : public GraphOptimizer {
+ public:
+  LayoutOptimizer() {}
+  ~LayoutOptimizer() override {}
+
+  // Name under which this optimizer is identified.
+  string name() const override { return "layout"; }
+
+  // Writes a layout-converted version of the graph in "item" to "*output".
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* output) override;
+
+  // Feedback hook required by GraphOptimizer.
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimize_output, double result) override;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_GRAPPLER_OPTIMIZERS_LAYOUT_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
new file mode 100644
index 0000000000..daef641fd5
--- /dev/null
+++ b/tensorflow/core/grappler/utils.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/lib/strings/scanner.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/stream_executor.h" + +namespace tensorflow { +namespace grappler { + +int GetNumAvailableGPUs() { + int num_eligible_gpus = 0; + if (ValidateGPUMachineManager().ok()) { + perftools::gputools::Platform* gpu_manager = GPUMachineManager(); + if (gpu_manager != nullptr) { + int num_gpus = gpu_manager->VisibleDeviceCount(); + for (int i = 0; i < num_gpus; i++) { + auto exec_status = gpu_manager->ExecutorForDevice(i); + if (exec_status.ok()) { + perftools::gputools::StreamExecutor* se = exec_status.ValueOrDie(); + const perftools::gputools::DeviceDescription& desc = + se->GetDeviceDescription(); + int min_gpu_core_count = 8; + if (desc.core_count() >= min_gpu_core_count) { + num_eligible_gpus++; + } + } + } + } + } + LOG(INFO) << "Number of eligible GPUs (core count >= 8): " + << num_eligible_gpus; + return num_eligible_gpus; +} + +int GetNumAvailableLogicalCPUCores() { return port::NumSchedulableCPUs(); } + +string ParseNodeName(const string& name, int* position) { + // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any) + // to get a node name. 
+ strings::Scanner scan(name); + scan.ZeroOrOneLiteral("^") + .RestartCapture() + .One(strings::Scanner::LETTER_DIGIT_DOT) + .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); + StringPiece capture; + StringPiece remaining; + if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) { + *position = 0; + return ""; + } else { + if (name[0] == '^') { + *position = -1; + } else if (remaining.empty()) { + *position = 0; + } else { + // Skip the first ':' character. + *position = std::stoi(remaining.substr(1).ToString()); + } + return capture.ToString(); + } +} + +string NodeName(const string& name) { + int position; + return ParseNodeName(name, &position); +} + +int NodePosition(const string& name) { + int position; + ParseNodeName(name, &position); + return position; +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h new file mode 100644 index 0000000000..1aa91e25e4 --- /dev/null +++ b/tensorflow/core/grappler/utils.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_GRAPPLER_UTILS_H_ +#define TENSORFLOW_GRAPPLER_UTILS_H_ + +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace grappler { + +// Get the number of available GPUs whose number of multiprocessors is no less +// than 8. +int GetNumAvailableGPUs(); + +// Get the number of logical CPU cores (aka hyperthreads) available. +int GetNumAvailableLogicalCPUCores(); + +// Return the node name corresponding to 'name' if name is valid, or the empty +// string otherwise. +string NodeName(const string& name); + +// Get the trailing position number ":{digits}" (if any) of a node name. +int NodePosition(const string& name); + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_GRAPPLER_UTILS_H_ diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc new file mode 100644 index 0000000000..57ba45352a --- /dev/null +++ b/tensorflow/core/grappler/utils_test.cc @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class UtilsTest : public ::testing::Test {}; + +TEST_F(UtilsTest, NodeName) { + EXPECT_EQ("abc", NodeName("abc")); + EXPECT_EQ("abc", NodeName("^abc")); + EXPECT_EQ("abc", NodeName("abc:0")); + EXPECT_EQ("abc", NodeName("^abc:0")); + + EXPECT_EQ("abc/def", NodeName("abc/def")); + EXPECT_EQ("abc/def", NodeName("^abc/def")); + EXPECT_EQ("abc/def", NodeName("abc/def:1")); + EXPECT_EQ("abc/def", NodeName("^abc/def:1")); + + EXPECT_EQ("abc/def0", NodeName("abc/def0")); + EXPECT_EQ("abc/def0", NodeName("^abc/def0")); + EXPECT_EQ("abc/def0", NodeName("abc/def0:0")); + EXPECT_EQ("abc/def0", NodeName("^abc/def0:0")); + + EXPECT_EQ("abc/def_0", NodeName("abc/def_0")); + EXPECT_EQ("abc/def_0", NodeName("^abc/def_0")); + EXPECT_EQ("abc/def_0", NodeName("abc/def_0:3")); + EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3")); + + EXPECT_EQ("abc/def_0", NodeName("^abc/def_0:3214")); +} + +TEST_F(UtilsTest, NodePosition) { + EXPECT_EQ(2, NodePosition("abc:2")); + EXPECT_EQ(123, NodePosition("abc:123")); + EXPECT_EQ(-1, NodePosition("^abc:123")); + EXPECT_EQ(-1, NodePosition("^abc")); + EXPECT_EQ(0, NodePosition("")); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc index d2b252629e..d729963616 100644 --- a/tensorflow/core/platform/env.cc +++ b/tensorflow/core/platform/env.cc @@ -132,6 +132,52 @@ Status Env::FileExists(const string& fname) { return fs->FileExists(fname); } +bool Env::FilesExist(const std::vector<string>& files, + std::vector<Status>* status) { + std::unordered_map<string, std::vector<string>> files_per_fs; + for (const auto& file : files) { + StringPiece scheme, host, path; + io::ParseURI(file, &scheme, &host, &path); + 
files_per_fs[scheme.ToString()].push_back(file); + } + + std::unordered_map<string, Status> per_file_status; + bool result = true; + for (auto itr : files_per_fs) { + FileSystem* file_system = file_system_registry_->Lookup(itr.first); + bool fs_result; + std::vector<Status> local_status; + std::vector<Status>* fs_status = status ? &local_status : nullptr; + if (!file_system) { + fs_result = false; + if (fs_status) { + Status s = errors::Unimplemented("File system scheme ", itr.first, + " not implemented"); + local_status.resize(itr.second.size(), s); + } + } else { + fs_result = file_system->FilesExist(itr.second, fs_status); + } + if (fs_status) { + result &= fs_result; + for (int i = 0; i < itr.second.size(); ++i) { + per_file_status[itr.second[i]] = fs_status->at(i); + } + } else if (!fs_result) { + // Return early + return false; + } + } + + if (status) { + for (const auto& file : files) { + status->push_back(per_file_status[file]); + } + } + + return result; +} + Status Env::GetChildren(const string& dir, std::vector<string>* result) { FileSystem* fs; TF_RETURN_IF_ERROR(GetFileSystemForFile(dir, &fs)); diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h index 354748eb3e..1b7e024b0f 100644 --- a/tensorflow/core/platform/env.h +++ b/tensorflow/core/platform/env.h @@ -136,6 +136,12 @@ class Env { /// Returns OK if the named path exists and NOT_FOUND otherwise. Status FileExists(const string& fname); + /// Returns true if all the listed files exist, false otherwise. + /// if status is not null, populate the vector with a detailed status + /// for each file. + bool FilesExist(const std::vector<string>& files, + std::vector<Status>* status); + /// \brief Stores in *result the names of the children of the specified /// directory. The names are relative to "dir". 
/// diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc index 4564297b5f..3d7553e6da 100644 --- a/tensorflow/core/platform/file_system.cc +++ b/tensorflow/core/platform/file_system.cc @@ -76,6 +76,22 @@ WritableFile::~WritableFile() {} FileSystemRegistry::~FileSystemRegistry() {} +bool FileSystem::FilesExist(const std::vector<string>& files, + std::vector<Status>* status) { + bool result = true; + for (const auto& file : files) { + Status s = FileExists(file); + result &= s.ok(); + if (status != nullptr) { + status->push_back(s); + } else if (!result) { + // Return early since there is no need to check other files. + return false; + } + } + return result; +} + Status FileSystem::GetMatchingPaths(const string& pattern, std::vector<string>* results) { results->clear(); diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h index b2499769aa..903df96b58 100644 --- a/tensorflow/core/platform/file_system.h +++ b/tensorflow/core/platform/file_system.h @@ -105,6 +105,12 @@ class FileSystem { /// Returns OK if the named path exists and NOT_FOUND otherwise. virtual Status FileExists(const string& fname) = 0; + /// Returns true if all the listed files exist, false otherwise. + /// if status is not null, populate the vector with a detailed status + /// for each file. + virtual bool FilesExist(const std::vector<string>& files, + std::vector<Status>* status); + /// \brief Returns the immediate children in the given directory. /// /// The returned paths are relative to 'dir'. |