aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-08 09:06:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 09:10:28 -0700
commit5f308cb408eb46ec9af0546be6b9ae1d5166b185 (patch)
treef6da0324d24a19486ceee4d91989198099394a06 /tensorflow/core
parent75f57a8b7836a1ed3cda8ba81c88f6caf15cf0c6 (diff)
Optimize PinToHostOptimizer by adding cache, also add PinToHostOptimizer to benchmarks.
Original runtime: 4.83492736816 secs; with cache, runtime: 2.19033999443 secs. PiperOrigin-RevId: 216195286
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/grappler/op_types.cc22
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc162
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h4
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc76
4 files changed, 179 insertions, 85 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 1b5a215987..cbf5c8e038 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -102,15 +102,19 @@ bool IsConjugateTranspose(const NodeDef& node) {
}
bool IsControlFlow(const NodeDef& node) {
- // clang-format off
- return node.op() == "ControlTrigger" ||
- node.op() == "Enter" ||
- node.op() == "Exit" ||
- node.op() == "LoopCond" ||
- node.op() == "Merge" ||
- node.op() == "NextIteration" ||
- node.op() == "Switch";
- // clang-format on
+ // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative
+ // string comparison.
+ static const gtl::FlatSet<string>* const kControlFlowOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "ControlTrigger",
+ "Enter",
+ "Exit",
+ "LoopCond",
+ "Merge",
+ "NextIteration",
+ "Switch",
+ }));
+ return kControlFlowOps->count(node.op()) > 0;
}
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 8ed4271fa4..29a3b2b74c 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -25,16 +25,29 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
+
namespace internal {
+namespace {
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
+struct OpDevicePortHasher {
+ std::size_t operator()(const std::tuple<string, string, int>& x) const {
+ uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x)));
+
+ return Hash64Combine(code, hash<int>()(std::get<2>(x)));
+ }
+};
+using OpDevicePortOnHostMap =
+ gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>;
+
// All the nodes that should be blacklisted and not swapped.
bool IsBlacklisted(const NodeDef& node) {
return
@@ -82,10 +95,10 @@ Status TryFindKernelDef(const std::vector<DeviceType>& devices,
// Checks if a node's output port is host friendly.
// Roughly this means checking if the output port is on Host memory.
-Status IsNodeOutputPortHostFriendly(const GraphView& graph,
- GraphProperties* properties,
- const NodeDef& node, int port_id,
- bool* is_candidate) {
+Status IsNodeOutputPortHostFriendly(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
*is_candidate = false;
// Make sure we are not a blacklisted op.
@@ -117,7 +130,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
for (const auto& fanin : graph.GetFanins(node, false)) {
bool fanin_candidate = false;
TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
- graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
if (!fanin_candidate) {
return Status::OK();
}
@@ -132,11 +146,22 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
return Status::OK();
}
+ // Check `op_device_outport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ const std::tuple<string, string, int> cache_key(node.op(), node.device(),
+ port_id);
+ auto it = op_device_outport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_outport_pinned_to_host_cache->end()) {
+ *is_candidate = it->second;
+ return Status::OK();
+ }
+
// Check if op's output port is pinned to HostMemory.
const OpDef* op = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
if (!s.ok()) {
LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -146,6 +171,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
LOG(WARNING) << "Invalid port: " << port_id << "!\n"
<< node.DebugString() << "\n"
<< op->DebugString();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -155,6 +181,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
&kernel);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -166,22 +193,35 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
}
}
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate);
+
return Status::OK();
}
// Checks if a node's input port is Host friendly.
// Roughly this means checking if the input port is on Host memory.
-bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
+bool IsNodeInputPortHostFriendly(
+ const NodeDef& node, int port_id,
+ OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) {
// If node is on Host, assume its inputs are Host friendly.
if (str_util::StrContains(node.device(), DEVICE_CPU)) {
return true;
}
+ // Check `op_device_inport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id);
+ auto it = op_device_inport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_inport_pinned_to_host_cache->end()) {
+ return it->second;
+ }
+
// Check if op's input port is pinned to HostMemory.
const OpDef* op = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
if (!s.ok()) {
LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
@@ -192,16 +232,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
{node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
// Check if the input_arg is pinned to Host.
for (const string& host_memory_arg : kernel->host_memory_arg()) {
if (op->input_arg(input_arg_id).name() == host_memory_arg) {
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, true);
return true;
}
}
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
+
return false;
}
@@ -211,18 +255,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
// ints, and pinned to Host).
-Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
- const NodeDef& node, bool* is_candidate) {
+Status IsNodeHostCandidate(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
*is_candidate = false;
- // Check if node already on CPU.
- if (str_util::StrContains(node.device(), DEVICE_CPU)) {
- *is_candidate = true;
+ // Skip these node types.
+ if (IsBlacklisted(node)) {
return Status::OK();
}
- // Skip these node types.
- if (IsBlacklisted(node)) {
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
return Status::OK();
}
@@ -232,17 +278,6 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
return Status::OK();
}
- // Check all inputs are Host friendly.
- for (const GraphView::OutputPort& fanin :
- graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
- bool fanin_candidate = false;
- TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
- graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
- if (!fanin_candidate) {
- return Status::OK();
- }
- }
-
// Check all outputs are Host friendly.
if (!properties->has_properties()) {
// This is an expensive call, call it lazily.
@@ -255,16 +290,42 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
}
}
+ // Check all inputs are Host friendly.
+ for (const GraphView::OutputPort& fanin :
+ graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
+ }
+ }
+
*is_candidate = true;
return Status::OK();
}
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device) {
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+ node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+} // end namespace
+
+// Tries to swap `device` to a Host device from `devices`. Returns true iff
+// there was a swap.
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device) {
// Force this node onto the CPU.
- if (device.empty() && has_device_cpu) {
- return "/device:CPU:0";
- } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ if (device->empty() && has_device_cpu) {
+ *device = "/device:CPU:0";
+ return true;
+ } else if (str_util::StrContains(*device, DEVICE_GPU)) {
// Sometimes the cluster can have:
// devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
// and we need to handle them properly.
@@ -272,27 +333,19 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices,
{std::pair<string, string>("GPU", "CPU:0"),
std::pair<string, string>("/device", "/device:CPU:0")}) {
const string device_host =
- strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ strings::StrCat(device->substr(0, device->rfind(device_match.first)),
device_match.second);
if (devices.find(device_host) != devices.end()) {
- return device_host;
+ *device = device_host;
+ return true;
}
}
}
- // We couldn't find an appropriate Host device, return original device.
- return device;
-}
-
-bool IsTPUGraphDef(const GraphDef& def) {
- for (const auto& node : def.node()) {
- if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
- node.op() == "TPUPartitionedCall") {
- return true;
- }
- }
+ // We couldn't find an appropriate Host device, return false.
return false;
}
+
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -324,20 +377,26 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// All the Const nodes, and their original devices in topological order.
std::vector<std::pair<NodeDef*, string>> const_nodes;
+ // Cache to map {op, device, port} -> bool on whether it is pinned to host.
+ internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache;
+ internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache;
+
for (auto& node : *optimized_graph->mutable_node()) {
bool is_candidate = false;
- TF_RETURN_IF_ERROR(
- internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
+ TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate(
+ graph, &properties, node, &op_device_outport_pinned_to_host_cache,
+ &is_candidate));
if (!is_candidate) {
continue;
}
- if (IsConstant(node)) {
- const_nodes.emplace_back(&node, node.device());
+ const string original_device = node.device();
+ const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu,
+ node.mutable_device());
+ // Keep track of all Const nodes that we swapped.
+ if (swapped && IsConstant(node)) {
+ const_nodes.emplace_back(&node, original_device);
}
- // Try and swap the device to Host.
- node.set_device(
- internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
}
// Traverse all `const_nodes`, and map them back to GPU greedily.
@@ -349,8 +408,9 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// this node back onto the original device.
for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
// The consumer is not Host friendly, swap it back to the original device.
- if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
- fanout.port_id)) {
+ if (!internal::IsNodeInputPortHostFriendly(
+ *fanout.node, fanout.port_id,
+ &op_device_inport_pinned_to_host_cache)) {
node->set_device(device);
break;
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
index d557a03463..bed4a9ef95 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -26,8 +26,8 @@ namespace tensorflow {
namespace grappler {
namespace internal {
// Try and find an appropriate Host device in `devices` given `device`.
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device);
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device);
} // end namespace internal
// Optimize TensorFlow ops that should be swapped into the CPU to avoid
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 7c64529441..9bb030b220 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -28,30 +28,60 @@ namespace {
class PinToHostOptimizerTest : public GrapplerTest {};
-TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) {
gtl::FlatSet<string> devices = {};
- EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
-
- devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
- "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
- "/device:CPU:0");
-
- devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_CPU:0");
-
- devices = {"/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_GPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_GPU:*");
+
+ string device = "ABC";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "ABC");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:*");
}
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {