aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/variant_op_registry_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/variant_op_registry_test.cc')
-rw-r--r--tensorflow/core/framework/variant_op_registry_test.cc96
1 files changed, 50 insertions, 46 deletions
diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc
index 7055e62c0e..b2443e8676 100644
--- a/tensorflow/core/framework/variant_op_registry_test.cc
+++ b/tensorflow/core/framework/variant_op_registry_test.cc
@@ -89,41 +89,37 @@ struct VariantValue {
int value;
};
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
- VariantValue::ShapeFn);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn);
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "TEST VariantValue", VariantValue::CPUToGPUCopyFn);
+ VariantValue::CPUToGPUCopyFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, VariantValue,
- "TEST VariantValue",
VariantValue::CPUZerosLikeFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, VariantValue,
- "TEST VariantValue",
VariantValue::GPUZerosLikeFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- VariantValue, "TEST VariantValue",
- VariantValue::CPUAddFn);
+ VariantValue, VariantValue::CPUAddFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- VariantValue, "TEST VariantValue",
- VariantValue::GPUAddFn);
+ VariantValue, VariantValue::GPUAddFn);
} // namespace
TEST(VariantOpShapeRegistryTest, TestBasic) {
- EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"),
+ class Blah {};
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()),
nullptr);
- auto* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue");
+ auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn(
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(shape_fn, nullptr);
TensorShape shape;
@@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) {
TEST(VariantOpShapeRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantShapeFn f;
- string kTypeName = "fjfjfj";
- registry.RegisterShapeFn(kTypeName, f);
- EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f),
- "fjfjfj already registered");
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
+ registry.RegisterShapeFn(kTypeIndex, f);
+ EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpDecodeRegistryTest, TestBasic) {
@@ -180,13 +177,14 @@ TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
// No registered copy fn for GPU<->GPU.
- EXPECT_EQ(
- UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"),
- nullptr);
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
+ VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
+ MakeTypeIndex<VariantValue>()),
+ nullptr);
auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE,
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(copy_to_gpu_fn, nullptr);
VariantValue vv{true /* early_exit */};
@@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE,
- kTypeName, f);
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f),
- "fjfjfj already registered");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantUnaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_CPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_CPU, kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_GPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_GPU, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
- return;
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) {
TEST(VariantOpAddRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantBinaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
}
} // namespace tensorflow