#include "tensorflow/core/framework/node_def_builder.h"

#include <memory>
#include <vector>
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include <gtest/gtest.h>

namespace tensorflow {
namespace {

// Fixture that holds the OpDef under test and the most recently created
// NodeDefBuilder, plus helpers that Finalize() a builder and check the result.
class NodeDefBuilderTest : public ::testing::Test {
 protected:
  // Specify an OpDef via an OpDefBuilder.
  void Op(const OpDefBuilder& op_def_builder) {
    EXPECT_OK(op_def_builder.Finalize(&op_def_));
  }

  // Resets builder_ with a new NodeDefBuilder (node name "n") using the Op
  // from the last call to Op() above. Any reference previously returned by
  // Builder() refers to the destroyed builder after this call.
  NodeDefBuilder& Builder() {
    EXPECT_FALSE(op_def_.name().empty()) << "Must call Op() before Builder()";
    builder_.reset(new NodeDefBuilder("n", &op_def_));
    return *builder_;
  }

  // Calls Finalize() and verifies it returns success and the result matches
  // expectations.
  //   expected_in_types / expected_out_types: types InOutTypesForNode() must
  //     report for the finalized NodeDef.
  //   proto: text-format NodeDef body; "name: 'n' " is prepended to it.
  void ExpectSuccess(const NodeDefBuilder& builder,
                     DataTypeSlice expected_in_types,
                     DataTypeSlice expected_out_types, StringPiece proto) {
    NodeDef node_def;
    Status status = builder.Finalize(&node_def);
    EXPECT_OK(status);
    if (!status.ok()) return;
    NodeDef expected;
    // A typo in the expected text proto should be reported as such, not as a
    // confusing DebugString() mismatch against a partially-parsed proto.
    EXPECT_TRUE(protobuf::TextFormat::ParseFromString(
        strings::StrCat("name: 'n' ", proto), &expected))
        << "Failed to parse expected proto: " << proto;
    EXPECT_EQ(node_def.DebugString(), expected.DebugString());

    DataTypeVector in_types, out_types;
    status =
        InOutTypesForNode(node_def, builder.op_def(), &in_types, &out_types);
    EXPECT_OK(status);
    if (!status.ok()) return;
    EXPECT_EQ(DataTypeSliceString(expected_in_types),
              DataTypeVectorString(in_types));
    EXPECT_EQ(DataTypeSliceString(expected_out_types),
              DataTypeVectorString(out_types));

    status = ValidateNodeDef(node_def, op_def_);
    EXPECT_OK(status);
  }

  // Calls Finalize() and verifies it returns an error.
  // Each message must appear as a substring of the error.
  void ExpectFailures(const NodeDefBuilder& builder,
                      const std::vector<string>& messages) {
    NodeDef node_def;
    Status status = builder.Finalize(&node_def);
    EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
    if (status.ok()) return;
    for (const string& message : messages) {
      EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
          << status << ", " << message;
    }
  }

  // Calls Finalize() and verifies it returns an error.
  // Message must appear as a substring of the error.
  void ExpectFailure(const NodeDefBuilder& builder, const string& message) {
    ExpectFailures(builder, {message});
  }

  // Like ExpectFailure(), except that the error can come from
  // ValidateNodeDef().
  void ExpectInvalid(const NodeDefBuilder& builder, const string& message) {
    NodeDef node_def;
    Status status = builder.Finalize(&node_def);
    if (status.ok()) {
      status = ValidateNodeDef(node_def, op_def_);
    }
    EXPECT_FALSE(status.ok()) << SummarizeNodeDef(node_def);
    if (status.ok()) return;
    EXPECT_TRUE(StringPiece(status.error_message()).contains(message))
        << "Actual error: " << status.error_message()
        << "\nDoes not contain: " << message;
  }

  OpDef op_def_;
  std::unique_ptr<NodeDefBuilder> builder_;
};

TEST_F(NodeDefBuilderTest, Simple) {
  Op(OpDefBuilder("Simple").Input("a: int32").Output("out: float"));

  ExpectSuccess(Builder().Input("x", 0, DT_INT32), {DT_INT32}, {DT_FLOAT},
                R"proto( op: "Simple" input: "x" )proto");

  // Port != 0
  ExpectSuccess(Builder().Input("y", 2, DT_INT32), {DT_INT32}, {DT_FLOAT},
                R"proto( op: "Simple" input: "y:2" )proto");

  // FakeInput
  ExpectSuccess(Builder().Input(FakeInput()), {DT_INT32}, {DT_FLOAT},
                R"proto( op: "Simple" input: "a" )proto");

  ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_FLOAT},
                R"proto( op: "Simple" input: "a" )proto");

  // Ref input
  ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32},
                {DT_FLOAT}, R"proto( op: "Simple" input: "a" )proto");

  // ControlInput
  ExpectSuccess(
      Builder().ControlInput("x").Input(FakeInput()).ControlInput("y"),
      {DT_INT32}, {DT_FLOAT},
      R"proto( op: "Simple" input: ["a", "^x", "^y"] )proto");

  // Device
  ExpectSuccess(Builder().Input(FakeInput()).Device("ddd"), {DT_INT32},
                {DT_FLOAT},
                R"proto( op: "Simple" input: "a" device: "ddd" )proto");

  // Extra input
  ExpectFailure(Builder().Input("x", 0, DT_INT32).Input("y", 0, DT_INT32),
                "More Input() calls than the 1 input_args while building "
                "NodeDef 'n' using Op<name=Simple; signature=a:int32 -> "
                "out:float>");

  // Missing input
  ExpectFailure(Builder(), "0 inputs specified of 1 inputs in Op while");

  {  // Finalize() twice.
    NodeDefBuilder& builder = Builder();
    builder.Input(FakeInput()).Finalize(nullptr);  // First call to Finalize()
    // ExpectSuccess() also calls Finalize().
    ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT},
                  R"proto( op: "Simple" input: "a" )proto");
  }

  {  // Input() after Finalize()
    NodeDefBuilder& builder = Builder();
    // Calling Finalize() before enough inputs -> error.
    ExpectFailure(builder, "0 inputs specified of 1 inputs in Op while");
    builder.Input(FakeInput());
    // Calling Finalize() with enough inputs -> success
    ExpectSuccess(builder, {DT_INT32}, {DT_FLOAT},
                  R"proto( op: "Simple" input: "a" )proto");
    // Calling Finalize() with too many inputs -> error.
    builder.Input(FakeInput(DT_INT32));
    ExpectFailure(builder, "More Input() calls than the 1 input_args while");
  }

  // Wrong input type
  ExpectFailure(Builder().Input("x", 0, DT_FLOAT),
                "Input 'a' passed float expected int32 ");

  ExpectFailure(Builder().Input("x", 0, DT_FLOAT_REF),
                "Input 'a' passed float_ref expected int32 ");

  // List input
  ExpectFailure(Builder().Input(FakeInput(3, DT_FLOAT)),
                "List provided to input 'a' when single Tensor expected while");

  ExpectFailure(Builder().Input(FakeInput(3)),
                "List provided to input 'a' when single Tensor expected while");

  // Bad ControlInput
  ExpectInvalid(Builder().Input(FakeInput()).ControlInput("z:2"),
                "Control input '^z:2' must not have ':' in NodeDef:");

  // Bad input name
  ExpectFailure(Builder().Input("", 0, DT_INT32),
                "Empty input node name while");

  ExpectFailure(Builder().Input("^x", 0, DT_INT32),
                "Non-control input starting with ^: ^x while");
}

TEST_F(NodeDefBuilderTest, OpDoesNotExist) {
  NodeDefBuilder builder("n", "Op Does Not Exist");
  builder.Input(FakeInput())
      .Input(FakeInput(12))
      .ControlInput("y")
      .Attr("foo", 12)
      .Device("device");
  ExpectFailure(
      builder,
      "Op type not registered 'Op Does Not Exist' while building NodeDef 'n'");
}

TEST_F(NodeDefBuilderTest, Polymorphic) {
  Op(OpDefBuilder("Polymorphic")
         .Input("v: T")
         .Output("out: T")
         .Attr("T: type"));

  ExpectSuccess(Builder().Input(FakeInput(DT_INT32)), {DT_INT32}, {DT_INT32},
                R"proto(
      op: "Polymorphic" input: "a"
      attr { key: "T" value { type: DT_INT32 } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(DT_FLOAT)), {DT_FLOAT}, {DT_FLOAT},
                R"proto(
      op: "Polymorphic" input: "a"
      attr { key: "T" value { type: DT_FLOAT } } )proto");

  // Redundant Attr()
  ExpectSuccess(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_BOOL),
                {DT_BOOL}, {DT_BOOL}, R"proto(
      op: "Polymorphic" input: "a"
      attr { key: "T" value { type: DT_BOOL } } )proto");

  // Conflicting Attr()
  ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Attr("T", DT_STRING),
                "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while");

  ExpectFailure(Builder().Attr("T", DT_STRING).Input(FakeInput(DT_BOOL)),
                "Inconsistent values for attr 'T' DT_STRING vs. DT_BOOL while");

  ExpectFailure(Builder().Attr("T", 12).Input(FakeInput(DT_BOOL)),
                "Inconsistent values for attr 'T' 12 vs. DT_BOOL while");
}

TEST_F(NodeDefBuilderTest, PolymorphicOut) {
  Op(OpDefBuilder("PolymorphicOut").Output("out: T").Attr("T: type"));

  ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32}, R"proto(
      op: "PolymorphicOut"
      attr { key: "T" value { type: DT_INT32 } } )proto");

  ExpectSuccess(Builder().Attr("T", DT_FLOAT), {}, {DT_FLOAT}, R"proto(
      op: "PolymorphicOut"
      attr { key: "T" value { type: DT_FLOAT } } )proto");

  // Redundant attr
  ExpectSuccess(Builder().Attr("T", DT_FLOAT).Attr("T", DT_FLOAT), {},
                {DT_FLOAT}, R"proto(
      op: "PolymorphicOut"
      attr { key: "T" value { type: DT_FLOAT } } )proto");

  // Conflicting attr
  ExpectFailure(Builder().Attr("T", DT_BOOL).Attr("T", DT_FLOAT),
                "Inconsistent values for attr 'T' DT_BOOL vs. DT_FLOAT while");

  // Missing attr
  ExpectInvalid(Builder(), "NodeDef missing attr 'T' from");

  // Attr has the wrong type
  ExpectInvalid(Builder().Attr("T", {DT_INT32, DT_BOOL}),
                "AttrValue had value with type list(type) when type expected");

  ExpectInvalid(Builder().Attr("T", 12),
                "AttrValue had value with type int when type expected");
}

TEST_F(NodeDefBuilderTest, PolymorphicDefaultOut) {
  Op(OpDefBuilder("PolymorphicDefaultOut")
         .Output("out: T")
         .Attr("T: type = DT_STRING"));

  ExpectSuccess(Builder(), {}, {DT_STRING}, R"proto(
      op: "PolymorphicDefaultOut"
      attr { key: "T" value { type: DT_STRING } } )proto");

  ExpectSuccess(Builder().Attr("T", DT_BOOL), {}, {DT_BOOL}, R"proto(
      op: "PolymorphicDefaultOut"
      attr { key: "T" value { type: DT_BOOL } } )proto");
}

TEST_F(NodeDefBuilderTest, Binary) {
  Op(OpDefBuilder("Binary").Input("a: T").Input("b: T").Output("out: T").Attr(
      "T: type"));

  ExpectSuccess(Builder().Input(FakeInput(DT_INT32)).Input(FakeInput(DT_INT32)),
                {DT_INT32, DT_INT32}, {DT_INT32}, R"proto(
      op: "Binary" input: "a" input: "b"
      attr { key: "T" value { type: DT_INT32 } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(DT_STRING)).Input(FakeInput()),
                {DT_STRING, DT_STRING}, {DT_STRING}, R"proto(
      op: "Binary" input: "a" input: "b"
      attr { key: "T" value { type: DT_STRING } } )proto");

  // Type mismatch
  ExpectFailure(Builder().Input(FakeInput(DT_BOOL)).Input(FakeInput(DT_STRING)),
                "Inconsistent values for attr 'T' DT_BOOL vs. DT_STRING while");
}

TEST_F(NodeDefBuilderTest, Restrict) {
  Op(OpDefBuilder("Restrict")
         .Input("a: T")
         .Output("out: T")
         .Attr("T: {string, bool}"));
  ExpectSuccess(Builder().Input(FakeInput(DT_STRING)), {DT_STRING}, {DT_STRING},
                R"proto(
      op: "Restrict" input: "a"
      attr { key: "T" value { type: DT_STRING } } )proto");

  ExpectInvalid(Builder().Input(FakeInput(DT_INT32)),
                "Value for attr 'T' of int32 is not in the list of allowed "
                "values: string, bool");
}

TEST_F(NodeDefBuilderTest, TypeList) {
  Op(OpDefBuilder("TypeList").Input("a: T").Attr("T: list(type)"));

  ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_INT32})),
                {DT_STRING, DT_INT32}, {}, R"proto(
      op: "TypeList" input: ["a", "a:1"]
      attr { key: "T" value { list { type: [DT_STRING, DT_INT32] } } }
      )proto");

  ExpectSuccess(Builder().Input(FakeInput(3, DT_BOOL)),
                {DT_BOOL, DT_BOOL, DT_BOOL}, {}, R"proto(
      op: "TypeList" input: ["a", "a:1", "a:2"]
      attr { key: "T" value { list { type: [DT_BOOL, DT_BOOL, DT_BOOL] } } }
      )proto");

  ExpectInvalid(Builder().Input(FakeInput(0)),
                "Length for attr 'T' of 0 must be at least minimum 1");

  ExpectInvalid(Builder().Input(FakeInput({})),
                "Length for attr 'T' of 0 must be at least minimum 1");

  ExpectInvalid(Builder().Input(FakeInput(DT_BOOL)),
                "Single tensor passed to 'a', expected list while");

  ExpectFailures(Builder().Input(FakeInput()),
                 {"2 errors while building NodeDef",
                  "Could not infer list of types for input 'a': "
                  "No attr named 'T' in NodeDef:",
                  "0 inputs specified of 1 inputs in Op"});
}

TEST_F(NodeDefBuilderTest, TypeListNoMin) {
  Op(OpDefBuilder("TypeListNoMin").Input("a: T").Attr("T: list(type) >= 0"));

  ExpectSuccess(Builder().Input(FakeInput(0)), {}, {}, R"proto(
      op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(DataTypeVector())), {}, {}, R"proto(
      op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");

  ExpectSuccess(Builder().Input(FakeInput({})), {}, {}, R"proto(
      op: "TypeListNoMin" attr { key: "T" value { list { } } } )proto");

  ExpectSuccess(Builder().Input(FakeInput({DT_BOOL})), {DT_BOOL}, {}, R"proto(
      op: "TypeListNoMin" input: "a"
      attr { key: "T" value { list { type: DT_BOOL } } } )proto");
}

TEST_F(NodeDefBuilderTest, TypeListTwice) {
  Op(OpDefBuilder("TypeListTwice")
         .Input("a: T")
         .Input("b: T")
         .Attr("T: list(type) >= 0"));

  ExpectSuccess(Builder()
                    .Input(FakeInput({DT_INT32, DT_BOOL}))
                    .Input(FakeInput({DT_INT32, DT_BOOL})),
                {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto(
      op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"]
      attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto");

  ExpectSuccess(
      Builder().Input(FakeInput({DT_INT32, DT_BOOL})).Input(FakeInput()),
      {DT_INT32, DT_BOOL, DT_INT32, DT_BOOL}, {}, R"proto(
      op: "TypeListTwice" input: ["a", "a:1", "b", "b:1"]
      attr { key: "T" value { list { type: [DT_INT32, DT_BOOL] } } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(0)), {}, {},
                R"proto(
      op: "TypeListTwice" attr { key: "T" value { list { } } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {},
                R"proto(
      op: "TypeListTwice" attr { key: "T" value { list { } } } )proto");

  ExpectFailure(Builder()
                    .Input(FakeInput({DT_INT32, DT_BOOL}))
                    .Input(FakeInput({DT_INT32, DT_STRING})),
                "Inconsistent values for attr 'T' [DT_INT32, DT_BOOL] vs. "
                "[DT_INT32, DT_STRING] while");
}

TEST_F(NodeDefBuilderTest, OutTypeList) {
  Op(OpDefBuilder("OutTypeList").Output("out: T").Attr("T: list(type) >= 0"));

  ExpectSuccess(Builder().Attr("T", {DT_FLOAT}), {}, {DT_FLOAT}, R"proto(
      op: "OutTypeList"
      attr { key: "T" value { list { type: DT_FLOAT } } } )proto");

  ExpectSuccess(Builder().Attr("T", {DT_STRING, DT_BOOL}), {},
                {DT_STRING, DT_BOOL}, R"proto(
      op: "OutTypeList"
      attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto");

  ExpectSuccess(Builder().Attr("T", DataTypeVector()), {}, {}, R"proto(
      op: "OutTypeList"
      attr { key: "T" value { list { } } } )proto");

  ExpectInvalid(Builder().Attr("T", DT_FLOAT),
                "AttrValue had value with type type when list(type) expected");
}

TEST_F(NodeDefBuilderTest, TypeListRestrict) {
  Op(OpDefBuilder("TypeListRestrict")
         .Input("a: T")
         .Attr("T: list({string, bool}) >= 0"));

  ExpectSuccess(Builder().Input(FakeInput({DT_STRING, DT_BOOL})),
                {DT_STRING, DT_BOOL}, {}, R"proto(
      op: "TypeListRestrict" input: ["a", "a:1"]
      attr { key: "T" value { list { type: [DT_STRING, DT_BOOL] } } } )proto");

  ExpectInvalid(Builder().Input(FakeInput({DT_STRING, DT_INT32})),
                "Value for attr 'T' of int32 is not in the list of allowed "
                "values: string, bool");
}

TEST_F(NodeDefBuilderTest, OutTypeListRestrict) {
  Op(OpDefBuilder("OutTypeListRestrict")
         .Output("out: t")
         .Attr("t: list({string, bool}) >= 0"));

  ExpectSuccess(Builder().Attr("t", {DT_BOOL, DT_STRING}), {},
                {DT_BOOL, DT_STRING}, R"proto(
      op: "OutTypeListRestrict"
      attr { key: "t" value { list { type: [DT_BOOL, DT_STRING] } } } )proto");

  ExpectInvalid(Builder().Attr("t", {DT_STRING, DT_INT32}),
                "Value for attr 't' of int32 is not in the list of allowed "
                "values: string, bool");
}

TEST_F(NodeDefBuilderTest, Attr) {
  Op(OpDefBuilder("Attr").Attr("a: int"));

  ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto(
      op: "Attr" attr { key: "a" value { i: 12 } } )proto");

  // Attr has wrong type
  ExpectInvalid(Builder().Attr("a", "bad"),
                "AttrValue had value with type string when int expected");

  ExpectInvalid(Builder().Attr("a", {12}),
                "AttrValue had value with type list(int) when int expected");

  // Missing attr
  ExpectInvalid(Builder(), "NodeDef missing attr 'a' from Op<");

  // Wrong attr
  ExpectInvalid(Builder().Attr("b", 12),
                "NodeDef mentions attr 'b' not in Op<");

  // Extra attr
  ExpectInvalid(Builder().Attr("a", 12).Attr("extra", 12),
                "NodeDef mentions attr 'extra' not in Op<");
}

TEST_F(NodeDefBuilderTest, AttrFloat) {
  Op(OpDefBuilder("AttrFloat").Attr("a: float"));

  ExpectSuccess(Builder().Attr("a", 1.2f /* float */), {}, {}, R"proto(
      op: "AttrFloat" attr { key: "a" value { f: 1.2 } } )proto");

  ExpectSuccess(Builder().Attr("a", 1.2 /* double */), {}, {}, R"proto(
      op: "AttrFloat" attr { key: "a" value { f: 1.2 } } )proto");

  // Won't automatically cast int to float
  ExpectInvalid(Builder().Attr("a", 12),
                "AttrValue had value with type int when float expected");
}

TEST_F(NodeDefBuilderTest, AttrBoolList) {
  Op(OpDefBuilder("AttrBoolList").Attr("a: list(bool)"));

  ExpectSuccess(Builder().Attr("a", {true, false, true}), {}, {}, R"proto(
      op: "AttrBoolList"
      attr { key: "a" value { list { b: [true, false, true] } } } )proto");

  ExpectSuccess(Builder().Attr("a", std::vector<bool>()), {}, {}, R"proto(
      op: "AttrBoolList" attr { key: "a" value { list { } } } )proto");

  // Won't cast int -> bool.
  ExpectInvalid(Builder().Attr("a", {0}),
                "AttrValue had value with type list(int) when list(bool) "
                "expected");
}

TEST_F(NodeDefBuilderTest, AttrMin) {
  Op(OpDefBuilder("AttrMin").Attr("a: int >= 5"));

  ExpectSuccess(Builder().Attr("a", 12), {}, {}, R"proto(
      op: "AttrMin" attr { key: "a" value { i: 12 } } )proto");

  ExpectInvalid(Builder().Attr("a", 2),
                "Value for attr 'a' of 2 must be at least minimum 5");
}

TEST_F(NodeDefBuilderTest, AttrListMin) {
  Op(OpDefBuilder("AttrListMin").Attr("a: list(int) >= 2"));

  ExpectSuccess(Builder().Attr("a", {1, 2}), {}, {}, R"proto(
      op: "AttrListMin"
      attr { key: "a" value { list { i: [1, 2] } } } )proto");

  ExpectInvalid(Builder().Attr("a", {17}),
                "Length for attr 'a' of 1 must be at least minimum 2");
}

TEST_F(NodeDefBuilderTest, AttrEnum) {
  Op(OpDefBuilder("AttrEnum").Attr("a: {'apples', 'oranges'}"));

  ExpectSuccess(Builder().Attr("a", "oranges"), {}, {}, R"proto(
      op: "AttrEnum"
      attr { key: "a" value { s: "oranges" } } )proto");

  ExpectInvalid(
      Builder().Attr("a", "invalid"),
      "Value for attr 'a' of \"invalid\" is not in the list of allowed values: "
      "\"apples\", \"oranges\"");
}

TEST_F(NodeDefBuilderTest, AttrEnumList) {
  Op(OpDefBuilder("AttrEnumList").Attr("a: list({'apples', 'oranges'})"));

  ExpectSuccess(Builder().Attr("a", {"oranges", "apples"}), {}, {}, R"proto(
      op: "AttrEnumList"
      attr { key: "a" value { list { s: ["oranges", "apples"] } } } )proto");

  ExpectInvalid(
      Builder().Attr("a", {"apples", "invalid", "oranges"}),
      "Value for attr 'a' of \"invalid\" is not in the list of allowed values: "
      "\"apples\", \"oranges\"");
}

TEST_F(NodeDefBuilderTest, AttrShape) {
  Op(OpDefBuilder("AttrShape").Attr("a: shape"));

  ExpectSuccess(Builder().Attr("a", TensorShape({5})), {}, {}, R"proto(
      op: "AttrShape"
      attr { key: "a" value { shape { dim { size: 5 } } } } )proto");

  ExpectSuccess(Builder().Attr("a", TensorShape({4, 3, 2})), {}, {}, R"proto(
      op: "AttrShape"
      attr { key: "a" value { shape {
        dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } )proto");

  ExpectSuccess(Builder().Attr("a", TensorShape({3, 2})), {}, {}, R"proto(
      op: "AttrShape"
      attr { key: "a" value { shape {
        dim { size: 3 } dim { size: 2 } } } } )proto");

  ExpectSuccess(Builder().Attr("a", TensorShape()), {}, {}, R"proto(
      op: "AttrShape"
      attr { key: "a" value { shape { } } } )proto");
}

TEST_F(NodeDefBuilderTest, AttrDefault) {
  Op(OpDefBuilder("AttrDefault").Attr("a: string = 'banana'"));

  ExpectSuccess(Builder(), {}, {}, R"proto(
      op: "AttrDefault"
      attr { key: "a" value { s: "banana" } } )proto");

  ExpectSuccess(Builder().Attr("a", "kiwi"), {}, {}, R"proto(
      op: "AttrDefault"
      attr { key: "a" value { s: "kiwi" } } )proto");
}

TEST_F(NodeDefBuilderTest, AttrManyDefault) {
  Op(OpDefBuilder("AttrManyDefault")
         .Attr("a: string = 'banana'")
         .Attr("b: string = 'kiwi'"));

  ExpectSuccess(Builder(), {}, {}, R"proto(
      op: "AttrManyDefault"
      attr { key: "a" value { s: "banana" } }
      attr { key: "b" value { s: "kiwi" } } )proto");

  Op(OpDefBuilder("AttrManyDefaultWithMandatory")
         .Attr("a: string = 'banana'")
         .Attr("b: string = 'kiwi'")
         .Attr("c: string"));

  ExpectSuccess(Builder().Attr("c", "strawberry"), {}, {}, R"proto(
      op: "AttrManyDefaultWithMandatory"
      attr { key: "c" value { s: "strawberry" } }
      attr { key: "a" value { s: "banana" } }
      attr { key: "b" value { s: "kiwi" } } )proto");

  Op(OpDefBuilder("AttrManyDefaultAndInferred")
         .Input("input: T")
         .Attr("T: {float, double}")
         .Attr("a: string")
         .Attr("b: list(string) >= 1")
         .Attr("c: bool = true")
         .Attr("d: float = 0.3")
         .Attr("e: string")
         .Attr("f: float = 0.25"));

  ExpectSuccess(Builder()
                    .Input(FakeInput(DT_FLOAT))
                    .Attr("a", "foo")
                    .Attr("e", "foo")
                    .Attr("b", std::vector<string>({"bar", "baz"}))
                    .Attr("f", 1.0f),
                {DT_FLOAT}, {}, R"proto(
      op: "AttrManyDefaultAndInferred"
      input: "a"
      attr { key: "T" value { type: DT_FLOAT } }
      attr { key: "a" value { s: "foo" } }
      attr { key: "e" value { s: "foo" } }
      attr { key: "b" value { list { s: "bar" s: "baz" } } }
      attr { key: "f" value { f: 1.0 } }
      attr { key: "c" value { b: true } }
      attr { key: "d" value { f: 0.3 } } )proto");
}

TEST_F(NodeDefBuilderTest, AttrListDefault) {
  Op(OpDefBuilder("AttrListDefault").Attr("a: list(int) = [5, 15]"));

  ExpectSuccess(Builder(), {}, {}, R"proto(
      op: "AttrListDefault"
      attr { key: "a" value { list { i: [5, 15] } } } )proto");

  ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto(
      op: "AttrListDefault"
      attr { key: "a" value { list { i: 3 } } } )proto");

  ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto(
      op: "AttrListDefault"
      attr { key: "a" value { list { } } } )proto");
}

TEST_F(NodeDefBuilderTest, AttrEmptyListDefault) {
  Op(OpDefBuilder("AttrEmptyListDefault").Attr("a: list(int) = []"));

  ExpectSuccess(Builder(), {}, {}, R"proto(
      op: "AttrEmptyListDefault"
      attr { key: "a" value { list { } } } )proto");

  ExpectSuccess(Builder().Attr("a", {3}), {}, {}, R"proto(
      op: "AttrEmptyListDefault"
      attr { key: "a" value { list { i: 3 } } } )proto");

  ExpectSuccess(Builder().Attr("a", std::vector<int>()), {}, {}, R"proto(
      op: "AttrEmptyListDefault"
      attr { key: "a" value { list { } } } )proto");
}

TEST_F(NodeDefBuilderTest, NIntsIn) {
  Op(OpDefBuilder("NIntsIn").Input("a: N*int32").Attr("N: int >= 2"));

  ExpectSuccess(Builder().Input(FakeInput(2)), {DT_INT32, DT_INT32}, {},
                R"proto(
      op: "NIntsIn" input: ["a", "a:1"]
      attr { key: "N" value { i: 2 } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(5, DT_INT32)),
                {DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
      op: "NIntsIn"
      input: ["a", "a:1", "a:2", "a:3", "a:4"]
      attr { key: "N" value { i: 5 } } )proto");

  ExpectFailures(Builder().Input(FakeInput(2, DT_STRING)),
                 {"2 errors while building NodeDef",
                  "Input 'a' passed string expected int32"});

  ExpectInvalid(Builder().Input(FakeInput(1)),
                "Value for attr 'N' of 1 must be at least minimum 2");

  ExpectFailures(
      Builder().Input(FakeInput(DT_INT32)),
      {"2 errors while building NodeDef",
       "Could not infer length of input 'a': No attr named 'N' in NodeDef:",
       "0 inputs specified of 1 inputs in Op"});

  ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}),
                "Input 'a' passed string expected int32 while");

  ExpectFailures(
      Builder().Input(FakeInput()),
      {"2 errors while building NodeDef",
       "Could not infer length of input 'a': No attr named 'N' in NodeDef:",
       "0 inputs specified of 1 inputs in Op"});
}

TEST_F(NodeDefBuilderTest, NPolymorphicIn) {
  Op(OpDefBuilder("NPolymorphicIn")
         .Input("a: N*T")
         .Attr("T: type")
         .Attr("N: int >= 2"));

  ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)), {DT_INT32, DT_INT32},
                {}, R"proto(
      op: "NPolymorphicIn" input: ["a", "a:1"]
      attr { key: "N" value { i: 2 } }
      attr { key: "T" value { type: DT_INT32 } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)),
                {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto(
      op: "NPolymorphicIn"
      input: ["a", "a:1", "a:2"]
      attr { key: "N" value { i: 3 } }
      attr { key: "T" value { type: DT_STRING } } )proto");

  ExpectFailures(
      Builder().Input(FakeInput(2)),
      {"2 errors while building NodeDef",
       "Could not infer type for input 'a': No attr named 'T' in NodeDef:",
       "0 inputs specified of 1 inputs in Op"});

  ExpectFailure(Builder().Input(FakeInput({DT_INT32, DT_STRING})),
                "Input 'a' passed string expected int32 while");

  ExpectFailure(Builder().Input({{"in", 0, DT_INT32}, {"in", 1, DT_STRING}}),
                "Input 'a' passed string expected int32 while");

  ExpectInvalid(Builder().Input(FakeInput(1, DT_INT32)),
                "Value for attr 'N' of 1 must be at least minimum 2");

  ExpectFailure(Builder().Input("in", 0, DT_INT32),
                "Single tensor passed to 'a', expected list while");
}

TEST_F(NodeDefBuilderTest, NPolymorphicRestrictIn) {
  Op(OpDefBuilder("NPolymorphicRestrictIn")
         .Input("a: N*T")
         .Attr("T: {string, bool}")
         .Attr("N: int >= 2"));

  ExpectSuccess(Builder().Input(FakeInput(2, DT_BOOL)), {DT_BOOL, DT_BOOL}, {},
                R"proto(
      op: "NPolymorphicRestrictIn" input: ["a", "a:1"]
      attr { key: "N" value { i: 2 } }
      attr { key: "T" value { type: DT_BOOL } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(3, DT_STRING)),
                {DT_STRING, DT_STRING, DT_STRING}, {}, R"proto(
      op: "NPolymorphicRestrictIn"
      input: ["a", "a:1", "a:2"]
      attr { key: "N" value { i: 3 } }
      attr { key: "T" value { type: DT_STRING } } )proto");

  ExpectInvalid(Builder().Input(FakeInput(2, DT_INT32)),
                "Value for attr 'T' of int32 is not in the list of allowed "
                "values: string, bool");
}

TEST_F(NodeDefBuilderTest, NInTwice) {
  Op(OpDefBuilder("NInTwice")
         .Input("a: N*int32")
         .Input("b: N*string")
         .Attr("N: int >= 0"));

  ExpectSuccess(Builder().Input(FakeInput(2)).Input(FakeInput(2)),
                {DT_INT32, DT_INT32, DT_STRING, DT_STRING}, {}, R"proto(
      op: "NInTwice"
      input: ["a", "a:1", "b", "b:1"]
      attr { key: "N" value { i: 2 } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput()), {}, {},
                R"proto(
      op: "NInTwice" attr { key: "N" value { i: 0 } } )proto");

  ExpectFailure(Builder().Input(FakeInput(3)).Input(FakeInput(1)),
                "Inconsistent values for attr 'N' 3 vs. 1 while");
}

TEST_F(NodeDefBuilderTest, NInPolymorphicTwice) {
  Op(OpDefBuilder("NInPolymorphicTwice")
         .Input("a: N*T")
         .Input("b: N*T")
         .Attr("T: type")
         .Attr("N: int >= 0"));

  ExpectSuccess(Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput()),
                {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
      op: "NInPolymorphicTwice"
      input: ["a", "a:1", "b", "b:1"]
      attr { key: "N" value { i: 2 } }
      attr { key: "T" value { type: DT_INT32 } } )proto");

  ExpectFailure(
      Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_INT32)),
      "Inconsistent values for attr 'N' 3 vs. 1 while");

  ExpectFailure(Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1)),
                "Inconsistent values for attr 'N' 3 vs. 1 while");

  ExpectFailure(
      Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)),
      "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");

  ExpectFailure(
      Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_STRING)),
      "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
}

TEST_F(NodeDefBuilderTest, NInTwoTypeVariables) {
  Op(OpDefBuilder("NInTwoTypeVariables")
         .Input("a: N*S")
         .Input("b: N*T")
         .Attr("S: type")
         .Attr("T: type")
         .Attr("N: int >= 0"));

  ExpectSuccess(
      Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_BOOL)),
      {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto(
      op: "NInTwoTypeVariables"
      input: ["a", "a:1", "b", "b:1"]
      attr { key: "N" value { i: 2 } }
      attr { key: "S" value { type: DT_INT32 } }
      attr { key: "T" value { type: DT_BOOL } } )proto");

  ExpectSuccess(
      Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(DT_BOOL)),
      {DT_INT32, DT_INT32, DT_BOOL, DT_BOOL}, {}, R"proto(
      op: "NInTwoTypeVariables"
      input: ["a", "a:1", "b", "b:1"]
      attr { key: "N" value { i: 2 } }
      attr { key: "S" value { type: DT_INT32 } }
      attr { key: "T" value { type: DT_BOOL } } )proto");

  ExpectFailure(
      Builder().Input(FakeInput(3, DT_INT32)).Input(FakeInput(1, DT_STRING)),
      "Inconsistent values for attr 'N' 3 vs. 1 while");
}

TEST_F(NodeDefBuilderTest, InPolymorphicTwice) {
  Op(OpDefBuilder("InPolymorphicTwice")
         .Input("a: N*T")
         .Input("b: M*T")
         .Attr("T: type")
         .Attr("N: int >= 0")
         .Attr("M: int >= 0"));

  ExpectSuccess(
      Builder().Input(FakeInput(1, DT_INT32)).Input(FakeInput(3, DT_INT32)),
      {DT_INT32, DT_INT32, DT_INT32, DT_INT32}, {}, R"proto(
      op: "InPolymorphicTwice"
      input: ["a", "b", "b:1", "b:2"]
      attr { key: "N" value { i: 1 } }
      attr { key: "T" value { type: DT_INT32 } }
      attr { key: "M" value { i: 3 } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(1, DT_BOOL)).Input(FakeInput(0)),
                {DT_BOOL}, {}, R"proto(
      op: "InPolymorphicTwice" input: "a"
      attr { key: "N" value { i: 1 } }
      attr { key: "T" value { type: DT_BOOL } }
      attr { key: "M" value { i: 0 } } )proto");

  ExpectSuccess(Builder().Input(FakeInput(0)).Input(FakeInput(1, DT_BOOL)),
                {DT_BOOL}, {}, R"proto(
      op: "InPolymorphicTwice" input: "b"
      attr { key: "N" value { i: 0 } }
      attr { key: "M" value { i: 1 } }
      attr { key: "T" value { type: DT_BOOL } } )proto");

  ExpectFailure(
      Builder().Input(FakeInput(2, DT_INT32)).Input(FakeInput(2, DT_STRING)),
      "Inconsistent values for attr 'T' DT_INT32 vs. DT_STRING while");
}

TEST_F(NodeDefBuilderTest, NIntsOut) {
  Op(OpDefBuilder("NIntsOut").Output("a: N*int32").Attr("N: int >= 2"));

  ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto(
      op: "NIntsOut" attr { key: "N" value { i: 2 } } )proto");

  ExpectSuccess(Builder().Attr("N", 3), {}, {DT_INT32, DT_INT32, DT_INT32},
                R"proto(
      op: "NIntsOut" attr { key: "N" value { i: 3 } } )proto");

  ExpectInvalid(Builder().Attr("N", 1),
                "Value for attr 'N' of 1 must be at least minimum 2");

  ExpectInvalid(Builder().Attr("N", {3}),
                "AttrValue had value with type list(int) when int expected");

  ExpectInvalid(Builder(), "NodeDef missing attr 'N' from");
}

TEST_F(NodeDefBuilderTest, NIntsOutDefault) {
  Op(OpDefBuilder("NIntsOutDefault")
         .Output("a: N*int32")
         .Attr("N: int >= 2 = 3"));

  ExpectSuccess(Builder(), {}, {DT_INT32, DT_INT32, DT_INT32}, R"proto(
      op: "NIntsOutDefault" attr { key: "N" value { i: 3 } } )proto");

  ExpectSuccess(Builder().Attr("N", 2), {}, {DT_INT32, DT_INT32}, R"proto(
      op: "NIntsOutDefault" attr { key: "N" value { i: 2 } } )proto");
}

TEST_F(NodeDefBuilderTest, NPolymorphicOut) {
  Op(OpDefBuilder("NPolymorphicOut")
         .Output("a: N*T")
         .Attr("T: type")
         .Attr("N: int >= 2"));

  ExpectSuccess(Builder().Attr("T", DT_INT32).Attr("N", 2), {},
                {DT_INT32, DT_INT32}, R"proto(
      op: "NPolymorphicOut"
      attr { key: "T" value { type: DT_INT32 } }
      attr { key: "N" value { i: 2 } } )proto");

  ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_STRING), {},
                {DT_STRING, DT_STRING, DT_STRING}, R"proto(
      op: "NPolymorphicOut"
      attr { key: "N" value { i: 3 } }
      attr { key: "T" value { type: DT_STRING } } )proto");

  ExpectInvalid(Builder().Attr("N", 1).Attr("T", DT_STRING),
                "Value for attr 'N' of 1 must be at least minimum 2");

  ExpectInvalid(Builder().Attr("N", 3).Attr("T", {DT_STRING}),
                "AttrValue had value with type list(type) when type expected");
}
TEST_F(NodeDefBuilderTest, NPolymorphicOutDefault) { Op(OpDefBuilder("NPolymorphicOutDefault") .Output("a: N*T") .Attr("T: type = DT_BOOL") .Attr("N: int >= 2 = 2")); ExpectSuccess(Builder(), {}, {DT_BOOL, DT_BOOL}, R"proto( op: "NPolymorphicOutDefault" attr { key: "T" value { type: DT_BOOL } } attr { key: "N" value { i: 2 } } )proto"); ExpectSuccess(Builder().Attr("N", 3), {}, {DT_BOOL, DT_BOOL, DT_BOOL}, R"proto( op: "NPolymorphicOutDefault" attr { key: "N" value { i: 3 } } attr { key: "T" value { type: DT_BOOL } } )proto"); ExpectSuccess(Builder().Attr("T", DT_INT32), {}, {DT_INT32, DT_INT32}, R"proto( op: "NPolymorphicOutDefault" attr { key: "T" value { type: DT_INT32 } } attr { key: "N" value { i: 2 } } )proto"); ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_INT32), {}, {DT_INT32, DT_INT32, DT_INT32}, R"proto( op: "NPolymorphicOutDefault" attr { key: "N" value { i: 3 } } attr { key: "T" value { type: DT_INT32 } } )proto"); } TEST_F(NodeDefBuilderTest, NPolymorphicRestrictOut) { Op(OpDefBuilder("NPolymorphicRestrictOut") .Output("a: N*T") .Attr("T: {string, bool}") .Attr("N: int >= 2")); ExpectSuccess(Builder().Attr("N", 3).Attr("T", DT_BOOL), {}, {DT_BOOL, DT_BOOL, DT_BOOL}, R"proto( op: "NPolymorphicRestrictOut" attr { key: "N" value { i: 3 } } attr { key: "T" value { type: DT_BOOL } } )proto"); ExpectInvalid(Builder().Attr("N", 3).Attr("T", DT_INT32), "Value for attr 'T' of int32 is not in the list of allowed " "values: string, bool"); } TEST_F(NodeDefBuilderTest, RefIn) { Op(OpDefBuilder("RefIn").Input("a: Ref(int32)")); ExpectSuccess(Builder().Input(FakeInput(DT_INT32_REF)), {DT_INT32_REF}, {}, R"proto( op: "RefIn" input: "a" )proto"); ExpectFailure(Builder().Input(FakeInput(DT_BOOL_REF)), "Input 'a' passed bool_ref expected int32_ref while"); ExpectFailure(Builder().Input(FakeInput(DT_INT32)), "Input 'a' passed int32 expected int32_ref while"); } TEST_F(NodeDefBuilderTest, PolymorphicRefIn) { Op(OpDefBuilder("PolymorphicRefIn").Input("a: 
Ref(T)").Attr("T: type")); ExpectSuccess(Builder().Input(FakeInput(DT_BOOL_REF)), {DT_BOOL_REF}, {}, R"proto( op: "PolymorphicRefIn" input: "a" attr { key: "T" value { type: DT_BOOL } } )proto"); ExpectFailure(Builder().Input(FakeInput(DT_BOOL)), "Input 'a' passed bool expected ref type while"); } TEST_F(NodeDefBuilderTest, RefOut) { Op(OpDefBuilder("RefOut").Output("a: Ref(string)")); ExpectSuccess(Builder(), {}, {DT_STRING_REF}, R"proto( op: "RefOut" )proto"); } TEST_F(NodeDefBuilderTest, PolymorphicRefOut) { Op(OpDefBuilder("PolymorphicRefOut").Output("a: Ref(t)").Attr("t: type")); ExpectSuccess(Builder().Attr("t", DT_BOOL), {}, {DT_BOOL_REF}, R"proto( op: "PolymorphicRefOut" attr { key: "t" value { type: DT_BOOL } } )proto"); } TEST_F(NodeDefBuilderTest, SpecifyDevice) { Op(OpDefBuilder("SpecifyDevice")); ExpectSuccess(Builder().Device("ADevice"), {}, {}, R"proto( op: "SpecifyDevice" device: "ADevice" )proto"); } } // namespace } // namespace tensorflow