aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-06-04 16:41:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-04 16:44:21 -0700
commit14d4d1634dd2bd70ebc1629bc27354309bce0cb4 (patch)
treea5f2c94125a98ebf57e3c815b99b74ad69f2ca74
parentcf01d118ef0762c0554611bef123bf4559071fbf (diff)
Add TOKEN primitive type.
The token type will be threaded through side-effecting ops to order them. Subsequent cls will add new opcodes and change side effecting operations to support this ordering. This CL also does some cleanup in shape_util and layout_util where we have assumed that shapes are either arrays or tuples. PiperOrigin-RevId: 199215963
-rw-r--r--tensorflow/compiler/xla/layout_util.cc53
-rw-r--r--tensorflow/compiler/xla/layout_util_test.cc51
-rw-r--r--tensorflow/compiler/xla/shape_util.cc263
-rw-r--r--tensorflow/compiler/xla/shape_util.h26
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc49
-rw-r--r--tensorflow/compiler/xla/xla_data.proto11
6 files changed, 304 insertions, 149 deletions
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index 89cafa1a7d..e8f29b8329 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -98,8 +98,13 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
} // namespace
/* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
+ if (ShapeUtil::IsOpaque(shape) || ShapeUtil::IsToken(shape)) {
+ // Opaque and token types have empty layouts.
+ return Layout();
+ }
+
// A Layout proto corresponds to a single array, not a tuple.
- DCHECK(!ShapeUtil::IsTuple(shape));
+ CHECK(ShapeUtil::IsArray(shape));
return CreateDefaultLayoutForRank(shape.dimensions_size());
}
@@ -126,14 +131,15 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
SetToDefaultLayout(&element_shape);
}
shape->clear_layout();
- } else if (ShapeUtil::IsOpaque(*shape)) {
- shape->clear_layout();
- } else {
+ } else if (ShapeUtil::IsArray(*shape)) {
shape->mutable_layout()->set_format(DENSE);
tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
minor_to_major->Resize(shape->dimensions_size(), 0);
SetDefaultLayoutToContainer(minor_to_major);
+ } else {
+ // Opaque, token types etc. have no layout.
+ shape->clear_layout();
}
}
@@ -160,18 +166,20 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape));
}
return Status::OK();
- } else if (ShapeUtil::IsOpaque(shape)) {
- if (shape.has_layout()) {
- return InvalidArgument("opaque should not have a layout field");
- }
- return Status::OK();
- } else {
- // Array shape.
+ } else if (ShapeUtil::IsArray(shape)) {
if (!shape.has_layout()) {
return InvalidArgument("shape %s does not have a layout",
ShapeUtil::HumanString(shape).c_str());
}
return ValidateLayoutForShape(shape.layout(), shape);
+ } else {
+ // Token, opaque, etc. shape.
+ if (shape.has_layout()) {
+ return InvalidArgument(
+ "shape of primitive type %s should not have a layout",
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
+ return Status::OK();
}
}
@@ -181,8 +189,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return InvalidArgument("a single Layout is not valid for tuple shapes");
}
- if (ShapeUtil::IsOpaque(shape)) {
- return Status::OK();
+ if (!ShapeUtil::IsArray(shape)) {
+ return InvalidArgument(
+ "shape of primitive type %s should not have a layout",
+ PrimitiveType_Name(shape.element_type()).c_str());
}
if (layout.format() == INVALID_FORMAT) {
@@ -273,7 +283,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
}
/* static */ bool LayoutUtil::IsPadded(const Shape& shape) {
- if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) ||
+ if (!ShapeUtil::IsArray(shape) || !HasLayout(shape) ||
shape.layout().padded_dimensions_size() == 0) {
return false;
}
@@ -323,7 +333,8 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
// Tuple shape: all subshapes must have a layout.
return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(),
[](const Shape& s) { return HasLayout(s); });
- } else if (ShapeUtil::IsOpaque(shape)) {
+ } else if (!ShapeUtil::IsArray(shape)) {
+ // Opaque, token types etc. ignore layout.
return true;
}
return shape.has_layout() && shape.layout().format() != INVALID_FORMAT;
@@ -432,12 +443,9 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
const Shape& rhs) {
- if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) {
- return false;
- }
if (ShapeUtil::IsTuple(lhs)) {
- if (ShapeUtil::TupleElementCount(lhs) !=
- ShapeUtil::TupleElementCount(rhs)) {
+ if (!ShapeUtil::IsTuple(rhs) || ShapeUtil::TupleElementCount(lhs) !=
+ ShapeUtil::TupleElementCount(rhs)) {
return false;
}
for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
@@ -446,9 +454,12 @@ Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, Shape* dst) {
}
}
return true;
- } else {
+ } else if (ShapeUtil::IsArray(lhs)) {
return ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) &&
LayoutUtil::Equal(lhs.layout(), rhs.layout());
+ } else {
+ // Layouts of non-array and non-tuple shapes is ignored.
+ return true;
}
}
diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc
index 4fd1d818e3..e4c825450d 100644
--- a/tensorflow/compiler/xla/layout_util_test.cc
+++ b/tensorflow/compiler/xla/layout_util_test.cc
@@ -218,6 +218,47 @@ TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) {
"elements, but shape is rank"));
}
+TEST_F(LayoutUtilTest, CopyTokenLayout) {
+ Shape src = ShapeUtil::MakeTokenShape();
+ Shape dst = ShapeUtil::MakeTokenShape();
+
+ // Layouts are trivially the same for token types and copying layouts should
+ // be a nop.
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+}
+
+TEST_F(LayoutUtilTest, CopyOpaqueLayout) {
+ Shape src = ShapeUtil::MakeOpaqueShape();
+ Shape dst = ShapeUtil::MakeOpaqueShape();
+
+ // Layouts are trivially the same for opaque types and copying layouts should
+ // be a nop.
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+}
+
+TEST_F(LayoutUtilTest, CopyTupleLayoutWithTokenAndOpaque) {
+ Shape src = ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
+ MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
+ MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})});
+ Shape dst = ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
+ MakeShapeWithLayout(F32, {42, 123}, {1, 0}), ShapeUtil::MakeTokenShape(),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeOpaqueShape(), MakeShapeWithLayout(F32, {}, {}),
+ MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
+
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+}
+
TEST_F(LayoutUtilTest, ClearLayoutTuple) {
Shape shape = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
@@ -236,6 +277,16 @@ TEST_F(LayoutUtilTest, ClearLayoutTuple) {
EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout());
}
+TEST_F(LayoutUtilTest, ClearLayoutOpaqueAndToken) {
+ // Opaque and token types trivially have layouts.
+ for (Shape shape :
+ {ShapeUtil::MakeOpaqueShape(), ShapeUtil::MakeTokenShape()}) {
+ EXPECT_TRUE(LayoutUtil::HasLayout(shape));
+ LayoutUtil::ClearLayout(&shape);
+ EXPECT_TRUE(LayoutUtil::HasLayout(shape));
+ }
+}
+
TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
Shape shape = ShapeUtil::MakeTupleShape(
{MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}),
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index e8a28d76e9..ce4d0079ee 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
@@ -42,17 +41,18 @@ limitations under the License.
namespace xla {
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+
string ShapeIndex::ToString() const {
- return tensorflow::strings::StrCat(
- "{", tensorflow::str_util::Join(indices_, ","), "}");
+ return StrCat("{", tensorflow::str_util::Join(indices_, ","), "}");
}
string ShapeIndexView::ToString() const {
- return tensorflow::strings::StrCat(
- "{",
- tensorflow::str_util::Join(tensorflow::gtl::make_range(begin_, end_),
- ","),
- "}");
+ return StrCat("{",
+ tensorflow::str_util::Join(
+ tensorflow::gtl::make_range(begin_, end_), ","),
+ "}");
}
bool ShapeIndexView::operator==(const ShapeIndexView& other) const {
@@ -84,18 +84,30 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) {
namespace {
+// Returns whether the given primitive type corresponds to an array shape.
+bool IsArrayPrimitiveType(PrimitiveType primitive_type) {
+ return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
+ primitive_type != OPAQUE && primitive_type != TOKEN;
+}
+
// Recursive helper for comparing the equality of two shapes. Returns true if
// the shapes are the same. If compare_layouts is true, then layouts must also
// match.
bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
- if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) {
- return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) &&
- ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ if (!ShapeUtil::SameElementType(lhs, rhs)) {
+ VLOG(3) << "CompareShapes: lhs element type != rhs element type";
+ return false;
+ }
+
+ if (ShapeUtil::IsTuple(lhs)) {
+ return ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
[=](const Shape& l, const Shape& r) {
return CompareShapes(l, r, compare_layouts);
});
- } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) {
- return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs);
+ } else if (!ShapeUtil::IsArray(lhs)) {
+ // Non-tuple, non-array tupes such as opaque and token types are trivially
+ // the same.
+ return true;
}
if (compare_layouts) {
@@ -125,10 +137,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
VLOG(3) << "CompareShapes: lhs dimensions != rhs dimensions";
return false;
}
- if (!ShapeUtil::SameElementType(lhs, rhs)) {
- VLOG(3) << "CompareShapes: lhs element type != rhs element type";
- return false;
- }
return true;
}
@@ -171,8 +179,8 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
}
/* static */ int64 ShapeUtil::Rank(const Shape& shape) {
- CHECK(!ShapeUtil::IsTuple(shape))
- << "Tuples do not have a rank, shape: " << shape;
+ CHECK(ShapeUtil::IsArray(shape))
+ << "Non-arrays do not have a rank, shape: " << shape;
return shape.dimensions_size();
}
@@ -199,8 +207,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ Shape ShapeUtil::MakeShape(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
- DCHECK_NE(TUPLE, element_type);
- DCHECK_NE(OPAQUE, element_type);
+ CHECK(IsArrayPrimitiveType(element_type));
Shape result;
PopulateShape(element_type, dimensions, &result);
return result;
@@ -223,8 +230,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ Shape ShapeUtil::MakeShapeWithSparseLayout(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
int64 max_sparse_elements) {
- DCHECK_NE(TUPLE, element_type);
- DCHECK_NE(OPAQUE, element_type);
+ CHECK(IsArrayPrimitiveType(element_type));
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
*shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements);
TF_DCHECK_OK(ShapeUtil::ValidateShape(shape));
@@ -271,6 +277,13 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return result;
}
+/* static */ Shape ShapeUtil::MakeTokenShape() {
+ Shape result;
+ result.set_element_type(TOKEN);
+ TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
+ return result;
+}
+
/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
Shape* tuple_shape) {
TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
@@ -294,7 +307,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
- if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) {
+ if (!IsArray(shape)) {
return false;
}
return primitive_util::BitWidth(shape.element_type()) == bits;
@@ -320,6 +333,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
case C64:
case TUPLE:
case OPAQUE:
+ case TOKEN:
return false;
default:
@@ -335,6 +349,10 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return primitive_util::IsFloatingPointType(shape.element_type());
}
+/* static */ bool ShapeUtil::IsArray(const Shape& shape) {
+ return IsArrayPrimitiveType(shape.element_type());
+}
+
/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
shape.tuple_shapes().end(), IsTuple);
@@ -388,7 +406,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
}
/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
- CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape);
+ CHECK(IsArray(shape)) << ShapeUtil::HumanString(shape);
CHECK_EQ(shape.dimensions_size(), Rank(shape));
return std::accumulate<decltype(shape.dimensions().begin()), int64>(
shape.dimensions().begin(), shape.dimensions().end(), 1LL,
@@ -403,23 +421,6 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return shape.element_type() == F32 && Rank(shape) == 0;
}
-/* static */ string ShapeUtil::HumanString(const Shape& shape) {
- if (IsTuple(shape)) {
- string text = "(";
- const char* prefix = "";
- for (const Shape& elem_shape : shape.tuple_shapes()) {
- tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape));
- prefix = ", ";
- }
- text += ")";
- return text;
- } else {
- return tensorflow::strings::StrCat(
- tensorflow::str_util::Lowercase(
- PrimitiveType_Name(shape.element_type())),
- "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]");
- }
-}
namespace {
@@ -470,48 +471,56 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
} // namespace
-/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
+/* static */ string ShapeUtil::HumanString(const Shape& shape) {
if (IsTuple(shape)) {
string text = "(";
const char* prefix = "";
for (const Shape& elem_shape : shape.tuple_shapes()) {
- tensorflow::strings::StrAppend(&text, prefix,
- HumanStringWithLayout(elem_shape));
+ StrAppend(&text, prefix, HumanString(elem_shape));
prefix = ", ";
}
text += ")";
return text;
- } else {
- string result = tensorflow::strings::StrCat(
- LowercasePrimitiveTypeName(shape.element_type()), "[");
- for (int i = 0; i < shape.dimensions().size(); i++) {
- tensorflow::strings::StrAppend(&result, (i > 0) ? "," : "",
- shape.dimensions(i));
+ }
+ return StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[",
+ tensorflow::str_util::Join(shape.dimensions(), ","), "]");
+}
+
+/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
+ if (IsTuple(shape)) {
+ string text = "(";
+ const char* prefix = "";
+ for (const Shape& elem_shape : shape.tuple_shapes()) {
+ StrAppend(&text, prefix, HumanStringWithLayout(elem_shape));
+ prefix = ", ";
}
- result += "]";
- if (!IsScalar(shape) && !IsOpaque(shape)) {
- if (LayoutUtil::HasLayout(shape)) {
- tensorflow::strings::StrAppend(&result,
- LayoutUtil::HumanString(shape.layout()));
- }
+ text += ")";
+ return text;
+ }
+ string result = StrCat(LowercasePrimitiveTypeName(shape.element_type()), "[");
+ for (int i = 0; i < shape.dimensions().size(); i++) {
+ StrAppend(&result, (i > 0) ? "," : "", shape.dimensions(i));
+ }
+ result += "]";
+ if (!IsScalar(shape) && IsArray(shape)) {
+ if (LayoutUtil::HasLayout(shape)) {
+ StrAppend(&result, LayoutUtil::HumanString(shape.layout()));
}
- return result;
}
+ return result;
}
/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
std::vector<string> parameters;
for (auto& shape : program_shape.parameters()) {
const int i = parameters.size();
- parameters.push_back(
- tensorflow::strings::StrCat(i < program_shape.parameter_names_size()
- ? program_shape.parameter_names(i)
- : "(unknown)",
- ": ", HumanString(shape)));
+ parameters.push_back(StrCat(i < program_shape.parameter_names_size()
+ ? program_shape.parameter_names(i)
+ : "(unknown)",
+ ": ", HumanString(shape)));
}
- return tensorflow::strings::StrCat(
- "(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
- HumanString(program_shape.result()));
+ return StrCat("(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
+ HumanString(program_shape.result()));
}
namespace {
@@ -581,14 +590,17 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
// Extract the primitive element type.
TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type,
StringToPrimitiveType(element_type_string));
- if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE ||
- primitive_type == OPAQUE) {
+ if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE) {
return InvalidArgument("Invalid element type string: \"%s\".",
element_type_string.c_str());
}
Shape result;
- if (format_string.empty() && layout_string.empty()) {
+ if (primitive_type == OPAQUE) {
+ result = ShapeUtil::MakeOpaqueShape();
+ } else if (primitive_type == TOKEN) {
+ result = ShapeUtil::MakeTokenShape();
+ } else if (format_string.empty() && layout_string.empty()) {
// Create a shape without a layout set.
result = ShapeUtil::MakeShape(primitive_type, dimensions);
} else if (format_string == "sparse") {
@@ -633,43 +645,44 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
}
/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
- if (lhs.element_type() == TUPLE) {
+ if (IsArray(lhs)) {
+ return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
+ } else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible);
+ } else {
+ // Opaque, token, etc types are vacuously compatible.
+ return true;
}
- if (lhs.element_type() == OPAQUE) {
- return rhs.element_type() == OPAQUE;
- }
- return SameElementType(lhs, rhs) && SameDimensions(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringElementType(const Shape& lhs,
const Shape& rhs) {
- if (lhs.element_type() == TUPLE) {
+ if (IsArray(lhs)) {
+ return IsArray(rhs) && SameDimensions(lhs, rhs);
+ } else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringElementType);
+ } else {
+ // Opaque, token, etc types are vacuously compatible.
+ return true;
}
- if (lhs.element_type() == OPAQUE) {
- return rhs.element_type() == OPAQUE;
- }
- return ShapeUtil::IsArray(rhs) && SameDimensions(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) {
- if (lhs.element_type() == TUPLE) {
+ if (IsArray(lhs)) {
+ return IsArray(rhs) && SameElementTypeIgnoringFpPrecision(lhs, rhs) &&
+ CompatibleIgnoringElementType(lhs, rhs);
+ } else if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringFpPrecision);
+ } else {
+ // Opaque, token, etc types are vacuously compatible.
+ return true;
}
- if (lhs.element_type() == OPAQUE) {
- return rhs.element_type() == OPAQUE;
- }
- if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
- return CompatibleIgnoringElementType(lhs, rhs);
- }
- return false;
}
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
@@ -691,10 +704,6 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
switch (primitive_type) {
case PRED:
return sizeof(int8);
- case TUPLE:
- LOG(FATAL) << "tuples have no definitive size";
- case OPAQUE:
- LOG(FATAL) << "opaque have no definitive size";
case S8:
return sizeof(int8);
case S16:
@@ -721,6 +730,13 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return sizeof(double);
case C64:
return sizeof(complex64);
+ case TOKEN:
+ // Tokens require no space.
+ return 0;
+ case TUPLE:
+ case OPAQUE:
+ LOG(FATAL) << PrimitiveType_Name(primitive_type)
+ << " primitive type has no definitive size";
default:
LOG(FATAL) << "Unhandled primitive type " << primitive_type;
}
@@ -729,28 +745,32 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
int64 pointer_size) {
TF_DCHECK_OK(ValidateShape(shape));
- DCHECK_NE(OPAQUE, shape.element_type());
if (shape.element_type() == TUPLE) {
return ByteSizeOfTupleIndexTable(shape, pointer_size);
+ } else if (IsArray(shape)) {
+ int64 byte_size = ByteSizeOfElements(shape);
+ if (LayoutUtil::IsSparseArray(shape)) {
+ byte_size += ByteSizeOfSparseIndices(shape);
+ }
+ return byte_size;
+ } else if (shape.element_type() == TOKEN) {
+ return 0;
}
- int64 byte_size = ByteSizeOfElements(shape);
- if (LayoutUtil::IsSparseArray(shape)) {
- byte_size += ByteSizeOfSparseIndices(shape);
- }
- return byte_size;
+ LOG(FATAL) << PrimitiveType_Name(shape.element_type())
+ << " primitive type has no definitive size";
}
/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape,
int64 pointer_size) {
TF_DCHECK_OK(ValidateShape(shape));
- DCHECK_EQ(TUPLE, shape.element_type());
+ CHECK_EQ(TUPLE, shape.element_type());
CHECK_GT(pointer_size, 0);
return pointer_size * shape.tuple_shapes_size();
}
/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape));
- DCHECK(ShapeUtil::IsArray(shape));
+ CHECK(ShapeUtil::IsArray(shape));
int64 allocated_element_count;
if (LayoutUtil::IsSparseArray(shape)) {
@@ -775,13 +795,17 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) {
TF_DCHECK_OK(ValidateShape(shape));
- DCHECK(LayoutUtil::IsSparseArray(shape));
+ CHECK(LayoutUtil::IsSparseArray(shape));
return LayoutUtil::MaxSparseElements(shape.layout()) *
ShapeUtil::Rank(shape) * sizeof(int64);
}
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
const Shape& shape) {
+ if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
+ return InvalidArgument("shape has invalid element type: %s",
+ shape.ShortDebugString().c_str());
+ }
if (shape.element_type() == TUPLE) {
if (shape.dimensions_size() != 0) {
return InvalidArgument("tuples must not have dimensions specified");
@@ -797,10 +821,24 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
if (shape.tuple_shapes_size() > 0) {
return InvalidArgument("non-tuple shape has tuple_shapes field");
}
- if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
- return InvalidArgument("shape has invalid element type: %s",
- shape.ShortDebugString().c_str());
+
+ // Tokens and opaques can should not have layout or dimensions.
+ if (shape.element_type() == TOKEN || shape.element_type() == OPAQUE) {
+ if (shape.dimensions_size() != 0) {
+ return InvalidArgument(
+ "shape has %s element type, but has dimensions field: %s",
+ LowercasePrimitiveTypeName(shape.element_type()).c_str(),
+ shape.ShortDebugString().c_str());
+ }
+ if (shape.has_layout()) {
+ return InvalidArgument(
+ "shape has %s element type, but has layout field: %s",
+ LowercasePrimitiveTypeName(shape.element_type()).c_str(),
+ shape.ShortDebugString().c_str());
+ }
+ return Status::OK();
}
+
if (Rank(shape) != shape.dimensions_size()) {
return InvalidArgument(
"shape's rank is mismatched with dimension count; rank=%lld "
@@ -902,6 +940,8 @@ bool ShapeUtil::IsLeafIndex(const Shape& shape, const ShapeIndex& index) {
}
/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
+ CHECK(IsArray(shape));
+
std::vector<int64> dimension_sizes;
std::vector<int64> degenerate_dimensions;
for (int64 i = 0; i < shape.dimensions_size(); ++i) {
@@ -1066,6 +1106,9 @@ Status ForEachMutableSubshapeHelper(
/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
const Shape& shape_post) {
+ CHECK(IsArray(shape_pre));
+ CHECK(IsArray(shape_post));
+
auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
std::vector<int64> deleted_indices;
@@ -1123,6 +1166,9 @@ ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
/* static */ std::vector<std::pair<int64, int64>>
ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
const Shape& output_shape) {
+ CHECK(IsArray(input_shape));
+ CHECK(IsArray(output_shape));
+
// Unmodified dimensions are merely common factors of rank 1.
auto common_factors = CommonFactors(AsInt64Slice(input_shape.dimensions()),
AsInt64Slice(output_shape.dimensions()));
@@ -1176,8 +1222,10 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
const Shape& output_shape) {
- CHECK(LayoutUtil::HasLayout(input_shape) &&
- LayoutUtil::HasLayout(output_shape));
+ CHECK(IsArray(input_shape));
+ CHECK(IsArray(output_shape));
+ CHECK(LayoutUtil::HasLayout(input_shape));
+ CHECK(LayoutUtil::HasLayout(output_shape));
if (!SameElementType(input_shape, output_shape)) {
return false;
@@ -1339,6 +1387,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ tensorflow::gtl::optional<Shape> ShapeUtil::AlignLayouts(
const Shape& input_shape, const Shape& output_shape) {
+ CHECK(IsArray(input_shape));
+ CHECK(IsArray(output_shape));
+
int64 input_rank = Rank(input_shape);
int64 output_rank = Rank(output_shape);
@@ -1473,6 +1524,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
Shape shape) {
+ CHECK(IsArray(shape));
shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete);
if (LayoutUtil::HasLayout(shape)) {
Layout* layout = shape.mutable_layout();
@@ -1494,6 +1546,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ Shape ShapeUtil::FilterDimensions(
const std::function<bool(int64)>& p, Shape shape) {
+ CHECK(IsArray(shape));
std::vector<int64> dims_to_delete;
for (int64 i = shape.dimensions().size() - 1; i >= 0; --i) {
if (!p(i)) {
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 9df31d5d21..3853ada6ba 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -169,7 +169,7 @@ class ShapeUtil {
// may not actually be able to store this number of elements. See
// LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of
// elements that can be stored in a sparse shape.
- // Precondition: !IsTuple(shape)
+ // Precondition: IsArray(shape)
static int64 ElementsIn(const Shape& shape);
// Returns true if 'shape' has zero elements.
@@ -180,13 +180,11 @@ class ShapeUtil {
// shapes. This includes only the size of the top-level buffer. For example, a
// tuple is stored as an array of pointers to other buffers. In this case,
// this method only returns the size of the pointer array.
- // Precondition: (!ShapeUtil::IsTuple(shape) || pointer_size > 0) &&
- // !ShapeUtil::IsOpaque(shape)
static int64 ByteSizeOf(const Shape& shape, int64 pointer_size = -1);
// Returns the number of bytes used to store the primitive_type.
//
- // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
+ // Precondition: ShapeUtil::IsArray(shape)
static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
// Returns the number of bytes required to store the tuple member pointers for
@@ -245,7 +243,7 @@ class ShapeUtil {
}
// Returns the higher-precision element type if a and b are both floating
- // point types; otherwise, checks that they have the same element type
+ // point types; otherwise, checks that that they have the same element type
// and returns it.
static PrimitiveType HigherPrecisionElementType(const Shape& a,
const Shape& b) {
@@ -293,10 +291,10 @@ class ShapeUtil {
// Scalar-specific
static bool IsScalar(const Shape& shape) {
- return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0;
+ return IsArray(shape) && Rank(shape) == 0;
}
static bool IsEffectiveScalar(const Shape& shape) {
- return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0;
+ return IsArray(shape) && TrueRank(shape) == 0;
}
static bool IsScalarF32(const Shape& shape);
@@ -325,6 +323,10 @@ class ShapeUtil {
// into a custom operation.
static Shape MakeOpaqueShape();
+ // Creates a token shape. Values of this shape are used for ordering
+ // side-effecting operations.
+ static Shape MakeTokenShape();
+
// Appends a shape to the given tuple.
static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
@@ -424,11 +426,15 @@ class ShapeUtil {
return shape.element_type() == OPAQUE;
}
+ // Returns whether the shape is an token value used for ordering
+ // side-effecting operations.
+ static bool IsToken(const Shape& shape) {
+ return shape.element_type() == TOKEN;
+ }
+
// Returns whether the shape is an array. Note that scalars are considered
// arrays.
- static bool IsArray(const Shape& shape) {
- return !IsTuple(shape) && !IsOpaque(shape);
- }
+ static bool IsArray(const Shape& shape);
// Returns whether the shape is a tuple with at least one element which is
// also a tuple.
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
index f7675e97da..ecdb6532f1 100644
--- a/tensorflow/compiler/xla/shape_util_test.cc
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -93,12 +93,14 @@ TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) {
}
TEST(ShapeUtilTest, ParseShapeStringNestedTuple) {
- string shape_string = "(f32[1],(f32[2]), f32[3])";
+ string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
TF_ASSERT_OK_AND_ASSIGN(Shape actual,
ShapeUtil::ParseShapeString(shape_string));
Shape expected = ShapeUtil::MakeTupleShape({
ShapeUtil::MakeShape(F32, {1}),
- ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
+ ShapeUtil::MakeOpaqueShape(),
ShapeUtil::MakeShape(F32, {3}),
});
ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
@@ -136,6 +138,23 @@ TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) {
<< "actual: " << ShapeUtil::HumanString(actual);
}
+TEST(ShapeUtilTest, ParseOpaqueType) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape actual,
+ ShapeUtil::ParseShapeString("opaque[]"));
+ Shape expected = ShapeUtil::MakeOpaqueShape();
+ ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
+ << "expected: " << ShapeUtil::HumanString(expected)
+ << "actual: " << ShapeUtil::HumanString(actual);
+}
+
+TEST(ShapeUtilTest, ParseTokenType) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape actual, ShapeUtil::ParseShapeString("token[]"));
+ Shape expected = ShapeUtil::MakeTokenShape();
+ ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
+ << "expected: " << ShapeUtil::HumanString(expected)
+ << "actual: " << ShapeUtil::HumanString(actual);
+}
+
TEST(ShapeUtilTest, ParseInvalidShapeString) {
string shape_strings[] = {
"f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
@@ -295,6 +314,9 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) {
EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64));
EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {})));
EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20})));
+
+ EXPECT_EQ(0, ShapeUtil::ByteSizeOfPrimitiveType(TOKEN));
+ EXPECT_EQ(0, ShapeUtil::ByteSizeOf(ShapeUtil::MakeTokenShape()));
}
TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
@@ -449,19 +471,21 @@ TEST(ShapeUtilTest, IsLeafIndex) {
TEST(ShapeUtilTest, HumanString) {
Shape opaque = ShapeUtil::MakeOpaqueShape();
+ Shape token = ShapeUtil::MakeTokenShape();
Shape scalar = ShapeUtil::MakeShape(F32, {});
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
- Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix});
+ Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque));
+ EXPECT_EQ("token[]", ShapeUtil::HumanString(token));
EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar));
EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix));
EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2));
EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])",
ShapeUtil::HumanString(tuple));
- EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(nested_tuple));
EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque));
@@ -470,8 +494,10 @@ TEST(ShapeUtilTest, HumanString) {
EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2));
EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})",
ShapeUtil::HumanStringWithLayout(tuple));
- EXPECT_EQ("((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0})",
- ShapeUtil::HumanStringWithLayout(nested_tuple));
+ EXPECT_EQ(
+ "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
+ "token[])",
+ ShapeUtil::HumanStringWithLayout(nested_tuple));
ProgramShape prog = ShapeUtil::MakeProgramShape(
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
@@ -481,8 +507,9 @@ TEST(ShapeUtilTest, HumanString) {
"(unknown): u32[1,2], "
"(unknown): s32[3,4], "
"(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
- "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
- "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
+ "-> "
+ "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
prog.add_parameter_names("arg0");
@@ -497,8 +524,10 @@ TEST(ShapeUtilTest, HumanString) {
"matrix: u32[1,2], "
"matrix2: s32[3,4], "
"tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
- "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
- "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
+ "token[])) "
+ "-> "
+ "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
}
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index b895ac045c..6bdfb0179c 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -66,11 +66,16 @@ enum PrimitiveType {
// in the dimensions field.
TUPLE = 13;
- // An opaque type used for passing context specific data to a custom
- // operation.
+ // An opaque type used for passing context-specific data to a custom
+ // operation. Shapes of this primitive type will have empty dimensions and
+ // tuple_shapes fields.
OPAQUE = 14;
- // Next = 17
+ // A token type threaded between side-effecting operations. Shapes of this
+ // primitive type will have empty dimensions and tuple_shapes fields.
+ TOKEN = 17;
+
+ // Next = 18
}
// Describes the value held inside padding elements.