diff options
author | Mark Heffernan <meheff@google.com> | 2018-06-08 12:50:16 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-08 12:52:42 -0700 |
commit | 278fbe4146b160980fec318187546d9d8870d244 (patch) | |
tree | 402e929ee51c089458fa7bfac1ae6826f73d71cd /tensorflow | |
parent | 7bb79ee219d4efbd92d1ef4e0dbe45f4aee26654 (diff) |
Add kGenerateToken HLO instruction.
The new HLO instruction serves two purposes. (1) It generates a new token value. This is the only way to create tokens. (2) The operation is variadic, taking zero or more token operands. The operation acts as a join of its operands.
I initially considered using a kConstant instruction as the way to create new tokens, but this ran into problems because of expectations in backends regarding constants and their materialization.
This CL enables creation of generate-token instructions, but the new instruction is not yet supported in any backend.
PiperOrigin-RevId: 199836205
Diffstat (limited to 'tensorflow')
19 files changed, 263 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 64678d9d74..ee2b455730 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -243,6 +243,8 @@ class DfsHloVisitorBase { virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; + virtual Status HandleGenerateToken(HloInstructionPtr token) = 0; + // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". virtual Status FinishVisit(HloInstructionPtr root) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 240faebe62..6934e00a4b 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -188,6 +188,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleGenerateToken(HloInstructionPtr token) override { + return DefaultAction(token); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". 
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index b9d30ee802..92a66681a9 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -387,6 +387,10 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { auto lhs = convolution->operand(0); auto rhs = convolution->operand(1); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index d17678d20f..0d66736fe1 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleGenerateToken(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 1e78d775c8..e0648e1467 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -910,6 +910,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } +Status HloEvaluator::HandleGenerateToken(HloInstruction* token) { + // Literals cannot represent a TOKEN shape so just create an empty tuple as + // the "result" of the kGenerateToken operation. 
+ // TODO(b/109929053): Add support for TOKENs in Literals. + evaluated_[token] = Literal::MakeTuple({}); + return Status::OK(); +} + Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index b53d5644de..fc2fc9437b 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -174,6 +174,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleGenerateToken(HloInstruction* token) override; + // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be // returned directly without looking up the cache. diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index a6750460e5..cf954001c6 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -964,6 +964,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kBitcast: case HloOpcode::kGetTupleElement: case HloOpcode::kTrace: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: return kWhite; case HloOpcode::kBroadcast: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index ae230d2740..a778a6a965 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -583,6 +583,17 @@ HloInstruction::CreateCrossReplicaSum( return MakeUnique<HloReverseInstruction>(shape, operand, dimensions); } +/* static */ std::unique_ptr<HloInstruction> 
+HloInstruction::CreateGenerateToken( + tensorflow::gtl::ArraySlice<HloInstruction*> operands) { + auto instruction = WrapUnique(new HloInstruction( + HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape())); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + return instruction; +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { @@ -1512,6 +1523,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(), user_side_metadata_->Clone()); break; + case HloOpcode::kGenerateToken: + clone = CreateGenerateToken(new_operands); + break; case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } @@ -1776,6 +1790,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kRng: case HloOpcode::kTrace: case HloOpcode::kWhile: + case HloOpcode::kGenerateToken: return false; case HloOpcode::kParameter: @@ -2776,6 +2791,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleGather(this); case HloOpcode::kDomain: return visitor->HandleDomain(this); + case HloOpcode::kGenerateToken: + return visitor->HandleGenerateToken(this); // These opcodes are not handled here. case HloOpcode::kTrace: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index cc4a8b8252..d252533eb2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -664,6 +664,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions); + // Creates a token instruction used for joining or creating token types which + // thread through side-effecting operations. 
+ static std::unique_ptr<HloInstruction> CreateGenerateToken( + tensorflow::gtl::ArraySlice<HloInstruction*> operands); + // Creates an instance of GatherDimensionNumbers. static GatherDimensionNumbers MakeGatherDimNumbers( tensorflow::gtl::ArraySlice<int64> output_window_dims, diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 1fe06ee0c0..a35546f5f4 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -81,6 +81,7 @@ namespace xla { V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ V(kHostCompute, "host-compute") \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index cd2ce5c69f..774345124b 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kConcatenate: case HloOpcode::kFusion: case HloOpcode::kMap: + case HloOpcode::kGenerateToken: case HloOpcode::kTuple: EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); break; diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index a1bc269400..bf1c7b9323 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -606,6 +606,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateReshape(shape, operands[0])); break; } + case HloOpcode::kGenerateToken: { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateGenerateToken(operands)); + break; + } 
case HloOpcode::kTuple: { if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 9cfd8a9bf7..9034073cc8 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -426,6 +426,14 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } +Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) { + std::vector<const Shape*> operand_shapes; + for (const HloInstruction* operand : token->operands()) { + operand_shapes.push_back(&operand->shape()); + } + return CheckShape(token, ShapeInference::InferTokenShape(operand_shapes)); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with @@ -791,6 +799,46 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) { return Status::OK(); } +namespace { + +// Returns true if the given Shape has a TOKEN shape as any subshape. +bool ShapeContainsToken(const Shape& shape) { + bool contains_token = false; + ShapeUtil::ForEachSubshape( + shape, [&contains_token](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsToken(subshape)) { + contains_token = true; + } + }); + return contains_token; +} + +// Verifies that all types entering and exiting the entry computation are +// legal. For example, TOKEN types have no Literal representation and cannot be +// on the interface of the entry computation (parameters and root instruction). 
+Status VerifyEntryAndExitShapes(const HloModule& module) { + for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) { + HloInstruction* param = + module.entry_computation()->parameter_instruction(i); + if (ShapeContainsToken(param->shape())) { + return InternalError( + "Entry parameter %d is or contains a token shape: %s", i, + ShapeUtil::HumanString(param->shape()).c_str()); + } + } + if (ShapeContainsToken( + module.entry_computation()->root_instruction()->shape())) { + return InternalError( + "Entry root is or contains a token shape: %s", + ShapeUtil::HumanString( + module.entry_computation()->root_instruction()->shape()) + .c_str()); + } + return Status::OK(); +} + +} // namespace + StatusOr<bool> HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); @@ -851,6 +899,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); } + TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); + return false; } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 1392a78097..7283b3e7dc 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -81,6 +81,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleGenerateToken(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 429c850343..abedb4063d 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -96,6 +96,7 @@ bool IsAlwaysDuplicable(const HloInstruction& 
instruction) { case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: + case HloOpcode::kGenerateToken: case HloOpcode::kTranspose: case HloOpcode::kTuple: return false; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index d624f548b1..fdc7f41759 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -463,6 +463,17 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } +/* static */ StatusOr<Shape> ShapeInference::InferTokenShape( + tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) { + for (const Shape* arg_shape : arg_shapes) { + if (arg_shape->element_type() != TOKEN) { + return InvalidArgument( + "Operands of token instructions must be TOKEN types."); + } + } + return ShapeUtil::MakeTokenShape(); +} + /* static */ StatusOr<Shape> ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 9da2c99b41..6100e2cd33 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -227,6 +227,13 @@ class ShapeInference { static StatusOr<Shape> InferConcatOpShape( tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension); + // Infers the shape produced by a kGenerateToken operation. Trivially this + // shape is always a TOKEN shape. However, ShapeInference serves two purposes: + // inferring shapes and checking operand shapes. This method verifies that the + // operand shapes are all TOKENs. 
+ static StatusOr<Shape> InferTokenShape( + tensorflow::gtl::ArraySlice<const Shape*> arg_shapes); + // Helper that validates the given operand shape can be converted to the // target output_shape via a convert instruction -- the requirement is that // the shape is identical except for the element type. diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 7f6bbe6f87..e7e0a19db0 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -1204,6 +1204,22 @@ xla_test( ) xla_test( + name = "token_hlo_test", + srcs = ["token_hlo_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +xla_test( name = "call_test", srcs = ["call_test.cc"], tags = [ diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc new file mode 100644 index 0000000000..4585244ce8 --- /dev/null +++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc @@ -0,0 +1,124 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <array> + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class TokenHloTest : public HloTestBase {}; + +// TODO(b/79770375): Compile, not just verify the HLO module when the backends +// support kGenerateToken. +XLA_TEST_F(TokenHloTest, SingleTokenInstruction) { + std::unique_ptr<HloModule> module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, TokenTree) { + std::unique_ptr<HloModule> module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + builder.AddInstruction( + HloInstruction::CreateGenerateToken({token0, token0, token1, token2})); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))); + + module->AddEntryComputation(builder.Build()); + EXPECT_IS_OK(HloVerifier().Run(module.get()).status()); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) { + std::unique_ptr<HloModule> module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction( + 
HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<int32>(42))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 1 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) { + std::unique_ptr<HloModule> module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateParameter( + 0, + ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}), + "param")); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT( + status.error_message(), + ::testing::HasSubstr("Entry parameter 0 is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidTokenRoot) { + std::unique_ptr<HloModule> module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + builder.AddInstruction(HloInstruction::CreateGenerateToken({})); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("Entry root is or contains a token shape")); +} + +XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) { + std::unique_ptr<HloModule> module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + builder.AddInstruction(HloInstruction::CreateGenerateToken({param})); + builder.AddInstruction( + 
HloInstruction::CreateConstant(Literal::CreateR0<int32>(123))); + module->AddEntryComputation(builder.Build()); + + Status status = HloVerifier().Run(module.get()).status(); + ASSERT_IS_NOT_OK(status); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr( + "Operands of token instructions must be TOKEN types")); +} + +} // namespace +} // namespace xla |