author    Mark Heffernan <meheff@google.com>    2018-06-13 14:19:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-13 14:24:30 -0700
commit    e1296c15a32cac020160a1c89002dc561333c66b (patch)
tree      62ef334470d3484d166ea583eddea10fc24d1718
parent    bf920de58a3ccb2cfe6642be9c487c3fcb13ccae (diff)
Fix assumptions that a Shape must be a tuple or an array.
A TOKEN primitive type was added with cl/199215963, and XLA also has an OPAQUE primitive type. However, in many places in XLA we assume a shape is either a tuple or an array. This CL fixes many of those instances, but some may remain. Instances were identified by searching for IsTuple or IsArray, so the set of fixes is not exhaustive.

Also opportunistically addressed a couple of potential points of confusion in the ShapeUtil interface:

(1) Rename ShapeUtil::HasZeroElements to ShapeUtil::IsZeroElementArray. The point of confusion here is that tuples can also have zero elements, and HasZeroElements would check-fail on tuple shapes. The method no longer check-fails if the given shape is not an array.

(2) ShapeUtil::IsNil now returns true only for empty tuples. Previously it also returned true for zero-element array types, which was confusing because ShapeUtil::MakeNil creates an empty tuple.

PiperOrigin-RevId: 200452672
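For reference, a minimal sketch of the revised predicate semantics, written against the ShapeUtil calls that appear in this CL. The function below and its asserts are illustrative only; the authoritative behavior is in the shape_util_test.cc changes further down.

#include <cassert>

#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
void IllustrateNewShapePredicates() {
  // A zero-element array is IsZeroElementArray, but no longer Nil.
  Shape empty_array = ShapeUtil::MakeShape(F32, {0, 3});
  assert(ShapeUtil::IsZeroElementArray(empty_array));
  assert(!ShapeUtil::IsNil(empty_array));

  // The empty tuple is Nil, but not a zero-element array.
  Shape nil = ShapeUtil::MakeNil();  // equivalent to MakeTupleShape({})
  assert(ShapeUtil::IsNil(nil));
  assert(!ShapeUtil::IsZeroElementArray(nil));

  // A TOKEN shape is neither a tuple nor an array; instead of
  // check-failing, both predicates now simply return false.
  Shape token = ShapeUtil::MakeTokenShape();
  assert(!ShapeUtil::IsZeroElementArray(token));
  assert(!ShapeUtil::IsNil(token));
}
}  // namespace xla

MakeTokenShape is the same helper exercised by the new CannotConcatTokens test at the end of this diff.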
-rw-r--r--  tensorflow/compiler/tf2xla/lib/batch_dot.cc | 4
-rw-r--r--  tensorflow/compiler/xla/BUILD | 1
-rw-r--r--  tensorflow/compiler/xla/layout_util.cc | 10
-rw-r--r--  tensorflow/compiler/xla/literal_comparison.cc | 7
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc | 12
-rw-r--r--  tensorflow/compiler/xla/literal_util.h | 2
-rw-r--r--  tensorflow/compiler/xla/primitive_util.cc | 5
-rw-r--r--  tensorflow/compiler/xla/primitive_util.h | 3
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 24
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_propagation.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 14
-rw-r--r--  tensorflow/compiler/xla/service/gather_expander.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/generic_transfer_manager.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc | 4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc | 2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 141
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h | 2
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc | 8
-rw-r--r--  tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc | 4
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc | 6
-rw-r--r--  tensorflow/compiler/xla/shape_util.h | 10
-rw-r--r--  tensorflow/compiler/xla/shape_util_test.cc | 53
-rw-r--r--  tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc | 2
-rw-r--r--  tensorflow/compiler/xla/tests/concat_test.cc | 17
33 files changed, 208 insertions(+), 171 deletions(-)
diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
index 526694d5a0..ee0bb91a6b 100644
--- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc
+++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc
@@ -71,8 +71,8 @@ xla::StatusOr<xla::XlaOp> BatchDot(xla::XlaBuilder* builder, xla::XlaOp x,
}
// Check for zero lhs/rhs dim size.
- if (xla::ShapeUtil::HasZeroElements(x_shape) ||
- xla::ShapeUtil::HasZeroElements(y_shape)) {
+ if (xla::ShapeUtil::IsZeroElementArray(x_shape) ||
+ xla::ShapeUtil::IsZeroElementArray(y_shape)) {
std::vector<int64> dimensions(batch_dimension_numbers.size());
for (int i = 0; i < batch_dimension_numbers.size(); ++i) {
dimensions[i] = x_shape.dimensions(batch_dimension_numbers[i]);
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 1b8e516770..4525197146 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -309,7 +309,6 @@ cc_library(
":types",
":util",
":xla_data_proto",
- "//tensorflow/core:framework",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index e8f29b8329..3f059cac30 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -190,9 +190,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
if (!ShapeUtil::IsArray(shape)) {
- return InvalidArgument(
- "shape of primitive type %s should not have a layout",
- PrimitiveType_Name(shape.element_type()).c_str());
+ if (layout.minor_to_major_size() != 0 ||
+ layout.padded_dimensions_size() != 0) {
+ return InvalidArgument(
+ "shape of primitive type %s should not have a non-trivial layout",
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
+ return Status::OK();
}
if (layout.format() == INVALID_FORMAT) {
diff --git a/tensorflow/compiler/xla/literal_comparison.cc b/tensorflow/compiler/xla/literal_comparison.cc
index bf9679cafe..748a243e53 100644
--- a/tensorflow/compiler/xla/literal_comparison.cc
+++ b/tensorflow/compiler/xla/literal_comparison.cc
@@ -606,8 +606,8 @@ Status NearHelper(const LiteralSlice& expected, const LiteralSlice& actual,
} // namespace
Status EqualShapes(const Shape& expected, const Shape& actual) {
- if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) {
- return InvalidArgument("tupleness-mismatch! want: %s got %s",
+ if (expected.element_type() != actual.element_type()) {
+ return InvalidArgument("element type mismatch, want: %s got %s",
ShapeUtil::HumanString(expected).c_str(),
ShapeUtil::HumanString(actual).c_str());
}
@@ -626,7 +626,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
return AppendStatus(result, StrCat("mismatch in tuple index", i));
}
}
- } else {
+ } else if (ShapeUtil::IsArray(expected)) {
if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) {
return InvalidArgument("want rank of %s got rank of %s",
ShapeUtil::HumanString(expected).c_str(),
@@ -652,6 +652,7 @@ Status EqualShapes(const Shape& expected, const Shape& actual) {
}
}
}
+ // Non-array, non-tuple shapes are trivially equivalent.
return Status::OK();
}
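With this change, EqualShapes no longer branches on tupleness first: it compares element types, recurses into tuples, checks rank and dimensions for arrays, and treats any remaining shapes (e.g. two TOKENs) as trivially equal. A hedged sketch of the resulting behavior, assuming EqualShapes is exposed via the xla::literal_comparison namespace in literal_comparison.h:

#include <cassert>

#include "tensorflow/compiler/xla/literal_comparison.h"
#include "tensorflow/compiler/xla/shape_util.h"

void IllustrateEqualShapes() {
  // Two TOKEN shapes: element types match, neither is a tuple or an
  // array, so the comparison falls through to the trivial case.
  xla::Shape t1 = xla::ShapeUtil::MakeTokenShape();
  xla::Shape t2 = xla::ShapeUtil::MakeTokenShape();
  assert(xla::literal_comparison::EqualShapes(t1, t2).ok());

  // A TOKEN vs. an F32 array now fails fast on the element-type check
  // rather than on a tupleness mismatch.
  xla::Shape arr = xla::ShapeUtil::MakeShape(xla::F32, {2});
  assert(!xla::literal_comparison::EqualShapes(t1, arr).ok());
}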
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 6b29589700..72740e5976 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -264,8 +264,8 @@ Status Literal::CopySliceFromInternal(
StridedCopy(data<NativeT>(), linear_index(shape(), dest_base), 0,
src_literal.data<NativeT>(),
linear_index(src_literal.shape(), src_base), 0, 1);
- } else if (!ShapeUtil::HasZeroElements(shape()) &&
- !ShapeUtil::HasZeroElements(src_literal.shape())) {
+ } else if (!ShapeUtil::IsZeroElementArray(shape()) &&
+ !ShapeUtil::IsZeroElementArray(src_literal.shape())) {
// Perform copy if neither src nor dest has dimensions with zero element,
// otherwise it's a no-op.
TF_RET_CHECK(src_base.size() == dest_base.size());
@@ -379,7 +379,7 @@ void CopyElementsBetween(tensorflow::gtl::MutableArraySlice<NativeT> dest,
tensorflow::gtl::ArraySlice<NativeT> src,
const Shape& dest_shape, const Shape& src_shape) {
CHECK(ShapeUtil::Compatible(dest_shape, src_shape));
- if (ShapeUtil::HasZeroElements(dest_shape)) {
+ if (ShapeUtil::IsZeroElementArray(dest_shape)) {
return;
}
std::vector<int64> index(ShapeUtil::Rank(dest_shape));
@@ -1177,7 +1177,7 @@ size_t LiteralBase::Hash() const {
ShapeUtil::ForEachSubshape(
shape(), [&](const Shape& subshape, const ShapeIndex& index) {
- if (ShapeUtil::IsTuple(subshape)) {
+ if (!ShapeUtil::IsArray(subshape)) {
return;
}
@@ -1556,7 +1556,7 @@ string LiteralBase::ToString(bool print_layout) const {
void LiteralBase::EachCellAsString(
const std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
const string& value)>& per_cell) const {
- if (ShapeUtil::HasZeroElements(shape())) {
+ if (ShapeUtil::IsZeroElementArray(shape())) {
return;
}
std::vector<int64> indices = IndexUtil::LinearIndexToMultidimensionalIndex(
@@ -1962,7 +1962,7 @@ bool LiteralBase::IsAllFirst() const {
// Empty shapes are not all the first element since there is no first
// element.
- if (ShapeUtil::HasZeroElements(piece.subshape())) {
+ if (ShapeUtil::IsZeroElementArray(piece.subshape())) {
return false;
}
auto piece_is_all = [&]() {
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 8e4159e360..bcecbcccb7 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -1456,7 +1456,7 @@ void LiteralBase::EachCell(
std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
NativeT value)>
per_cell) const {
- if (ShapeUtil::HasZeroElements(shape())) {
+ if (ShapeUtil::IsZeroElementArray(shape())) {
return;
}
std::vector<int64> indices(ShapeUtil::Rank(shape()), 0);
diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc
index 143c9a2366..b16147e3be 100644
--- a/tensorflow/compiler/xla/primitive_util.cc
+++ b/tensorflow/compiler/xla/primitive_util.cc
@@ -85,5 +85,10 @@ PrimitiveType ComplexComponentType(PrimitiveType complex_type) {
}
}
+bool IsArrayType(PrimitiveType primitive_type) {
+ return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
+ primitive_type != OPAQUE && primitive_type != TOKEN;
+}
+
} // namespace primitive_util
} // namespace xla
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index b26a10ade6..889e9a1cec 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -133,6 +133,9 @@ bool IsUnsignedIntegralType(PrimitiveType type);
bool IsIntegralType(PrimitiveType type);
+// Returns true if values of the given primitive type are held in array shapes.
+bool IsArrayType(PrimitiveType primitive_type);
+
// Returns the number of bits in the representation for a given type.
int BitWidth(PrimitiveType type);
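A quick sketch of the new predicate's truth table, following directly from the definition in primitive_util.cc above (illustrative only, not part of this CL's tests):

#include <cassert>

#include "tensorflow/compiler/xla/primitive_util.h"

void IllustrateIsArrayType() {
  namespace pu = xla::primitive_util;
  // Numeric and predicate element types are held in array shapes.
  assert(pu::IsArrayType(xla::F32));
  assert(pu::IsArrayType(xla::PRED));
  // Container and non-data types are not.
  assert(!pu::IsArrayType(xla::TUPLE));
  assert(!pu::IsArrayType(xla::OPAQUE));
  assert(!pu::IsArrayType(xla::TOKEN));
  assert(!pu::IsArrayType(xla::PRIMITIVE_TYPE_INVALID));
}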
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 3b36939b8a..1fc8fb9b69 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -449,7 +449,7 @@ Status AlgebraicSimplifierVisitor::HandleConcatenate(
// Filter out and remove empty operands.
std::vector<HloInstruction*> nonempty_operands;
for (HloInstruction* operand : operands) {
- if (!ShapeUtil::HasZeroElements(operand->shape())) {
+ if (!ShapeUtil::IsZeroElementArray(operand->shape())) {
nonempty_operands.push_back(operand);
}
}
@@ -1058,9 +1058,9 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
}
// Replace a zero element dot with a broadcast of the constant 0.
- if (ShapeUtil::HasZeroElements(dot->shape()) ||
- ShapeUtil::HasZeroElements(lhs->shape()) ||
- ShapeUtil::HasZeroElements(rhs->shape())) {
+ if (ShapeUtil::IsZeroElementArray(dot->shape()) ||
+ ShapeUtil::IsZeroElementArray(lhs->shape()) ||
+ ShapeUtil::IsZeroElementArray(rhs->shape())) {
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
return ReplaceWithNewInstruction(
@@ -1392,7 +1392,7 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) {
}
Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
- if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) {
+ if (ShapeUtil::IsZeroElementArray(pad->operand(0)->shape())) {
return ReplaceWithNewInstruction(
pad, HloInstruction::CreateBroadcast(pad->shape(),
pad->mutable_operand(1), {}));
@@ -1638,7 +1638,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
// Reshape directly to empty constant if the shape contains zero-element
// dimension.
- if (ShapeUtil::HasZeroElements(reshape->shape())) {
+ if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
auto empty_constant = HloInstruction::CreateConstant(
Literal::CreateFromShape(reshape->shape()));
@@ -1739,7 +1739,7 @@ Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice(
// If any dimension of update is 0, elide the DynamicUpdateSlice. This
// optimization becomes invalid should we later prefer to warn about out of
// bound indices.
- if (ShapeUtil::HasZeroElements(update->shape())) {
+ if (ShapeUtil::IsZeroElementArray(update->shape())) {
return ReplaceInstruction(dynamic_update_slice,
dynamic_update_slice->mutable_operand(0));
}
@@ -1751,8 +1751,8 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
auto init_value = reduce->mutable_operand(1);
tensorflow::gtl::ArraySlice<int64> dimensions(reduce->dimensions());
HloComputation* function = reduce->to_apply();
- if (ShapeUtil::HasZeroElements(arg->shape()) ||
- ShapeUtil::HasZeroElements(reduce->shape())) {
+ if (ShapeUtil::IsZeroElementArray(arg->shape()) ||
+ ShapeUtil::IsZeroElementArray(reduce->shape())) {
return ReplaceWithNewInstruction(
reduce,
HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
@@ -1863,7 +1863,7 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) {
Status AlgebraicSimplifierVisitor::HandleReduceWindow(
HloInstruction* reduce_window) {
- if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) {
+ if (ShapeUtil::IsZeroElementArray(reduce_window->operand(0)->shape())) {
return ReplaceWithNewInstruction(
reduce_window,
HloInstruction::CreateBroadcast(reduce_window->shape(),
@@ -2059,8 +2059,8 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
HloInstruction* convolution) {
auto lhs = convolution->mutable_operand(0);
auto rhs = convolution->mutable_operand(1);
- if (ShapeUtil::HasZeroElements(lhs->shape()) ||
- ShapeUtil::HasZeroElements(rhs->shape())) {
+ if (ShapeUtil::IsZeroElementArray(lhs->shape()) ||
+ ShapeUtil::IsZeroElementArray(rhs->shape())) {
return ReplaceWithNewInstruction(
convolution,
HloInstruction::CreateBroadcast(
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index ed0746980f..8f1d2f0804 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -631,7 +631,7 @@ Status BFloat16Propagation::ResolveInconsistentFusions(HloModule* module) {
subshape, converted_outputs.element(parent_index),
output_index.back()));
}
- if (ShapeUtil::IsTuple(subshape)) {
+ if (!ShapeUtil::IsArray(subshape)) {
continue;
}
if (!ShapeUtil::Compatible(
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 8eb39d615f..e8b205051e 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -1627,8 +1627,8 @@ bool PotentiallyImplementedAsEigenDot(
const Shape& lhs_shape = hlo.operand(0)->shape();
const Shape& rhs_shape = hlo.operand(1)->shape();
- if (ShapeUtil::HasZeroElements(lhs_shape) ||
- ShapeUtil::HasZeroElements(rhs_shape)) {
+ if (ShapeUtil::IsZeroElementArray(lhs_shape) ||
+ ShapeUtil::IsZeroElementArray(rhs_shape)) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
index b560b7531c..1a8bedfe6a 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -64,8 +64,8 @@ bool PotentiallyImplementedAsEigenConvolution(
return false;
}
- if (ShapeUtil::HasZeroElements(input_shape) ||
- ShapeUtil::HasZeroElements(kernel_shape)) {
+ if (ShapeUtil::IsZeroElementArray(input_shape) ||
+ ShapeUtil::IsZeroElementArray(kernel_shape)) {
return false;
}
// Make sure input and kernel has the same data type.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index a4141dee01..94053e5716 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -226,10 +226,13 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) {
// kCopy shallow copies a tuple so just memcpy the top-level buffer.
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy));
return EmitMemcpy(*(copy->operand(0)), *copy);
- } else {
- // Use the elemental emitter for non-tuple shapes.
+ } else if (ShapeUtil::IsArray(copy->shape())) {
+ // Use the elemental emitter for array shapes.
return DefaultAction(copy);
}
+ return Unimplemented(
+ "unsupported operand type %s for copy instruction",
+ PrimitiveType_Name(copy->shape().element_type()).c_str());
}
// Calculate the alignment of a buffer allocated for a given primitive type.
@@ -1867,7 +1870,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice));
- if (ShapeUtil::HasZeroElements(slice->shape())) {
+ if (ShapeUtil::IsZeroElementArray(slice->shape())) {
return Status::OK();
}
@@ -2803,7 +2806,10 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
// For the root node, we write directly to the output buffer of the
// function.
llvm::Argument* retval = compute_function_->result_arg();
- if (!ShapeUtil::IsNil(target_shape)) {
+ if ((ShapeUtil::IsArray(target_shape) &&
+ !ShapeUtil::IsZeroElementArray(target_shape)) ||
+ (ShapeUtil::IsTuple(target_shape) &&
+ !ShapeUtil::IsEmptyTuple(target_shape))) {
llvm::AttrBuilder attr_builder;
attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
diff --git a/tensorflow/compiler/xla/service/gather_expander.cc b/tensorflow/compiler/xla/service/gather_expander.cc
index 2d3e4b1fcd..7cd2c9c136 100644
--- a/tensorflow/compiler/xla/service/gather_expander.cc
+++ b/tensorflow/compiler/xla/service/gather_expander.cc
@@ -300,7 +300,7 @@ static StatusOr<HloInstruction*> PermuteGatherAndWindowDims(
StatusOr<HloInstruction*> GatherExpander::ExpandGather(
HloInstruction* gather_instr) {
- CHECK(!ShapeUtil::HasZeroElements(gather_instr->shape()));
+ CHECK(!ShapeUtil::IsZeroElementArray(gather_instr->shape()));
HloComputation* computation = gather_instr->parent();
HloInstruction* operand = gather_instr->mutable_operand(0);
@@ -369,7 +369,7 @@ StatusOr<bool> GatherExpander::Run(HloModule* module) {
return inst->opcode() == HloOpcode::kGather &&
// Avoid expanding gather ops that produce zero sized tensors,
// instead punt these to ZeroSizedHloElimination.
- !ShapeUtil::HasZeroElements(inst->shape());
+ !ShapeUtil::IsZeroElementArray(inst->shape());
};
std::vector<HloInstruction*> gather_instrs;
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 5ee67ccb4a..d9f62c21c4 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -74,7 +74,7 @@ GenericTransferManager::TransferLiteralFromDevice(
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
device_buffer.on_host_shape(),
[&](const Shape& subshape, const ShapeIndex& index) -> Status {
- if (!ShapeUtil::IsTuple(subshape)) {
+ if (ShapeUtil::IsArray(subshape)) {
TF_RETURN_IF_ERROR(TransferBufferFromDevice(
executor,
/*source=*/device_buffer.buffer(index),
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
index e0c73aa73a..f9dccd287d 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc
@@ -42,8 +42,8 @@ bool CanImplementAsCudnnForwardConv(HloInstruction* conv) {
}
// CuDNN does not accept zero-element arguments
- if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) ||
- ShapeUtil::HasZeroElements(conv->operand(1)->shape())) {
+ if (ShapeUtil::IsZeroElementArray(conv->operand(0)->shape()) ||
+ ShapeUtil::IsZeroElementArray(conv->operand(1)->shape())) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 67890bfed1..388aa35d7d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -56,8 +56,8 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
return type_is_allowed && IsRank2WithNoPadding(lhs_shape) &&
IsRank2WithNoPadding(rhs_shape) &&
IsRank2WithNoPadding(output_shape) &&
- !ShapeUtil::HasZeroElements(lhs_shape) &&
- !ShapeUtil::HasZeroElements(rhs_shape);
+ !ShapeUtil::IsZeroElementArray(lhs_shape) &&
+ !ShapeUtil::IsZeroElementArray(rhs_shape);
}
bool DotImplementedAsGemm(const HloInstruction& dot) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 547af33e9a..7b7dd673a5 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -610,7 +610,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
}
Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
- if (ShapeUtil::HasZeroElements(convolution->shape())) {
+ if (ShapeUtil::IsZeroElementArray(convolution->shape())) {
// Emit no code for an empty output.
return Status::OK();
}
@@ -620,7 +620,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) {
}
Status IrEmitter::HandleFft(HloInstruction* fft) {
- if (ShapeUtil::HasZeroElements(fft->shape())) {
+ if (ShapeUtil::IsZeroElementArray(fft->shape())) {
// Emit no code for an empty output.
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index b158f44923..c73e54a0b1 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -556,8 +556,13 @@ StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
}
return AddInstruction(HloInstruction::CreateTuple(elements));
} else {
- return FailedPrecondition(
- "Can only copy array and tuple shaped instructions");
+ // Tokens, opaques, etc are not copyable.
+ if (indices_to_copy == nullptr || indices_to_copy->element(*index)) {
+ return FailedPrecondition(
+ "Cannot copy instruction of shape: %s",
+ ShapeUtil::HumanString(instruction->shape()).c_str());
+ }
+ return instruction;
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index e0648e1467..080ee4ad18 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -372,7 +372,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
// The result concatenate dimension is going to be the sum of all
// concatenate dimensions of the operands taking part of the operation.
const Shape& reference_shape = operands[0]->shape();
- CHECK(!ShapeUtil::IsTuple(reference_shape));
+ CHECK(ShapeUtil::IsArray(reference_shape));
const int64 rank = ShapeUtil::Rank(reference_shape);
const int64 concat_dim = concatenate->dimensions()[0];
CHECK_GE(concat_dim, 0);
@@ -383,7 +383,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) {
for (int64 i = 1; i < operands.size(); ++i) {
const Shape& operand_shape = operands[i]->shape();
- CHECK(!ShapeUtil::IsTuple(operand_shape));
+ CHECK(ShapeUtil::IsArray(operand_shape));
// Accumulate the concat dimension from all tensors taking part to the
// operation.
concat_dimensions[concat_dim] +=
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 13f46407e3..e01ce19d04 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -778,7 +778,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
Status HandleSelect(HloInstruction* select) override {
CHECK(!ShapeUtil::IsScalar(select->operand(0)->shape()));
- CHECK(!ShapeUtil::IsTuple(select->shape()));
+ CHECK(ShapeUtil::IsArray(select->shape()));
std::function<ReturnT(bool, ReturnT, ReturnT)> select_op =
[](bool pred, ReturnT on_true, ReturnT on_false) {
if (pred) {
@@ -1103,7 +1103,7 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
Status HandlePad(HloInstruction* pad) override {
- CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape()));
+ CHECK(ShapeUtil::IsArray(pad->operand(0)->shape()));
// Padding value must be scalar.
CHECK(ShapeUtil::IsScalar(pad->operand(1)->shape()));
CHECK_EQ(ShapeUtil::Rank(pad->operand(0)->shape()),
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 28fc6c4209..ab224021c5 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -832,13 +832,13 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
// "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
// enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
// is just noise.
- if (!ShapeUtil::IsTuple(shape) && ShapeUtil::HasZeroElements(shape)) {
+ if (ShapeUtil::IsZeroElementArray(shape)) {
return Printf("{} (%s)", ShapeUtil::HumanString(constant->shape()));
}
// Print the literal value of constants with <= K elements.
optional<int64> elem_count;
- if (!ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)) {
+ if (ShapeUtil::IsArray(shape)) {
elem_count = 1;
for (int64 dim : shape.dimensions()) {
*elem_count *= dim;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 761d833546..34038ae0ae 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -658,7 +658,7 @@ string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
CanonicalNameMap* canonical_name_map) const {
string operands;
// For constants, show the actual value in place of an empty operand list.
- if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
+ if ((ShapeUtil::IsArray(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
options.print_large_constants()) {
// Literal::ToString emits multidimensional arrays over multiple
// lines. Compact this into one line by stripping out white space.
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 9034073cc8..1d6cd4cb23 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -431,7 +431,8 @@ Status ShapeVerifier::HandleGenerateToken(HloInstruction* token) {
for (const HloInstruction* operand : token->operands()) {
operand_shapes.push_back(&operand->shape());
}
- return CheckShape(token, ShapeInference::InferTokenShape(operand_shapes));
+ return CheckShape(token,
+ ShapeInference::InferGenerateTokenShape(operand_shapes));
}
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index bd98e86b08..e25f5e67c7 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -49,19 +49,13 @@ bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
}
-Status ExpectNotTupleOrOpaque(const Shape& shape,
- tensorflow::StringPiece op_type) {
- if (ShapeUtil::IsTuple(shape)) {
- return InvalidArgument("Expected non-tuple argument for %s, but got %s.",
+Status ExpectArray(const Shape& shape, tensorflow::StringPiece op_type) {
+ if (!ShapeUtil::IsArray(shape)) {
+ return InvalidArgument("Expected array argument for %s, but got %s.",
std::string(op_type).c_str(),
ShapeUtil::HumanString(shape).c_str());
- } else if (ShapeUtil::IsOpaque(shape)) {
- return InvalidArgument("Expected non-opaque argument for %s, but got %s.",
- std::string(op_type).c_str(),
- ShapeUtil::HumanString(shape).c_str());
- } else {
- return Status::OK();
}
+ return Status::OK();
}
Status VerifyReducerShape(const ProgramShape& reducer_shape,
@@ -198,8 +192,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return shape;
}
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(shape, "operand of unary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(shape, "operand of unary operation"));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
switch (opcode) {
@@ -289,8 +282,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
const Shape* arg_shape = nullptr;
PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
for (const Shape* shape : arg_shapes) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
+ TF_RETURN_IF_ERROR(ExpectArray(*shape, "operand of concatenation"));
if (!arg_shape) {
arg_shape = shape;
element_type = arg_shape->element_type();
@@ -337,7 +329,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::MakeShape(element_type, new_dimensions);
}
-/* static */ StatusOr<Shape> ShapeInference::InferTokenShape(
+/* static */ StatusOr<Shape> ShapeInference::InferGenerateTokenShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
for (const Shape* arg_shape : arg_shapes) {
if (arg_shape->element_type() != TOKEN) {
@@ -358,12 +350,13 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
- if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
+ if (!ShapeUtil::IsArray(operand_shape) ||
+ !primitive_util::IsArrayType(new_element_type)) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
// are valid. For now we just reject them, though.
return InvalidArgument(
- "Convert does not allow tuples, so cannot convert from %s to %s.",
+ "Convert does not allow non-arrays, so cannot convert from %s to %s.",
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
@@ -380,7 +373,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(operand_shape).c_str(),
PrimitiveType_Name(new_element_type).c_str());
}
- if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
+ if (!ShapeUtil::IsArray(operand_shape) ||
+ !primitive_util::IsArrayType(new_element_type)) {
// Note: we may want to support tuple conversions via this operation in the
// future, by recursing into the tuple elements to check all sub-conversions
// are valid. For now we just reject them, though.
@@ -427,7 +421,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
const Shape& operand_shape, const Shape& padding_value_shape,
const PaddingConfig& padding_config) {
- if (ShapeUtil::IsTuple(operand_shape)) {
+ if (!ShapeUtil::IsArray(operand_shape)) {
return InvalidArgument(
"Pad operation does not support tuple-shape operands.");
}
@@ -566,8 +560,8 @@ Status ValidateDotDimensionNumbers(
/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of dot"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of dot"));
auto fail = [lhs, rhs](const string& addendum) -> Status {
string message = tensorflow::strings::Printf(
@@ -786,10 +780,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
HloOpcode operation, const Shape& lhs, const Shape& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of elementwise binary operation"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of elementwise binary operation"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
@@ -853,12 +845,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs));
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- lhs, tensorflow::strings::StrCat("lhs of binary operation ",
- HloOpcodeString(opcode))));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- rhs, tensorflow::strings::StrCat("rhs of binary operation ",
- HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(lhs, tensorflow::strings::StrCat("lhs of binary operation ",
+ HloOpcodeString(opcode))));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(rhs, tensorflow::strings::StrCat("rhs of binary operation ",
+ HloOpcodeString(opcode))));
switch (opcode) {
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
@@ -984,15 +976,12 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// All arguments must have the same shape.
const Shape* arg_shape = arg_shapes[0];
for (size_t i = 1; i < arg_shapes.size(); ++i) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
+ TF_RETURN_IF_ERROR(ExpectArray(*arg_shapes[i], "operand of map"));
if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
continue;
}
- if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
- !ShapeUtil::IsTuple(*arg_shape) &&
- ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
+ if (ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
*arg_shape)) {
if (ShapeUtil::IsScalar(*arg_shapes[i])) {
continue;
@@ -1075,11 +1064,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand_shape, const Shape& scale_shape,
const Shape& offset_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- offset_shape, "offset input of batch norm training"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- scale_shape, "scale input of batch norm training"));
+ ExpectArray(operand_shape, "operand of batch norm training"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(offset_shape, "offset input of batch norm training"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(scale_shape, "scale input of batch norm training"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
Status::OK());
@@ -1181,11 +1170,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& offset_shape, const Shape& mean_shape,
const Shape& variance_shape, int64 feature_index) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm inference"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- offset_shape, "offset input of batch norm inference"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- scale_shape, "scale input of batch norm inference"));
+ ExpectArray(operand_shape, "operand of batch norm inference"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(offset_shape, "offset input of batch norm inference"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(scale_shape, "scale input of batch norm inference"));
TF_RET_CHECK(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape) ==
Status::OK());
@@ -1328,16 +1317,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand_shape, const Shape& scale_shape,
const Shape& mean_shape, const Shape& var_shape,
const Shape& output_grad_shape, int64 feature_index) {
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of batch norm grad"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm grad"));
+ ExpectArray(scale_shape, "scale input of batch norm grad"));
+ TF_RETURN_IF_ERROR(ExpectArray(mean_shape, "mean input of batch norm grad"));
+ TF_RETURN_IF_ERROR(ExpectArray(var_shape, "var input of batch norm grad"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(scale_shape, "scale input of batch norm grad"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(mean_shape, "mean input of batch norm grad"));
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(var_shape, "var input of batch norm grad"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- output_grad_shape, "output_grad input of batch norm grad"));
+ ExpectArray(output_grad_shape, "output_grad input of batch norm grad"));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(operand_shape));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(mean_shape));
@@ -1486,8 +1472,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
const Shape& lhs, const Shape& rhs, const Window& window,
const ConvolutionDimensionNumbers& dnums) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
+ TF_RETURN_IF_ERROR(ExpectArray(lhs, "lhs of convolution"));
+ TF_RETURN_IF_ERROR(ExpectArray(rhs, "rhs of convolution"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
@@ -1722,7 +1708,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
tensorflow::gtl::ArraySlice<const Shape*> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(*operand_shape, "operand of cross replica sum"));
+ ExpectArray(*operand_shape, "operand of cross replica sum"));
}
if (operand_shapes.size() == 1) {
return *operand_shapes[0];
@@ -1764,8 +1750,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
const Shape& operand_shape, const Shape& init_value_shape,
const Window& window, const ProgramShape& to_apply_shape) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reduce-window"));
TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
operand_shape.element_type()));
return InferWindowOutputShape(operand_shape, window,
@@ -1778,7 +1763,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Window& window, const Shape& source_shape,
const Shape& init_value_shape, const ProgramShape& scatter_shape) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter"));
+ ExpectArray(operand_shape, "operand of select-and-scatter"));
// Check if the select function has a proper shape of (T,T) -> PRED.
if (select_shape.parameters_size() != 2) {
@@ -1843,7 +1828,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
Join(starts, ",").c_str(), Join(limits, ",").c_str(),
Join(strides, ",").c_str());
};
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
+ TF_RETURN_IF_ERROR(ExpectArray(arg, "operand of slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s starts={%s} limits={%s}",
ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(),
@@ -1902,10 +1887,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
const Shape& operand_shape, const Shape& start_indices_shape,
tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of dynamic slice"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape,
- "start indices of dynamic slice"));
+ ExpectArray(start_indices_shape, "start indices of dynamic slice"));
VLOG(2) << tensorflow::strings::Printf(
"slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
@@ -1963,11 +1947,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
const Shape& operand_shape, const Shape& update_shape,
const Shape& start_indices_shape) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice"));
+ ExpectArray(operand_shape, "operand of dynamic update slice"));
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- start_indices_shape, "start indices of dynamic update slice"));
+ ExpectArray(update_shape, "update of dynamic update slice"));
+ TF_RETURN_IF_ERROR(ExpectArray(start_indices_shape,
+ "start indices of dynamic update slice"));
VLOG(2) << tensorflow::strings::Printf(
"updating slice of shape %s at dynamic start_indices %s with update "
@@ -2035,8 +2019,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
- TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(operand_shape, "operand of reverse"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand_shape, "operand of reverse"));
if (!AllUnique(dimensions)) {
return InvalidArgument("a dimension number is duplicated in reverse");
}
@@ -2166,7 +2149,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "operand of broadcast"));
for (int64 size : broadcast_sizes) {
if (size < 0) {
return InvalidArgument("Broadcast with negative dimension size %lld.",
@@ -2185,7 +2168,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> new_sizes) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "reshape"));
Shape inferred_shape =
ShapeUtil::MakeShape(operand.element_type(), new_sizes);
@@ -2217,7 +2200,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "transpose"));
std::vector<int64> indices(ShapeUtil::Rank(operand));
std::iota(indices.begin(), indices.end(), 0);
@@ -2238,9 +2221,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
// "degenerate" cases, as with binary elementwise ops.
/* static */ StatusOr<Shape> ShapeInference::InferClampShape(
const Shape& min, const Shape& operand, const Shape& max) {
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
+ TF_RETURN_IF_ERROR(ExpectArray(min, "clamp min"));
+ TF_RETURN_IF_ERROR(ExpectArray(operand, "clamp operand"));
+ TF_RETURN_IF_ERROR(ExpectArray(max, "clamp max"));
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
!ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
return InvalidArgument("Clamp with different operand types: %s, %s, %s.",
@@ -2439,9 +2422,9 @@ static Status ValidateGatherDimensionNumbers(
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds) {
TF_RETURN_IF_ERROR(
- ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op"));
- TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
- gather_indices_shape, "gather indices operand of gather op"));
+ ExpectArray(input_shape, "input tensor operand gather op"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(gather_indices_shape, "gather indices operand of gather op"));
if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
return InvalidArgument(
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index f1f7b50902..eef6e62fc8 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -220,7 +220,7 @@ class ShapeInference {
// shape is always a TOKEN shape. However, ShapeInference serves two purposes:
// inferring shapes and checking operand shapes. This method verifies that the
// operand shapes are all TOKENs.
- static StatusOr<Shape> InferTokenShape(
+ static StatusOr<Shape> InferGenerateTokenShape(
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes);
// Helper that validates the given operand shape can be converted to the
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 6d017dffe2..bafe14d6f4 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1311,7 +1311,7 @@ TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
ASSERT_FALSE(inferred_status_error4.ok());
ASSERT_THAT(
inferred_status_error4.status().error_message(),
- HasSubstr("Expected non-tuple argument for operand of concatenation"));
+ HasSubstr("Expected array argument for operand of concatenation"));
const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
@@ -1387,7 +1387,7 @@ TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
ShapeInference::InferReverseShape(tuple_shape, {0});
ASSERT_FALSE(inferred_status_error3.ok());
ASSERT_THAT(inferred_status_error3.status().error_message(),
- HasSubstr("Expected non-tuple argument"));
+ HasSubstr("Expected array argument"));
}
TEST_F(ShapeInferenceTest, Call) {
@@ -1686,7 +1686,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Expected non-tuple argument for input"))
+ HasSubstr("Expected array argument for input"))
<< statusor.status();
}
@@ -1700,7 +1700,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
/*window_bounds=*/{64, 1});
ASSERT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
- HasSubstr("Expected non-tuple argument for gather indices"))
+ HasSubstr("Expected array argument for gather indices"))
<< statusor.status();
}
diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
index aa40b5cb26..44b0ec5cd4 100644
--- a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
+++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc
@@ -32,11 +32,11 @@ StatusOr<bool> ZeroSizedHloElimination::Run(HloModule* module) {
for (HloComputation* comp : module->MakeNonfusionComputations()) {
for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
if (instruction->HasSideEffect() ||
- ShapeUtil::IsTuple(instruction->shape())) {
+ !ShapeUtil::IsArray(instruction->shape())) {
continue;
}
if (comp->IsRemovable(instruction) &&
- ShapeUtil::HasZeroElements(instruction->shape())) {
+ ShapeUtil::IsZeroElementArray(instruction->shape())) {
TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction(
instruction, HloInstruction::CreateConstant(
Literal::CreateFromShape(instruction->shape()))));
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 5db6659932..2c484661ee 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -363,7 +363,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ bool ShapeUtil::IsNil(const Shape& shape) {
- return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape);
+ return IsEmptyTuple(shape);
}
/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
@@ -413,8 +413,8 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
std::multiplies<int64>());
}
-/* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) {
- return ElementsIn(shape) == 0;
+/* static */ bool ShapeUtil::IsZeroElementArray(const Shape& shape) {
+ return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
}
/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index ae2d17d6bb..b6d29976d1 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -175,8 +175,8 @@ class ShapeUtil {
// Precondition: IsArray(shape)
static int64 ElementsIn(const Shape& shape);
- // Returns true if 'shape' has zero elements.
- static bool HasZeroElements(const Shape& shape);
+ // Returns true if 'shape' is an array with zero elements.
+ static bool IsZeroElementArray(const Shape& shape);
// Returns the number of bytes required for an allocation of shape. The
// |pointer_size| parameter is used for calculating the size of tuple
@@ -336,7 +336,7 @@ class ShapeUtil {
// Appends a major dimension to the shape with the given bound.
static void AppendMajorDimension(int bound, Shape* shape);
- // Returns an empty tuple shape. Can be used to indicate side-effects.
+ // Returns an empty tuple shape. Can be used as a sentinel Shape value.
static Shape MakeNil() { return MakeTupleShape({}); }
// Checks whether the shape is initialized.
@@ -446,7 +446,7 @@ class ShapeUtil {
// Returns true if shape is an empty tuple.
static bool IsEmptyTuple(const Shape& shape);
- // Returns true if shape is an empty tuple, or is an array with no elements.
+ // Returns true if shape is the nil shape (an empty tuple).
static bool IsNil(const Shape& shape);
// Returns the number of elements in the given tuple shape.
@@ -697,7 +697,7 @@ class ShapeUtil {
tensorflow::gtl::ArraySlice<int64> incr,
const FnType& visitor_function,
bool parallel = false) {
- if (ShapeUtil::HasZeroElements(shape)) {
+ if (ShapeUtil::IsZeroElementArray(shape)) {
return Status::OK();
}
CHECK_EQ(Rank(shape), base.size());
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index 0ff514564b..ebfe06d4bc 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -329,6 +329,16 @@ TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape));
}
+TEST(ShapeUtilTest, NilShape) {
+ EXPECT_TRUE(ShapeUtil::IsNil(ShapeUtil::MakeNil()));
+ EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {1, 2, 3})));
+ EXPECT_FALSE(ShapeUtil::IsNil(ShapeUtil::MakeShape(F32, {0, 1})));
+ EXPECT_FALSE(ShapeUtil::IsNil(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})})));
+ EXPECT_FALSE(ShapeUtil::IsNil(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {0})})));
+}
+
TEST(ShapeUtilTest, NestedTuple) {
EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({})));
EXPECT_FALSE(ShapeUtil::IsNestedTuple(
@@ -359,25 +369,30 @@ TEST(ShapeUtilTest, ElementsIn) {
EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
}
-TEST(ShapeUtilTest, HasZeroElements) {
- EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {})));
- EXPECT_EQ(true, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0})));
- EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1})));
- EXPECT_EQ(false, ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5})));
- EXPECT_EQ(true,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5})));
- EXPECT_EQ(true,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5})));
- EXPECT_EQ(false,
- ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17})));
+TEST(ShapeUtilTest, IsZeroElementArray) {
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {})));
+ EXPECT_TRUE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0})));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 1})));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {2, 1})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 5})));
+ EXPECT_TRUE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {3, 0, 5})));
+ EXPECT_TRUE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {0, 3, 0})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {1, 3, 5})));
+ EXPECT_FALSE(
+ ShapeUtil::IsZeroElementArray(ShapeUtil::MakeShape(S32, {13, 17})));
+
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeNil()));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(ShapeUtil::MakeTupleShape({})));
+ EXPECT_FALSE(ShapeUtil::IsZeroElementArray(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {0, 3, 0})})));
}
TEST(ShapeUtilTest, SameDimensions) {
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 36a7064969..c3a289ee09 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -2758,7 +2758,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(computation_status.status().ToString(),
::testing::ContainsRegex(
- "Expected non-opaque argument for lhs of binary operation"));
+ "Expected array argument for lhs of binary operation"));
}
XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) {
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
index a4c8a83eb1..352864502a 100644
--- a/tensorflow/compiler/xla/tests/concat_test.cc
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -417,7 +417,22 @@ XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
ASSERT_FALSE(computation_status.ok());
EXPECT_THAT(
computation_status.status().ToString(),
- HasSubstr("Expected non-opaque argument for operand of concatenation"));
+ HasSubstr("Expected array argument for operand of concatenation"));
+}
+
+// Show that we can't concatenate with tokens.
+XLA_TEST_F(ConcatTest, CannotConcatTokens) {
+ XlaBuilder builder(TestName());
+ auto token_shape = ShapeUtil::MakeTokenShape();
+ auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
+ auto x = builder.Parameter(0, r1f32, "x");
+ auto y = builder.Parameter(1, token_shape, "y");
+ builder.ConcatInDim({x, y}, 0);
+ StatusOr<XlaComputation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_THAT(
+ computation_status.status().ToString(),
+ HasSubstr("Expected array argument for operand of concatenation"));
}
XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {