aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-06-08 12:50:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-08 12:52:42 -0700
commit278fbe4146b160980fec318187546d9d8870d244 (patch)
tree402e929ee51c089458fa7bfac1ae6826f73d71cd /tensorflow
parent7bb79ee219d4efbd92d1ef4e0dbe45f4aee26654 (diff)
Add kGenerateToken HLO instruction.
The new HLO instruction serves two purposes. (1) It generates a new token value. This is the only way to create tokens. (2) The operation is variadic, taking zero or more token operands. The operation acts as a join of its operands. I considered initially using a kConstant constant as a method to create new tokens, but this ran into problems because of expectations in backends regarding constants and their materialization. This CL enables creation of generate-token instructions, but the new instruction is not supported yet in any backend. PiperOrigin-RevId: 199836205
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h2
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc17
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode_test.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc50
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc11
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h7
-rw-r--r--tensorflow/compiler/xla/tests/BUILD16
-rw-r--r--tensorflow/compiler/xla/tests/token_hlo_test.cc124
19 files changed, 263 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 64678d9d74..ee2b455730 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -243,6 +243,8 @@ class DfsHloVisitorBase {
virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0;
+ virtual Status HandleGenerateToken(HloInstructionPtr token) = 0;
+
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
virtual Status FinishVisit(HloInstructionPtr root) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 240faebe62..6934e00a4b 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -188,6 +188,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleGather(HloInstructionPtr gather) override {
return DefaultAction(gather);
}
+ Status HandleGenerateToken(HloInstructionPtr token) override {
+ return DefaultAction(token);
+ }
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index b9d30ee802..92a66681a9 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -387,6 +387,10 @@ Status HloCostAnalysis::HandleTranspose(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleGenerateToken(const HloInstruction*) {
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) {
auto lhs = convolution->operand(0);
auto rhs = convolution->operand(1);
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index d17678d20f..0d66736fe1 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -97,6 +97,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleBroadcast(const HloInstruction* broadcast) override;
Status HandlePad(const HloInstruction* pad) override;
Status HandleReshape(const HloInstruction* reshape) override;
+ Status HandleGenerateToken(const HloInstruction* token) override;
Status HandleTranspose(const HloInstruction* transpose) override;
Status HandleWhile(const HloInstruction* xla_while) override;
Status HandleConditional(const HloInstruction* conditional) override;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 1e78d775c8..e0648e1467 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -910,6 +910,14 @@ Status HloEvaluator::HandleBroadcast(HloInstruction* broadcast) {
return Status::OK();
}
+Status HloEvaluator::HandleGenerateToken(HloInstruction* token) {
+ // Literals cannot represent a TOKEN shape so just create an empty tuple as
+ // the "result" of the kGenerateToken operation.
+ // TODO(b/109929053): Add support for TOKENs in Literals.
+ evaluated_[token] = Literal::MakeTuple({});
+ return Status::OK();
+}
+
Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) {
const auto result_shape = get_tuple_element->shape();
const int64 index = get_tuple_element->tuple_index();
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index b53d5644de..fc2fc9437b 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -174,6 +174,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
Status HandleBroadcast(HloInstruction* broadcast) override;
+ Status HandleGenerateToken(HloInstruction* token) override;
+
// Returns the already-evaluated literal result for the instruction.
// A Constant instruction is considered evaluated and its literal will be
// returned directly without looking up the cache.
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index a6750460e5..cf954001c6 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -964,6 +964,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kBitcast:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTrace:
+ case HloOpcode::kGenerateToken:
case HloOpcode::kTuple:
return kWhite;
case HloOpcode::kBroadcast:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index ae230d2740..a778a6a965 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -583,6 +583,17 @@ HloInstruction::CreateCrossReplicaSum(
return MakeUnique<HloReverseInstruction>(shape, operand, dimensions);
}
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateGenerateToken(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ auto instruction = WrapUnique(new HloInstruction(
+ HloOpcode::kGenerateToken, ShapeUtil::MakeTokenShape()));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
const Shape& shape, HloComputation* condition, HloComputation* body,
HloInstruction* init) {
@@ -1512,6 +1523,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CreateDomain(shape, new_operands[0], operand_side_metadata_->Clone(),
user_side_metadata_->Clone());
break;
+ case HloOpcode::kGenerateToken:
+ clone = CreateGenerateToken(new_operands);
+ break;
case HloOpcode::kTrace:
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
}
@@ -1776,6 +1790,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kRng:
case HloOpcode::kTrace:
case HloOpcode::kWhile:
+ case HloOpcode::kGenerateToken:
return false;
case HloOpcode::kParameter:
@@ -2776,6 +2791,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleGather(this);
case HloOpcode::kDomain:
return visitor->HandleDomain(this);
+ case HloOpcode::kGenerateToken:
+ return visitor->HandleGenerateToken(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index cc4a8b8252..d252533eb2 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -664,6 +664,11 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Creates a token instruction used for joining or creating token types which
+ // thread through side-effecting operations.
+ static std::unique_ptr<HloInstruction> CreateGenerateToken(
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+
// Creates an instance of GatherDimensionNumbers.
static GatherDimensionNumbers MakeGatherDimNumbers(
tensorflow::gtl::ArraySlice<int64> output_window_dims,
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 1fe06ee0c0..a35546f5f4 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -81,6 +81,7 @@ namespace xla {
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
V(kGather, "gather") \
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
+ V(kGenerateToken, "generate-token", kHloOpcodeIsVariadic) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
V(kHostCompute, "host-compute") \
diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
index cd2ce5c69f..774345124b 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
@@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
case HloOpcode::kConcatenate:
case HloOpcode::kFusion:
case HloOpcode::kMap:
+ case HloOpcode::kGenerateToken:
case HloOpcode::kTuple:
EXPECT_TRUE(HloOpcodeIsVariadic(opcode));
break;
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index a1bc269400..bf1c7b9323 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -606,6 +606,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
HloInstruction::CreateReshape(shape, operands[0]));
break;
}
+ case HloOpcode::kGenerateToken: {
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateGenerateToken(operands));
+ break;
+ }
case HloOpcode::kTuple: {
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 9cfd8a9bf7..9034073cc8 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -426,6 +426,14 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather->gather_dimension_numbers(), gather->gather_window_bounds()));
}
+Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) {
+ std::vector<const Shape*> operand_shapes;
+ for (const HloInstruction* operand : token->operands()) {
+ operand_shapes.push_back(&operand->shape());
+ }
+ return CheckShape(token, ShapeInference::InferTokenShape(operand_shapes));
+}
+
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const Shape& inferred_shape) {
// If allow_mixed_precision_ is false, check if there are operands with
@@ -791,6 +799,46 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
return Status::OK();
}
+namespace {
+
+// Returns true if the given Shape has a TOKEN shape as any subshape.
+bool ShapeContainsToken(const Shape& shape) {
+ bool contains_token = false;
+ ShapeUtil::ForEachSubshape(
+ shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
+ if (ShapeUtil::IsToken(subshape)) {
+ contains_token = true;
+ }
+ });
+ return contains_token;
+}
+
+// Verifies that all types entering and exiting the entry computation are
+// legal. For example, TOKEN types have no Literal representation and cannot be
+// on the interface of the entry computation (parameters and root instruction).
+Status VerifyEntryAndExitShapes(const HloModule& module) {
+ for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
+ HloInstruction* param =
+ module.entry_computation()->parameter_instruction(i);
+ if (ShapeContainsToken(param->shape())) {
+ return InternalError(
+ "Entry parameter %d is or contains a token shape: %s", i,
+ ShapeUtil::HumanString(param->shape()).c_str());
+ }
+ }
+ if (ShapeContainsToken(
+ module.entry_computation()->root_instruction()->shape())) {
+ return InternalError(
+ "Entry root is or contains a token shape: %s",
+ ShapeUtil::HumanString(
+ module.entry_computation()->root_instruction()->shape())
+ .c_str());
+ }
+ return Status::OK();
+}
+
+} // namespace
+
StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
@@ -851,6 +899,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
}
+ TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
+
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 1392a78097..7283b3e7dc 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -81,6 +81,7 @@ class ShapeVerifier : public DfsHloVisitor {
HloInstruction* batch_norm_inference) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
Status HandleGather(HloInstruction* gather) override;
+ Status HandleGenerateToken(HloInstruction* token) override;
Status FinishVisit(HloInstruction*) override { return Status::OK(); }
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 429c850343..abedb4063d 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -96,6 +96,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kShiftRightLogical:
case HloOpcode::kSlice:
case HloOpcode::kSubtract:
+ case HloOpcode::kGenerateToken:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
return false;
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index d624f548b1..fdc7f41759 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -463,6 +463,17 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::MakeShape(element_type, new_dimensions);
}
+/* static */ StatusOr<Shape> ShapeInference::InferTokenShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
+ for (const Shape* arg_shape : arg_shapes) {
+ if (arg_shape->element_type() != TOKEN) {
+ return InvalidArgument(
+ "Operands of token instructions must be TOKEN types.");
+ }
+ }
+ return ShapeUtil::MakeTokenShape();
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type) {
auto old_element_type = operand_shape.element_type();
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 9da2c99b41..6100e2cd33 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -227,6 +227,13 @@ class ShapeInference {
static StatusOr<Shape> InferConcatOpShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
+ // Infers the shape produced by a kGenerateToken operation. Trivially this
+ // shape is always a TOKEN shape. However, ShapeInference serves two purposes:
+ // inferring shapes and checking operand shapes. This method verifies that the
+ // operand shapes are all TOKENs.
+ static StatusOr<Shape> InferTokenShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes);
+
// Helper that validates the given operand shape can be converted to the
// target output_shape via a convert instruction -- the requirement is that
// the shape is identical except for the element type.
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 7f6bbe6f87..e7e0a19db0 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -1204,6 +1204,22 @@ xla_test(
)
xla_test(
+ name = "token_hlo_test",
+ srcs = ["token_hlo_test.cc"],
+ tags = [
+ "enable_for_xla_interpreter",
+ ],
+ deps = [
+ ":client_library_test_base",
+ "//tensorflow/compiler/xla/service:hlo_verifier",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
name = "call_test",
srcs = ["call_test.cc"],
tags = [
diff --git a/tensorflow/compiler/xla/tests/token_hlo_test.cc b/tensorflow/compiler/xla/tests/token_hlo_test.cc
new file mode 100644
index 0000000000..4585244ce8
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/token_hlo_test.cc
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <array>
+
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class TokenHloTest : public HloTestBase {};
+
+// TODO(b/79770375): Compile, not just verify the HLO module when the backends
+// support kGenerateToken.
+XLA_TEST_F(TokenHloTest, SingleTokenInstruction) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+
+ module->AddEntryComputation(builder.Build());
+ EXPECT_IS_OK(HloVerifier().Run(module.get()).status());
+}
+
+XLA_TEST_F(TokenHloTest, TokenTree) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto token0 = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ auto token1 = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ auto token2 = builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ builder.AddInstruction(
+ HloInstruction::CreateGenerateToken({token0, token0, token1, token2}));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+
+ module->AddEntryComputation(builder.Build());
+ EXPECT_IS_OK(HloVerifier().Run(module.get()).status());
+}
+
+XLA_TEST_F(TokenHloTest, InvalidTokenShapedEntryParameter) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
+ builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeTokenShape(), "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(42)));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Entry parameter 1 is or contains a token shape"));
+}
+
+XLA_TEST_F(TokenHloTest, InvalidTupleTokenShapedEntryParameter) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0,
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeTokenShape()}),
+ "param"));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Entry parameter 0 is or contains a token shape"));
+}
+
+XLA_TEST_F(TokenHloTest, InvalidTokenRoot) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateGenerateToken({}));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("Entry root is or contains a token shape"));
+}
+
+XLA_TEST_F(TokenHloTest, InvalidOperandToTokenInstruction) {
+ std::unique_ptr<HloModule> module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
+ builder.AddInstruction(HloInstruction::CreateGenerateToken({param}));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<int32>(123)));
+ module->AddEntryComputation(builder.Build());
+
+ Status status = HloVerifier().Run(module.get()).status();
+ ASSERT_IS_NOT_OK(status);
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr(
+ "Operands of token instructions must be TOKEN types"));
+}
+
+} // namespace
+} // namespace xla