diff options
Diffstat (limited to 'tensorflow/core/framework/variant_op_registry_test.cc')
-rw-r--r-- | tensorflow/core/framework/variant_op_registry_test.cc | 96 |
1 files changed, 50 insertions, 46 deletions
diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc index 7055e62c0e..b2443e8676 100644 --- a/tensorflow/core/framework/variant_op_registry_test.cc +++ b/tensorflow/core/framework/variant_op_registry_test.cc @@ -89,41 +89,37 @@ struct VariantValue { int value; }; -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", - VariantValue::ShapeFn); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn); REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue"); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, - "TEST VariantValue", VariantValue::CPUToGPUCopyFn); + VariantValue::CPUToGPUCopyFn); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, VariantValue, - "TEST VariantValue", VariantValue::CPUZerosLikeFn); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, VariantValue, - "TEST VariantValue", VariantValue::GPUZerosLikeFn); REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - VariantValue, "TEST VariantValue", - VariantValue::CPUAddFn); + VariantValue, VariantValue::CPUAddFn); REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - VariantValue, "TEST VariantValue", - VariantValue::GPUAddFn); + VariantValue, VariantValue::GPUAddFn); } // namespace TEST(VariantOpShapeRegistryTest, TestBasic) { - EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"), + class Blah {}; + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()), nullptr); - auto* shape_fn = - UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue"); + auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn( + MakeTypeIndex<VariantValue>()); EXPECT_NE(shape_fn, nullptr); TensorShape shape; @@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) { TEST(VariantOpShapeRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantShapeFn f; - string kTypeName = "fjfjfj"; - registry.RegisterShapeFn(kTypeName, f); - EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f), - "fjfjfj already registered"); + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); + registry.RegisterShapeFn(kTypeIndex, f); + EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpDecodeRegistryTest, TestBasic) { @@ -180,13 +177,14 @@ TEST(VariantOpDecodeRegistryTest, TestDuplicate) { TEST(VariantOpCopyToGPURegistryTest, TestBasic) { // No registered copy fn for GPU<->GPU. - EXPECT_EQ( - UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( - VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"), - nullptr); + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( + VariantDeviceCopyDirection::DEVICE_TO_DEVICE, + MakeTypeIndex<VariantValue>()), + nullptr); auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( - VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue"); + VariantDeviceCopyDirection::HOST_TO_DEVICE, + MakeTypeIndex<VariantValue>()); EXPECT_NE(copy_to_gpu_fn, nullptr); VariantValue vv{true /* early_exit */}; @@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) { TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE, - kTypeName, f); + kTypeIndex, f); EXPECT_DEATH(registry.RegisterDeviceCopyFn( - VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f), - "fjfjfj already registered"); + VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( - ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; @@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { #if GOOGLE_CUDA TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( - ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; @@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantUnaryOpFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); - registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName, - f); + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, + kTypeIndex, f); EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, - DEVICE_CPU, kTypeName, f), - "fjfjfj already registered"); + DEVICE_CPU, kTypeIndex, f), + "FjFjFj already registered"); - registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName, - f); + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, + kTypeIndex, f); EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, - DEVICE_GPU, kTypeName, f), - "fjfjfj already registered"); + DEVICE_GPU, kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpAddRegistryTest, TestBasicCPU) { - return; + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( - ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; @@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) { #if GOOGLE_CUDA TEST(VariantOpAddRegistryTest, TestBasicGPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( - ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), + ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; @@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) { TEST(VariantOpAddRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantBinaryOpFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); - registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f); + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f); EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - kTypeName, f), - "fjfjfj already registered"); + kTypeIndex, f), + "FjFjFj already registered"); - registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f); + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f); EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - kTypeName, f), - "fjfjfj already registered"); + kTypeIndex, f), + "FjFjFj already registered"); } } // namespace tensorflow |