aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-10-02 13:08:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 13:13:22 -0700
commit78e4ce52aeda5a10ddaf5e64ea8958f439a2f9f2 (patch)
tree096cb5b28777e227053a50e622313cfe66e56744 /tensorflow/compiler/xla/service/hlo_instructions.cc
parent0a201955b47d484c6bfa149364c264a5b5f91be7 (diff)
Add proto serialization/deserialization testing to the HLO parser tests.
Many of the HLO parser tests verify that an text form of an HLO module preserves all information when running through ToString then parsing. It makes sense to also use these tests to exercise proto serialization/deserialization. This is done by adding additional instantiations of the parameterized parsing tests. This caught several bugs which are fixed in this CL: (1) Domain instructions were not being serialized properly. (2) Host send/recv instructions did not preserve the is_host_transfer bit. (3) Sparse literals could not be serialized or deserialized. PiperOrigin-RevId: 215445200
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 1bc168c8b7..68d0979f5c 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/window_util.h"
namespace xla {
@@ -213,6 +214,7 @@ HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
HloInstructionProto HloSendRecvInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
proto.set_channel_id(channel_id_);
+ proto.set_is_host_transfer(is_host_transfer_);
return proto;
}
@@ -2310,4 +2312,23 @@ std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], operand_side_metadata_->Clone(),
user_side_metadata_->Clone());
}
+
+HloInstructionProto HloDomainInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ auto operand_side_sharding =
+ dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
+ if (operand_side_sharding) {
+ *proto.mutable_domain_entry_sharding() =
+ operand_side_sharding->sharding()->ToProto();
+ }
+
+ auto user_side_sharding =
+ dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
+ if (user_side_sharding) {
+ *proto.mutable_domain_exit_sharding() =
+ user_side_sharding->sharding()->ToProto();
+ }
+
+ return proto;
+}
} // namespace xla