diff options
author | 2018-08-22 11:37:18 -0700 | |
---|---|---|
committer | 2018-08-22 11:37:18 -0700 | |
commit | 792aa3d576242b83ad38c36f7e50fc4f5714b561 (patch) | |
tree | 44f7a3338e92703e3501fea16ba42ec11f4e67ab | |
parent | 6528b69885fa00c21db648c004be93b823d36d0d (diff) | |
parent | 1f6f5e3b9fe092c5218b872648a5fb65334a2af8 (diff) |
Merge remote-tracking branch 'upstream/master'
28 files changed, 296 insertions, 150 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 175a571ddb..2027ec7737 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -124,11 +124,11 @@ void XlaTransferManager::TransferLiteralFromDevice( TensorReference ref(device_tensor); transfer_manager_->TransferLiteralFromDevice( device_to_host_stream_.get(), shaped_buffer, literal, - [=, &shaped_buffer, &literal](xla::Status status) { + [=, &shaped_buffer](xla::Status status) { ref.Unref(); done([&]() -> Status { - VLOG(1) << "Transfer from device as literal: " << literal.ToString() - << " " << shaped_buffer.ToString(); + VLOG(1) << "Transfer from device as literal: " + << shaped_buffer.ToString(); return status; }()); }); diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 7bc6e8d860..4e7ef66dc5 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1881,7 +1881,7 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale, XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, - tensorflow::gtl::ArraySlice<int64> replica_group_ids) { + tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand)); const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {}); @@ -1889,14 +1889,14 @@ XlaOp XlaBuilder::CrossReplicaSum( b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"), b->Parameter(/*parameter_number=*/1, scalar_shape, "y")); TF_ASSIGN_OR_RETURN(auto computation, b->Build()); - return CrossReplicaSum(operand, computation, replica_group_ids, + return CrossReplicaSum(operand, computation, replica_groups, /*channel_id=*/absl::nullopt); }); } XlaOp XlaBuilder::CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, + tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups, const absl::optional<ChannelHandle>& channel_id) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; @@ -1904,8 +1904,9 @@ XlaOp XlaBuilder::CrossReplicaSum( TF_ASSIGN_OR_RETURN( *instr.mutable_shape(), ShapeInference::InferCrossReplicaSumShape({&operand_shape})); - for (int64 replica_group_id : replica_group_ids) { - instr.add_replica_group_ids(replica_group_id); + + for (const ReplicaGroup& group : replica_groups) { + *instr.add_replica_groups() = group; } if (channel_id.has_value()) { @@ -2767,16 +2768,17 @@ XlaOp ReduceWindowWithGeneralPadding( padding); } -XlaOp CrossReplicaSum(const XlaOp& operand, - tensorflow::gtl::ArraySlice<int64> replica_group_ids) { - return operand.builder()->CrossReplicaSum(operand, replica_group_ids); +XlaOp CrossReplicaSum( + const XlaOp& operand, + tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups) { + return operand.builder()->CrossReplicaSum(operand, replica_groups); } XlaOp CrossReplicaSum(const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, + tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups, const absl::optional<ChannelHandle>& channel_id) { return operand.builder()->CrossReplicaSum(operand, computation, - replica_group_ids, channel_id); + replica_groups, channel_id); } XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 8d9ec9a18a..adb62f5f02 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -685,7 +685,7 @@ class XlaBuilder { // sum for each subgroup. XlaOp CrossReplicaSum( const XlaOp& operand, - tensorflow::gtl::ArraySlice<int64> replica_group_ids = {}); + tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -694,10 +694,11 @@ class XlaBuilder { // scalars, e.g., add, min, or max. The way that AllReduce is applied is // configured by: // - // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all - // replicas belong to one group. Allreduce will be applied within subgroups. - // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, - // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. + // - `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group. Allreduce will be applied within + // subgroups. For example, we have 4 replicas, then + // replica_groups={{0,2},{1,3}} means, replica 0 and 2 are in subgroup 0, + // replica 1 and 3 are in subgroup 1. // // - `channel_id`: for Allreduce nodes from different modules, if they have // the same channel_id, they will be 'Allreduce'd. If empty, Allreduce will @@ -706,7 +707,7 @@ class XlaBuilder { // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids = {}, + tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups = {}, const absl::optional<ChannelHandle>& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. @@ -1253,10 +1254,10 @@ class XlaBuilder { tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding); friend XlaOp CrossReplicaSum( const XlaOp& operand, - tensorflow::gtl::ArraySlice<int64> replica_group_ids); + tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups); friend XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, + tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups, const absl::optional<ChannelHandle>& channel_id); friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension, int64 concat_dimension, int64 split_count, @@ -1833,7 +1834,7 @@ XlaOp ReduceWindowWithGeneralPadding( // sum for each subgroup. XlaOp CrossReplicaSum( const XlaOp& operand, - tensorflow::gtl::ArraySlice<int64> replica_group_ids = {}); + tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups = {}); // Enqueues an operation that do an AllReduce of the operand cross cores. Here // AllReduce means doing a reduction on the input operand cross cores and then @@ -1842,10 +1843,10 @@ XlaOp CrossReplicaSum( // scalars, e.g., add, min, or max. The way that AllReduce is applied is // configured by: // -// - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all -// replicas belong to one group. Allreduce will be applied within subgroups. -// For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, -// replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. +// - `replica_groups`: each ReplicaGroup contains a list of replica id. If +// empty, all replicas belong to one group. Allreduce will be applied within +// subgroups. For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} +// means, replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // // - `channel_id`: for Allreduce nodes from different modules, if they have the // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be @@ -1854,7 +1855,7 @@ XlaOp CrossReplicaSum( // TODO(b/79737069): Rename this to AllReduce when it's ready to use. XlaOp CrossReplicaSum( const XlaOp& operand, const XlaComputation& computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids = {}, + tensorflow::gtl::ArraySlice<ReplicaGroup> replica_groups = {}, const absl::optional<ChannelHandle>& channel_id = absl::nullopt); // Enqueues an operation that do an Alltoall of the operand cross cores. diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc index a44756e136..6363a21c3b 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -235,7 +235,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}, - sum, /*replica_group_ids=*/{}, /*barrier=*/"", + sum, /*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte_a = builder.AddInstruction( HloInstruction::CreateGetTupleElement(f32_shape, crs, 0)); diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc index 303ceac2e0..49ae5320b0 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -251,7 +251,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction, - /*replica_group_ids=*/{}, /*barrier=*/"", + /*replica_groups=*/{}, /*barrier=*/"", /*all_reduce_id=*/absl::nullopt)); HloInstruction* gte = builder.AddInstruction( HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 12b609a60f..821c599863 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -46,6 +46,8 @@ message HloInstructionProto { reserved "control_predecessor_names"; reserved 6; reserved "called_computation_names"; + reserved 44; + reserved "replica_group_ids"; string name = 1; string opcode = 2; @@ -158,9 +160,6 @@ message HloInstructionProto { string backend_config = 43; // Cross replica op fields. - // TODO(b/112107579): remove replica_group_ids field and always use - // replica_groups. - repeated int64 replica_group_ids = 44; repeated ReplicaGroup replica_groups = 49; int64 all_reduce_id = 45; string cross_replica_sum_barrier = 46; diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc index 78955db0da..af904647f8 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.cc @@ -31,12 +31,12 @@ class HloDomainIsolator::RunContext { StatusOr<bool> Run(); private: - // Inserts a kDomain instruction between parent and operand, in case - // the attribute (ie, sharding) values change between instruction and operand. + // Inserts a kDomain instruction between operand and instruction in case + // the attribute (ie, sharding) values change between root and instruction. // Returns the newly inserted kDomain instruction, or nullptr if no kDomain // instruction was necessary. StatusOr<HloInstruction*> CreateDomain(HloInstruction* instruction, - HloInstruction* parent, + HloInstruction* root, HloInstruction* operand); HloModule* module_; @@ -44,14 +44,14 @@ class HloDomainIsolator::RunContext { }; StatusOr<HloInstruction*> HloDomainIsolator::RunContext::CreateDomain( - HloInstruction* instruction, HloInstruction* parent, + HloInstruction* instruction, HloInstruction* root, HloInstruction* operand) { HloInstruction* domain = nullptr; std::unique_ptr<HloInstruction> domain_instruction = - isolator_->creator_(instruction, operand); + isolator_->creator_(instruction, root, operand); if (domain_instruction != nullptr) { domain = operand->parent()->AddInstruction(std::move(domain_instruction)); - TF_RETURN_IF_ERROR(operand->ReplaceUseWith(parent, domain)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain)); } return domain; } @@ -71,14 +71,13 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() { // When applying multiple domains, we could end up stacking more than // one in one edge, so here we want to build the effective // (kDomain-less) instruction->operand edge. - HloInstruction* parent = instruction; - while (operand->opcode() == HloOpcode::kDomain) { - parent = operand; - operand = operand->mutable_operand(0); + HloInstruction* root = operand; + while (root->opcode() == HloOpcode::kDomain) { + root = root->mutable_operand(0); } // Check whether a kDomain is necessary between instruction and operand. TF_ASSIGN_OR_RETURN(HloInstruction * domain, - CreateDomain(instruction, parent, operand)); + CreateDomain(instruction, root, operand)); if (domain != nullptr) { VLOG(4) << "New domain: " << domain->ToString(); ++added_domains; diff --git a/tensorflow/compiler/xla/service/hlo_domain_isolator.h b/tensorflow/compiler/xla/service/hlo_domain_isolator.h index eded3e78ee..bb1537766c 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_isolator.h +++ b/tensorflow/compiler/xla/service/hlo_domain_isolator.h @@ -34,10 +34,12 @@ class HloDomainIsolator : public HloPassInterface { public: // Creates a new kDomain instruction for the edge between the use instruction // (the first HloInstruction argument), and the operand instruction (the - // second HloInstruction argument). + // third HloInstruction argument) if the interesting attribute of the + // instruction differes from the attribute of the root (the second + // HloInstruction argument). // Returns nullptr in case no domain separation is necessary. using DomainCreator = std::function<std::unique_ptr<HloInstruction>( - HloInstruction*, HloInstruction*)>; + HloInstruction*, HloInstruction*, HloInstruction*)>; explicit HloDomainIsolator(DomainCreator creator); diff --git a/tensorflow/compiler/xla/service/hlo_domain_test.cc b/tensorflow/compiler/xla/service/hlo_domain_test.cc index 7d48be15cf..2654929bf0 100644 --- a/tensorflow/compiler/xla/service/hlo_domain_test.cc +++ b/tensorflow/compiler/xla/service/hlo_domain_test.cc @@ -106,12 +106,13 @@ class OpNameMetadata : public DomainMetadata { // Creator function for OpNameMetadata domains. std::unique_ptr<HloInstruction> OpNameDomainCreator(HloInstruction* instruction, + HloInstruction* root, HloInstruction* operand) { - if (instruction->metadata().op_name() == operand->metadata().op_name()) { + if (instruction->metadata().op_name() == root->metadata().op_name()) { return nullptr; } std::unique_ptr<DomainMetadata> operand_side_metadata = - absl::make_unique<OpNameMetadata>(operand->metadata().op_name()); + absl::make_unique<OpNameMetadata>(root->metadata().op_name()); std::unique_ptr<DomainMetadata> user_side_metadata = absl::make_unique<OpNameMetadata>(instruction->metadata().op_name()); return HloInstruction::CreateDomain(operand->shape(), operand, @@ -524,5 +525,64 @@ ENTRY entry { tpl->sharding()); } +TEST_F(HloDomainTest, MultiDomainMultiUser) { + const char* const hlo_string = R"( + HloModule Module + +ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) { + %p0 = (f32[4], f32[4]) parameter(0) + %a = f32[4]{0} get-tuple-element(%p0), index=0 + %domain = f32[4] domain(%a), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %b = f32[4] get-tuple-element(%p0), index=1 + %domain.1 = f32[4] domain(%b), + domain={kind="sharding", entry={maximal device=1}, exit={maximal device=0}} + %c = f32[4] add(%domain, %domain.1), sharding={maximal device=1} + %domain.2 = f32[4] domain(%c), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %d = f32[4] subtract(%domain, %c), + sharding={maximal device=1}, metadata={op_name="D"} + %domain.3 = f32[4] domain(%d), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + %e = f32[4] multiply(%c, %d), + sharding={maximal device=1}, metadata={op_name="D"} + %f = f32[4] add(f32[4]{0} %e, f32[4]{0} %c), sharding={maximal device=1} + %domain.4 = f32[4]{0} domain(%f), + domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}} + ROOT %g = (f32[4], f32[4], f32[4]) tuple(%domain.2, %domain.3, %domain.4) +})"; + + TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string)); + LOG(INFO) << "Original module:\n" << module->ToString(); + + HloDomainIsolator opname_isolator(OpNameDomainCreator); + TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed, + opname_isolator.Run(module)); + EXPECT_TRUE(opname_isolator_changed); + + EXPECT_TRUE(HasDomainEdge(module, "c", "a")); + EXPECT_TRUE(HasDomainEdge(module, "c", "b")); + EXPECT_TRUE(HasDomainEdge(module, "d", "a")); + EXPECT_TRUE(HasDomainEdge(module, "d", "c")); + EXPECT_FALSE(HasDomainEdge(module, "e", "d")); + + HloDomainRemover sharding_remover(ShardingMetadata::KindName(), + ShardingMetadata::NormalizeShardingDomain); + TF_ASSERT_OK_AND_ASSIGN(bool sharding_remover_changed, + sharding_remover.Run(module)); + EXPECT_TRUE(sharding_remover_changed); + + HloDomainRemover opname_remover(OpNameMetadata::KindName(), + OpNameDomainNormalizer); + TF_ASSERT_OK_AND_ASSIGN(bool opname_remover_changed, + opname_remover.Run(module)); + EXPECT_TRUE(opname_remover_changed); + + EXPECT_FALSE(HasDomainEdge(module, "c", "a")); + EXPECT_FALSE(HasDomainEdge(module, "c", "b")); + EXPECT_FALSE(HasDomainEdge(module, "d", "a")); + EXPECT_FALSE(HasDomainEdge(module, "d", "c")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 9d795da100..a211167519 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -302,9 +302,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( } instruction = CreateCrossReplicaSum( proto.shape(), all_operands(), computations(0), - /*replica_group_ids=*/ - std::vector<int64>(proto.replica_group_ids().begin(), - proto.replica_group_ids().end()), + /*replica_groups=*/ + std::vector<ReplicaGroup>(proto.replica_groups().begin(), + proto.replica_groups().end()), /*barrier=*/proto.cross_replica_sum_barrier(), /*all_reduce_id=*/all_reduce_id); break; @@ -665,11 +665,11 @@ HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction::CreateCrossReplicaSum( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, + const std::vector<ReplicaGroup>& replica_groups, tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id) { return absl::make_unique<HloAllReduceInstruction>( - shape, operands, reduce_computation, replica_group_ids, barrier, + shape, operands, reduce_computation, replica_groups, barrier, all_reduce_id); } @@ -3184,11 +3184,10 @@ const string& HloInstruction::outfeed_config() const { return Cast<HloOutfeedInstruction>(this)->outfeed_config(); } -const std::vector<int64>& HloInstruction::replica_group_ids() const { - return Cast<HloAllReduceInstruction>(this)->replica_group_ids(); -} - const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const { + if (opcode() == HloOpcode::kCrossReplicaSum) { + return Cast<HloAllReduceInstruction>(this)->replica_groups(); + } return Cast<HloAllToAllInstruction>(this)->replica_groups(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 21710bd31d..fdd34544eb 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -435,9 +435,10 @@ class HloInstruction { // // `reduction_computation`: the reduction function. // - // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all - // replicas belong to one group. Allreduce will be applied within subgroups. - // For example, we have 4 replicas, then replica_group_ids={0,1,0,1} means, + // `replica_groups`: each ReplicaGroup contains a list of replica id. If + // empty, all replicas belong to one group in the order of 0 - (n-1). + // Allreduce will be applied within subgroups. + // For example, we have 4 replicas, then replica_groups={{0,2},{1,3}} means, // replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1. // // `all_reduce_id`: for Allreduce nodes from different modules, if they have @@ -448,7 +449,7 @@ class HloInstruction { static std::unique_ptr<HloInstruction> CreateCrossReplicaSum( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, + const std::vector<ReplicaGroup>& replica_groups, tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id); @@ -1439,9 +1440,6 @@ class HloInstruction { // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; - // Delegates to HloAllReduceInstruction::replica_group_ids. - const std::vector<int64>& replica_group_ids() const; - // Delegates to HloAllToAllInstruction::replica_groups. const std::vector<ReplicaGroup>& replica_groups() const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index dbafa35b2a..36fac4a266 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -300,10 +300,10 @@ HloRecvDoneInstruction::CloneWithNewOperandsImpl( HloAllReduceInstruction::HloAllReduceInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, + const std::vector<ReplicaGroup>& replica_groups, tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id) : HloInstruction(HloOpcode::kCrossReplicaSum, shape), - replica_group_ids_(replica_group_ids.begin(), replica_group_ids.end()), + replica_groups_(replica_groups), cross_replica_sum_barrier_(barrier.begin(), barrier.end()), all_reduce_id_(all_reduce_id) { for (auto operand : operands) { @@ -314,9 +314,8 @@ HloAllReduceInstruction::HloAllReduceInstruction( HloInstructionProto HloAllReduceInstruction::ToProto() const { HloInstructionProto proto = HloInstruction::ToProto(); - for (int64 i : replica_group_ids_) { - proto.add_replica_group_ids(i); - } + *proto.mutable_replica_groups() = {replica_groups_.begin(), + replica_groups_.end()}; // Proto3 is so sad. if (all_reduce_id_) { proto.set_all_reduce_id(*all_reduce_id_); @@ -327,8 +326,14 @@ HloInstructionProto HloAllReduceInstruction::ToProto() const { std::vector<string> HloAllReduceInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& /*options*/) const { - std::vector<string> result = { - StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}")}; + std::vector<string> result; + std::vector<string> replica_group_str; + for (const ReplicaGroup& group : replica_groups()) { + replica_group_str.push_back( + StrCat("{", Join(group.replica_ids(), ","), "}")); + } + result.push_back( + StrCat("replica_groups={", Join(replica_group_str, ","), "}")); if (!cross_replica_sum_barrier().empty()) { result.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\"")); } @@ -343,7 +348,11 @@ bool HloAllReduceInstruction::IdenticalSlowPath( const std::function<bool(const HloComputation*, const HloComputation*)>& eq_computations) const { const auto& casted_other = static_cast<const HloAllReduceInstruction&>(other); - return replica_group_ids() == casted_other.replica_group_ids() && + return ContainersEqual(replica_groups(), casted_other.replica_groups(), + [](const ReplicaGroup& a, const ReplicaGroup& b) { + return ContainersEqual(a.replica_ids(), + b.replica_ids()); + }) && eq_computations(to_apply(), casted_other.to_apply()) && cross_replica_sum_barrier() == casted_other.cross_replica_sum_barrier() && @@ -356,7 +365,7 @@ HloAllReduceInstruction::CloneWithNewOperandsImpl( tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* /*context*/) const { return absl::make_unique<HloAllReduceInstruction>( - shape, new_operands, to_apply(), replica_group_ids(), + shape, new_operands, to_apply(), replica_groups(), cross_replica_sum_barrier(), all_reduce_id()); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 93e4c21b2f..0a6a0c6233 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -223,13 +223,12 @@ class HloAllReduceInstruction : public HloInstruction { explicit HloAllReduceInstruction( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, HloComputation* reduce_computation, - tensorflow::gtl::ArraySlice<int64> replica_group_ids, + const std::vector<ReplicaGroup>& replica_groups, tensorflow::StringPiece barrier, const absl::optional<int64>& all_reduce_id); - // Returns the group ids of each replica for CrossReplicaSum op. - const std::vector<int64>& replica_group_ids() const { - return replica_group_ids_; + const std::vector<ReplicaGroup>& replica_groups() const { + return replica_groups_; } // Returns the barrier config used for the CrossReplicaSum implementation of @@ -260,8 +259,8 @@ class HloAllReduceInstruction : public HloInstruction { tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, HloCloneContext* context) const override; - // The group id of each replica for CrossReplicaSum. - std::vector<int64> replica_group_ids_; + // The replica ids of each subgroup for CrossReplicaSum op. + std::vector<ReplicaGroup> replica_groups_; // The string representation of the barrier config used for CrossReplicaSum. string cross_replica_sum_barrier_; diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index ede55510d3..beef96476c 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -293,6 +293,20 @@ class HloParser { missing_instruction_hook_; }; +// Creates replica groups from the provided nested array. groups[i] represents +// the replica ids for group 'i'. +std::vector<ReplicaGroup> CreateReplicaGroups( + tensorflow::gtl::ArraySlice<std::vector<int64>> groups) { + std::vector<ReplicaGroup> replica_groups; + absl::c_transform(groups, std::back_inserter(replica_groups), + [](const std::vector<int64>& ids) { + ReplicaGroup group; + *group.mutable_replica_ids() = {ids.begin(), ids.end()}; + return group; + }); + return replica_groups; +} + bool HloParser::Error(LocTy loc, StringPiece msg) { auto line_col = lexer_.GetLineAndColumn(loc); const unsigned line = line_col.first; @@ -637,31 +651,29 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { + optional<std::vector<std::vector<int64>>> tmp_groups; optional<HloComputation*> to_apply; optional<std::vector<int64>> replica_group_ids; optional<string> barrier; optional<int64> all_reduce_id; attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, &to_apply}; - attrs["replica_group_ids"] = { - /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids}; + attrs["replica_groups"] = {/*required=*/false, + AttrTy::kBracedInt64ListList, &tmp_groups}; attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier}; attrs["all_reduce_id"] = {/*required=*/false, AttrTy::kInt64, &all_reduce_id}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - if (replica_group_ids) { - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, *replica_group_ids, - barrier ? *barrier : "", all_reduce_id)); - } else { - instruction = - builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( - shape, operands, *to_apply, {}, barrier ? *barrier : "", - all_reduce_id)); + std::vector<ReplicaGroup> replica_groups; + if (tmp_groups) { + replica_groups = CreateReplicaGroups(*tmp_groups); } + instruction = + builder->AddInstruction(HloInstruction::CreateCrossReplicaSum( + shape, operands, *to_apply, replica_groups, + barrier ? *barrier : "", all_reduce_id)); break; } case HloOpcode::kAllToAll: { @@ -675,13 +687,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } std::vector<ReplicaGroup> replica_groups; if (tmp_groups) { - absl::c_transform( - *tmp_groups, std::back_inserter(replica_groups), - [](const std::vector<int64>& ids) { - ReplicaGroup group; - *group.mutable_replica_ids() = {ids.begin(), ids.end()}; - return group; - }); + replica_groups = CreateReplicaGroups(*tmp_groups); } instruction = builder->AddInstruction(HloInstruction::CreateAllToAll( shape, operands, replica_groups, barrier ? *barrier : "")); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 5a1993a3bb..f52cfadb81 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1052,7 +1052,7 @@ add { ENTRY CRS { input = f32[8]{0} parameter(0) - ROOT crs = f32[8]{0} cross-replica-sum(input), replica_group_ids={}, to_apply=add + ROOT crs = f32[8]{0} cross-replica-sum(input), replica_groups={}, to_apply=add } )" @@ -1070,7 +1070,7 @@ add { ENTRY CrossReplicaSumWithSubgroups { input = f32[128,32]{0,1} parameter(0) - ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_group_ids={0,0,1,1}, barrier="abc", to_apply=add + ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add } )" diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc index 4e19557f82..6f0353ee5f 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc @@ -284,18 +284,19 @@ Status ApplyDomainSharding(const DomainMetadata::Domain& domain, // The kDomain instruction will be created only if the sharding differ between // the instruction and the operand. std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction, + HloInstruction* root, HloInstruction* operand) { const HloSharding* instruction_sharding = instruction->has_sharding() ? &instruction->sharding() : nullptr; - const HloSharding* operand_sharding = - operand->has_sharding() ? &operand->sharding() : nullptr; + const HloSharding* root_sharding = + root->has_sharding() ? &root->sharding() : nullptr; // No need for domain if they both have no sharding. - if (instruction_sharding == nullptr && operand_sharding == nullptr) { + if (instruction_sharding == nullptr && root_sharding == nullptr) { return nullptr; } // No need for domain if they match. - if (instruction_sharding != nullptr && operand_sharding != nullptr && - ShardingMatches(*instruction_sharding, *operand_sharding)) { + if (instruction_sharding != nullptr && root_sharding != nullptr && + ShardingMatches(*instruction_sharding, *root_sharding)) { return nullptr; } std::unique_ptr<HloSharding> real_instruction_sharding; @@ -303,8 +304,8 @@ std::unique_ptr<HloInstruction> CreateDomain(HloInstruction* instruction, if (instruction_sharding != nullptr) { real_instruction_sharding = CloneShardingForDomain(*instruction_sharding); } - if (operand_sharding != nullptr) { - real_operand_sharding = CloneShardingForDomain(*operand_sharding); + if (root_sharding != nullptr) { + real_operand_sharding = CloneShardingForDomain(*root_sharding); } VLOG(3) << "Creating domain:"; VLOG(3) << " Instruction: " << instruction->name(); @@ -417,8 +418,9 @@ Status ShardingMetadata::NormalizeShardingDomain( } std::unique_ptr<HloInstruction> CreateShardingDomain( - HloInstruction* instruction, HloInstruction* operand) { - return CreateDomain(instruction, operand); + HloInstruction* instruction, HloInstruction* root, + HloInstruction* operand) { + return CreateDomain(instruction, root, operand); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h index 5e01fc0e22..dc258e4094 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_sharding_metadata.h @@ -60,10 +60,10 @@ class ShardingMetadata : public DomainMetadata { // Given an HLO graph edge between instruction and one of its operands, creates // a ShardingMetadata based kDomain instruction if the sharding between -// instruction and operand changes. Returns nullptr if there is no need for a +// instruction and parent changes. Returns nullptr if there is no need for a // domain separation. std::unique_ptr<HloInstruction> CreateShardingDomain( - HloInstruction* instruction, HloInstruction* operand); + HloInstruction* instruction, HloInstruction* root, HloInstruction* operand); } // namespace xla diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py index 5a5a6ad63a..f7dd3183b0 100644 --- a/tensorflow/contrib/autograph/converters/control_flow.py +++ b/tensorflow/contrib/autograph/converters/control_flow.py @@ -95,6 +95,18 @@ class ControlFlowTransformer(converter.Base): return 'no variables' return ', '.join(map(str, symbol_set)) + def _validate_no_live_vars_created(self, node): + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) + live_vars_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT) + live_vars_created_in_body = live_vars_out & body_scope.created + if live_vars_created_in_body: + raise ValueError( + 'The following variables are created inside the loop and used later:' + '\n%s\n' + 'Variables must be declared outside loops because loops may not' + ' necessarily execute.' % self._fmt_symbol_list( + live_vars_created_in_body)) + def visit_If(self, node): node = self.generic_visit(node) @@ -197,6 +209,8 @@ class ControlFlowTransformer(converter.Base): def visit_While(self, node): self.generic_visit(node) + self._validate_no_live_vars_created(node) + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) body_closure = body_scope.modified - body_scope.created all_referenced = body_scope.referenced @@ -262,6 +276,8 @@ class ControlFlowTransformer(converter.Base): def visit_For(self, node): self.generic_visit(node) + self._validate_no_live_vars_created(node) + body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE) body_closure = body_scope.modified - body_scope.created all_referenced = body_scope.referenced diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py index 6cb907f69a..02bc00dbc8 100644 --- a/tensorflow/contrib/autograph/converters/control_flow_test.py +++ b/tensorflow/contrib/autograph/converters/control_flow_test.py @@ -57,6 +57,17 @@ class ControlFlowTest(converter_testing.TestCase): self.assertTransformedResult(test_fn, constant_op.constant(5), 0) + def test_while_variable_defined_in_body(self): + def bad_while_loop(n): + while n > 0: + n -= 1 + s = n + return s + + node, ctx = self.prepare(bad_while_loop, {}) + with self.assertRaises(transformer.AutographParseError): + control_flow.transform(node, ctx) + def test_if_basic(self): def test_fn(n): @@ -196,6 +207,15 @@ class ControlFlowTest(converter_testing.TestCase): self.assertEqual(result.test_fn(5), 10) self.assertEqual(eval_count[0], 1) + def test_for_variable_defined_in_body(self): + def bad_for_loop(n): + for i in range(n): + s = i + return s + + node, ctx = self.prepare(bad_for_loop, {}) + with self.assertRaises(transformer.AutographParseError): + control_flow.transform(node, ctx) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 58a94ff4a5..dcad77ccbb 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -639,9 +639,9 @@ table SubGraph { } // Table of raw data buffers (used for constant tensors). Referenced by tensors -// by index. +// by index. The generous alignment accommodates mmap-friendly data structures. table Buffer { - data:[ubyte]; + data:[ubyte] (force_align: 16); } table Model { diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc index 35662e278f..313d976e2c 100644 --- a/tensorflow/core/kernels/cwise_op_div.cc +++ b/tensorflow/core/kernels/cwise_op_div.cc @@ -33,6 +33,7 @@ REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16, int64); REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double, complex64, complex128); +REGISTER2(BinaryOp, GPU, "DivNoNan", functor::div_no_nan, float, double); // A special GPU kernel for int32. // TODO(b/25387198): Also enable int32 in device memory. This kernel diff --git a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc index 0b05416274..25ccdcfb00 100644 --- a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc @@ -21,6 +21,7 @@ namespace tensorflow { namespace functor { DEFINE_BINARY10(div, Eigen::half, float, double, uint8, uint16, int16, int32, int64, complex64, complex128); +DEFINE_BINARY2(div_no_nan, float, double); } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index 1f2e57e9a9..3d03bc1d5f 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -354,6 +354,27 @@ Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Transpose", TransposeGrad); +Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + *g = FDH::Define( + // Arg defs + {"params: Tparams", "indices: Tindices", "doutput: Tparams"}, + // Ret val defs + {"dparams: Tparams", "dindices: Tindices"}, + // Attr defs + {"Tparams: type", "Tindices: type"}, + // Nodes + { + {{"x_shape"}, "Shape", {"params"}, {{"T", "$Tparams"}}}, + {{"dparams"}, "ScatterNd", {"indices", "doutput", "x_shape"}, + {{"T", "$Tparams"}, {"Tindices", "$Tindices"}}}, + {{"dindices"}, "ZerosLike", {"indices"}, {{"T", "$Tindices"}}}, + }); + // clang-format on + return Status::OK(); +} +REGISTER_OP_GRADIENT("GatherNd", GatherNdGrad); + Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) { *g = FDH::Define( // Arg defs diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 380bcf763f..ca6aafd715 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -211,6 +211,18 @@ class FunctionTest(test.TestCase): random_seed.set_random_seed(1) self.assertAllEqual(f(), x) + def testSymGradGatherNd(self): + with ops.Graph().as_default(), self.test_session() as sess: + + @function.defun + def f(x): + return array_ops.gather_nd(x, [[0]]) + + c = constant_op.constant([[2.]]) + f_c = f(c) + g, = gradients_impl.gradients(f_c, c) + self.assertAllEqual(sess.run(g), [[1.0]]) + def testNestedInputsDefunOpGraphMode(self): matmul = function.defun(math_ops.matmul) diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py index dbcad323b9..290c4604ce 100644 --- a/tensorflow/python/estimator/keras_test.py +++ b/tensorflow/python/estimator/keras_test.py @@ -203,7 +203,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer='rmsprop', metrics=['mse', keras.metrics.categorical_accuracy]) - with self.cached_session(): + with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) before_eval_results = est_keras.evaluate( @@ -228,7 +228,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): metrics=['mse', keras.metrics.categorical_accuracy]) my_hook = MyHook() - with self.cached_session(): + with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) before_eval_results = est_keras.evaluate( @@ -252,7 +252,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(1e-3), metrics=['mse', keras.metrics.categorical_accuracy]) my_hook = MyHook() - with self.cached_session(): + with self.test_session(): keras_model.fit(x_train, y_train, epochs=1) keras_est = keras_lib.model_to_estimator( @@ -274,7 +274,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(1e-3), metrics=['mse', keras.metrics.categorical_accuracy]) - with self.cached_session(): + with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) @@ -297,7 +297,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(1e-3), metrics=['mse', keras.metrics.categorical_accuracy]) - with self.cached_session(): + with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16) @@ -316,7 +316,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer=rmsprop.RMSPropOptimizer(1e-3), metrics=['mse', keras.metrics.categorical_accuracy]) - with self.cached_session(): + with self.test_session(): # Create state keras_model.train_on_batch(np.random.random((10,) + _INPUT_SIZE), np.random.random((10, _NUM_CLASS))) @@ -343,7 +343,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): x_test, y_test), _, eval_input_fn = get_resource_for_simple_model( model_type='functional', is_evaluate=True) - with self.cached_session(): + with self.test_session(): metrics = [ 'binary_accuracy', 'binary_crossentropy', 'categorical_accuracy', 'categorical_crossentropy', 'cosine_proximity', 'hinge', @@ -357,7 +357,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.fit(x_train, y_train, epochs=1) keras_eval = keras_model.evaluate(x_test, y_test, batch_size=32) - with self.cached_session(): + with self.test_session(): keras_est = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) est_eval = keras_est.evaluate(input_fn=eval_input_fn) @@ -385,7 +385,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): x_test, _), _, pred_input_fn = get_resource_for_simple_model( model_type='sequential', is_evaluate=False) - with self.cached_session(): + with self.test_session(): keras_model.compile( loss='categorical_crossentropy', optimizer='adam', @@ -393,7 +393,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): keras_model.fit(x_train, y_train, epochs=1) keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)] - with self.cached_session(): + with self.test_session(): keras_est = keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) est_pred = [ @@ -439,7 +439,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): output_dict = {'dense_2': c_test, 'dense_3': d_test} return input_dict, output_dict - with self.cached_session(): + with self.test_session(): model = multi_inputs_multi_outputs_model() est_keras = keras_lib.model_to_estimator( keras_model=model, config=self._config) @@ -456,7 +456,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): x_test, _), _, pred_input_fn = get_resource_for_simple_model( model_type='functional', is_evaluate=False) - with self.cached_session(): + with self.test_session(): keras_model.compile( loss='categorical_crossentropy', optimizer='rmsprop', @@ -466,7 +466,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): fname = os.path.join(self._base_dir, 'keras_model.h5') keras.models.save_model(keras_model, fname) - with self.cached_session(): + with self.test_session(): keras_est = keras_lib.model_to_estimator( keras_model_path=fname, config=self._config) est_pred = [ @@ -479,19 +479,19 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, 'Either'): keras_lib.model_to_estimator() - with self.cached_session(): + with self.test_session(): keras_model = simple_sequential_model() with self.assertRaisesRegexp(ValueError, 'not both'): keras_lib.model_to_estimator( keras_model=keras_model, keras_model_path=tempfile.mkdtemp(dir=self._base_dir)) - with self.cached_session(): + with self.test_session(): keras_model = simple_sequential_model() with self.assertRaisesRegexp(ValueError, 'compiled'): keras_lib.model_to_estimator(keras_model=keras_model) - with self.cached_session(): + with self.test_session(): keras_model = simple_sequential_model() with self.assertRaisesRegexp(ValueError, 'not a local path'): keras_lib.model_to_estimator( @@ -516,10 +516,10 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): model = simple_functional_model() model.compile( loss='categorical_crossentropy', optimizer='adam', metrics=['acc']) - with self.cached_session(): + with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=model, config=self._config) - with self.cached_session(): + with self.test_session(): with self.assertRaisesRegexp(KeyError, 'Difference: .*invalid_input_name'): est_keras.train(input_fn=invald_input_name_input_fn, steps=100) @@ -554,13 +554,13 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): num_epochs=None, batch_size=16) with self.assertRaisesRegexp(ValueError, 'relu6'): - with self.cached_session(): + with self.test_session(): est = keras_lib.model_to_estimator( keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir)) est.train(input_fn=train_input_fn, steps=1) - with self.cached_session(): + with self.test_session(): est = keras_lib.model_to_estimator( keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir), @@ -586,7 +586,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): } }) with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): - with self.cached_session(): + with self.test_session(): keras_lib.model_to_estimator( keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir)) @@ -602,7 +602,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3) sess_config = config_pb2.ConfigProto(gpu_options=gpu_options) self._config._session_config = sess_config - with self.cached_session(): + with self.test_session(): keras_lib.model_to_estimator( keras_model=keras_model, config=self._config) self.assertEqual( @@ -618,7 +618,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer='rmsprop', metrics=['mse', keras.metrics.categorical_accuracy]) - with self.cached_session(): + with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, model_dir=self._base_dir, config=run_config_lib.RunConfig()) @@ -629,7 +629,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): self.assertEqual(self._base_dir, est_keras._config.model_dir) self.assertEqual(self._base_dir, est_keras._model_dir) - with self.cached_session(): + with self.test_session(): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, model_dir=self._base_dir, config=None) @@ -648,7 +648,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer='rmsprop', metrics=['mse', keras.metrics.categorical_accuracy]) - with self.cached_session(): + with self.test_session(): with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR): est_keras = keras_lib.model_to_estimator( keras_model=keras_model, @@ -663,7 +663,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): optimizer='rmsprop', metrics=['mse', keras.metrics.categorical_accuracy]) - with self.cached_session(): + with self.test_session(): with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in ' 'constructor and `RunConfig`'): keras_lib.model_to_estimator( @@ -676,7 +676,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): loss='categorical_crossentropy', optimizer=rmsprop.RMSPropOptimizer(1e-3), metrics=['mse', keras.metrics.categorical_accuracy]) - with self.cached_session(): + with self.test_session(): keras_model.train_on_batch( np.random.random((10,) + _INPUT_SIZE), np.random.random((10, _NUM_CLASS))) diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 4823ba541d..d1bdd9b80a 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -112,8 +112,6 @@ class SparseTensor(_TensorLike): values: A 1-D tensor of any type and shape `[N]`. dense_shape: A 1-D int64 tensor of shape `[ndims]`. - Returns: - A `SparseTensor`. """ with ops.name_scope(None, "SparseTensor", [indices, values, dense_shape]): diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py index 6bd41020c5..1b01d1d37f 100644 --- a/tensorflow/python/ops/math_ops_test.py +++ b/tensorflow/python/ops/math_ops_test.py @@ -483,7 +483,7 @@ class DivNoNanTest(test_util.TensorFlowTestCase): np_result = np.true_divide(nums, divs) np_result[:, divs[0] == 0] = 0 - with self.cached_session(): + with self.cached_session(use_gpu=True): tf_result = math_ops.div_no_nan(nums, divs).eval() self.assertAllEqual(tf_result, np_result) diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 653e46fc41..090cf48a07 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -171,14 +171,15 @@ def write_docs(output_dir, os.path.join('/', from_path), os.path.join('/', to_path))) - redirects = sorted(redirects) - template = ('- from: {}\n' - ' to: {}\n') - redirects = [template.format(f, t) for f, t in redirects] - api_redirects_path = os.path.join(output_dir, '_redirects.yaml') - with open(api_redirects_path, 'w') as redirect_file: - redirect_file.write('redirects:\n') - redirect_file.write(''.join(redirects)) + if redirects: + redirects = sorted(redirects) + template = ('- from: {}\n' + ' to: {}\n') + redirects = [template.format(f, t) for f, t in redirects] + api_redirects_path = os.path.join(output_dir, '_redirects.yaml') + with open(api_redirects_path, 'w') as redirect_file: + redirect_file.write('redirects:\n') + redirect_file.write(''.join(redirects)) if yaml_toc: # Generate table of contents |