aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar Tong Shen <endlessroad@google.com>2018-08-21 10:25:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 10:33:51 -0700
commit938a3b77797164db736a1006a7656326240baa59 (patch)
tree73c2697cf3b7c8900694c5c3358f55096f0cdb82 /tensorflow/compiler/xla
parent792a933b113aa772b5ff5dbb6ef1892ffeb99063 (diff)
In HostCompute op, use SendToHost/RecvFromHost instead of Send/Recv.
PiperOrigin-RevId: 209617148
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/service/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.cc14
4 files changed, 18 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 01f273ad1f..7fdffe85c0 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1098,6 +1098,7 @@ cc_library(
hdrs = ["hlo_module_group_util.h"],
deps = [
":hlo",
+ ":hlo_casting_utils",
":hlo_module_group_metadata",
":hlo_reachability",
"//tensorflow/compiler/xla:status",
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 3b512bf0f8..cd10913763 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -204,6 +204,10 @@ const HloModuleGroupMetadata::Channel& HloModuleGroupMetadata::GetChannel(
return channels_[channel_id_map_.at(channel_id)];
}
+bool HloModuleGroupMetadata::HasChannel(int64 channel_id) const {
+ return channel_id_map_.find(channel_id) != channel_id_map_.end();
+}
+
HloComputation* HloModuleGroupMetadata::PeerComputation(
const HloInstruction* instruction) const {
CHECK(IsChannelInstruction(instruction));
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 1b256cd00e..924c8fda71 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -125,6 +125,9 @@ class HloModuleGroupMetadata {
// Returns the Channel instance for the given channel id.
const Channel& GetChannel(int64 channel_id) const;
+ // Returns if the given channel id exists in metadata.
+ bool HasChannel(int64 channel_id) const;
+
// Returns the all-reduce instructions with the same all_reduce_id.
const std::vector<HloInstruction*>& GetAllReduceGroup(
int64 all_reduce_id) const;
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index 4f11ce322e..1a4da388e4 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include <utility>
#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -94,12 +96,14 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
add_unique_predecessor(control_predecessor);
}
}
- if (instruction->opcode() == HloOpcode::kRecvDone) {
+ if (instruction->opcode() == HloOpcode::kRecvDone &&
+ !DynCast<HloRecvDoneInstruction>(instruction)->is_host_transfer()) {
// Send is a remote predecessor of RecvDone.
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
add_unique_predecessor(send);
}
- if (instruction->opcode() == HloOpcode::kSend) {
+ if (instruction->opcode() == HloOpcode::kSend &&
+ !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// Recv is a remote predecessor of Send.
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;
@@ -170,14 +174,16 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
add_unique_successor(control_successor);
}
}
- if (instruction->opcode() == HloOpcode::kRecv) {
+ if (instruction->opcode() == HloOpcode::kRecv &&
+ !DynCast<HloRecvInstruction>(instruction)->is_host_transfer()) {
// Send is a remote successor of Recv.
const HloInstruction* recv_done = instruction->users().front();
CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
HloInstruction* send = metadata_.GetChannel(instruction->channel_id()).send;
add_unique_successor(send);
}
- if (instruction->opcode() == HloOpcode::kSend) {
+ if (instruction->opcode() == HloOpcode::kSend &&
+ !DynCast<HloSendInstruction>(instruction)->is_host_transfer()) {
// RecvDone is a remote successor of Send.
HloInstruction* recv_done =
metadata_.GetChannel(instruction->channel_id()).recv_done;