aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_verifier_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_verifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier_test.cc51
1 files changed, 51 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index c92db0be14..04c6ba3eeb 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -123,5 +124,55 @@ TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
EXPECT_FALSE(verifier().Run(module.get()).status().ok());
}
+TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+callme {
+ ROOT param = (s32[], f32[4]) parameter(0)
+}
+
+ENTRY entry {
+ p0 = (f32[4], s32[]) parameter(0)
+ ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("shape does not match parameter"));
+}
+
+TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
+ const char* const hlo_string = R"(
+HloModule Module
+
+true_branch {
+ tparam = (s32[], f32[4]) parameter(0)
+ ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
+}
+
+false_branch {
+ fparam = (s32[], f32[4]) parameter(0)
+ ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
+}
+
+ENTRY entry {
+ p0 = (f32[4], s32[]) parameter(0)
+ constant = pred[] constant(true)
+ ROOT conditional = f32[4] conditional(constant, p0, p0),
+ true_computation=true_branch, false_computation=false_branch
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
+
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("shape does not match parameter"));
+}
+
} // namespace
} // namespace xla