about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-10-08 23:42:02 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-10-08 23:46:56 -0700
commit: 129bb5e845ccb2ab6339e85d39545800dac6ca33 (patch)
tree: 1d0f33cd2db065c56dca200fd247551b48bb037b
parent: e27ee15fa45a5f4e43e10ed1fe0eb3a1feb4253a (diff)
Automated rollback of commit 5f308cb408eb46ec9af0546be6b9ae1d5166b185
PiperOrigin-RevId: 216309111
-rw-r--r-- tensorflow/core/grappler/op_types.cc 22
-rw-r--r-- tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc 162
-rw-r--r-- tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h 4
-rw-r--r-- tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc 76
4 files changed, 85 insertions, 179 deletions
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index cbf5c8e038..1b5a215987 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -102,19 +102,15 @@ bool IsConjugateTranspose(const NodeDef& node) {
}
bool IsControlFlow(const NodeDef& node) {
- // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative
- // string comparison.
- static const gtl::FlatSet<string>* const kControFlowOps =
- CHECK_NOTNULL((new gtl::FlatSet<string>{
- "ControlTrigger",
- "Enter",
- "Exit",
- "LoopCond",
- "Merge",
- "NextIteration",
- "Switch",
- }));
- return kControFlowOps->count(node.op()) > 0;
+ // clang-format off
+ return node.op() == "ControlTrigger" ||
+ node.op() == "Enter" ||
+ node.op() == "Exit" ||
+ node.op() == "LoopCond" ||
+ node.op() == "Merge" ||
+ node.op() == "NextIteration" ||
+ node.op() == "Switch";
+ // clang-format on
}
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 29a3b2b74c..8ed4271fa4 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -25,29 +25,16 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
-
namespace internal {
-namespace {
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
-struct OpDevicePortHasher {
- std::size_t operator()(const std::tuple<string, string, int>& x) const {
- uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x)));
-
- return Hash64Combine(code, hash<int>()(std::get<2>(x)));
- }
-};
-using OpDevicePortOnHostMap =
- gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>;
-
// All the nodes that should be blacklisted and not swapped.
bool IsBlacklisted(const NodeDef& node) {
return
@@ -95,10 +82,10 @@ Status TryFindKernelDef(const std::vector<DeviceType>& devices,
// Checks if a node's output port is host friendly.
// Roughly this means checking if the output port is on Host memory.
-Status IsNodeOutputPortHostFriendly(
- const GraphView& graph, GraphProperties* properties, const NodeDef& node,
- int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
- bool* is_candidate) {
+Status IsNodeOutputPortHostFriendly(const GraphView& graph,
+ GraphProperties* properties,
+ const NodeDef& node, int port_id,
+ bool* is_candidate) {
*is_candidate = false;
// Make sure we are not a blacklisted op.
@@ -130,8 +117,7 @@ Status IsNodeOutputPortHostFriendly(
for (const auto& fanin : graph.GetFanins(node, false)) {
bool fanin_candidate = false;
TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
- graph, properties, *fanin.node, fanin.port_id,
- op_device_outport_pinned_to_host_cache, &fanin_candidate));
+ graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
if (!fanin_candidate) {
return Status::OK();
}
@@ -146,22 +132,11 @@ Status IsNodeOutputPortHostFriendly(
return Status::OK();
}
- // Check `op_device_outport_pinned_to_host_cache` for our
- // {op, device, port_id} combo to see if the arg is pinned on Host.
- const std::tuple<string, string, int> cache_key(node.op(), node.device(),
- port_id);
- auto it = op_device_outport_pinned_to_host_cache->find(cache_key);
- if (it != op_device_outport_pinned_to_host_cache->end()) {
- *is_candidate = it->second;
- return Status::OK();
- }
-
// Check if op's output port is pinned to HostMemory.
const OpDef* op = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
if (!s.ok()) {
LOG(WARNING) << "Could not find OpDef for : " << node.op();
- op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -171,7 +146,6 @@ Status IsNodeOutputPortHostFriendly(
LOG(WARNING) << "Invalid port: " << port_id << "!\n"
<< node.DebugString() << "\n"
<< op->DebugString();
- op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -181,7 +155,6 @@ Status IsNodeOutputPortHostFriendly(
&kernel);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for: " << node.op();
- op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -193,35 +166,22 @@ Status IsNodeOutputPortHostFriendly(
}
}
- op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate);
-
return Status::OK();
}
// Checks if a node's input port is Host friendly.
// Roughly this means checking if the input port is on Host memory.
-bool IsNodeInputPortHostFriendly(
- const NodeDef& node, int port_id,
- OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) {
+bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
// If node is on Host, assume its inputs are Host friendly.
if (str_util::StrContains(node.device(), DEVICE_CPU)) {
return true;
}
- // Check `op_device_inport_pinned_to_host_cache` for our
- // {op, device, port_id} combo to see if the arg is pinned on Host.
- std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id);
- auto it = op_device_inport_pinned_to_host_cache->find(cache_key);
- if (it != op_device_inport_pinned_to_host_cache->end()) {
- return it->second;
- }
-
// Check if op's input port is pinned to HostMemory.
const OpDef* op = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
if (!s.ok()) {
LOG(WARNING) << "Could not find OpDef for : " << node.op();
- op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
@@ -232,20 +192,16 @@ bool IsNodeInputPortHostFriendly(
{node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for: " << node.op();
- op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
// Check if the input_arg is pinned to Host.
for (const string& host_memory_arg : kernel->host_memory_arg()) {
if (op->input_arg(input_arg_id).name() == host_memory_arg) {
- op_device_inport_pinned_to_host_cache->emplace(cache_key, true);
return true;
}
}
- op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
-
return false;
}
@@ -255,29 +211,38 @@ bool IsNodeInputPortHostFriendly(
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
// ints, and pinned to Host).
-Status IsNodeHostCandidate(
- const GraphView& graph, GraphProperties* properties, const NodeDef& node,
- OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
- bool* is_candidate) {
+Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
+ const NodeDef& node, bool* is_candidate) {
*is_candidate = false;
- // Skip these node types.
- if (IsBlacklisted(node)) {
- return Status::OK();
- }
-
// Check if node already on CPU.
if (str_util::StrContains(node.device(), DEVICE_CPU)) {
*is_candidate = true;
return Status::OK();
}
+ // Skip these node types.
+ if (IsBlacklisted(node)) {
+ return Status::OK();
+ }
+
// Check the node can be run on CPU.
Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
if (!s.ok()) {
return Status::OK();
}
+ // Check all inputs are Host friendly.
+ for (const GraphView::OutputPort& fanin :
+ graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
+ }
+ }
+
// Check all outputs are Host friendly.
if (!properties->has_properties()) {
// This is an expensive call, call it lazily.
@@ -290,42 +255,16 @@ Status IsNodeHostCandidate(
}
}
- // Check all inputs are Host friendly.
- for (const GraphView::OutputPort& fanin :
- graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
- bool fanin_candidate = false;
- TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
- graph, properties, *fanin.node, fanin.port_id,
- op_device_outport_pinned_to_host_cache, &fanin_candidate));
- if (!fanin_candidate) {
- return Status::OK();
- }
- }
-
*is_candidate = true;
return Status::OK();
}
-bool IsTPUGraphDef(const GraphDef& def) {
- for (const auto& node : def.node()) {
- if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
- node.op() == "TPUPartitionedCall") {
- return true;
- }
- }
- return false;
-}
-} // end namespace
-
-// Tries to swap `device` to a Host device from `devices`. Returns true iff
-// there was a swap.
-bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, string* device) {
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device) {
// Force this node onto the CPU.
- if (device->empty() && has_device_cpu) {
- *device = "/device:CPU:0";
- return true;
- } else if (str_util::StrContains(*device, DEVICE_GPU)) {
+ if (device.empty() && has_device_cpu) {
+ return "/device:CPU:0";
+ } else if (str_util::StrContains(device, DEVICE_GPU)) {
// Sometimes the cluster can have:
// devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
// and we need to handle them properly.
@@ -333,19 +272,27 @@ bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
{std::pair<string, string>("GPU", "CPU:0"),
std::pair<string, string>("/device", "/device:CPU:0")}) {
const string device_host =
- strings::StrCat(device->substr(0, device->rfind(device_match.first)),
+ strings::StrCat(device.substr(0, device.rfind(device_match.first)),
device_match.second);
if (devices.find(device_host) != devices.end()) {
- *device = device_host;
- return true;
+ return device_host;
}
}
}
- // We couldn't find an appropriate Host device, return false.
- return false;
+ // We couldn't find an appropriate Host device, return original device.
+ return device;
}
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+ node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -377,26 +324,20 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// All the Const nodes, and their original devices in topological order.
std::vector<std::pair<NodeDef*, string>> const_nodes;
- // Cache to map {op, device, port} -> bool on whether it is pinned to host.
- internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache;
- internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache;
-
for (auto& node : *optimized_graph->mutable_node()) {
bool is_candidate = false;
- TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate(
- graph, &properties, node, &op_device_outport_pinned_to_host_cache,
- &is_candidate));
+ TF_RETURN_IF_ERROR(
+ internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
if (!is_candidate) {
continue;
}
- const string original_device = node.device();
- const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu,
- node.mutable_device());
- // Keep track of all Const nodes that we swapped.
- if (swapped && IsConstant(node)) {
- const_nodes.emplace_back(&node, original_device);
+ if (IsConstant(node)) {
+ const_nodes.emplace_back(&node, node.device());
}
+ // Try and swap the device to Host.
+ node.set_device(
+ internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
}
// Traverse all `const_nodes`, and map them back to GPU greedily.
@@ -408,9 +349,8 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// this node back onto the original device.
for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
// The consumer is not Host friendly, swap it back to the original device.
- if (!internal::IsNodeInputPortHostFriendly(
- *fanout.node, fanout.port_id,
- &op_device_inport_pinned_to_host_cache)) {
+ if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
+ fanout.port_id)) {
node->set_device(device);
break;
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
index bed4a9ef95..d557a03463 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -26,8 +26,8 @@ namespace tensorflow {
namespace grappler {
namespace internal {
// Try and find an appropriate Host device in `devices` given `device`.
-bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, string* device);
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device);
} // end namespace internal
// Optimize TensorFlow ops that should be swapped into the CPU to avoid
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 9bb030b220..7c64529441 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -28,60 +28,30 @@ namespace {
class PinToHostOptimizerTest : public GrapplerTest {};
-TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) {
+TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
gtl::FlatSet<string> devices = {};
-
- string device = "ABC";
- EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
- EXPECT_EQ(device, "ABC");
-}
-
-TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) {
- gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
-
- string device = "";
- EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
- EXPECT_EQ(device, "/device:CPU:0");
-
- device = "/device:XLA_GPU:0";
- EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
- EXPECT_EQ(device, "/device:CPU:0");
-
- device = "/device:XLA_GPU:*";
- EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
- EXPECT_EQ(device, "/device:CPU:0");
-}
-
-TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) {
- gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
-
- string device = "";
- EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
- EXPECT_TRUE(device.empty());
-
- device = "/device:XLA_GPU:0";
- EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
- EXPECT_EQ(device, "/device:XLA_CPU:0");
-
- device = "/device:XLA_GPU:*";
- EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
- EXPECT_EQ(device, "/device:XLA_CPU:0");
-}
-
-TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) {
- gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
-
- string device = "";
- EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
- EXPECT_TRUE(device.empty());
-
- device = "/device:XLA_GPU:0";
- EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
- EXPECT_EQ(device, "/device:XLA_GPU:0");
-
- device = "/device:XLA_GPU:*";
- EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
- EXPECT_EQ(device, "/device:XLA_GPU:*");
+ EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
+
+ devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
+ "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
+ "/device:CPU:0");
+
+ devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_CPU:0");
+
+ devices = {"/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_GPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_GPU:*");
}
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {