aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc60
1 files changed, 48 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index f162d52d3c..e8eaf54949 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -118,6 +119,7 @@ class HloParser {
// Types of attributes.
enum class AttrTy {
+ kBool,
kInt64,
kInt32,
kFloat,
@@ -490,6 +492,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloInstruction::CreateConstant(std::move(literal)));
break;
}
+ case HloOpcode::kIota: {
+ if (!ParseOperands(&operands, /*expected_size=*/0) ||
+ !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateIota(shape));
+ break;
+ }
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
@@ -680,18 +690,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kRecv: {
optional<tensorflow::int64> channel_id;
+ // If the is_host_transfer attribute is not present then default to false.
+ optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
+ &is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
+ // If the is_host_transfer attribute is not present then default to false.
instruction = builder->AddInstruction(HloInstruction::CreateRecv(
- shape.tuple_shapes(0), operands[0], *channel_id));
+ shape.tuple_shapes(0), operands[0], *channel_id, *is_host_transfer));
break;
}
case HloOpcode::kRecvDone: {
optional<tensorflow::int64> channel_id;
+ // If the is_host_transfer attribute is not present then default to false.
+ optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
+ &is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
@@ -699,24 +718,32 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (channel_id != operands[0]->channel_id()) {
return false;
}
- instruction =
- builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0]));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateRecvDone(operands[0], *is_host_transfer));
break;
}
case HloOpcode::kSend: {
optional<tensorflow::int64> channel_id;
+ // If the is_host_transfer attribute is not present then default to false.
+ optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
+ &is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/2) ||
!ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateSend(operands[0], operands[1], *channel_id));
+ instruction = builder->AddInstruction(HloInstruction::CreateSend(
+ operands[0], operands[1], *channel_id, *is_host_transfer));
break;
}
case HloOpcode::kSendDone: {
optional<tensorflow::int64> channel_id;
+ // If the is_host_transfer attribute is not present then default to false.
+ optional<bool> is_host_transfer = false;
attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ attrs["is_host_transfer"] = {/*required=*/false, AttrTy::kBool,
+ &is_host_transfer};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
@@ -724,8 +751,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (channel_id != operands[0]->channel_id()) {
return false;
}
- instruction =
- builder->AddInstruction(HloInstruction::CreateSendDone(operands[0]));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateSendDone(operands[0], *is_host_transfer));
break;
}
case HloOpcode::kGetTupleElement: {
@@ -1192,11 +1219,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
return false;
}
- GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers(
- /*output_window_dims=*/*output_window_dims,
- /*elided_window_dims=*/*elided_window_dims,
- /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims,
- /*index_vector_dim=*/*index_vector_dim);
+ GatherDimensionNumbers dim_numbers =
+ HloGatherInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/*output_window_dims,
+ /*elided_window_dims=*/*elided_window_dims,
+ /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims,
+ /*index_vector_dim=*/*index_vector_dim);
instruction = builder->AddInstruction(HloInstruction::CreateGather(
shape, /*operand=*/operands[0], /*gather_indices=*/operands[1],
@@ -2041,6 +2069,14 @@ bool HloParser::ParseAttributeHelper(
bool success = [&] {
LocTy attr_loc = lexer_.GetLoc();
switch (attr_type) {
+ case AttrTy::kBool: {
+ bool result;
+ if (!ParseBool(&result)) {
+ return false;
+ }
+ static_cast<optional<bool>*>(attr_out_ptr)->emplace(result);
+ return true;
+ }
case AttrTy::kInt64: {
tensorflow::int64 result;
if (!ParseInt64(&result)) {