aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/op_def_library_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/op_def_library_test.py')
-rw-r--r--tensorflow/python/ops/op_def_library_test.py1402
1 files changed, 1402 insertions, 0 deletions
diff --git a/tensorflow/python/ops/op_def_library_test.py b/tensorflow/python/ops/op_def_library_test.py
new file mode 100644
index 0000000000..72de4586a3
--- /dev/null
+++ b/tensorflow/python/ops/op_def_library_test.py
@@ -0,0 +1,1402 @@
+"""Tests for tensorflow.python.ops.op_def_library."""
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import types
+from tensorflow.python.ops.op_def_library import OpDefLibrary
+from tensorflow.python.platform import googletest
+
+
+# NOTE(mrry): Dummy shape registrations for ops used in the tests.
+ops.RegisterShape("Attr")(None)
+ops.RegisterShape("AttrBool")(None)
+ops.RegisterShape("AttrBoolList")(None)
+ops.RegisterShape("AttrDefault")(None)
+ops.RegisterShape("AttrEmptyListDefault")(None)
+ops.RegisterShape("AttrEnum")(None)
+ops.RegisterShape("AttrEnumList")(None)
+ops.RegisterShape("AttrFloat")(None)
+ops.RegisterShape("AttrListDefault")(None)
+ops.RegisterShape("AttrListMin")(None)
+ops.RegisterShape("AttrMin")(None)
+ops.RegisterShape("AttrShape")(None)
+ops.RegisterShape("AttrShapeList")(None)
+ops.RegisterShape("Binary")(None)
+ops.RegisterShape("ComplexStruct")(None)
+ops.RegisterShape("InPolymorphicTwice")(None)
+ops.RegisterShape("MixedStruct")(None)
+ops.RegisterShape("NInPolymorphicTwice")(None)
+ops.RegisterShape("NInTwice")(None)
+ops.RegisterShape("NInTwoTypeVariables")(None)
+ops.RegisterShape("NIntsIn")(None)
+ops.RegisterShape("NIntsOut")(None)
+ops.RegisterShape("NIntsOutDefault")(None)
+ops.RegisterShape("NPolymorphicIn")(None)
+ops.RegisterShape("NPolymorphicOut")(None)
+ops.RegisterShape("NPolymorphicOutDefault")(None)
+ops.RegisterShape("NPolymorphicRestrictIn")(None)
+ops.RegisterShape("NPolymorphicRestrictOut")(None)
+ops.RegisterShape("OutT")(None)
+ops.RegisterShape("OutTypeList")(None)
+ops.RegisterShape("OutTypeListRestrict")(None)
+ops.RegisterShape("Polymorphic")(None)
+ops.RegisterShape("PolymorphicDefaultOut")(None)
+ops.RegisterShape("PolymorphicOut")(None)
+ops.RegisterShape("RefIn")(None)
+ops.RegisterShape("RefOut")(None)
+ops.RegisterShape("ReservedAttr")(None)
+ops.RegisterShape("ReservedInput")(None)
+ops.RegisterShape("Restrict")(None)
+ops.RegisterShape("Simple")(None)
+ops.RegisterShape("SimpleStruct")(None)
+ops.RegisterShape("TypeList")(None)
+ops.RegisterShape("TypeListRestrict")(None)
+ops.RegisterShape("TypeListTwice")(None)
+
+
+class OpDefLibraryTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._lib = OpDefLibrary()
+ self._g = ops.Graph()
+ self._default_graph_controller = self._g.as_default()
+ self._default_graph_controller.__enter__()
+ self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } "
+ "output_arg { name: 'out' type: DT_FLOAT }")
+ self._add_op("name: 'OutT' output_arg { name: 'a' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ def tearDown(self):
+ self._default_graph_controller.__exit__(None, None, None)
+
+ def _add_op(self, ascii):
+ op_def = op_def_pb2.OpDef()
+ text_format.Merge(ascii, op_def)
+ self._lib.add_op(op_def)
+
+ def Tensor(self, t, name="in"):
+ return self._lib.apply_op("OutT", T=t, name=name)
+
+ def testNoRegisteredOpFails(self):
+ with self.assertRaises(RuntimeError) as cm:
+ self._lib.apply_op("unknown", g=self._g)
+ self.assertEqual(cm.exception.message, "Unrecognized Op name unknown")
+
+ def testAddOpValidation(self):
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'MissingTypeAttr' "
+ "input_arg { name: 'a' type_attr: 'T' } ")
+ self.assertEqual(cm.exception.message,
+ "Inconsistent OpDef for 'MissingTypeAttr', "
+ "missing attr 'T'")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'BadTypeAttr' "
+ "output_arg { name: 'a' type_attr: 'T' } "
+ "attr { name: 'T' type: 'int' }")
+ self.assertEqual(
+ cm.exception.message,
+ "Attr 'T' of 'BadTypeAttr' used as a type_attr but has type int")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'MissingNumberAttr' "
+ "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } ")
+ self.assertEqual(cm.exception.message,
+ "Inconsistent OpDef for 'MissingNumberAttr', "
+ "missing attr 'N'")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'BadNumberAttr' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'type' }")
+ self.assertEqual(
+ cm.exception.message,
+ "Attr 'N' of 'BadNumberAttr' used as a number_attr but has type type")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'TwoTypesA' "
+ "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'TwoTypesA' must have one type field not 2")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'TwoTypesB' "
+ "input_arg { name: 'a' type: DT_INT32 type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'TwoTypesB' must have one type field not 2")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'ThreeTypes' "
+ "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' "
+ "type_list_attr: 'U' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'U' type: 'list(type)' }")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'ThreeTypes' must have one type field not 3")
+
+ with self.assertRaises(TypeError) as cm:
+ self._add_op("name: 'NoTypes' output_arg { name: 'a' } ")
+ self.assertEqual(cm.exception.message,
+ "Arg 'a' of 'NoTypes' must have one type field not 0")
+
+ def testSimple(self):
+ out = self._lib.apply_op("Simple", a=3)
+ self.assertEquals(types.float32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'Simple' op: 'Simple' input: 'Simple/a'
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Simple", a=4)
+ self.assertProtoEquals("""
+ name: 'Simple_1' op: 'Simple' input: 'Simple_1/a'
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Simple", a=5, name="named")
+ self.assertProtoEquals("""
+ name: 'named' op: 'Simple' input: 'named/a'
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Simple", a=[[1, 2, 3], [4, 5, 6]], name="two_d")
+ self.assertProtoEquals("""
+ name: 'two_d' op: 'Simple' input: 'two_d/a'
+ """, out.op.node_def)
+
+ def testSimpleFailures(self):
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a="Bad string")
+ self.assertEqual(cm.exception.message,
+ "Expected int32, got 'Bad string' instead.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=self.Tensor(types.string))
+ self.assertEqual(cm.exception.message,
+ "Input 'a' of 'Simple' Op has type string "
+ "that does not match expected type of int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=6, extra="bogus")
+ self.assertEqual(cm.exception.message,
+ "apply_op() got unexpected keyword arguments: extra")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=6, extra1="bogus", extra2="also_bogus")
+ self.assertEqual(cm.exception.message,
+ "apply_op() got unexpected keyword arguments: extra1, "
+ "extra2")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple")
+ self.assertEqual(cm.exception.message, "No argument for input a")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", wrong=7)
+ self.assertEqual(cm.exception.message, "No argument for input a")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Simple", a=[self.Tensor(types.int32)])
+ self.assertStartsWith(cm.exception.message, "Expected int32, got")
+
+ def testReservedInput(self):
+ self._add_op("name: 'ReservedInput' "
+ "input_arg { name: 'input' type: DT_INT32 } ")
+ op = self._lib.apply_op("ReservedInput", input_=7, name="x")
+ self.assertProtoEquals("""
+ name: 'x' op: 'ReservedInput' input: 'x/input'
+ """, op.node_def)
+
+ def testPolymorphic(self):
+ self._add_op("name: 'Polymorphic' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ out = self._lib.apply_op("Polymorphic", a=7, name="p")
+ self.assertEquals(types.int32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'p' op: 'Polymorphic' input: 'p/a'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Polymorphic", a="s", name="q")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'q' op: 'Polymorphic' input: 'q/a'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Polymorphic", a=["s", "t", "u"], name="r")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'r' op: 'Polymorphic' input: 'r/a'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Polymorphic", a="s", T=types.string)
+ self.assertEqual(cm.exception.message,
+ "Should not specify value for inferred attr 'T'.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Polymorphic", a=[self.Tensor(types.bool)])
+ self.assertEqual(cm.exception.message,
+ "List of Tensors when single Tensor expected")
+
+ def testPolymorphicOut(self):
+ self._add_op("name: 'PolymorphicOut' "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ out = self._lib.apply_op("PolymorphicOut", T=types.int32, name="p")
+ self.assertEquals(types.int32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'p' op: 'PolymorphicOut'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("PolymorphicOut", T=types.bool, name="q")
+ self.assertEquals(types.bool, out.dtype)
+ self.assertProtoEquals("""
+ name: 'q' op: 'PolymorphicOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("PolymorphicOut")
+ self.assertEqual(cm.exception.message,
+ "No argument for attr T")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("PolymorphicOut", T=None)
+ self.assertEqual(cm.exception.message,
+ "Expected DataType for argument 'T' not None.")
+
+ def testPolymorphicDefaultOut(self):
+ self._add_op("name: 'PolymorphicDefaultOut' "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' "
+ " default_value { type: DT_STRING } }")
+
+ out = self._lib.apply_op("PolymorphicDefaultOut", T=None, name="p")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'p' op: 'PolymorphicDefaultOut'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("PolymorphicDefaultOut", T=types.bool,
+ name="q")
+ self.assertEquals(types.bool, out.dtype)
+ self.assertProtoEquals("""
+ name: 'q' op: 'PolymorphicDefaultOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ def testBinary(self):
+ self._add_op("name: 'Binary' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "input_arg { name: 'b' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ out = self._lib.apply_op("Binary", a=8, b=9, name="b")
+ self.assertEquals(types.int32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'b' op: 'Binary' input: 'b/a' input: 'b/b'
+ attr { key: 'T' value { type: DT_INT32 } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Binary", a="left", b="right", name="c")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'c' op: 'Binary' input: 'c/a' input: 'c/b'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Binary", a="left", b=12)
+ self.assertEqual(cm.exception.message,
+ "Expected string, got 12 instead.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Binary", a=self.Tensor(types.string),
+ b=self.Tensor(types.int32))
+ self.assertEqual(cm.exception.message,
+ "Input 'b' of 'Binary' Op has type int32 "
+ "that does not match type string of argument 'a'.")
+
+ def testRestrict(self):
+ self._add_op("name: 'Restrict' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' allowed_values { list { "
+ " type: DT_STRING type: DT_BOOL } } }")
+
+ out = self._lib.apply_op("Restrict", a="foo", name="g")
+ self.assertEquals(types.string, out.dtype)
+ self.assertProtoEquals("""
+ name: 'g' op: 'Restrict' input: 'g/a'
+ attr { key: 'T' value { type: DT_STRING } }
+ """, out.op.node_def)
+
+ out = self._lib.apply_op("Restrict", a=True, name="h")
+ self.assertEquals(types.bool, out.dtype)
+ self.assertProtoEquals("""
+ name: 'h' op: 'Restrict' input: 'h/a'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Restrict", a=17)
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: "
+ "string, bool")
+
+ def testTypeList(self):
+ self._add_op("name: 'TypeList' "
+ "input_arg { name: 'a' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+
+ op = self._lib.apply_op("TypeList", a=["foo"], name="z")
+ self.assertProtoEquals("""
+ name: 'z' op: 'TypeList' input: 'z/a_0'
+ attr { key: 'T' value { list { type: DT_STRING } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("TypeList", a=[True, 12], name="y")
+ self.assertProtoEquals("""
+ name: 'y' op: 'TypeList' input: 'y/a_0' input: 'y/a_1'
+ attr { key: 'T' value { list { type: DT_BOOL type: DT_INT32 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("TypeList", a=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'TypeList' attr { key: 'T' value { list { } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("TypeList", a=17)
+ self.assertStartsWith(cm.exception.message,
+ "Expected list for 'a' "
+ "argument to 'TypeList' Op, not ")
+
+ def testTypeListTwice(self):
+ self._add_op("name: 'TypeListTwice' "
+ "input_arg { name: 'a' type_list_attr: 'T' } "
+ "input_arg { name: 'b' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+
+ op = self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", False],
+ name="z")
+ self.assertProtoEquals("""
+ name: 'z' op: 'TypeListTwice'
+ input: 'z/a_0' input: 'z/a_1' input: 'z/b_0' input: 'z/b_1'
+ attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("TypeListTwice", a=[], b=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'TypeListTwice' attr { key: 'T' value { list { } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", 6])
+ self.assertEqual(cm.exception.message,
+ "Input 'b' of 'TypeListTwice' Op has type list of "
+ "string, int32 that does not match type list "
+ "string, bool of argument 'a'.")
+
+ def testOutTypeList(self):
+ self._add_op("name: 'OutTypeList' "
+ "output_arg { name: 'out' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' }")
+
+ out, = self._lib.apply_op("OutTypeList", T=[types.float32], name="x")
+ self.assertEquals(types.float32, out.dtype)
+ self.assertProtoEquals("""
+ name: 'x' op: 'OutTypeList'
+ attr { key: 'T' value { list { type: DT_FLOAT } } }
+ """, out.op.node_def)
+
+ out1, out2 = self._lib.apply_op("OutTypeList",
+ T=[types.int32, types.bool],
+ name="w")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'w' op: 'OutTypeList'
+ attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } }
+ """, out1.op.node_def)
+
+ out = self._lib.apply_op("OutTypeList", T=[], name="empty")
+ self.assertEqual([], out)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("OutTypeList", T=types.int32)
+ self.assertEqual(cm.exception.message, "Expected list for attr T")
+
+ def testTypeListRestrict(self):
+ self._add_op("name: 'TypeListRestrict' "
+ "input_arg { name: 'a' type_list_attr: 'T' } "
+ "attr { name: 'T' type: 'list(type)' allowed_values { list { "
+ " type: DT_STRING type: DT_BOOL } } }")
+
+ op = self._lib.apply_op("TypeListRestrict", a=["foo", False], name="v")
+ self.assertProtoEquals("""
+ name: 'v' op: 'TypeListRestrict' input: 'v/a_0' input: 'v/a_1'
+ attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("TypeListRestrict", a=[True, 12])
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: string, bool")
+
+ def testOutTypeListRestrict(self):
+ self._add_op("name: 'OutTypeListRestrict' "
+ "output_arg { name: 'out' type_list_attr: 't' } "
+ "attr { name: 't' type: 'list(type)' allowed_values { list { "
+ " type: DT_STRING type: DT_BOOL } } }")
+
+ out1, out2 = self._lib.apply_op("OutTypeListRestrict",
+ t=[types.bool, types.string],
+ name="u")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.string, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'u' op: 'OutTypeListRestrict'
+ attr { key: 't' value { list { type: DT_BOOL type: DT_STRING } } }
+ """, out1.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("OutTypeListRestrict",
+ t=[types.string, types.int32])
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 't' "
+ "not in list of allowed values: string, bool")
+
+ def testAttr(self):
+ self._add_op("name: 'Attr' attr { name: 'a' type: 'int' }")
+ op = self._lib.apply_op("Attr", a=12, name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'Attr' attr { key: 'a' value { i: 12 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("Attr", a=tensor_shape.Dimension(13), name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'Attr' attr { key: 'a' value { i: 13 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr", a="bad")
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'a' not 'bad'.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr", a=[12])
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'a' not [12].")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr", a=None)
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'a' not None.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("Attr")
+ self.assertEqual(cm.exception.message, "No argument for attr a")
+
+ def testAttrFloat(self):
+ self._add_op("name: 'AttrFloat' attr { name: 'a' type: 'float' }")
+
+ op = self._lib.apply_op("AttrFloat", a=1.2, name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'AttrFloat' attr { key: 'a' value { f: 1.2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrFloat", a=12, name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'AttrFloat' attr { key: 'a' value { f: 12 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrFloat", a="bad")
+ self.assertEqual(cm.exception.message,
+ "Expected float for argument 'a' not 'bad'.")
+
+ def testAttrBool(self):
+ self._add_op("name: 'AttrBool' attr { name: 'a' type: 'bool' }")
+
+ op = self._lib.apply_op("AttrBool", a=True, name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'AttrBool' attr { key: 'a' value { b: true } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrBool", a=False, name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'AttrBool' attr { key: 'a' value { b: false } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBool", a=0)
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not 0.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBool", a=1)
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not 1.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBool", a=[])
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not [].")
+
+ def testAttrBoolList(self):
+ self._add_op("name: 'AttrBoolList' attr { name: 'a' type: 'list(bool)' }")
+
+ op = self._lib.apply_op("AttrBoolList", a=[True, False, True], name="t")
+ self.assertProtoEquals("""
+ name: 't' op: 'AttrBoolList'
+ attr { key: 'a' value { list { b: true b: false b:true } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrBoolList", a=[], name="u")
+ self.assertProtoEquals("""
+ name: 'u' op: 'AttrBoolList' attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("AttrBoolList", a=[0])
+ self.assertEqual(cm.exception.message,
+ "Expected bool for argument 'a' not 0.")
+
+ def testAttrMin(self):
+ self._add_op("name: 'AttrMin' attr { name: 'a' type: 'int' "
+ "has_minimum: true minimum: 5 }")
+ op = self._lib.apply_op("AttrMin", a=12, name="s")
+ self.assertProtoEquals("""
+ name: 's' op: 'AttrMin' attr { key: 'a' value { i: 12 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrMin", a=2)
+ self.assertEqual(cm.exception.message,
+ "Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.")
+
+ def testAttrListMin(self):
+ self._add_op("name: 'AttrListMin' attr { name: 'a' type: 'list(int)' "
+ "has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("AttrListMin", a=[1, 2], name="r")
+ self.assertProtoEquals("""
+ name: 'r' op: 'AttrListMin'
+ attr { key: 'a' value { list { i: 1 i: 2 } } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrListMin", a=[17])
+ self.assertEqual(cm.exception.message,
+ "Attr 'a' of 'AttrListMin' Op "
+ "passed list of length 1 less than minimum 2.")
+
+ def testAttrEnum(self):
+ self._add_op("name: 'AttrEnum' "
+ "attr { name: 'a' type: 'string' "
+ " allowed_values { list { s: 'apples' s: 'oranges' } } }")
+
+ op = self._lib.apply_op("AttrEnum", a="oranges", name="e")
+ self.assertProtoEquals("""
+ name: 'e' op: 'AttrEnum' attr { key: 'a' value { s: 'oranges' } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrEnum", a="invalid")
+ self.assertEqual(cm.exception.message,
+ 'Attr \'a\' of \'AttrEnum\' Op '
+ 'passed string \'invalid\' not in: '
+ '"apples", "oranges".')
+
+ def testAttrEnumList(self):
+ self._add_op("name: 'AttrEnumList' "
+ "attr { name: 'a' type: 'list(string)' "
+ " allowed_values { list { s: 'apples' s: 'oranges' } } }")
+
+ op = self._lib.apply_op("AttrEnumList", a=["oranges", "apples"], name="f")
+ self.assertProtoEquals("""
+ name: 'f' op: 'AttrEnumList'
+ attr { key: 'a' value { list { s: 'oranges' s: 'apples' } } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrEnumList", a=["apples", "invalid", "oranges"])
+ self.assertEqual(cm.exception.message,
+ 'Attr \'a\' of \'AttrEnumList\' Op '
+ 'passed string \'invalid\' not '
+ 'in: "apples", "oranges".')
+
+ def testAttrShape(self):
+ self._add_op("name: 'AttrShape' attr { name: 'a' type: 'shape' }")
+
+ op = self._lib.apply_op("AttrShape", a=[5], name="s1")
+ self.assertProtoEquals("""
+ name: 's1' op: 'AttrShape'
+ attr { key: 'a' value { shape { dim { size: 5 } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrShape", a=(4, 3, 2), name="s2")
+ self.assertProtoEquals("""
+ name: 's2' op: 'AttrShape'
+ attr { key: 'a' value {
+ shape { dim { size: 4 } dim { size: 3 } dim { size: 2 } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op(
+ "AttrShape", a=tensor_shape.TensorShape([3, 2]), name="s3")
+ self.assertProtoEquals("""
+ name: 's3' op: 'AttrShape'
+ attr { key: 'a' value {
+ shape { dim { size: 3 } dim { size: 2 } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrShape", a=[], name="s4")
+ self.assertProtoEquals("""
+ name: 's4' op: 'AttrShape' attr { key: 'a' value { shape { } } }
+ """, op.node_def)
+
+ shape = tensor_shape_pb2.TensorShapeProto()
+ shape.dim.add().size = 6
+ shape.dim.add().size = 3
+ op = self._lib.apply_op("AttrShape", a=shape, name="s5")
+ self.assertProtoEquals("""
+ name: 's5' op: 'AttrShape'
+ attr { key: 'a' value { shape { dim { size: 6 } dim { size: 3 } } } }
+ """, op.node_def)
+
+ # TODO(josh11b): Re-enable this test once we stop promoting scalars to shapes.
+ # with self.assertRaises(TypeError) as cm:
+ # self._lib.apply_op("AttrShape", a=5)
+ # self.assertEqual(cm.exception.message,
+ # "Don't know how to convert 5 to a TensorShapeProto for "
+ # "argument 'a'")
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("AttrShape", a="ABC")
+
+ def testAttrShapeList(self):
+ self._add_op("name: 'AttrShapeList' attr { name: 'a' type: 'list(shape)' }")
+
+ op = self._lib.apply_op("AttrShapeList", a=[[3, 2], [6, 5, 4]], name="sl")
+ self.assertProtoEquals("""
+ name: 'sl' op: 'AttrShapeList'
+ attr { key: 'a' value { list {
+ shape { dim { size: 3 } dim { size: 2 } }
+ shape { dim { size: 6 } dim { size: 5 } dim { size: 4 } } } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrShapeList", a=[], name="esl")
+ self.assertProtoEquals("""
+ name: 'esl' op: 'AttrShapeList' attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ def testAttrDefault(self):
+ self._add_op("name: 'AttrDefault' "
+ "attr { name: 'a' type: 'string' "
+ " default_value { s: 'banana' } }")
+
+ op = self._lib.apply_op("AttrDefault", a=None, name="d")
+ self.assertProtoEquals("""
+ name: 'd' op: 'AttrDefault' attr { key: 'a' value { s: 'banana' } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrDefault", a="kiwi", name="c")
+ self.assertProtoEquals("""
+ name: 'c' op: 'AttrDefault' attr { key: 'a' value { s: 'kiwi' } }
+ """, op.node_def)
+
+ def testAttrListDefault(self):
+ self._add_op("name: 'AttrListDefault' "
+ "attr { name: 'a' type: 'list(int)' "
+ " default_value { list { i: 5 i: 15 } } }")
+
+ op = self._lib.apply_op("AttrListDefault", a=None, name="b")
+ self.assertProtoEquals("""
+ name: 'b' op: 'AttrListDefault'
+ attr { key: 'a' value { list { i: 5 i: 15 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrListDefault", a=[3], name="a")
+ self.assertProtoEquals("""
+ name: 'a' op: 'AttrListDefault'
+ attr { key: 'a' value { list { i: 3 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrListDefault", a=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'AttrListDefault'
+ attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ def testAttrEmptyListDefault(self):
+ self._add_op("name: 'AttrEmptyListDefault' "
+ "attr { name: 'a' type: 'list(float)' "
+ " default_value { list { } } }")
+
+ op = self._lib.apply_op("AttrEmptyListDefault", a=None, name="b")
+ self.assertProtoEquals("""
+ name: 'b' op: 'AttrEmptyListDefault'
+ attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrEmptyListDefault", a=[3], name="a")
+ self.assertProtoEquals("""
+ name: 'a' op: 'AttrEmptyListDefault'
+ attr { key: 'a' value { list { f: 3 } } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("AttrEmptyListDefault", a=[], name="empty")
+ self.assertProtoEquals("""
+ name: 'empty' op: 'AttrEmptyListDefault'
+ attr { key: 'a' value { list { } } }
+ """, op.node_def)
+
+ def testReservedAttr(self):
+ self._add_op("name: 'ReservedAttr' "
+ "attr { name: 'range' type: 'int' } ")
+ op = self._lib.apply_op("ReservedAttr", range_=7, name="x")
+ self.assertProtoEquals("""
+ name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } }
+ """, op.node_def)
+
+ def testNIntsIn(self):
+ self._add_op("name: 'NIntsIn' "
+ "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("NIntsIn", a=[1, 2], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NIntsIn' input: 'n/a_0' input: 'n/a_1'
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NIntsIn", a=[5, 4, 3, 2, 1], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NIntsIn'
+ input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'
+ attr { key: 'N' value { i: 5 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=["foo", "bar"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op have types "
+ "[string, string] that do not match expected type int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=[self.Tensor(types.string),
+ self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op have "
+ "types [string, string] that do not match expected type "
+ "int32.")
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NIntsIn", a=[99])
+ self.assertEqual(cm.exception.message,
+ "List argument 'a' to 'NIntsIn' Op "
+ "with length 1 shorter than "
+ "minimum length 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=[38, "bar"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op have types "
+ "[int32, string] that do not match expected type int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=[self.Tensor(types.int32),
+ self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NIntsIn' Op "
+ "have types [int32, string] that do not match expected "
+ "type int32.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsIn", a=17)
+ self.assertStartsWith(cm.exception.message,
+ "Expected list for 'a' argument "
+ "to 'NIntsIn' Op, not ")
+
+ def testNPolymorphicIn(self):
+ self._add_op("name: 'NPolymorphicIn' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("NPolymorphicIn", a=[1, 2], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NPolymorphicIn' input: 'n/a_0' input: 'n/a_1'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicIn", a=[5, 4, 3, 2, 1], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NPolymorphicIn'
+ input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 5 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicIn", a=["foo", "bar"], name="p")
+ self.assertProtoEquals("""
+ name: 'p' op: 'NPolymorphicIn' input: 'p/a_0' input: 'p/a_1'
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicIn",
+ a=[1, self.Tensor(types.float32, name="x")],
+ name="q")
+ self.assertProtoEquals("""
+ name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NPolymorphicIn", a=[99])
+ self.assertEqual(cm.exception.message,
+ "List argument 'a' to 'NPolymorphicIn' Op with length 1 "
+ "shorter than minimum length 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn", a=[38, "bar"])
+ self.assertEqual(cm.exception.message,
+ "All tensors passed to 'a' of 'NPolymorphicIn' "
+ "Op must have the same type.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn",
+ a=[38, self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
+ "have types [int32, string] that don't all match.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn",
+ a=["abcd", self.Tensor(types.int32)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'a' of 'NPolymorphicIn' Op "
+ "have types [string, int32] that don't all match.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicIn", a=17)
+ self.assertStartsWith(cm.exception.message,
+ "Expected list for 'a' argument "
+ "to 'NPolymorphicIn' Op, not ")
+
+ def testNPolymorphicRestrictIn(self):
+ self._add_op("name: 'NPolymorphicRestrictIn' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' allowed_values { "
+ " list { type: DT_STRING type: DT_BOOL } } } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ op = self._lib.apply_op("NPolymorphicRestrictIn", a=["foo", "bar"],
+ name="p")
+ self.assertProtoEquals("""
+ name: 'p' op: 'NPolymorphicRestrictIn' input: 'p/a_0' input: 'p/a_1'
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NPolymorphicRestrictIn", a=[False, True, False],
+ name="b")
+ self.assertProtoEquals("""
+ name: 'b' op: 'NPolymorphicRestrictIn'
+ input: 'b/a_0' input: 'b/a_1' input: 'b/a_2'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 3 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicRestrictIn", a=[1, 2])
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: string, bool")
+
+ def testNInTwice(self):
+ self._add_op("name: 'NInTwice' "
+ "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "input_arg { name: 'b' type: DT_STRING number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")
+
+ op = self._lib.apply_op("NInTwice", a=[1, 2], b=["one", "two"], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NInTwice'
+ input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NInTwice", a=[], b=[], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NInTwice' attr { key: 'N' value { i: 0 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NInTwice", a=[1, 2, 3], b=["too short"])
+ self.assertEqual(cm.exception.message,
+ "List argument 'b' to 'NInTwice' Op "
+ "with length 1 must match "
+ "length 3 of argument 'a'.")
+
+ def testNInPolymorphicTwice(self):
+ self._add_op("name: 'NInPolymorphicTwice' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")
+
+ op = self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=[3, 4], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NInPolymorphicTwice'
+ input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5])
+ self.assertEqual(cm.exception.message,
+ "List argument 'b' to 'NInPolymorphicTwice' Op "
+ "with length 1 "
+ "must match length 3 of argument 'a'.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=["one", "two"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of 'NInPolymorphicTwice' "
+ "Op have types [string, string] that do not match type "
+ "int32 inferred from earlier arguments.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NInPolymorphicTwice",
+ a=[self.Tensor(types.int32)],
+ b=[self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of "
+ "'NInPolymorphicTwice' Op have types [string] that do not "
+ "match type int32 inferred from earlier arguments.")
+
+ def testNInTwoTypeVariables(self):
+ self._add_op("name: 'NInTwoTypeVariables' "
+ "input_arg { name: 'a' type_attr: 'S' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'S' type: 'type' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }")
+
+ op = self._lib.apply_op("NInTwoTypeVariables", a=[1, 2], b=[True, False],
+ name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'NInTwoTypeVariables'
+ input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1'
+ attr { key: 'S' value { type: DT_INT32 } }
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NInTwoTypeVariables", a=[1, 2], b=[3, 4], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'NInTwoTypeVariables'
+ input: 'o/a_0' input: 'o/a_1' input: 'o/b_0' input: 'o/b_1'
+ attr { key: 'S' value { type: DT_INT32 } }
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("NInTwoTypeVariables",
+ a=[self.Tensor(types.int32, name="q")],
+ b=[self.Tensor(types.string, name="r")],
+ name="p")
+ self.assertProtoEquals("""
+ name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r'
+ attr { key: 'S' value { type: DT_INT32 } }
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 1 } }
+ """, op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"])
+ self.assertEqual(cm.exception.message,
+ "List argument 'b' to 'NInTwoTypeVariables' Op "
+ "with length 1 "
+ "must match length 3 of argument 'a'.")
+
+ def testInPolymorphicTwice(self):
+ self._add_op("name: 'InPolymorphicTwice' "
+ "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "input_arg { name: 'b' type_attr: 'T' number_attr: 'M' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 } "
+ "attr { name: 'M' type: 'int' has_minimum: true minimum: 0 } ")
+
+ op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[3, 4, 5], name="n")
+ self.assertProtoEquals("""
+ name: 'n' op: 'InPolymorphicTwice'
+ input: 'n/a_0' input: 'n/b_0' input: 'n/b_1' input: 'n/b_2'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 1 } }
+ attr { key: 'M' value { i: 3 } }
+ """, op.node_def)
+
+ op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[], name="o")
+ self.assertProtoEquals("""
+ name: 'o' op: 'InPolymorphicTwice' input: 'o/a_0'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 1 } }
+ attr { key: 'M' value { i: 0 } }
+ """, op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5])
+ self.assertEqual(cm.exception.message,
+ "Don't know how to infer type variable from empty input "
+ "list passed to input 'a' of 'InPolymorphicTwice' Op.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("InPolymorphicTwice", a=[1, 2], b=["one", "two"])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of 'InPolymorphicTwice' Op "
+ "have types [string, string] that do not match type int32 "
+ "inferred from earlier arguments.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("InPolymorphicTwice",
+ a=[self.Tensor(types.int32)],
+ b=[self.Tensor(types.string)])
+ self.assertEqual(cm.exception.message,
+ "Tensors in list passed to 'b' of 'InPolymorphicTwice' "
+ "Op have types [string] that do not match type int32 "
+ "inferred from earlier arguments.")
+
+ def testNIntsOut(self):
+ self._add_op("name: 'NIntsOut' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ out1, out2 = self._lib.apply_op("NIntsOut", N=2, name="n")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3, out4, out5 = self._lib.apply_op(
+ "NIntsOut", N=5, name="o")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertEquals(types.int32, out3.dtype)
+ self.assertEquals(types.int32, out4.dtype)
+ self.assertEquals(types.int32, out5.dtype)
+ self.assertProtoEquals("""
+ name: 'o' op: 'NIntsOut' attr { key: 'N' value { i: 5 } }
+ """, out5.op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NIntsOut", N=1)
+ self.assertEqual(cm.exception.message,
+ "Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NIntsOut", N=[3])
+ self.assertEqual(cm.exception.message,
+ "Expected int for argument 'N' not [3].")
+
+ def testNIntsOutDefault(self):
+ self._add_op("name: 'NIntsOutDefault' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2"
+ " default_value { i:3 } }")
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NIntsOutDefault", N=None, name="z")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertEquals(types.int32, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ out1, out2 = self._lib.apply_op("NIntsOutDefault", N=2, name="y")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'y' op: 'NIntsOutDefault' attr { key: 'N' value { i: 2 } }
+ """, out2.op.node_def)
+
+ def testNPolymorphicOut(self):
+ self._add_op("name: 'NPolymorphicOut' "
+ "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ out1, out2 = self._lib.apply_op("NPolymorphicOut", N=2,
+ T=types.int32, name="n")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'n' op: 'NPolymorphicOut'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicOut", T=types.string, N=3, name="o")
+ self.assertEquals(types.string, out1.dtype)
+ self.assertEquals(types.string, out2.dtype)
+ self.assertEquals(types.string, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'o' op: 'NPolymorphicOut'
+ attr { key: 'T' value { type: DT_STRING } }
+ attr { key: 'N' value { i: 3 } }
+ """, out3.op.node_def)
+
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("NPolymorphicOut", N=1, T=types.string)
+ self.assertEqual(cm.exception.message,
+ "Attr 'N' of 'NPolymorphicOut' Op "
+ "passed 1 less than minimum 2.")
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicOut", N=3, T=[types.string])
+ self.assertEqual(
+ cm.exception.message,
+ "Expected DataType for argument 'T' not [tf.string].")
+
+ def testNPolymorphicOutDefault(self):
+ self._add_op("name: 'NPolymorphicOutDefault' "
+ "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type'"
+ " default_value { type: DT_BOOL } } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 "
+ " default_value { i: 2 } }")
+
+ out1, out2 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=None, T=None, name="r")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertProtoEquals("""
+ name: 'r' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=3, T=None, name="s")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertEquals(types.bool, out3.dtype)
+ self.assertProtoEquals("""
+ name: 's' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ out1, out2 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=None, T=types.int32, name="t")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertProtoEquals("""
+ name: 't' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 2 } }
+ """, out1.op.node_def)
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicOutDefault", N=3, T=types.int32, name="u")
+ self.assertEquals(types.int32, out1.dtype)
+ self.assertEquals(types.int32, out2.dtype)
+ self.assertEquals(types.int32, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'u' op: 'NPolymorphicOutDefault'
+ attr { key: 'T' value { type: DT_INT32 } }
+ attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ def testNPolymorphicRestrictOut(self):
+ self._add_op("name: 'NPolymorphicRestrictOut' "
+ "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } "
+ "attr { name: 'T' type: 'type' allowed_values { "
+ " list { type: DT_STRING type: DT_BOOL } } } "
+ "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }")
+
+ out1, out2, out3 = self._lib.apply_op(
+ "NPolymorphicRestrictOut", N=3, T=types.bool, name="u")
+ self.assertEquals(types.bool, out1.dtype)
+ self.assertEquals(types.bool, out2.dtype)
+ self.assertEquals(types.bool, out3.dtype)
+ self.assertProtoEquals("""
+ name: 'u' op: 'NPolymorphicRestrictOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ attr { key: 'N' value { i: 3 } }
+ """, out1.op.node_def)
+
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=types.int32)
+ self.assertEqual(cm.exception.message,
+ "DataType int32 for attr 'T' "
+ "not in list of allowed values: string, bool")
+
+ def testRef(self):
+ self._add_op("name: 'RefIn' "
+ "input_arg { name: 'a' type_attr: 'T' is_ref: true } "
+ "attr { name: 'T' type: 'type' } ")
+ self._add_op("name: 'RefOut' "
+ "output_arg { name: 'a' type_attr: 'T' is_ref: true } "
+ "attr { name: 'T' type: 'type' } ")
+
+ out = self._lib.apply_op("RefOut", T=types.bool, name="o")
+ self.assertEquals(types.bool_ref, out.dtype)
+ self.assertProtoEquals("""
+ name: 'o' op: 'RefOut'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, out.op.node_def)
+
+ op = self._lib.apply_op("RefIn", a=out, name="i")
+ self.assertProtoEquals("""
+ name: 'i' op: 'RefIn' input: 'o'
+ attr { key: 'T' value { type: DT_BOOL } }
+ """, op.node_def)
+
+ # Can pass ref to non-ref input.
+ out = self._lib.apply_op("RefOut", T=types.int32, name="r")
+ out = self._lib.apply_op("Simple", a=out, name="s")
+ self.assertProtoEquals("""
+ name: 's' op: 'Simple' input: 'r'
+ """, out.op.node_def)
+
+ # Can't pass non-ref to ref input.
+ with self.assertRaises(TypeError) as cm:
+ self._lib.apply_op("RefIn", a=2)
+ self.assertEqual(cm.exception.message,
+ "Input 'a' of 'RefIn' Op requires l-value input")
+
+ def testSpecifyDevice(self):
+ with self._g.device("ADevice"):
+ self._lib.apply_op("Simple", a=3)
+ # We look at the whole graph here to make sure the Const op is also given
+ # the specified device.
+ graph_def = self._g.as_graph_def()
+ self.assertEqual(len(graph_def.node), 2)
+ for node in graph_def.node:
+ self.assertEqual(node.device, "ADevice")
+
+ def testStructuredOutputSingleList(self):
+ self._add_op("name: 'SimpleStruct' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
+ "attr { name: 'n_a' type: 'int' }")
+ for n_a in [0, 1, 3]:
+ a = self._lib.apply_op("SimpleStruct", n_a=n_a)
+ self.assertTrue(isinstance(a, list))
+ self.assertEqual(n_a, len(a))
+
+ def testStructuredOutputListAndSingle(self):
+ self._add_op("name: 'MixedStruct' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
+ "output_arg { name: 'b' type: DT_FLOAT } "
+ "attr { name: 'n_a' type: 'int' }")
+ for n_a in [0, 1, 3]:
+ a, b = self._lib.apply_op("MixedStruct", n_a=n_a)
+ self.assertTrue(isinstance(a, list))
+ self.assertEqual(n_a, len(a))
+ self.assertTrue(all(x.dtype == types.int32 for x in a))
+ self.assertTrue(isinstance(b, ops.Tensor))
+ self.assertEqual(types.float32, b.dtype)
+
+ def testStructuredOutputMultipleLists(self):
+ self._add_op("name: 'ComplexStruct' "
+ "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } "
+ "output_arg { name: 'b' type: DT_INT64 number_attr: 'n_b' } "
+ "output_arg { name: 'c' type_list_attr: 't_c' } "
+ "attr { name: 'n_a' type: 'int' } "
+ "attr { name: 'n_b' type: 'int' } "
+ "attr { name: 't_c' type: 'list(type)' }")
+ for n_a in [0, 1, 3]:
+ for n_b in [0, 1, 3]:
+ for t_c in [[],
+ [types.int32],
+ [types.int32, types.float32]]:
+ a, b, c = self._lib.apply_op("ComplexStruct",
+ n_a=n_a, n_b=n_b, t_c=t_c)
+
+ self.assertEqual(n_a, len(a))
+ self.assertTrue(all(x.dtype == types.int32 for x in a))
+ self.assertEqual(n_b, len(b))
+ self.assertTrue(all(x.dtype == types.int64 for x in b))
+ self.assertEqual(t_c, [x.dtype for x in c])
+
+
+class OpDefLibraryGraphTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._lib = OpDefLibrary()
+ self._g = ops.Graph()
+ self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } "
+ "output_arg { name: 'out' type: DT_FLOAT }")
+ self._add_op("name: 'Binary' "
+ "input_arg { name: 'a' type_attr: 'T' } "
+ "input_arg { name: 'b' type_attr: 'T' } "
+ "output_arg { name: 'out' type_attr: 'T' } "
+ "attr { name: 'T' type: 'type' }")
+
+ def _add_op(self, ascii):
+ op_def = op_def_pb2.OpDef()
+ text_format.Merge(ascii, op_def)
+ self._lib.add_op(op_def)
+
+ def testNoGraph(self):
+ out = self._lib.apply_op("Simple", a=3)
+ self.assertEquals(out.graph, ops.get_default_graph())
+
+ def testDefaultGraph(self):
+ with self._g.as_default():
+ out = self._lib.apply_op("Simple", a=3)
+ self.assertEquals(out.graph, self._g)
+
+ def testIgnoreDefaultGraphWithGraphArgument(self):
+ default_g = ops.Graph()
+ with default_g.as_default():
+ out = self._lib.apply_op("Simple", a=3, g=self._g)
+ self.assertEquals(ops.get_default_graph(), default_g)
+ self.assertEquals(out.graph, self._g)
+
+ def testDifferentGraphFails(self):
+ a = self._lib.apply_op("Simple", a=3, g=self._g)
+ other_g = ops.Graph()
+ b = self._lib.apply_op("Simple", a=4, g=other_g)
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("Binary", a=a, b=b)
+ self.assertTrue("must be from the same graph" in cm.exception.message)
+
+ def testDifferentGraphFailsWithGraphArgument(self):
+ other_g = ops.Graph()
+ a = self._lib.apply_op("Simple", a=3, g=other_g)
+ b = self._lib.apply_op("Simple", a=4, g=other_g)
+ with self.assertRaises(ValueError) as cm:
+ self._lib.apply_op("Binary", a=a, b=b, g=self._g)
+ self.assertTrue(
+ "not from the passed-in graph" in cm.exception.message)
+
+
+if __name__ == "__main__":
+ googletest.main()