diff options
Diffstat (limited to 'tensorflow/core/framework/op_kernel_test.cc')
-rw-r--r-- | tensorflow/core/framework/op_kernel_test.cc | 803 |
1 files changed, 803 insertions, 0 deletions
diff --git a/tensorflow/core/framework/op_kernel_test.cc b/tensorflow/core/framework/op_kernel_test.cc new file mode 100644 index 0000000000..9400ef24f8 --- /dev/null +++ b/tensorflow/core/framework/op_kernel_test.cc @@ -0,0 +1,803 @@ +#include "tensorflow/core/framework/op_kernel.h" + +#include <memory> +#include <vector> +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include <gtest/gtest.h> + +class DummyKernel : public tensorflow::OpKernel { + public: + explicit DummyKernel(tensorflow::OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(tensorflow::OpKernelContext* context) override {} +}; + +// Test that registration works outside a namespace. +REGISTER_OP("Test1").Input("a: float").Input("b: int32").Output("o: uint8"); +REGISTER_KERNEL_BUILDER(Name("Test1").Device(tensorflow::DEVICE_CPU), + DummyKernel); + +namespace foo { +bool match_signature_ = false; + +// Test that registration works inside a different namespace. +class TestOp2 : public ::tensorflow::OpKernel { + public: + explicit TestOp2(::tensorflow::OpKernelConstruction* context) + : OpKernel(context) { + ::tensorflow::Status status = context->MatchSignature( + {::tensorflow::DT_INT32}, {::tensorflow::DT_INT32}); + match_signature_ = status.ok(); + context->SetStatus(status); + } + void Compute(::tensorflow::OpKernelContext* context) override {} +}; + +REGISTER_OP("Test2").Input("i: T").Output("o: T").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("Test2") + .Device(::tensorflow::DEVICE_GPU) + .HostMemory("i") + .HostMemory("o"), + TestOp2); +} // namespace foo + +namespace tensorflow { + +// Two operations with the same name but different devices. +REGISTER_OP("Test3").Input("a: T").Input("b: T").Attr("T: type"); + +class TestOp3Cpu : public tensorflow::OpKernel { + public: + explicit TestOp3Cpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER( + Name("Test3").Device(DEVICE_CPU).TypeConstraint<int8>("T"), TestOp3Cpu); + +namespace { + +class TestOp3Gpu : public tensorflow::OpKernel { + public: + explicit TestOp3Gpu(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(OpKernelContext* context) override {} +}; + +REGISTER_KERNEL_BUILDER( + Name("Test3").Device(DEVICE_GPU).TypeConstraint<float>("T"), TestOp3Cpu); + +// An Op registered for both +REGISTER_OP("Test4").Input("i: float").Output("o: float"); +REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("Test4").Device(DEVICE_GPU), DummyKernel); + +static std::vector<DeviceType> DeviceTypes() { + return {DeviceType(DEVICE_GPU), DeviceType(DEVICE_CPU)}; +} + +class OpKernelTest : public ::testing::Test { + public: + OpKernelTest() : device_(Env::Default()) {} + + protected: + NodeDef CreateNodeDef(const string& op_type, const DataTypeVector& inputs) { + NodeDefBuilder builder(op_type + "-op", op_type); + for (DataType dt : inputs) { + builder.Input(FakeInput(dt)); + } + NodeDef node_def; + TF_CHECK_OK(builder.Finalize(&node_def)); + return node_def; + } + + void ExpectEqual(const string& what, const DataTypeVector& expected, + const DataTypeVector& observed) { + EXPECT_EQ(expected.size(), observed.size()) << what; + const int size = std::min(expected.size(), observed.size()); + for (int i = 0; i < size; ++i) { + bool match = TypesCompatible(expected[i], observed[i]); + EXPECT_TRUE(match) << what << " i:" << i << ", expected: " << expected[i] + << ", observed: " << observed[i]; + } + } + + void ExpectSuccess(const string& op_type, DeviceType device_type, + const DataTypeVector& inputs, + const DataTypeVector& outputs) { + Status status; + std::unique_ptr<OpKernel> op( + CreateOpKernel(device_type, &device_, cpu_allocator(), + CreateNodeDef(op_type, inputs), &status)); + EXPECT_TRUE(status.ok()) << status; + EXPECT_TRUE(op != nullptr); + if (op != nullptr) { + ExpectEqual("inputs", op->input_types(), inputs); + ExpectEqual("outputs", op->output_types(), outputs); + } + } + + void ExpectFailure(const string& ascii_node_def, DeviceType device_type, + error::Code code) { + NodeDef node_def; + protobuf::TextFormat::ParseFromString(ascii_node_def, &node_def); + Status status; + std::unique_ptr<OpKernel> op(CreateOpKernel( + device_type, &device_, cpu_allocator(), node_def, &status)); + EXPECT_TRUE(op == nullptr); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + LOG(INFO) << "Status message: " << status.error_message(); + EXPECT_EQ(code, status.code()); + } + } + + private: + DeviceBase device_; +}; + +TEST_F(OpKernelTest, SuccessCpu) { + ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT, DT_INT32}, {DT_UINT8}); + ExpectSuccess("Test1", DEVICE_CPU, {DT_FLOAT_REF, DT_INT32}, {DT_UINT8}); +} + +TEST_F(OpKernelTest, SuccessGpu) { + foo::match_signature_ = false; + ExpectSuccess("Test2", DEVICE_GPU, {DT_INT32}, {DT_INT32}); + EXPECT_TRUE(foo::match_signature_); +} + +TEST_F(OpKernelTest, SuccessBothCpuAndGpu) { + ExpectSuccess("Test3", DEVICE_CPU, {DT_INT8, DT_INT8}, {}); + ExpectSuccess("Test3", DEVICE_GPU, {DT_FLOAT, DT_FLOAT}, {}); +} + +TEST_F(OpKernelTest, CpuTypeRegistered) { + NodeDef ndef = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); +} + +TEST_F(OpKernelTest, CpuAndGpuTypeRegistered) { + { + // Try a node def of an op that is registered for a specific type + // only on CPU. + NodeDef ndef = CreateNodeDef("Test3", {DT_INT8, DT_INT8}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[0]); + } + { + // Try a node def of an op that is registered for a specific type + // only on GPU. + NodeDef ndef = CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(1, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); + } + { + // Try a node def of an op that is only registered for other types. + NodeDef ndef = CreateNodeDef("Test3", {DT_STRING, DT_STRING}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(0, devs.size()); + } + + { + // Try a node def of an op that is registered for both. + NodeDef ndef = CreateNodeDef("Test4", {DT_FLOAT}); + DeviceTypeVector devs; + ASSERT_OK(SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs)); + EXPECT_EQ(2, devs.size()); + EXPECT_EQ(DeviceType(DEVICE_GPU), devs[0]); + EXPECT_EQ(DeviceType(DEVICE_CPU), devs[1]); + } +} + +TEST_F(OpKernelTest, NotFound) { + const auto not_found = error::NOT_FOUND; + // Something with that op type name exists, but only with a + // different DeviceType. + ExpectFailure(CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}).DebugString(), + DEVICE_GPU, not_found); + ExpectFailure(CreateNodeDef("Test3", {DT_INT8, DT_INT8}).DebugString(), + DEVICE_GPU, not_found); + ExpectFailure(CreateNodeDef("Test3", {DT_FLOAT, DT_FLOAT}).DebugString(), + DEVICE_CPU, not_found); + + // No kernel with that signature registered. + ExpectFailure(CreateNodeDef("Test3", {DT_INT32, DT_INT32}).DebugString(), + DEVICE_GPU, not_found); + + // Nothing with that op type name exists. + ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_CPU, not_found); + ExpectFailure("name: 'NF' op: 'Testnotfound'", DEVICE_GPU, not_found); +} + +TEST_F(OpKernelTest, TooFewInputs) { + const auto invalid = error::INVALID_ARGUMENT; + NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + node_def.clear_input(); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); + node_def.add_input("a"); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); +} + +TEST_F(OpKernelTest, TooManyInputs) { + const auto invalid = error::INVALID_ARGUMENT; + NodeDef node_def = CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}); + node_def.add_input("c"); + ExpectFailure(node_def.DebugString(), DEVICE_CPU, invalid); +} + +TEST_F(OpKernelTest, MatchSignatureFailes) { + const auto invalid = error::INVALID_ARGUMENT; + foo::match_signature_ = true; + ExpectFailure(CreateNodeDef("Test2", {DT_FLOAT}).DebugString(), DEVICE_GPU, + invalid); + EXPECT_FALSE(foo::match_signature_); +} + +class DummyDevice : public DeviceBase { + public: + DummyDevice(Env* env, bool save) : DeviceBase(env), save_(save) {} + bool SaveTemporaryTensors() const override { return save_; } + Allocator* GetAllocator(AllocatorAttributes /*attr*/) override { + return cpu_allocator(); + } + + private: + bool save_; +}; + +TEST_F(OpKernelTest, SaveTempFalse) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.device = new DummyDevice(env, false); + Status status; + std::unique_ptr<OpKernel> op( + CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), &status)); + EXPECT_TRUE(status.ok()); + params.op_kernel = op.get(); + OpKernelContext* ctx = new OpKernelContext(params); + + Tensor t; + EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); + + EXPECT_EQ(0, ctx->num_temps()); + + delete ctx; + delete params.device; +} + +TEST_F(OpKernelTest, SaveTempTrue) { + Env* env = Env::Default(); + OpKernelContext::Params params; + params.device = new DummyDevice(env, true); + Status status; + std::unique_ptr<OpKernel> op( + CreateOpKernel(DEVICE_CPU, params.device, cpu_allocator(), + CreateNodeDef("Test1", {DT_FLOAT, DT_INT32}), &status)); + EXPECT_TRUE(status.ok()); + params.op_kernel = op.get(); + OpKernelContext* ctx = new OpKernelContext(params); + + Tensor t; + EXPECT_OK(ctx->allocate_temp(DT_FLOAT, TensorShape(), &t)); + + EXPECT_EQ(1, ctx->num_temps()); + + delete ctx; + delete params.device; +} + +class OpKernelBuilderTest : public ::testing::Test { + protected: + // Each attr is described by a "name|type|value". + NodeDef CreateNodeDef(const string& op_type, + const std::vector<string>& attrs) { + NodeDef node_def; + node_def.set_name(op_type + "-op"); + node_def.set_op(op_type); + for (const string& attr_desc : attrs) { + std::vector<string> parts = str_util::Split(attr_desc, '|'); + CHECK_EQ(parts.size(), 3); + AttrValue attr_value; + CHECK(ParseAttrValue(parts[1], parts[2], &attr_value)) << attr_desc; + node_def.mutable_attr()->insert( + AttrValueMap::value_type(parts[0], attr_value)); + } + return node_def; + } + + std::unique_ptr<OpKernel> ExpectSuccess(const string& op_type, + DeviceType device_type, + const std::vector<string>& attrs, + DataTypeSlice input_types = {}) { + Status status; + NodeDef def = CreateNodeDef(op_type, attrs); + for (size_t i = 0; i < input_types.size(); ++i) { + def.add_input("a:0"); + } + + Env* env = Env::Default(); + DeviceBase device(env); + + // Test CreateOpKernel() + std::unique_ptr<OpKernel> op( + CreateOpKernel(device_type, &device, cpu_allocator(), def, &status)); + EXPECT_TRUE(status.ok()) << status; + EXPECT_TRUE(op != nullptr); + if (op != nullptr) { + EXPECT_EQ(input_types.size(), op->num_inputs()); + EXPECT_EQ(0, op->num_outputs()); + } + + // Test SupportedDeviceTypesForNode() + DeviceTypeVector devices; + EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); + bool found = false; + for (DeviceType dt : devices) { + if (dt == device_type) { + found = true; + } + } + EXPECT_TRUE(found) << "Missing " << device_type << " from " + << devices.size() << " devices."; + + // In case the caller wants to use the OpKernel + return op; + } + + void ExpectFailure(const string& op_type, DeviceType device_type, + const std::vector<string>& attrs, error::Code code) { + Status status; + const NodeDef def = CreateNodeDef(op_type, attrs); + Env* env = Env::Default(); + DeviceBase device(env); + + // Test CreateOpKernel(). + std::unique_ptr<OpKernel> op( + CreateOpKernel(device_type, &device, cpu_allocator(), def, &status)); + EXPECT_TRUE(op == nullptr); + EXPECT_FALSE(status.ok()); + if (!status.ok()) { + LOG(INFO) << "Status message: " << status.error_message(); + EXPECT_EQ(code, status.code()); + + // Test SupportedDeviceTypesForNode(). + DeviceTypeVector devices; + if (errors::IsNotFound(status)) { + EXPECT_OK(SupportedDeviceTypesForNode(DeviceTypes(), def, &devices)); + for (DeviceType dt : devices) { + EXPECT_NE(dt, device_type); + } + } else { + Status status2 = + SupportedDeviceTypesForNode(DeviceTypes(), def, &devices); + EXPECT_EQ(status.code(), status2.code()); + } + } + } +}; + +REGISTER_OP("BuildCPU"); +REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderCPU) { + ExpectSuccess("BuildCPU", DEVICE_CPU, {}); + ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND); +} + +REGISTER_OP("BuildGPU"); +REGISTER_KERNEL_BUILDER(Name("BuildGPU").Device(DEVICE_GPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderGPU) { + ExpectFailure("BuildGPU", DEVICE_CPU, {}, error::NOT_FOUND); + ExpectSuccess("BuildGPU", DEVICE_GPU, {}); +} + +REGISTER_OP("BuildBoth"); +REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("BuildBoth").Device(DEVICE_GPU), DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderBoth) { + ExpectSuccess("BuildBoth", DEVICE_CPU, {}); + ExpectSuccess("BuildBoth", DEVICE_GPU, {}); +} + +REGISTER_OP("BuildTypeAttr").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("BuildTypeAttr") + .Device(DEVICE_CPU) + .TypeConstraint<float>("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderTypeAttr) { + ExpectSuccess("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_FLOAT"}); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|type|DT_BOOL"}, + error::NOT_FOUND); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); + ExpectFailure("BuildTypeAttr", DEVICE_CPU, {"T|int|7"}, + error::INVALID_ARGUMENT); +} + +REGISTER_OP("BuildTypeListAttr").Attr("T: list(type)"); +REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr") + .Device(DEVICE_CPU) + .TypeConstraint<bool>("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) { + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"}); + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"}); + ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, + {"T|list(type)|[DT_BOOL, DT_BOOL]"}); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"}, + error::NOT_FOUND); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); + ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"}, + error::INVALID_ARGUMENT); +} + +REGISTER_OP("DuplicateKernel"); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), + DummyKernel); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernel").Device(DEVICE_CPU), + DummyKernel); + +TEST_F(OpKernelBuilderTest, DuplicateKernel) { + const NodeDef ndef = CreateNodeDef("DuplicateKernel", {}); + DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Multiple OpKernel registrations match NodeDef")); + + ExpectFailure("DuplicateKernel", DEVICE_CPU, {}, error::INVALID_ARGUMENT); +} + +REGISTER_OP("DuplicateKernelForT").Attr("T: type"); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT") + .Device(DEVICE_CPU) + .TypeConstraint<float>("T"), + DummyKernel); +REGISTER_KERNEL_BUILDER(Name("DuplicateKernelForT") + .Device(DEVICE_CPU) + .TypeConstraint<float>("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, DuplicateKernelForT) { + const NodeDef ndef = + CreateNodeDef("DuplicateKernelForT", {"T|type|DT_FLOAT"}); + DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Multiple OpKernel registrations match NodeDef")); + + ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_FLOAT"}, + error::INVALID_ARGUMENT); + ExpectFailure("DuplicateKernelForT", DEVICE_CPU, {"T|type|DT_BOOL"}, + error::NOT_FOUND); +} + +REGISTER_OP("BadConstraint").Attr("dtype: type"); +REGISTER_KERNEL_BUILDER(Name("BadConstraint") + .Device(DEVICE_CPU) + // Mistake: "T" should be "dtype". + .TypeConstraint<float>("T"), + DummyKernel); + +TEST_F(OpKernelBuilderTest, BadConstraint) { + const NodeDef ndef = CreateNodeDef("BadConstraint", {}); + DeviceTypeVector devs; + Status status = SupportedDeviceTypesForNode(DeviceTypes(), ndef, &devs); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("OpKernel 'BadConstraint' has constraint on attr " + "'T' not in NodeDef")); + + ExpectFailure("BadConstraint", DEVICE_CPU, {"dtype|type|DT_FLOAT"}, + error::INVALID_ARGUMENT); +} + +class GetAttrKernel : public ::tensorflow::OpKernel { + public: + explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) { + string attr_name; + OP_REQUIRES_OK(context, context->GetAttr("attr_name", &attr_name)); + + status.emplace_back("s", context->GetAttr(attr_name, &s)); + status.emplace_back("s_list", context->GetAttr(attr_name, &s_list)); + status.emplace_back("i", context->GetAttr(attr_name, &i)); + status.emplace_back("i_list", context->GetAttr(attr_name, &i_list)); + status.emplace_back("i32", context->GetAttr(attr_name, &i32)); + status.emplace_back("i32_list", context->GetAttr(attr_name, &i32_list)); + status.emplace_back("f", context->GetAttr(attr_name, &f)); + status.emplace_back("f_list", context->GetAttr(attr_name, &f_list)); + status.emplace_back("b", context->GetAttr(attr_name, &b)); + status.emplace_back("b_list", context->GetAttr(attr_name, &b_list)); + status.emplace_back("type", context->GetAttr(attr_name, &type)); + status.emplace_back("type_list", context->GetAttr(attr_name, &type_list)); + status.emplace_back("type_vector", + context->GetAttr(attr_name, &type_vector)); + status.emplace_back("shape_proto", + context->GetAttr(attr_name, &shape_proto)); + status.emplace_back("shape_proto_list", + context->GetAttr(attr_name, &shape_proto_list)); + status.emplace_back("shape", context->GetAttr(attr_name, &shape)); + status.emplace_back("shape_list", context->GetAttr(attr_name, &shape_list)); + } + void Compute(::tensorflow::OpKernelContext* context) override {} + + void ExpectOk(std::initializer_list<string> keys) { + for (const auto& key_status : status) { + // Only the status for keys in "keys" should be ok(). + bool in_keys = false; + for (const string& key : keys) { + if (key_status.first == key) { + in_keys = true; + } + } + EXPECT_EQ(in_keys, key_status.second.ok()) + << "key_status: " << key_status.first << ", " << key_status.second; + } + } + + string s; + std::vector<string> s_list; + int64 i; + std::vector<int64> i_list; + int32 i32; + std::vector<int32> i32_list; + float f; + std::vector<float> f_list; + bool b; + std::vector<bool> b_list; + DataType type; + std::vector<DataType> type_list; + DataTypeVector type_vector; + TensorShapeProto shape_proto; + std::vector<TensorShapeProto> shape_proto_list; + TensorShape shape; + std::vector<TensorShape> shape_list; + std::vector<std::pair<string, Status>> status; +}; + +class GetAttrTest : public OpKernelBuilderTest {}; + +REGISTER_OP("GetAttrStringList") + .Attr("attr_name: string") + .Attr("a: list(string)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrStringList").Device(DEVICE_CPU), + GetAttrKernel); + +TEST_F(GetAttrTest, StringList) { + std::unique_ptr<OpKernel> op_kernel = + ExpectSuccess("GetAttrStringList", DEVICE_CPU, + {"attr_name|string|'a'", "a|list(string)|['foo', 'bar']"}); + auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"s_list"}); + EXPECT_EQ(std::vector<string>({"foo", "bar"}), get_attr_kernel->s_list); + + op_kernel = ExpectSuccess("GetAttrStringList", DEVICE_CPU, + {"attr_name|string|'b'", "a|list(string)|['baz']"}); + get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({}); + EXPECT_TRUE(get_attr_kernel->s_list.empty()); +} + +REGISTER_OP("GetAttrInt") + .Attr("attr_name: string") + .Attr("a: int") + .Attr("b: list(int)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrInt").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Int) { + std::unique_ptr<OpKernel> op_kernel = ExpectSuccess( + "GetAttrInt", DEVICE_CPU, + {"attr_name|string|'a'", "a|int|35", "b|list(int)|[-1, 2, -4]"}); + auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"i", "i32"}); + EXPECT_EQ(35, get_attr_kernel->i); + EXPECT_EQ(35, get_attr_kernel->i32); + + op_kernel = ExpectSuccess( + "GetAttrInt", DEVICE_CPU, + {"attr_name|string|'b'", "a|int|35", "b|list(int)|[-1, 2, -4]"}); + get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"i_list", "i32_list"}); + EXPECT_EQ(std::vector<int64>({-1, 2, -4}), get_attr_kernel->i_list); + EXPECT_EQ(std::vector<int32>({-1, 2, -4}), get_attr_kernel->i32_list); + + // 8589934592 == 2^33, too big to fit in an int32 + op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU, + {"attr_name|string|'a'", "a|int|8589934592", + "b|list(int)|[-8589934592]"}); + get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"i"}); // no i32 + EXPECT_EQ(8589934592ll, get_attr_kernel->i); + for (const auto& key_status : get_attr_kernel->status) { + if (key_status.first == "i32") { + EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code()); + EXPECT_EQ("Attr a has value 8589934592 out of range for an int32", + key_status.second.error_message()); + } + } + + op_kernel = ExpectSuccess("GetAttrInt", DEVICE_CPU, + {"attr_name|string|'b'", "a|int|8589934592", + "b|list(int)|[-8589934592]"}); + get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"i_list"}); // no i32_list + EXPECT_EQ(std::vector<int64>({-8589934592ll}), get_attr_kernel->i_list); + for (const auto& key_status : get_attr_kernel->status) { + if (key_status.first == "i32_list") { + EXPECT_EQ(error::INVALID_ARGUMENT, key_status.second.code()); + EXPECT_EQ("Attr b has value -8589934592 out of range for an int32", + key_status.second.error_message()); + } + } +} + +REGISTER_OP("GetAttrShape") + .Attr("attr_name: string") + .Attr("a: shape") + .Attr("b: list(shape)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrShape").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Shape) { + std::unique_ptr<OpKernel> op_kernel = ExpectSuccess( + "GetAttrShape", DEVICE_CPU, + {"attr_name|string|'a'", "a|shape|{ dim { size: 3 } }", + "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"}); + auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"shape", "shape_proto"}); + EXPECT_EQ(get_attr_kernel->shape_proto.ShortDebugString(), "dim { size: 3 }"); + EXPECT_EQ("[3]", get_attr_kernel->shape.ShortDebugString()); + + op_kernel = ExpectSuccess( + "GetAttrShape", DEVICE_CPU, + {"attr_name|string|'b'", "a|shape|{ dim { size: 3 } }", + "b|list(shape)|[{ dim { size:2 } }, { dim { size: 4 } }]"}); + get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"shape_list", "shape_proto_list"}); + ASSERT_EQ(2, get_attr_kernel->shape_proto_list.size()); + EXPECT_EQ(get_attr_kernel->shape_proto_list[0].ShortDebugString(), + "dim { size: 2 }"); + EXPECT_EQ(get_attr_kernel->shape_proto_list[1].ShortDebugString(), + "dim { size: 4 }"); + ASSERT_EQ(2, get_attr_kernel->shape_list.size()); + EXPECT_EQ("[2]", get_attr_kernel->shape_list[0].ShortDebugString()); + EXPECT_EQ("[4]", get_attr_kernel->shape_list[1].ShortDebugString()); +} + +REGISTER_OP("GetAttrType").Attr("attr_name: string").Attr("a: type"); +REGISTER_KERNEL_BUILDER(Name("GetAttrType").Device(DEVICE_CPU), GetAttrKernel); + +TEST_F(GetAttrTest, Type) { + std::unique_ptr<OpKernel> op_kernel = ExpectSuccess( + "GetAttrType", DEVICE_CPU, {"attr_name|string|'a'", "a|type|DT_FLOAT"}); + auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + get_attr_kernel->ExpectOk({"type"}); + EXPECT_EQ(DT_FLOAT, get_attr_kernel->type); +} + +REGISTER_OP("GetAttrTypeList").Attr("attr_name: string").Attr("a: list(type)"); +REGISTER_KERNEL_BUILDER(Name("GetAttrTypeList").Device(DEVICE_CPU), + GetAttrKernel); + +TEST_F(GetAttrTest, TypeList) { + std::unique_ptr<OpKernel> op_kernel = ExpectSuccess( + "GetAttrTypeList", DEVICE_CPU, + {"attr_name|string|'a'", "a|list(type)|[DT_INT32, DT_BOOL]"}); + auto* get_attr_kernel = static_cast<GetAttrKernel*>(op_kernel.get()); + + get_attr_kernel->ExpectOk({"type_list", "type_vector"}); + ASSERT_EQ(2, get_attr_kernel->type_list.size()); + EXPECT_EQ(DT_INT32, get_attr_kernel->type_list[0]); + EXPECT_EQ(DT_BOOL, get_attr_kernel->type_list[1]); + ASSERT_EQ(2, get_attr_kernel->type_vector.size()); + EXPECT_EQ(DT_INT32, get_attr_kernel->type_vector[0]); + EXPECT_EQ(DT_BOOL, get_attr_kernel->type_vector[1]); +} + +REGISTER_OP("HostMemoryTest") + .Input("a: float") + .Input("b: T") + .Input("c: N * string") + .Output("o: N * T") + .Attr("T: type") + .Attr("N: int"); +REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel); +REGISTER_KERNEL_BUILDER(Name("HostMemoryTest") + .Device(DEVICE_GPU) + .HostMemory("a") + .HostMemory("c") + .HostMemory("o"), + DummyKernel); + +TEST(MemoryTypesForNode, Simple) { + NodeDef node_def; + ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest") + .Input(FakeInput()) + .Input(FakeInput(DT_BOOL)) + .Input(FakeInput(3)) + .Finalize(&node_def)); + MemoryTypeVector input, output; + + EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def, + &input, &output)); + EXPECT_EQ(MemoryTypeVector(5, DEVICE_MEMORY), input); + EXPECT_EQ(MemoryTypeVector(3, DEVICE_MEMORY), output); + + EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def, + &input, &output)); + EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, + HOST_MEMORY, HOST_MEMORY}), + input); + EXPECT_EQ(MemoryTypeVector(3, HOST_MEMORY), output); +} + +class BaseKernel : public ::tensorflow::OpKernel { + public: + explicit BaseKernel(OpKernelConstruction* context) : OpKernel(context) {} + void Compute(::tensorflow::OpKernelContext* context) override {} + virtual int Which() const = 0; +}; + +template <int WHICH> +class LabeledKernel : public BaseKernel { + public: + using BaseKernel::BaseKernel; + int Which() const override { return WHICH; } +}; + +class LabelTest : public OpKernelBuilderTest {}; + +REGISTER_OP("LabeledKernel"); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU), + LabeledKernel<0>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("one"), + LabeledKernel<1>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"), + LabeledKernel<2>); +REGISTER_KERNEL_BUILDER(Name("LabeledKernel").Device(DEVICE_CPU).Label("dupe"), + LabeledKernel<3>); + +TEST_F(LabelTest, Default) { + std::unique_ptr<OpKernel> op_kernel = + ExpectSuccess("LabeledKernel", DEVICE_CPU, {}); + auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get()); + EXPECT_EQ(0, get_labeled_kernel->Which()); +} + +TEST_F(LabelTest, Specified) { + std::unique_ptr<OpKernel> op_kernel = + ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"}); + auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get()); + EXPECT_EQ(1, get_labeled_kernel->Which()); +} + +TEST_F(LabelTest, Duplicate) { + ExpectFailure("LabeledKernel", DEVICE_CPU, {"_kernel|string|'dupe'"}, + error::INVALID_ARGUMENT); +} + +} // namespace +} // namespace tensorflow |