aboutsummaryrefslogtreecommitdiffhomepage
path: root/unsupported/test/cxx11_tensor_block_eval.cpp
diff options
context:
space:
mode:
authorGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-07 15:34:26 -0700
committerGravatar Eugene Zhulenev <ezhulenev@google.com>2019-10-07 15:34:26 -0700
commitf74ab8cb8de5e425ddd25f4b06657926a2ad4599 (patch)
tree21686c69f54cd402fdf6508cedcfd25750f70898 /unsupported/test/cxx11_tensor_block_eval.cpp
parent3afb640b5647654f272b1903b71877cb60ed3a78 (diff)
Add block evaluation to TensorEvalTo and fix few small bugs
Diffstat (limited to 'unsupported/test/cxx11_tensor_block_eval.cpp')
-rw-r--r--unsupported/test/cxx11_tensor_block_eval.cpp71
1 files changed, 48 insertions, 23 deletions
diff --git a/unsupported/test/cxx11_tensor_block_eval.cpp b/unsupported/test/cxx11_tensor_block_eval.cpp
index 75252362c..1dc0a9e2c 100644
--- a/unsupported/test/cxx11_tensor_block_eval.cpp
+++ b/unsupported/test/cxx11_tensor_block_eval.cpp
@@ -131,6 +131,7 @@ static void VerifyBlockEvaluator(Expression expr, GenBlockParams gen_block) {
// TensorEvaluator is needed to produce tensor blocks of the expression.
auto eval = TensorEvaluator<const decltype(expr), Device>(expr, d);
+ eval.evalSubExprsIfNeeded(nullptr);
// Choose a random offsets, sizes and TensorBlockDescriptor.
TensorBlockParams<NumDims> block_params = gen_block();
@@ -266,29 +267,6 @@ static void test_eval_tensor_reshape() {
[&shuffled]() { return SkewedInnerBlock<Layout>(shuffled); });
}
-template <typename T, int Layout>
-static void test_eval_tensor_reshape_with_bcast() {
- Index dim = internal::random<Index>(1, 100);
-
- Tensor<T, 2, Layout> lhs(1, dim);
- Tensor<T, 2, Layout> rhs(dim, 1);
- lhs.setRandom();
- rhs.setRandom();
-
- auto reshapeLhs = NByOne(dim);
- auto reshapeRhs = OneByM(dim);
-
- auto bcastLhs = OneByM(dim);
- auto bcastRhs = NByOne(dim);
-
- DSizes<Index, 2> dims(dim, dim);
-
- VerifyBlockEvaluator<T, 2, Layout>(
- lhs.reshape(reshapeLhs).broadcast(bcastLhs) +
- rhs.reshape(reshapeRhs).broadcast(bcastRhs),
- [dims]() { return SkewedInnerBlock<Layout, 2>(dims); });
-}
-
template <typename T, int NumDims, int Layout>
static void test_eval_tensor_cast() {
DSizes<Index, NumDims> dims = RandomDims<NumDims>(10, 20);
@@ -355,6 +333,52 @@ static void test_eval_tensor_padding() {
[&padded_dims]() { return SkewedInnerBlock<Layout>(padded_dims); });
}
+template <typename T, int Layout>
+static void test_eval_tensor_reshape_with_bcast() {
+ Index dim = internal::random<Index>(1, 100);
+
+ Tensor<T, 2, Layout> lhs(1, dim);
+ Tensor<T, 2, Layout> rhs(dim, 1);
+ lhs.setRandom();
+ rhs.setRandom();
+
+ auto reshapeLhs = NByOne(dim);
+ auto reshapeRhs = OneByM(dim);
+
+ auto bcastLhs = OneByM(dim);
+ auto bcastRhs = NByOne(dim);
+
+ DSizes<Index, 2> dims(dim, dim);
+
+ VerifyBlockEvaluator<T, 2, Layout>(
+ lhs.reshape(reshapeLhs).broadcast(bcastLhs) +
+ rhs.reshape(reshapeRhs).broadcast(bcastRhs),
+ [dims]() { return SkewedInnerBlock<Layout, 2>(dims); });
+}
+
+template <typename T, int Layout>
+static void test_eval_tensor_forced_eval() {
+ Index dim = internal::random<Index>(1, 100);
+
+ Tensor<T, 2, Layout> lhs(dim, 1);
+ Tensor<T, 2, Layout> rhs(1, dim);
+ lhs.setRandom();
+ rhs.setRandom();
+
+ auto bcastLhs = OneByM(dim);
+ auto bcastRhs = NByOne(dim);
+
+ DSizes<Index, 2> dims(dim, dim);
+
+ VerifyBlockEvaluator<T, 2, Layout>(
+ (lhs.broadcast(bcastLhs) + rhs.broadcast(bcastRhs)).eval().reshape(dims),
+ [dims]() { return SkewedInnerBlock<Layout, 2>(dims); });
+
+ VerifyBlockEvaluator<T, 2, Layout>(
+ (lhs.broadcast(bcastLhs) + rhs.broadcast(bcastRhs)).eval().reshape(dims),
+ [dims]() { return RandomBlock<Layout, 2>(dims, 1, 50); });
+}
+
// -------------------------------------------------------------------------- //
// Verify that assigning block to a Tensor expression produces the same result
// as an assignment to TensorSliceOp (writing a block is is identical to
@@ -482,6 +506,7 @@ EIGEN_DECLARE_TEST(cxx11_tensor_block_eval) {
CALL_SUBTESTS_DIMS_LAYOUTS(test_eval_tensor_padding);
CALL_SUBTESTS_LAYOUTS(test_eval_tensor_reshape_with_bcast);
+ CALL_SUBTESTS_LAYOUTS(test_eval_tensor_forced_eval);
CALL_SUBTESTS_DIMS_LAYOUTS(test_assign_to_tensor);
CALL_SUBTESTS_DIMS_LAYOUTS(test_assign_to_tensor_reshape);