aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-06-16 00:06:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-16 00:09:15 -0700
commit1c697bc9094365cf5dab1ec1550eba019dffa3b8 (patch)
tree01d75f03cdc9919b0ec084ffe00687831be0199c /tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
parent990e1f218c7180b2ebf407b8ec06d59936e9cc12 (diff)
Teach gather-reshape folding to work with degenerate dims
I was hoping not to do this, but the motivating benchmark for all this work has reshapes on degenerate dimensions. This also forced me to introduce a new node to the analysis which isn't great (we don't want to replicate HLO inside IndexedArrayAnalysis!) but this is cleanest solution I can think of. In brief I support gather-reshape folding with degenerate dimensions by disallowing it in the core tricky part of the algorithm and instead reshaping the degenerate dimensions "in and out" in a helper that calls the core part of the folding logic. Also worth calling out that before we weren't doing something conservative -- we were just buggy. For instance the CHECK_NE(candidate_operand_dim, 0) in ComputeReshapePassthroughDimPairs can fail with degenerate dims. I also made some other supporting changes: - I was not checking window bounds in ComputeArrayForGather. I've fixed this and beefed up testing in this area (the hammer for all my nails). - Added a bunch of VLOG(3) info that was useful when debugging. - Added a simple helper to the test that makes the strings I'm matching against "whitespace insensitive" so that I can indent these. I'm happy to pull these out into separate CLs if that makes reviewing easier but for now I took the path of least resistance. :) PiperOrigin-RevId: 200821883
Diffstat (limited to 'tensorflow/compiler/xla/service/indexed_array_analysis_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis_test.cc313
1 files changed, 306 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
index 373556ebeb..fc2befe05b 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis_test.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <ctype.h>
+
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
@@ -34,6 +36,27 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
}
private:
+ // Replaces seqences of whitespace with a single space. This makes the
+ // strings being matched against "whitespace insensitive" which lets us indent
+ // them for readability.
+ string CanonicalizeWhitespace(const string& text) {
+ string result;
+
+ for (char c : text) {
+ if (!isspace(c)) {
+ result.push_back(c);
+ } else if (!result.empty() && result.back() != ' ') {
+ result.push_back(' ');
+ }
+ }
+
+ while (!result.empty() && result.back() == ' ') {
+ result.pop_back();
+ }
+
+ return result;
+ }
+
void AssertArrayForRootExpressionIsImpl(const string& hlo_text,
const string& root_expression,
bool print_constants) {
@@ -44,10 +67,10 @@ class IndexedArrayAnalysisTest : public HloVerifiedTestBase {
IndexedArrayAnalysis::Array* const array_result,
indexed_tensor_analysis.GetArrayFor(
module().entry_computation()->root_instruction()));
- string string_result =
- indexed_tensor_analysis.ToString(array_result, print_constants);
+ string string_result = CanonicalizeWhitespace(
+ indexed_tensor_analysis.ToString(array_result, print_constants));
LOG(INFO) << string_result;
- ASSERT_EQ(string_result, root_expression);
+ ASSERT_EQ(string_result, CanonicalizeWhitespace(root_expression));
}
};
@@ -91,6 +114,82 @@ ENTRY main {
hlo_text, "(scalar-indexed-const (constant s32[3,3]) %indices 0->[0])");
}
+TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed0) {
+ string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+ operand = s32[3,3] constant(s32[3,3]{{1,2,3},{1,2,3},{1,2,3}})
+ indices = s32[5,2] parameter(0)
+ ROOT gather = s32[5] gather(operand, indices),
+ output_window_dims={},
+ elided_window_dims={0,1},
+ gather_dims_to_operand_dims={0,1},
+ index_vector_dim=1,
+ window_bounds={1,1}
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%gather");
+}
+
+TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed1) {
+ string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+ operand = s32[3,3,1] parameter(0)
+ indices = s32[5] parameter(1)
+ ROOT gather = s32[5,3] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0,2},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,3,1}
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%gather");
+}
+
+TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed2) {
+ string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+ operand = s32[3,3,1] parameter(0)
+ indices = s32[5] parameter(1)
+ ROOT gather = s32[5,2,3] gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={2},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={2,3,1}
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%gather");
+}
+
+TEST_F(IndexedArrayAnalysisTest, GatherIsNotScalarIndexed3) {
+ string hlo_text = R"(
+HloModule SimpleGather
+
+ENTRY main {
+ operand = s32[3,3] parameter(0)
+ indices = s32[5] parameter(1)
+ ROOT gather = s32[5,2] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,2}
+}
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, "%gather");
+}
+
TEST_F(IndexedArrayAnalysisTest, GatherOfGather_OneToOne) {
string hlo_text = R"(
HloModule SimpleGather
@@ -273,7 +372,157 @@ ENTRY main {
"(scalar-indexed-const (constant s32[3,3,4]) %indices 0->[0,3])");
}
-TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative0) {
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather3) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[2,6] constant(s32[2,6]{
+ {1,2,3,4,5,6},{1,2,3,4,5,6}})
+ indices = s32[1] parameter(0)
+ gather = s32[1,6] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,6}
+ ROOT reshape = s32[1,1,6] reshape(gather)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(scalar-indexed-const
+ (constant s32[2,1,1,6])
+ (reshape %indices to s32[])
+ 0->[])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather4) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 1, 2, 3 } })
+
+ i.0 = s64[1,3]{1,0} parameter(0)
+ g.0 = s32[1,3,3]{2,1,0} gather(operand, i.0), output_window_dims={2},
+ elided_window_dims={0}, gather_dims_to_operand_dims={0},
+ index_vector_dim=2, window_bounds={1,3}
+
+ i.1 = s64[1] parameter(1)
+ g.1 = s32[1,1,3]{2,1,0} gather(g.0, i.1), output_window_dims={0,2},
+ elided_window_dims={1}, gather_dims_to_operand_dims={1},
+ index_vector_dim=1, window_bounds={1,1,3}
+
+ ROOT reshape = s32[1,3]{1,0} reshape(g.1)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(scalar-indexed-const
+ (constant s32[2,1,3])
+ (reshape
+ (scalar-indexed %i.0 %i.1 1->[1])
+ to s64[])
+ 0->[])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather5) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[1,6] constant(s32[1,6]{{1,2,3,4,5,6}})
+ indices = s32[1] parameter(0)
+ gather = s32[1,6] gather(operand, indices),
+ output_window_dims={1},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=1,
+ window_bounds={1,6}
+ ROOT reshape = s32[1,1,6] reshape(gather)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(scalar-indexed-const
+ (constant s32[1,1,1,6])
+ (reshape %indices to s32[])
+ 0->[])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather6) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[1,2,6] constant(s32[1,2,6]{{
+ {1,2,3,4,5,6},{1,2,3,4,5,6}}})
+ indices = s32[1] parameter(0)
+ gather = s32[1,1,6] gather(operand, indices),
+ output_window_dims={1,2},
+ elided_window_dims={1},
+ gather_dims_to_operand_dims={1},
+ index_vector_dim=1,
+ window_bounds={1,1,6}
+ ROOT reshape = s32[1,1,1,6] reshape(gather)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(scalar-indexed-const
+ (constant s32[2,1,1,1,6] s32[2,1,1,1,6] {
+ { /*i0=0*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } },
+ { /*i0=1*/ { /*i1=0*/ { /*i2=0*/ {1, 2, 3, 4, 5, 6} } } } })
+ (reshape %indices to s32[])
+ 0->[])
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text,
+ expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGather7) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[2,6] constant(s32[2,6]{
+ {1,2,3,4,5,6},{1,2,3,4,5,6}})
+ indices = s32[1,5] parameter(0)
+ gather = s32[1,5,6] gather(operand, indices),
+ output_window_dims={2},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=2,
+ window_bounds={1,6}
+ ROOT reshape = s32[1,1,5,6] reshape(gather)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(scalar-indexed-const
+ (constant s32[2,1,1,6] s32[2,1,1,6] {
+ { /*i0=0*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } },
+ { /*i0=1*/ { /*i1=0*/ {1, 2, 3, 4, 5, 6} } } })
+ (reshape %indices to s32[5])
+ 0->[2])
+)";
+
+ AssertArrayWithConstantsForRootExpressionIs(hlo_text,
+ expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold0) {
string hlo_text = R"(
HloModule ReshapeOfGather
@@ -290,10 +539,19 @@ ENTRY main {
}
)";
- AssertArrayForRootExpressionIs(hlo_text, "%reshape");
+ const char* expected_root_expression = R"(
+(reshape
+ (scalar-indexed-const
+ (constant s32[3,4])
+ %indices
+ 0->[0,2])
+ to s32[5,2,2,2,3])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
}
-TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNegative1) {
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold1) {
string hlo_text = R"(
HloModule ReshapeOfGather
@@ -313,7 +571,48 @@ ENTRY main {
}
)";
- AssertArrayForRootExpressionIs(hlo_text, "%reshape");
+ const char* expected_root_expression = R"(
+(reshape
+ (scalar-indexed-const
+ (constant s32[3,5,2])
+ %indices
+ 1->[2])
+ to s32[6,7])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
+}
+
+TEST_F(IndexedArrayAnalysisTest, ReshapeOfGatherNoFold2) {
+ string hlo_text = R"(
+HloModule ReshapeOfGather
+
+ENTRY main {
+ operand = s32[3,4,1] constant(s32[3,4,1]{
+ {{1},{2},{3},{4}},
+ {{1},{2},{3},{4}},
+ {{1},{2},{3},{4}}})
+ indices = s32[5,6] parameter(0)
+ gather = s32[5,4,6,1] gather(operand, indices),
+ output_window_dims={1,3},
+ elided_window_dims={0},
+ gather_dims_to_operand_dims={0},
+ index_vector_dim=2,
+ window_bounds={1,4,1}
+ ROOT reshape = s32[5,2,2,2,3,1] reshape(gather)
+}
+)";
+
+ const char* expected_root_expression = R"(
+(reshape
+ (scalar-indexed-const
+ (constant s32[3,4,1])
+ %indices
+ 0->[0,2])
+ to s32[5,2,2,2,3,1])
+)";
+
+ AssertArrayForRootExpressionIs(hlo_text, expected_root_expression);
}
TEST_F(IndexedArrayAnalysisTest, UnaryOpOfGather) {