aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util_test.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util_test.cc17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util_test.cc b/tensorflow/contrib/lite/toco/tooling_util_test.cc
index 8609e5bedd..eb495646a2 100644
--- a/tensorflow/contrib/lite/toco/tooling_util_test.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util_test.cc
@@ -39,6 +39,8 @@ std::vector<ShapePair> CreateShapePairs() {
{Shape({256, 256, 3}), Shape({256, 256, 3}), Agreement::kBroadcast},
{Shape({256, 256, 3}), Shape({3}), Agreement::kBroadcast},
{Shape({8, 1, 6, 1}), Shape({7, 1, 5}), Agreement::kBroadcast},
+ {Shape({}), Shape({3}), Agreement::kBroadcast},
+ {Shape({}), Shape({3, 1}), Agreement::kBroadcast},
// These extend (and therefore broadcast).
{Shape({3}), Shape({3}), Agreement::kExtend},
@@ -54,6 +56,7 @@ std::vector<ShapePair> CreateShapePairs() {
{Shape({15, 3, 5}), Shape({15, 1, 5}), Agreement::kBroadcastNotExtend},
{Shape({15, 3, 5}), Shape({3, 5}), Agreement::kBroadcastNotExtend},
{Shape({15, 3, 5}), Shape({3, 1}), Agreement::kBroadcastNotExtend},
+ {Shape({3, 1}), Shape({}), Agreement::kBroadcastNotExtend},
// These do not broadcast (and therefore also do not extend).
{Shape({3}), Shape({4}), Agreement::kNeither},
@@ -175,6 +178,20 @@ TEST(NumElementsTest, UnsignedInt64) {
EXPECT_EQ(status.error_message(), kLargeTensorMessage);
}
+TEST(NumElementsTest, Scalar) {
+ tensorflow::Status status = tensorflow::Status::OK();
+
+ int32_t count;
+ status = NumElements(std::vector<int32_t>{}, &count);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(count, 1);
+
+ uint64_t countu64;
+ status = NumElements(std::vector<uint64_t>{}, &countu64);
+ EXPECT_TRUE(status.ok());
+ EXPECT_EQ(countu64, 1ULL);
+}
+
TEST(FusedActivationTest, DefaultsToUnfused) {
EXPECT_TRUE(OperatorSupportsFusedActivation(OperatorType::kAdd));
EXPECT_FALSE(OperatorSupportsFusedActivation(OperatorType::kNone));