diff options
Diffstat (limited to 'tensorflow/core/framework/node_def_builder_test.cc')
-rw-r--r-- | tensorflow/core/framework/node_def_builder_test.cc | 1036 |
1 files changed, 1036 insertions, 0 deletions
diff --git a/tensorflow/core/framework/node_def_builder_test.cc b/tensorflow/core/framework/node_def_builder_test.cc new file mode 100644 index 0000000000..6fd4a8d1ed --- /dev/null +++ b/tensorflow/core/framework/node_def_builder_test.cc @@ -0,0 +1,1036 @@ +#include "tensorflow/core/framework/node_def_builder.h" + +#include <memory> +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/op_def_util.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include <gtest/gtest.h> + +namespace tensorflow { +namespace { + +class NodeDefBuilderTest : public ::testing::Test { + protected: + // Specify an OpDef via an OpDefBuilder. + void Op(const OpDefBuilder& op_def_builder) { + EXPECT_OK(op_def_builder.Finalize(&op_def_)); + } + + // Resets builder_ with a new NodeDefBuilder using the Op from the last call + // to Op() above. + NodeDefBuilder& Builder() { + EXPECT_FALSE(op_def_.name().empty()) << "Must call Op() before Builder()"; + builder_.reset(new NodeDefBuilder("n", &op_def_)); + return *builder_; + } + + // Calls Finalize() and verifies it returns success and the result matches + // expectations. + void ExpectSuccess(const NodeDefBuilder& builder, + DataTypeSlice expected_in_types, + DataTypeSlice expected_out_types, StringPiece proto) { + NodeDef node_def; + Status status = builder.Finalize(&node_def); + EXPECT_OK(status); + if (!status.ok()) return; + NodeDef expected; + protobuf::TextFormat::ParseFromString(strings::StrCat("name: 'n' ", proto), + &expected); + EXPECT_EQ(node_def.DebugString(), expected.DebugString()); + + DataTypeVector in_types, out_types; + status = + InOutTypesForNode(node_def, builder.op_def(), &in_types, &out_types); + EXPECT_OK(status); + if (!status.ok()) return; + EXPECT_EQ(DataTypeSliceString(expected_in_types), + DataTypeVectorString(in_types)); + EXPECT_EQ(DataTypeSliceString(expected_out_types), + DataTypeVectorString(out_types)); + + status = ValidateNodeDef(node_def, op_def_); + EXPECT_OK(status); + } + + // Calls Finalize() and verifies it returns an error. + // Each message must appear as a substring of the error. + void ExpectFailures(const NodeDefBuilder& builder, + const std::vector<string>& messages) { + NodeDef node_def; + Status status = builder.Finalize(&node_def); + EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); + if (status.ok()) return; + for (const string& message : messages) { + EXPECT_TRUE(StringPiece(status.error_message()).contains(message)) + << status << ", " << message; + } + } + + // Calls Finalize() and verifies it returns an error. + // Message must appear as a substring of the error. + void ExpectFailure(const NodeDefBuilder& builder, const string& message) { + ExpectFailures(builder, {message}); + } + + // Like ExpectFailure(), except that the error can come from + // ValidateNodeDef(). + void ExpectInvalid(const NodeDefBuilder& builder, const string& message) { + NodeDef node_def; + Status status = builder.Finalize(&node_def); + if (status.ok()) { + status = ValidateNodeDef(node_def, op_def_); + } + EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def); + if (status.ok()) return; + EXPECT_TRUE(StringPiece(status.error_message()).contains(message)) + << "Actual error: " << status.error_message() + << "\nDoes not contain: " << message; + } + + OpDef op_def_; + std::unique_ptr<NodeDefBuilder> builder_; +}; + +TEST_F(NodeDefBuilderTest, Simple) { + Op(OpDefBuilder("Simple").Input("a: int32").Output("out: float")); + + ExpectSuccess(Builder().Input("x", 0, DT_INT32), {DT_INT32}, {DT_FLOAT}, + R"proto( op: "Simple" input: "x" )proto"); + + // Port != 0 + ExpectSuccess(Builder().Input("y", 2, DT_INT32), {DT_INT32}, {DT_FLOAT}, + R"proto( op: "Simple" input: "y:2" )proto"); + + // FakeInput + ExpectSuccess(Builder().Input(FakeInput()), {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: "a" )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_FLOAT}, + R"proto( op: "Simple" input: "a" )proto"); + + // Ref input + ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32}, + {DT_FLOAT}, R"proto( op: "Simple" input: "a" )proto"); + + // ControlInput + ExpectSuccess( + Builder().ControlInput("x").Input(FakeInput()).ControlInput("y"), + {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: ["a", "^x", "^y"] )proto"); + + // Device + ExpectSuccess(Builder().Input(FakeInput()).Device("ddd"), {DT_INT32}, + {DT_FLOAT}, R"proto( + op: "Simple" input: "a" device: "ddd" )proto"); + + // Extra input + ExpectFailure(Builder().Input("x", 0, DT_INT32).Input("y", 0, DT_INT32), + "More Input() calls than the 1 input_args while building " + "NodeDef 'n' using Op<name=Simple; signature=a:int32 -> " + "out:float>"); + + // Missing input + ExpectFailure(Builder(), "0 inputs specified of 1 inputs in Op while"); + + { // Finalize() twice. + NodeDefBuilder& builder = Builder(); + builder.Input(FakeInput()).Finalize(nullptr); // First call to Finalize() + // ExpectSuccess() also calls Finalize(). + ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: "a" )proto"); + } + + { // Input() after Finalize() + NodeDefBuilder& builder = Builder(); + // Calling Finalize() before enough inputs -> error. + ExpectFailure(builder, "0 inputs specified of 1 inputs in Op while"); + builder.Input(FakeInput()); + // Calling Finalize() with enough inputs -> success + ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT}, R"proto( + op: "Simple" input: "a" )proto"); + // Calling Finalize() with too many inputs -> error. + builder.Input(FakeInput(DT_INT32)); + ExpectFailure(builder, "More Input() calls than the 1 input_args while"); + } + + // Wrong input type + ExpectFailure(Builder().Input("x", 0, DT_FLOAT), + "Input 'a' passed float expected int32 "); + + ExpectFailure(Builder().Input("x", 0, DT_FLOAT_REF), + "Input 'a' passed float_ref expected int32 "); + + // List input + ExpectFailure(Builder().Input(FakeInput(3, DT_FLOAT)), + "List provided to input 'a' when single Tensor expected while"); + + ExpectFailure(Builder().Input(FakeInput(3)), + "List provided to input 'a' when single Tensor expected while"); + + // Bad ControlInput + ExpectInvalid(Builder().Input(FakeInput()).ControlInput("z:2"), + "Control input '^z:2' must not have ':' in NodeDef:"); + + // Bad input name + ExpectFailure(Builder().Input("", 0, DT_INT32), + "Empty input node name while"); + + ExpectFailure(Builder().Input("^x", 0, DT_INT32), + "Non-control input starting with ^: ^x while"); +} + +TEST_F(NodeDefBuilderTest, OpDoesNotExist) { + NodeDefBuilder builder("n", "Op Does Not Exist"); + builder.Input(FakeInput()) + .Input(FakeInput(12)) + .ControlInput("y") + .Attr("foo", 12) + .Device("device"); + ExpectFailure( + builder, + "Op type not registered 'Op Does Not Exist' while building NodeDef 'n'"); +} + +TEST_F(NodeDefBuilderTest, Polymorphic) { + Op(OpDefBuilder("Polymorphic") + .Input("v: T") + .Output("out: T") + .Attr("T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_INT32}, + R"proto( + op: "Polymorphic" input: "a" + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DT_FLOAT)), {DT_FLOAT}, {DT_FLOAT}, + R"proto( + op: "Polymorphic" input: "a" + attr { key: "T" value { type: DT_FLOAT } } )proto"); + + // Redundant Attr() + ExpectSuccess(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_BOOL), + {DT_BOOL}, {DT_BOOL}, R"proto( + op: "Polymorphic" input: "a" + attr { key: "T" value { type: DT_BOOL } } )proto"); + + // Conficting Attr() + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_STRING), + "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while"); + + ExpectFailure(Builder().Attr("T", DT_STRING).Input(FakeInput(DT_BOOL)), + "Inconsistent values for attr 'T' DT_STRING vs. DT_BOOL while"); + + ExpectFailure(Builder().Attr("T", 12).Input(FakeInput(DT_BOOL)), + "Inconsistent values for attr 'T' 12 vs. DT_BOOL while"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicOut) { + Op(OpDefBuilder("PolymorphicOut").Output("out: T").Attr("T: type")); + + ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32}, R"proto( + op: "PolymorphicOut" + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_FLOAT), {}, {DT_FLOAT}, R"proto( + op: "PolymorphicOut" + attr { key: "T" value { type: DT_FLOAT } } )proto"); + + // Redundant attr + ExpectSuccess(Builder().Attr("T", DT_FLOAT).Attr("T", DT_FLOAT), {}, + {DT_FLOAT}, R"proto( + op: "PolymorphicOut" + attr { key: "T" value { type: DT_FLOAT } } )proto"); + + // Conflicting attr + ExpectFailure(Builder().Attr("T", DT_BOOL).Attr("T", DT_FLOAT), + "Inconsistent values for attr 'T' DT_BOOL vs. DT_FLOAT while"); + + // Missing attr + ExpectInvalid(Builder(), "NodeDef missing attr 'T' from"); + + // Attr has the wrong type + ExpectInvalid(Builder().Attr("T", {DT_INT32, DT_BOOL}), + "AttrValue had value with type list(type) when type expected"); + + ExpectInvalid(Builder().Attr("T", 12), + "AttrValue had value with type int when type expected"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicDefaultOut) { + Op(OpDefBuilder("PolymorphicDefaultOut") + .Output("out: T") + .Attr("T: type = DT_STRING")); + + ExpectSuccess(Builder(), {}, {DT_STRING}, R"proto( + op: "PolymorphicDefaultOut" + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_BOOL), {}, {DT_BOOL}, R"proto( + op: "PolymorphicDefaultOut" + attr { key: "T" value { type: DT_BOOL } } )proto"); +} + +TEST_F(NodeDefBuilderTest, Binary) { + Op(OpDefBuilder("Binary").Input("a: T").Input("b: T").Output("out: T").Attr( + "T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32)).Input(FakeInput(DT_INT32)), + {DT_INT32, DT_INT32}, {DT_INT32}, R"proto( + op: "Binary" input: "a" input: "b" + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DT_STRING)).Input(FakeInput()), + {DT_STRING, DT_STRING}, {DT_STRING}, R"proto( + op: "Binary" input: "a" input: "b" + attr { key: "T" value { type: DT_STRING } } )proto"); + + // Type mismatch + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Input(FakeInput(DT_STRING)), + "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, Restrict) { + Op(OpDefBuilder("Restrict") + .Input("a: T") + .Output("out: T") + .Attr("T: {string, bool}")); + ExpectSuccess(Builder().Input(FakeInput(DT_STRING)), {DT_STRING}, {DT_STRING}, + R"proto( + op: "Restrict" input: "a" + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput(DT_INT32)), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, TypeList) { + Op(OpDefBuilder("TypeList").Input("a: T").Attr("T: list(type)")); + + ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_INT32})), + {DT_STRING, DT_INT32}, {}, R"proto( + op: "TypeList" input: ["a", "a:1"] + attr { key: "T" value { list { type: [DT_STRING, DT_INT32] } } } + )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_BOOL)), + {DT_BOOL, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "TypeList" input: ["a", "a:1", "a:2"] + attr { key: "T" value { list { type: [DT_BOOL, DT_BOOL, DT_BOOL] } } } + )proto"); + + ExpectInvalid(Builder().Input(FakeInput(0)), + "Length for attr 'T' of 0 must be at least minimum 1"); + + ExpectInvalid(Builder().Input(FakeInput({})), + "Length for attr 'T' of 0 must be at least minimum 1"); + + ExpectInvalid(Builder().Input(FakeInput(DT_BOOL)), + "Single tensor passed to 'a', expected list while"); + + ExpectFailures(Builder().Input(FakeInput()), + {"2 errors while building NodeDef", + "Could not infer list of types for input 'a': " + "No attr named 'T' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); +} + +TEST_F(NodeDefBuilderTest, TypeListNoMin) { + Op(OpDefBuilder("TypeListNoMin").Input("a: T").Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(0)), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(DataTypeVector())), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput({})), {}, {}, R"proto( + op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput({DT_BOOL})), {DT_BOOL}, {}, R"proto( + op: "TypeListNoMin" input: "a" + attr { key: "T" value { list { type: DT_BOOL } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, TypeListTwice) { + Op(OpDefBuilder("TypeListTwice") + .Input("a: T") + .Input("b: T") + .Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder() + .Input(FakeInput({DT_INT32, DT_BOOL})) + .Input(FakeInput({DT_INT32, DT_BOOL})), + {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto( + op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"] + attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto"); + + ExpectSuccess( + Builder().Input(FakeInput({DT_INT32, DT_BOOL})).Input(FakeInput()), + {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto( + op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"] + attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(0)), {}, {}, + R"proto( + op: "TypeListTwice" attr { key: "T" value { list { } } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {}, + R"proto( + op: "TypeListTwice" attr { key: "T" value { list { } } } )proto"); + + ExpectFailure(Builder() + .Input(FakeInput({DT_INT32, DT_BOOL})) + .Input(FakeInput({DT_INT32, DT_STRING})), + "Inconsistent values for attr 'T' [DT_INT32, DT_BOOL] vs. " + "[DT_INT32, DT_STRING] while"); +} + +TEST_F(NodeDefBuilderTest, OutTypeList) { + Op(OpDefBuilder("OutTypeList").Output("out: T").Attr("T: list(type) >= 0")); + + ExpectSuccess(Builder().Attr("T", {DT_FLOAT}), {}, {DT_FLOAT}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { type: DT_FLOAT } } } )proto"); + + ExpectSuccess(Builder().Attr("T", {DT_STRING, DT_BOOL}), {}, + {DT_STRING, DT_BOOL}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto"); + + ExpectSuccess(Builder().Attr("T", DataTypeVector()), {}, {}, R"proto( + op: "OutTypeList" + attr { key: "T" value { list { } } } )proto"); + + ExpectInvalid(Builder().Attr("T", DT_FLOAT), + "AttrValue had value with type type when list(type) expected"); +} + +TEST_F(NodeDefBuilderTest, TypeListRestrict) { + Op(OpDefBuilder("TypeListRestrict") + .Input("a: T") + .Attr("T: list({string, bool}) >= 0")); + + ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_BOOL})), + {DT_STRING, DT_BOOL}, {}, R"proto( + op: "TypeListRestrict" input: ["a", "a:1"] + attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput({DT_STRING, DT_INT32})), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, OutTypeListRestrict) { + Op(OpDefBuilder("OutTypeListRestrict") + .Output("out: t") + .Attr("t: list({string, bool}) >= 0")); + + ExpectSuccess(Builder().Attr("t", {DT_BOOL, DT_STRING}), {}, + {DT_BOOL, DT_STRING}, R"proto( + op: "OutTypeListRestrict" + attr { key: "t" value { list { type: [DT_BOOL, DT_STRING] } } } )proto"); + + ExpectInvalid(Builder().Attr("t", {DT_STRING, DT_INT32}), + "Value for attr 't' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, Attr) { + Op(OpDefBuilder("Attr").Attr("a: int")); + + ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto( + op: "Attr" attr { key: "a" value { i: 12 } } )proto"); + + // Attr has wrong type + ExpectInvalid(Builder().Attr("a", "bad"), + "AttrValue had value with type string when int expected"); + + ExpectInvalid(Builder().Attr("a", {12}), + "AttrValue had value with type list(int) when int expected"); + + // Missing attr + ExpectInvalid(Builder(), "NodeDef missing attr 'a' from Op<"); + + // Wrong attr + ExpectInvalid(Builder().Attr("b", 12), + "NodeDef mentions attr 'b' not in Op<"); + + // Extra attr + ExpectInvalid(Builder().Attr("a", 12).Attr("extra", 12), + "NodeDef mentions attr 'extra' not in Op<"); +} + +TEST_F(NodeDefBuilderTest, AttrFloat) { + Op(OpDefBuilder("AttrFloat").Attr("a: float")); + + ExpectSuccess(Builder().Attr("a", 1.2f /* float */), {}, {}, R"proto( + op: "AttrFloat" attr { key: "a" value { f: 1.2 } } + )proto"); + + ExpectSuccess(Builder().Attr("a", 1.2 /* double */), {}, {}, R"proto( + op: "AttrFloat" attr { key: "a" value { f: 1.2 } } + )proto"); + + // Won't automatically cast int to float + ExpectInvalid(Builder().Attr("a", 12), + "AttrValue had value with type int when float expected"); +} + +TEST_F(NodeDefBuilderTest, AttrBoolList) { + Op(OpDefBuilder("AttrBoolList").Attr("a: list(bool)")); + + ExpectSuccess(Builder().Attr("a", {true, false, true}), {}, {}, R"proto( + op: "AttrBoolList" + attr { key: "a" value { list { b: [true, false, true] } } } + )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector<bool>()), {}, {}, R"proto( + op: "AttrBoolList" attr { key: "a" value { list { } } } + )proto"); + + // Won't cast int -> bool. + ExpectInvalid(Builder().Attr("a", {0}), + "AttrValue had value with type list(int) when list(bool) " + "expected"); +} + +TEST_F(NodeDefBuilderTest, AttrMin) { + Op(OpDefBuilder("AttrMin").Attr("a: int >= 5")); + + ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto( + op: "AttrMin" attr { key: "a" value { i: 12 } } )proto"); + + ExpectInvalid(Builder().Attr("a", 2), + "Value for attr 'a' of 2 must be at least minimum 5"); +} + +TEST_F(NodeDefBuilderTest, AttrListMin) { + Op(OpDefBuilder("AttrListMin").Attr("a: list(int) >= 2")); + + ExpectSuccess(Builder().Attr("a", {1, 2}), {}, {}, R"proto( + op: "AttrListMin" + attr { key: "a" value { list { i: [1, 2] } } } )proto"); + + ExpectInvalid(Builder().Attr("a", {17}), + "Length for attr 'a' of 1 must be at least minimum 2"); +} + +TEST_F(NodeDefBuilderTest, AttrEnum) { + Op(OpDefBuilder("AttrEnum").Attr("a: {'apples', 'oranges'}")); + + ExpectSuccess(Builder().Attr("a", "oranges"), {}, {}, R"proto( + op: "AttrEnum" + attr { key: "a" value { s: "oranges" } } )proto"); + + ExpectInvalid( + Builder().Attr("a", "invalid"), + "Value for attr 'a' of \"invalid\" is not in the list of allowed values: " + "\"apples\", \"oranges\""); +} + +TEST_F(NodeDefBuilderTest, AttrEnumList) { + Op(OpDefBuilder("AttrEnumList").Attr("a: list({'apples', 'oranges'})")); + + ExpectSuccess(Builder().Attr("a", {"oranges", "apples"}), {}, {}, R"proto( + op: "AttrEnumList" + attr { key: "a" value { list { s: ["oranges", "apples"] } } } )proto"); + + ExpectInvalid( + Builder().Attr("a", {"apples", "invalid", "oranges"}), + "Value for attr 'a' of \"invalid\" is not in the list of allowed values: " + "\"apples\", \"oranges\""); +} + +TEST_F(NodeDefBuilderTest, AttrShape) { + Op(OpDefBuilder("AttrShape").Attr("a: shape")); + + ExpectSuccess(Builder().Attr("a", TensorShape({5})), {}, {}, R"proto( + op: "AttrShape" + attr { key: "a" value { shape { dim { size: 5 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape({4, 3, 2})), {}, {}, R"proto( + op: "AttrShape" + attr { key: "a" value { shape { + dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape({3, 2})), {}, {}, + R"proto( + op: "AttrShape" + attr { key: "a" value { shape { + dim { size: 3 } dim { size: 2 } } } } )proto"); + + ExpectSuccess(Builder().Attr("a", TensorShape()), {}, {}, R"proto( + op: "AttrShape" + attr { key: "a" value { shape { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrDefault) { + Op(OpDefBuilder("AttrDefault").Attr("a: string = 'banana'")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrDefault" + attr { key: "a" value { s: "banana" } } )proto"); + + ExpectSuccess(Builder().Attr("a", "kiwi"), {}, {}, R"proto( + op: "AttrDefault" + attr { key: "a" value { s: "kiwi" } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrManyDefault) { + Op(OpDefBuilder("AttrManyDefault") + .Attr("a: string = 'banana'") + .Attr("b: string = 'kiwi'")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrManyDefault" + attr { key: "a" value { s: "banana" } } + attr { key: "b" value { s: "kiwi" } } )proto"); + + Op(OpDefBuilder("AttrManyDefaultWithMandatory") + .Attr("a: string = 'banana'") + .Attr("b: string = 'kiwi'") + .Attr("c: string")); + + ExpectSuccess(Builder().Attr("c", "strawberry"), {}, {}, R"proto( + op: "AttrManyDefaultWithMandatory" + attr { key: "c" value { s: "strawberry" } } + attr { key: "a" value { s: "banana" } } + attr { key: "b" value { s: "kiwi" } } )proto"); + + Op(OpDefBuilder("AttrManyDefaultAndInferred") + .Input("input: T") + .Attr("T: {float, double}") + .Attr("a: string") + .Attr("b: list(string) >= 1") + .Attr("c: bool = true") + .Attr("d: float = 0.3") + .Attr("e: string") + .Attr("f: float = 0.25")); + + ExpectSuccess(Builder() + .Input(FakeInput(DT_FLOAT)) + .Attr("a", "foo") + .Attr("e", "foo") + .Attr("b", std::vector<string>({"bar", "baz"})) + .Attr("f", 1.0f), + {DT_FLOAT}, {}, R"proto( + op: "AttrManyDefaultAndInferred" + input: "a" + attr { key: "T" value { type: DT_FLOAT } } + attr { key: "a" value { s: "foo" } } + attr { key: "e" value { s: "foo" } } + attr { key: "b" value { list { s: "bar" s: "baz" } } } + attr { key: "f" value { f: 1.0 } } + attr { key: "c" value { b: true } } + attr { key: "d" value { f: 0.3 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrListDefault) { + Op(OpDefBuilder("AttrListDefault").Attr("a: list(int) = [5, 15]")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { i: [5, 15] } } } )proto"); + + ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { i: 3 } } } )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto( + op: "AttrListDefault" + attr { key: "a" value { list { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, AttrEmptyListDefault) { + Op(OpDefBuilder("AttrEmptyListDefault").Attr("a: list(int) = []")); + + ExpectSuccess(Builder(), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { } } } )proto"); + + ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { i: 3 } } } )proto"); + + ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto( + op: "AttrEmptyListDefault" + attr { key: "a" value { list { } } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NIntsIn) { + Op(OpDefBuilder("NIntsIn").Input("a: N*int32").Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2)), {DT_INT32, DT_INT32}, {}, + R"proto( + op: "NIntsIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(5, DT_INT32)), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "NIntsIn" + input: ["a", "a:1", "a:2", "a:3", "a:4"] + attr { key: "N" value { i: 5 } } )proto"); + + ExpectFailures(Builder().Input(FakeInput(2, DT_STRING)), + {"2 errors while building NodeDef", + "Input 'a' passed string expected int32"}); + + ExpectInvalid(Builder().Input(FakeInput(1)), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectFailures( + Builder().Input(FakeInput(DT_INT32)), + {"2 errors while building NodeDef", + "Could not infer length of input 'a': No attr named 'N' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); + + ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}), + "Input 'a' passed string expected int32 while"); + + ExpectFailures( + Builder().Input(FakeInput()), + {"2 errors while building NodeDef", + "Could not infer length of input 'a': No attr named 'N' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicIn) { + Op(OpDefBuilder("NPolymorphicIn") + .Input("a: N*T") + .Attr("T: type") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)), {DT_INT32, DT_INT32}, + {}, R"proto( + op: "NPolymorphicIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)), + {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto( + op: "NPolymorphicIn" + input: ["a", "a:1", "a:2"] + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectFailures( + Builder().Input(FakeInput(2)), + {"2 errors while building NodeDef", + "Could not infer type for input 'a': No attr named 'T' in NodeDef:", + "0 inputs specified of 1 inputs in Op"}); + + ExpectFailure(Builder().Input(FakeInput({DT_INT32, DT_STRING})), + "Input 'a' passed string expected int32 while"); + + ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}), + "Input 'a' passed string expected int32 while"); + + ExpectInvalid(Builder().Input(FakeInput(1, DT_INT32)), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectFailure(Builder().Input("in", 0, DT_INT32), + "Single tensor passed to 'a', expected list while"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicRestrictIn) { + Op(OpDefBuilder("NPolymorphicRestrictIn") + .Input("a: N*T") + .Attr("T: {string, bool}") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_BOOL)), {DT_BOOL, DT_BOOL}, {}, + R"proto( + op: "NPolymorphicRestrictIn" input: ["a", "a:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)), + {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto( + op: "NPolymorphicRestrictIn" + input: ["a", "a:1", "a:2"] + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Input(FakeInput(2, DT_INT32)), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, NInTwice) { + Op(OpDefBuilder("NInTwice") + .Input("a: N*int32") + .Input("b: N*string") + .Attr("N: int >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(2)).Input(FakeInput(2)), + {DT_INT32, DT_INT32, DT_STRING, DT_STRING}, {}, R"proto( + op: "NInTwice" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {}, + R"proto( + op: "NInTwice" attr { key: "N" value { i: 0 } } )proto"); + + ExpectFailure(Builder().Input(FakeInput(3)).Input(FakeInput(1)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); +} + +TEST_F(NodeDefBuilderTest, NInPolymorphicTwice) { + Op(OpDefBuilder("NInPolymorphicTwice") + .Input("a: N*T") + .Input("b: N*T") + .Attr("T: type") + .Attr("N: int >= 0")); + + ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput()), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "NInPolymorphicTwice" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_INT32)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); + + ExpectFailure(Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, NInTwoTypeVariables) { + Op(OpDefBuilder("NInTwoTypeVariables") + .Input("a: N*S") + .Input("b: N*T") + .Attr("S: type") + .Attr("T: type") + .Attr("N: int >= 0")); + + ExpectSuccess( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_BOOL)), + {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "NInTwoTypeVariables" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "S" value { type: DT_INT32 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_BOOL)), + {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto( + op: "NInTwoTypeVariables" + input: ["a", "a:1", "b", "b:1"] + attr { key: "N" value { i: 2 } } + attr { key: "S" value { type: DT_INT32 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_STRING)), + "Inconsistent values for attr 'N' 3 vs. 1 while"); +} + +TEST_F(NodeDefBuilderTest, InPolymorphicTwice) { + Op(OpDefBuilder("InPolymorphicTwice") + .Input("a: N*T") + .Input("b: M*T") + .Attr("T: type") + .Attr("N: int >= 0") + .Attr("M: int >= 0")); + + ExpectSuccess( + Builder().Input(FakeInput(1, DT_INT32)).Input(FakeInput(3, DT_INT32)), + {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto( + op: "InPolymorphicTwice" + input: ["a", "b", "b:1", "b:2"] + attr { key: "N" value { i: 1 } } + attr { key: "T" value { type: DT_INT32 } } + attr { key: "M" value { i: 3 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(1, DT_BOOL)).Input(FakeInput(0)), + {DT_BOOL}, {}, R"proto( + op: "InPolymorphicTwice" input: "a" + attr { key: "N" value { i: 1 } } + attr { key: "T" value { type: DT_BOOL } } + attr { key: "M" value { i: 0 } } )proto"); + + ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(1, DT_BOOL)), + {DT_BOOL}, {}, R"proto( + op: "InPolymorphicTwice" input: "b" + attr { key: "N" value { i: 0 } } + attr { key: "M" value { i: 1 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure( + Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)), + "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while"); +} + +TEST_F(NodeDefBuilderTest, NIntsOut) { + Op(OpDefBuilder("NIntsOut").Output("a: N*int32").Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto( + op: "NIntsOut" + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3), {}, {DT_INT32, DT_INT32, DT_INT32}, + R"proto( + op: "NIntsOut" + attr { key: "N" value { i: 3 } } )proto"); + + ExpectInvalid(Builder().Attr("N", 1), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectInvalid(Builder().Attr("N", {3}), + "AttrValue had value with type list(int) when int expected"); + + ExpectInvalid(Builder(), "NodeDef missing attr 'N' from"); +} + +TEST_F(NodeDefBuilderTest, NIntsOutDefault) { + Op(OpDefBuilder("NIntsOutDefault") + .Output("a: N*int32") + .Attr("N: int >= 2 = 3")); + + ExpectSuccess(Builder(), {}, {DT_INT32, DT_INT32, DT_INT32}, R"proto( + op: "NIntsOutDefault" + attr { key: "N" value { i: 3 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto( + op: "NIntsOutDefault" + attr { key: "N" value { i: 2 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicOut) { + Op(OpDefBuilder("NPolymorphicOut") + .Output("a: N*T") + .Attr("T: type") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("T", DT_INT32).Attr("N", 2), {}, + {DT_INT32, DT_INT32}, R"proto( + op: "NPolymorphicOut" + attr { key: "T" value { type: DT_INT32 } } + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_STRING), {}, + {DT_STRING, DT_STRING, DT_STRING}, R"proto( + op: "NPolymorphicOut" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_STRING } } )proto"); + + ExpectInvalid(Builder().Attr("N", 1).Attr("T", DT_STRING), + "Value for attr 'N' of 1 must be at least minimum 2"); + + ExpectInvalid(Builder().Attr("N", 3).Attr("T", {DT_STRING}), + "AttrValue had value with type list(type) when type expected"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicOutDefault) { + Op(OpDefBuilder("NPolymorphicOutDefault") + .Output("a: N*T") + .Attr("T: type = DT_BOOL") + .Attr("N: int >= 2 = 2")); + + ExpectSuccess(Builder(), {}, {DT_BOOL, DT_BOOL}, R"proto( + op: "NPolymorphicOutDefault" + attr { key: "T" value { type: DT_BOOL } } + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3), {}, {DT_BOOL, DT_BOOL, DT_BOOL}, + R"proto( + op: "NPolymorphicOutDefault" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32, DT_INT32}, + R"proto( + op: "NPolymorphicOutDefault" + attr { key: "T" value { type: DT_INT32 } } + attr { key: "N" value { i: 2 } } )proto"); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_INT32), {}, + {DT_INT32, DT_INT32, DT_INT32}, R"proto( + op: "NPolymorphicOutDefault" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_INT32 } } )proto"); +} + +TEST_F(NodeDefBuilderTest, NPolymorphicRestrictOut) { + Op(OpDefBuilder("NPolymorphicRestrictOut") + .Output("a: N*T") + .Attr("T: {string, bool}") + .Attr("N: int >= 2")); + + ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_BOOL), {}, + {DT_BOOL, DT_BOOL, DT_BOOL}, R"proto( + op: "NPolymorphicRestrictOut" + attr { key: "N" value { i: 3 } } + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectInvalid(Builder().Attr("N", 3).Attr("T", DT_INT32), + "Value for attr 'T' of int32 is not in the list of allowed " + "values: string, bool"); +} + +TEST_F(NodeDefBuilderTest, RefIn) { + Op(OpDefBuilder("RefIn").Input("a: Ref(int32)")); + + ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32_REF}, {}, + R"proto( + op: "RefIn" input: "a" )proto"); + + ExpectFailure(Builder().Input(FakeInput(DT_BOOL_REF)), + "Input 'a' passed bool_ref expected int32_ref while"); + + ExpectFailure(Builder().Input(FakeInput(DT_INT32)), + "Input 'a' passed int32 expected int32_ref while"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicRefIn) { + Op(OpDefBuilder("PolymorphicRefIn").Input("a: Ref(T)").Attr("T: type")); + + ExpectSuccess(Builder().Input(FakeInput(DT_BOOL_REF)), {DT_BOOL_REF}, {}, + R"proto( + op: "PolymorphicRefIn" input: "a" + attr { key: "T" value { type: DT_BOOL } } )proto"); + + ExpectFailure(Builder().Input(FakeInput(DT_BOOL)), + "Input 'a' passed bool expected ref type while"); +} + +TEST_F(NodeDefBuilderTest, RefOut) { + Op(OpDefBuilder("RefOut").Output("a: Ref(string)")); + + ExpectSuccess(Builder(), {}, {DT_STRING_REF}, R"proto( + op: "RefOut" )proto"); +} + +TEST_F(NodeDefBuilderTest, PolymorphicRefOut) { + Op(OpDefBuilder("PolymorphicRefOut").Output("a: Ref(t)").Attr("t: type")); + + ExpectSuccess(Builder().Attr("t", DT_BOOL), {}, {DT_BOOL_REF}, R"proto( + op: "PolymorphicRefOut" + attr { key: "t" value { type: DT_BOOL } } )proto"); +} + +TEST_F(NodeDefBuilderTest, SpecifyDevice) { + Op(OpDefBuilder("SpecifyDevice")); + + ExpectSuccess(Builder().Device("ADevice"), {}, {}, R"proto( + op: "SpecifyDevice" device: "ADevice" )proto"); +} + +} // namespace +} // namespace tensorflow |