# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tensorflow.python.framework.op_def_library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from google.protobuf import text_format

from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.framework.op_def_library import OpDefLibrary
from tensorflow.python.platform import googletest


def _unknown_shape(op):
  """Shape function for use with ops whose output shapes are unknown.

  Args:
    op: The operation whose output shapes are being inferred. Only its
      `outputs` sequence is read.

  Returns:
    A list containing one freshly constructed `tensor_shape.unknown_shape()`
    per entry in `op.outputs`, in the same order.
  """
  return [tensor_shape.unknown_shape() for _ in op.outputs]


# NOTE(mrry): Dummy shape registrations for ops used in the tests, since they
# don't have C++ op registrations on which to attach C++ shape fns.
# Register the placeholder `_unknown_shape` function as the shape function for
# each op type named below, one `ops.RegisterShape` call per name, in the order
# listed.
for _op_type_name in (
    "Attr",
    "AttrBool",
    "AttrBoolList",
    "AttrDefault",
    "AttrEmptyListDefault",
    "AttrEnum",
    "AttrEnumList",
    "AttrFloat",
    "AttrListDefault",
    "AttrListMin",
    "AttrMin",
    "AttrShape",
    "AttrShapeList",
    "AttrPartialShape",
    "AttrPartialShapeList",
    "AttrTypeDefault",
    "AttrListTypeDefault",
    "Binary",
    "ComplexStruct",
    "InPolymorphicTwice",
    "MixedStruct",
    "NInPolymorphicTwice",
    "NInTwice",
    "NInTwoTypeVariables",
    "NIntsIn",
    "NIntsOut",
    "NIntsOutDefault",
    "NPolymorphicIn",
    "NPolymorphicOut",
    "NPolymorphicOutDefault",
    "NPolymorphicRestrictIn",
    "NPolymorphicRestrictOut",
    "OutT",
    "OutTypeList",
    "OutTypeListRestrict",
    "Polymorphic",
    "PolymorphicDefaultOut",
    "PolymorphicOut",
    "RefIn",
):
  ops.RegisterShape(_op_type_name)(_unknown_shape)
# Drop the loop variable so the module namespace matches a plain list of calls.
del _op_type_name
ops.RegisterShape("RefOut")(_unknown_shape) ops.RegisterShape("ReservedAttr")(_unknown_shape) ops.RegisterShape("ReservedInput")(_unknown_shape) ops.RegisterShape("Restrict")(_unknown_shape) ops.RegisterShape("Simple")(_unknown_shape) ops.RegisterShape("SimpleStruct")(_unknown_shape) ops.RegisterShape("TwoRefsIn")(_unknown_shape) ops.RegisterShape("TypeList")(_unknown_shape) ops.RegisterShape("TypeListRestrict")(_unknown_shape) ops.RegisterShape("TypeListTwice")(_unknown_shape) class OpDefLibraryTest(test_util.TensorFlowTestCase): def setUp(self): self._lib = OpDefLibrary() self._g = ops.Graph() self._default_graph_controller = self._g.as_default() self._default_graph_controller.__enter__() self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } " "output_arg { name: 'out' type: DT_FLOAT }") self._add_op("name: 'OutT' output_arg { name: 'a' type_attr: 'T' } " "attr { name: 'T' type: 'type' }") def tearDown(self): self._default_graph_controller.__exit__(None, None, None) def _add_op(self, ascii): op_def = op_def_pb2.OpDef() text_format.Merge(ascii, op_def) self._lib.add_op(op_def) def Tensor(self, t, name="in"): return self._lib.apply_op("OutT", T=t, name=name) def testNoRegisteredOpFails(self): with self.assertRaises(RuntimeError) as cm: self._lib.apply_op("unknown") self.assertEqual(str(cm.exception), "Unrecognized Op name unknown") def testAddOpValidation(self): with self.assertRaises(TypeError) as cm: self._add_op("name: 'MissingTypeAttr' " "input_arg { name: 'a' type_attr: 'T' } ") self.assertEqual(str(cm.exception), "Inconsistent OpDef for 'MissingTypeAttr', " "missing attr 'T'") with self.assertRaises(TypeError) as cm: self._add_op("name: 'BadTypeAttr' " "output_arg { name: 'a' type_attr: 'T' } " "attr { name: 'T' type: 'int' }") self.assertEqual( str(cm.exception), "Attr 'T' of 'BadTypeAttr' used as a type_attr but has type int") with self.assertRaises(TypeError) as cm: self._add_op("name: 'MissingNumberAttr' " "input_arg { name: 'a' type: 
DT_INT32 number_attr: 'N' } ") self.assertEqual(str(cm.exception), "Inconsistent OpDef for 'MissingNumberAttr', " "missing attr 'N'") with self.assertRaises(TypeError) as cm: self._add_op("name: 'BadNumberAttr' " "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } " "attr { name: 'N' type: 'type' }") self.assertEqual( str(cm.exception), "Attr 'N' of 'BadNumberAttr' used as a number_attr but has type type") with self.assertRaises(TypeError) as cm: self._add_op("name: 'TwoTypesA' " "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' } " "attr { name: 'T' type: 'type' }") self.assertEqual(str(cm.exception), "Arg 'a' of 'TwoTypesA' must have one type field not 2") with self.assertRaises(TypeError) as cm: self._add_op("name: 'TwoTypesB' " "input_arg { name: 'a' type: DT_INT32 type_list_attr: 'T' } " "attr { name: 'T' type: 'list(type)' }") self.assertEqual(str(cm.exception), "Arg 'a' of 'TwoTypesB' must have one type field not 2") with self.assertRaises(TypeError) as cm: self._add_op("name: 'ThreeTypes' " "input_arg { name: 'a' type: DT_INT32 type_attr: 'T' " "type_list_attr: 'U' } " "attr { name: 'T' type: 'type' } " "attr { name: 'U' type: 'list(type)' }") self.assertEqual(str(cm.exception), "Arg 'a' of 'ThreeTypes' must have one type field not 3") with self.assertRaises(TypeError) as cm: self._add_op("name: 'NoTypes' output_arg { name: 'a' } ") self.assertEqual(str(cm.exception), "Arg 'a' of 'NoTypes' must have one type field not 0") def testSimple(self): out = self._lib.apply_op("Simple", a=3) self.assertEqual(dtypes.float32, out.dtype) self.assertProtoEquals(""" name: 'Simple' op: 'Simple' input: 'Simple/a' """, out.op.node_def) out = self._lib.apply_op("Simple", a=4) self.assertProtoEquals(""" name: 'Simple_1' op: 'Simple' input: 'Simple_1/a' """, out.op.node_def) out = self._lib.apply_op("Simple", a=5, name="named") self.assertProtoEquals(""" name: 'named' op: 'Simple' input: 'named/a' """, out.op.node_def) out = self._lib.apply_op("Simple", a=[[1, 2, 3], 
[4, 5, 6]], name="two_d") self.assertProtoEquals(""" name: 'two_d' op: 'Simple' input: 'two_d/a' """, out.op.node_def) def testSimpleFailures(self): with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a="Bad string") self.assertEqual(str(cm.exception), "Expected int32 passed to parameter 'a' of op 'Simple', " "got 'Bad string' of type 'str' instead.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a=self.Tensor(dtypes.string)) self.assertEqual(str(cm.exception), "Input 'a' of 'Simple' Op has type string " "that does not match expected type of int32.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a=6, extra="bogus") self.assertEqual(str(cm.exception), "apply_op() got unexpected keyword arguments: extra") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a=6, extra1="bogus", extra2="also_bogus") self.assertEqual(str(cm.exception), "apply_op() got unexpected keyword arguments: extra1, " "extra2") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple") self.assertEqual(str(cm.exception), "No argument for input a") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", wrong=7) self.assertEqual(str(cm.exception), "No argument for input a") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Simple", a={"label": 1}) self.assertEqual(str(cm.exception), "Expected int32 passed to parameter 'a' of op 'Simple', " "got {'label': 1} of type 'dict' instead.") def testReservedInput(self): self._add_op("name: 'ReservedInput' " "input_arg { name: 'input' type: DT_INT32 } ") op = self._lib.apply_op("ReservedInput", input_=7, name="x") self.assertProtoEquals(""" name: 'x' op: 'ReservedInput' input: 'x/input' """, op.node_def) def testPolymorphic(self): self._add_op("name: 'Polymorphic' " "input_arg { name: 'a' type_attr: 'T' } " "output_arg { name: 'out' type_attr: 'T' } " "attr { name: 'T' type: 'type' }") out = self._lib.apply_op("Polymorphic", a=7, name="p") 
self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'Polymorphic' input: 'p/a' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) out = self._lib.apply_op("Polymorphic", a="s", name="q") self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'Polymorphic' input: 'q/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) out = self._lib.apply_op("Polymorphic", a=["s", "t", "u"], name="r") self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'r' op: 'Polymorphic' input: 'r/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("Polymorphic", a="s", T=dtypes.string) self.assertEqual(str(cm.exception), "Should not specify value for inferred attr 'T'.") def testPolymorphicOut(self): self._add_op("name: 'PolymorphicOut' " "output_arg { name: 'out' type_attr: 'T' } " "attr { name: 'T' type: 'type' }") out = self._lib.apply_op("PolymorphicOut", T=dtypes.int32, name="p") self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'PolymorphicOut' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) out = self._lib.apply_op("PolymorphicOut", T=dtypes.bool, name="q") self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'PolymorphicOut' attr { key: 'T' value { type: DT_BOOL } } """, out.op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("PolymorphicOut") self.assertEqual(str(cm.exception), "No argument for attr T") with self.assertRaises(TypeError) as cm: self._lib.apply_op("PolymorphicOut", T=None) self.assertEqual(str(cm.exception), "Expected DataType for argument 'T' not None.") def testPolymorphicDefaultOut(self): self._add_op("name: 'PolymorphicDefaultOut' " "output_arg { name: 'out' type_attr: 'T' } " "attr { name: 'T' type: 'type' " " default_value { type: DT_STRING } }") out = 
self._lib.apply_op("PolymorphicDefaultOut", T=None, name="p") self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'p' op: 'PolymorphicDefaultOut' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) out = self._lib.apply_op("PolymorphicDefaultOut", T=dtypes.bool, name="q") self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'q' op: 'PolymorphicDefaultOut' attr { key: 'T' value { type: DT_BOOL } } """, out.op.node_def) def testBinary(self): self._add_op("name: 'Binary' " "input_arg { name: 'a' type_attr: 'T' } " "input_arg { name: 'b' type_attr: 'T' } " "output_arg { name: 'out' type_attr: 'T' } " "attr { name: 'T' type: 'type' }") out = self._lib.apply_op("Binary", a=8, b=9, name="b") self.assertEqual(dtypes.int32, out.dtype) self.assertProtoEquals(""" name: 'b' op: 'Binary' input: 'b/a' input: 'b/b' attr { key: 'T' value { type: DT_INT32 } } """, out.op.node_def) out = self._lib.apply_op("Binary", a="left", b="right", name="c") self.assertEqual(dtypes.string, out.dtype) self.assertProtoEquals(""" name: 'c' op: 'Binary' input: 'c/a' input: 'c/b' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("Binary", a="left", b=12) self.assertEqual(str(cm.exception), "Expected string passed to parameter 'b' of op 'Binary', " "got 12 of type 'int' instead.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Binary", a=self.Tensor(dtypes.string), b=self.Tensor(dtypes.int32)) self.assertEqual(str(cm.exception), "Input 'b' of 'Binary' Op has type int32 " "that does not match type string of argument 'a'.") def testRestrict(self): self._add_op("name: 'Restrict' " "input_arg { name: 'a' type_attr: 'T' } " "output_arg { name: 'out' type_attr: 'T' } " "attr { name: 'T' type: 'type' allowed_values { list { " " type: DT_STRING type: DT_BOOL } } }") out = self._lib.apply_op("Restrict", a="foo", name="g") self.assertEqual(dtypes.string, 
out.dtype) self.assertProtoEquals(""" name: 'g' op: 'Restrict' input: 'g/a' attr { key: 'T' value { type: DT_STRING } } """, out.op.node_def) out = self._lib.apply_op("Restrict", a=True, name="h") self.assertEqual(dtypes.bool, out.dtype) self.assertProtoEquals(""" name: 'h' op: 'Restrict' input: 'h/a' attr { key: 'T' value { type: DT_BOOL } } """, out.op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("Restrict", a=17) self.assertEqual(str(cm.exception), "Value passed to parameter 'a' has DataType int32 " "not in list of allowed values: string, bool") def testTypeList(self): self._add_op("name: 'TypeList' " "input_arg { name: 'a' type_list_attr: 'T' } " "attr { name: 'T' type: 'list(type)' }") op = self._lib.apply_op("TypeList", a=["foo"], name="z") self.assertProtoEquals(""" name: 'z' op: 'TypeList' input: 'z/a_0' attr { key: 'T' value { list { type: DT_STRING } } } """, op.node_def) op = self._lib.apply_op("TypeList", a=[True, 12], name="y") self.assertProtoEquals(""" name: 'y' op: 'TypeList' input: 'y/a_0' input: 'y/a_1' attr { key: 'T' value { list { type: DT_BOOL type: DT_INT32 } } } """, op.node_def) op = self._lib.apply_op("TypeList", a=[], name="empty") self.assertProtoEquals(""" name: 'empty' op: 'TypeList' attr { key: 'T' value { list { } } } """, op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("TypeList", a=17) self.assertStartsWith(str(cm.exception), "Expected list for 'a' " "argument to 'TypeList' Op, not ") with self.assertRaises(TypeError) as cm: self._lib.apply_op("TypeList", a=[self.Tensor(dtypes.int32), None]) self.assertStartsWith(str(cm.exception), "Tensors in list passed to 'a' of 'TypeList' Op " "have types [int32, ]") def testTypeListTwice(self): self._add_op("name: 'TypeListTwice' " "input_arg { name: 'a' type_list_attr: 'T' } " "input_arg { name: 'b' type_list_attr: 'T' } " "attr { name: 'T' type: 'list(type)' }") op = self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", False], 
name="z") self.assertProtoEquals(""" name: 'z' op: 'TypeListTwice' input: 'z/a_0' input: 'z/a_1' input: 'z/b_0' input: 'z/b_1' attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } } """, op.node_def) op = self._lib.apply_op("TypeListTwice", a=[], b=[], name="empty") self.assertProtoEquals(""" name: 'empty' op: 'TypeListTwice' attr { key: 'T' value { list { } } } """, op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("TypeListTwice", a=["foo", True], b=["bar", 6]) self.assertEqual(str(cm.exception), "Input 'b' of 'TypeListTwice' Op has type list of " "string, int32 that does not match type list " "string, bool of argument 'a'.") def testOutTypeList(self): self._add_op("name: 'OutTypeList' " "output_arg { name: 'out' type_list_attr: 'T' } " "attr { name: 'T' type: 'list(type)' }") out, = self._lib.apply_op("OutTypeList", T=[dtypes.float32], name="x") self.assertEqual(dtypes.float32, out.dtype) self.assertProtoEquals(""" name: 'x' op: 'OutTypeList' attr { key: 'T' value { list { type: DT_FLOAT } } } """, out.op.node_def) out1, out2 = self._lib.apply_op("OutTypeList", T=[dtypes.int32, dtypes.bool], name="w") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.bool, out2.dtype) self.assertProtoEquals(""" name: 'w' op: 'OutTypeList' attr { key: 'T' value { list { type: DT_INT32 type: DT_BOOL } } } """, out1.op.node_def) out = self._lib.apply_op("OutTypeList", T=[], name="empty") self.assertEqual([], out) with self.assertRaises(TypeError) as cm: self._lib.apply_op("OutTypeList", T=dtypes.int32) self.assertEqual(str(cm.exception), "Expected list for attr T") def testTypeListRestrict(self): self._add_op("name: 'TypeListRestrict' " "input_arg { name: 'a' type_list_attr: 'T' } " "attr { name: 'T' type: 'list(type)' allowed_values { list { " " type: DT_STRING type: DT_BOOL } } }") op = self._lib.apply_op("TypeListRestrict", a=["foo", False], name="v") self.assertProtoEquals(""" name: 'v' op: 'TypeListRestrict' input: 'v/a_0' 
input: 'v/a_1' attr { key: 'T' value { list { type: DT_STRING type: DT_BOOL } } } """, op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("TypeListRestrict", a=[True, 12]) self.assertEqual(str(cm.exception), "Value passed to parameter 'a' has DataType int32 " "not in list of allowed values: string, bool") def testOutTypeListRestrict(self): self._add_op("name: 'OutTypeListRestrict' " "output_arg { name: 'out' type_list_attr: 't' } " "attr { name: 't' type: 'list(type)' allowed_values { list { " " type: DT_STRING type: DT_BOOL } } }") out1, out2 = self._lib.apply_op("OutTypeListRestrict", t=[dtypes.bool, dtypes.string], name="u") self.assertEqual(dtypes.bool, out1.dtype) self.assertEqual(dtypes.string, out2.dtype) self.assertProtoEquals(""" name: 'u' op: 'OutTypeListRestrict' attr { key: 't' value { list { type: DT_BOOL type: DT_STRING } } } """, out1.op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("OutTypeListRestrict", t=[dtypes.string, dtypes.int32]) self.assertEqual(str(cm.exception), "Value passed to parameter 't' has DataType int32 " "not in list of allowed values: string, bool") def testAttr(self): self._add_op("name: 'Attr' attr { name: 'a' type: 'int' }") op = self._lib.apply_op("Attr", a=12, name="t") self.assertProtoEquals(""" name: 't' op: 'Attr' attr { key: 'a' value { i: 12 } } """, op.node_def) op = self._lib.apply_op("Attr", a=tensor_shape.Dimension(13), name="u") self.assertProtoEquals(""" name: 'u' op: 'Attr' attr { key: 'a' value { i: 13 } } """, op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr", a="bad") self.assertEqual(str(cm.exception), "Expected int for argument 'a' not 'bad'.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr", a=[12]) self.assertEqual(str(cm.exception), "Expected int for argument 'a' not [12].") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr", a=None) self.assertEqual(str(cm.exception), "Expected int for argument 
'a' not None.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("Attr") self.assertEqual(str(cm.exception), "No argument for attr a") def testAttrFloat(self): self._add_op("name: 'AttrFloat' attr { name: 'a' type: 'float' }") op = self._lib.apply_op("AttrFloat", a=1.2, name="t") self.assertProtoEquals(""" name: 't' op: 'AttrFloat' attr { key: 'a' value { f: 1.2 } } """, op.node_def) op = self._lib.apply_op("AttrFloat", a=12, name="u") self.assertProtoEquals(""" name: 'u' op: 'AttrFloat' attr { key: 'a' value { f: 12 } } """, op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrFloat", a="bad") self.assertEqual(str(cm.exception), "Expected float for argument 'a' not 'bad'.") def testAttrBool(self): self._add_op("name: 'AttrBool' attr { name: 'a' type: 'bool' }") op = self._lib.apply_op("AttrBool", a=True, name="t") self.assertProtoEquals(""" name: 't' op: 'AttrBool' attr { key: 'a' value { b: true } } """, op.node_def) op = self._lib.apply_op("AttrBool", a=False, name="u") self.assertProtoEquals(""" name: 'u' op: 'AttrBool' attr { key: 'a' value { b: false } } """, op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBool", a=0) self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 0.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBool", a=1) self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 1.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBool", a=[]) self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not [].") def testAttrBoolList(self): self._add_op("name: 'AttrBoolList' attr { name: 'a' type: 'list(bool)' }") op = self._lib.apply_op("AttrBoolList", a=[True, False, True], name="t") self.assertProtoEquals(""" name: 't' op: 'AttrBoolList' attr { key: 'a' value { list { b: true b: false b:true } } } """, op.node_def) op = self._lib.apply_op("AttrBoolList", a=[], name="u") self.assertProtoEquals(""" name: 'u' 
op: 'AttrBoolList' attr { key: 'a' value { list { } } } """, op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("AttrBoolList", a=[0]) self.assertEqual(str(cm.exception), "Expected bool for argument 'a' not 0.") def testAttrMin(self): self._add_op("name: 'AttrMin' attr { name: 'a' type: 'int' " "has_minimum: true minimum: 5 }") op = self._lib.apply_op("AttrMin", a=12, name="s") self.assertProtoEquals(""" name: 's' op: 'AttrMin' attr { key: 'a' value { i: 12 } } """, op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrMin", a=2) self.assertEqual(str(cm.exception), "Attr 'a' of 'AttrMin' Op passed 2 less than minimum 5.") def testAttrListMin(self): self._add_op("name: 'AttrListMin' attr { name: 'a' type: 'list(int)' " "has_minimum: true minimum: 2 }") op = self._lib.apply_op("AttrListMin", a=[1, 2], name="r") self.assertProtoEquals(""" name: 'r' op: 'AttrListMin' attr { key: 'a' value { list { i: 1 i: 2 } } } """, op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrListMin", a=[17]) self.assertEqual(str(cm.exception), "Attr 'a' of 'AttrListMin' Op " "passed list of length 1 less than minimum 2.") def testAttrEnum(self): self._add_op("name: 'AttrEnum' " "attr { name: 'a' type: 'string' " " allowed_values { list { s: 'apples' s: 'oranges' } } }") op = self._lib.apply_op("AttrEnum", a="oranges", name="e") self.assertProtoEquals(""" name: 'e' op: 'AttrEnum' attr { key: 'a' value { s: 'oranges' } } """, op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrEnum", a="invalid") self.assertEqual(str(cm.exception), 'Attr \'a\' of \'AttrEnum\' Op ' 'passed string \'invalid\' not in: ' '"apples", "oranges".') def testAttrEnumList(self): self._add_op("name: 'AttrEnumList' " "attr { name: 'a' type: 'list(string)' " " allowed_values { list { s: 'apples' s: 'oranges' } } }") op = self._lib.apply_op("AttrEnumList", a=["oranges", "apples"], name="f") self.assertProtoEquals(""" name: 
'f' op: 'AttrEnumList' attr { key: 'a' value { list { s: 'oranges' s: 'apples' } } } """, op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("AttrEnumList", a=["apples", "invalid", "oranges"]) self.assertEqual(str(cm.exception), 'Attr \'a\' of \'AttrEnumList\' Op ' 'passed string \'invalid\' not ' 'in: "apples", "oranges".') def testAttrShape(self): self._add_op("name: 'AttrShape' attr { name: 'a' type: 'shape' }") op = self._lib.apply_op("AttrShape", a=[5], name="s1") self.assertProtoEquals(""" name: 's1' op: 'AttrShape' attr { key: 'a' value { shape { dim { size: 5 } } } } """, op.node_def) op = self._lib.apply_op("AttrShape", a=(4, 3, 2), name="s2") self.assertProtoEquals(""" name: 's2' op: 'AttrShape' attr { key: 'a' value { shape { dim { size: 4 } dim { size: 3 } dim { size: 2 } } } } """, op.node_def) op = self._lib.apply_op( "AttrShape", a=tensor_shape.TensorShape([3, 2]), name="s3") self.assertProtoEquals(""" name: 's3' op: 'AttrShape' attr { key: 'a' value { shape { dim { size: 3 } dim { size: 2 } } } } """, op.node_def) op = self._lib.apply_op("AttrShape", a=[], name="s4") self.assertProtoEquals(""" name: 's4' op: 'AttrShape' attr { key: 'a' value { shape { } } } """, op.node_def) shape = tensor_shape_pb2.TensorShapeProto() shape.dim.add().size = 6 shape.dim.add().size = 3 op = self._lib.apply_op("AttrShape", a=shape, name="s5") self.assertProtoEquals(""" name: 's5' op: 'AttrShape' attr { key: 'a' value { shape { dim { size: 6 } dim { size: 3 } } } } """, op.node_def) # TODO(josh11b): Re-enable this test once we stop promoting scalars to shapes. 
# with self.assertRaises(TypeError) as cm: # self._lib.apply_op("AttrShape", a=5) # self.assertEqual(str(cm.exception), # "Don't know how to convert 5 to a TensorShapeProto for " # "argument 'a'") with self.assertRaises(TypeError): self._lib.apply_op("AttrShape", a="ABC") def testAttrShapeList(self): self._add_op("name: 'AttrShapeList' attr { name: 'a' type: 'list(shape)' }") op = self._lib.apply_op("AttrShapeList", a=[[3, 2], [6, 5, 4]], name="sl") self.assertProtoEquals(""" name: 'sl' op: 'AttrShapeList' attr { key: 'a' value { list { shape { dim { size: 3 } dim { size: 2 } } shape { dim { size: 6 } dim { size: 5 } dim { size: 4 } } } } } """, op.node_def) op = self._lib.apply_op("AttrShapeList", a=[], name="esl") self.assertProtoEquals(""" name: 'esl' op: 'AttrShapeList' attr { key: 'a' value { list { } } } """, op.node_def) def testAttrPartialShape(self): self._add_op( "name: 'AttrPartialShape' attr { name: 'a' type: 'shape' }") op = self._lib.apply_op("AttrPartialShape", a=[5], name="s1") self.assertProtoEquals(""" name: 's1' op: 'AttrPartialShape' attr { key: 'a' value { shape { dim { size: 5 } } } } """, op.node_def) op = self._lib.apply_op("AttrPartialShape", a=(4, None, 2), name="s2") self.assertProtoEquals(""" name: 's2' op: 'AttrPartialShape' attr { key: 'a' value { shape { dim { size: 4 } dim { size: -1 } dim { size: 2 } } } } """, op.node_def) op = self._lib.apply_op( "AttrPartialShape", a=tensor_shape.TensorShape([3, None]), name="s3") self.assertProtoEquals(""" name: 's3' op: 'AttrPartialShape' attr { key: 'a' value { shape { dim { size: 3 } dim { size: -1 } } } } """, op.node_def) op = self._lib.apply_op("AttrPartialShape", a=[], name="s4") self.assertProtoEquals(""" name: 's4' op: 'AttrPartialShape' attr { key: 'a' value { shape { } } } """, op.node_def) shape = tensor_shape_pb2.TensorShapeProto() shape.dim.add().size = -1 shape.dim.add().size = 3 op = self._lib.apply_op("AttrPartialShape", a=shape, name="s5") self.assertProtoEquals(""" name: 's5' 
op: 'AttrPartialShape' attr { key: 'a' value { shape { dim { size: -1 } dim { size: 3 } } } } """, op.node_def) # TODO(ebrevdo): Re-enable once we stop promoting scalars to shapes. # with self.assertRaises(TypeError) as cm: # self._lib.apply_op("AttrPartialShape", a=5) # self.assertEqual(str(cm.exception), # "Don't know how to convert 5 to a TensorShapeProto for " # "argument 'a'") with self.assertRaises(TypeError): self._lib.apply_op("AttrPartialShape", a="ABC") def testAttrPartialShapeList(self): self._add_op(""" name: 'AttrPartialShapeList' attr { name: 'a' type: 'list(shape)' } """) op = self._lib.apply_op( "AttrPartialShapeList", a=[[3, 2], [6, None, 4]], name="sl") self.assertProtoEquals(""" name: 'sl' op: 'AttrPartialShapeList' attr { key: 'a' value { list { shape { dim { size: 3 } dim { size: 2 } } shape { dim { size: 6 } dim { size: -1 } dim { size: 4 } } } } } """, op.node_def) op = self._lib.apply_op("AttrPartialShapeList", a=[], name="esl") self.assertProtoEquals(""" name: 'esl' op: 'AttrPartialShapeList' attr { key: 'a' value { list { } } } """, op.node_def) def testAttrDefault(self): self._add_op("name: 'AttrDefault' " "attr { name: 'a' type: 'string' " " default_value { s: 'banana' } }") op = self._lib.apply_op("AttrDefault", a=None, name="d") self.assertProtoEquals(""" name: 'd' op: 'AttrDefault' attr { key: 'a' value { s: 'banana' } } """, op.node_def) op = self._lib.apply_op("AttrDefault", a="kiwi", name="c") self.assertProtoEquals(""" name: 'c' op: 'AttrDefault' attr { key: 'a' value { s: 'kiwi' } } """, op.node_def) def testAttrListDefault(self): self._add_op("name: 'AttrListDefault' " "attr { name: 'a' type: 'list(int)' " " default_value { list { i: 5 i: 15 } } }") op = self._lib.apply_op("AttrListDefault", a=None, name="b") self.assertProtoEquals(""" name: 'b' op: 'AttrListDefault' attr { key: 'a' value { list { i: 5 i: 15 } } } """, op.node_def) op = self._lib.apply_op("AttrListDefault", a=[3], name="a") self.assertProtoEquals(""" name: 'a' 
op: 'AttrListDefault' attr { key: 'a' value { list { i: 3 } } } """, op.node_def) op = self._lib.apply_op("AttrListDefault", a=[], name="empty") self.assertProtoEquals(""" name: 'empty' op: 'AttrListDefault' attr { key: 'a' value { list { } } } """, op.node_def) def testAttrEmptyListDefault(self): self._add_op("name: 'AttrEmptyListDefault' " "attr { name: 'a' type: 'list(float)' " " default_value { list { } } }") op = self._lib.apply_op("AttrEmptyListDefault", a=None, name="b") self.assertProtoEquals(""" name: 'b' op: 'AttrEmptyListDefault' attr { key: 'a' value { list { } } } """, op.node_def) op = self._lib.apply_op("AttrEmptyListDefault", a=[3], name="a") self.assertProtoEquals(""" name: 'a' op: 'AttrEmptyListDefault' attr { key: 'a' value { list { f: 3 } } } """, op.node_def) op = self._lib.apply_op("AttrEmptyListDefault", a=[], name="empty") self.assertProtoEquals(""" name: 'empty' op: 'AttrEmptyListDefault' attr { key: 'a' value { list { } } } """, op.node_def) def testReservedAttr(self): self._add_op("name: 'ReservedAttr' " "attr { name: 'range' type: 'int' } ") op = self._lib.apply_op("ReservedAttr", range_=7, name="x") self.assertProtoEquals(""" name: 'x' op: 'ReservedAttr' attr { key: 'range' value { i: 7 } } """, op.node_def) def testDefaultAttrType(self): self._add_op("name: 'AttrTypeDefault' " "input_arg { name: 'a' type_attr: 'T' } " "attr { name: 'T' type: 'type' " " default_value { type: DT_INT32 } }") # Give an input whose type has no obvious output type. op = self._lib.apply_op("AttrTypeDefault", a=[], name="n") self.assertProtoEquals(""" name: 'n' op: 'AttrTypeDefault' input: 'n/a' attr { key: 'T' value { type: DT_INT32 } } """, op.node_def) # Give an input whose type can be inferred as different # than the default. 
op = self._lib.apply_op("AttrTypeDefault", a=[1.0], name="f") self.assertProtoEquals(""" name: 'f' op: 'AttrTypeDefault' input: 'f/a' attr { key: 'T' value { type: DT_FLOAT } } """, op.node_def) def testDefaultListAttrType(self): self._add_op("name: 'AttrListTypeDefault' " "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } " "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } " "attr { name: 'T' type: 'type' " " default_value { type: DT_INT32 } }" "attr { name: 'N' type: 'int' }") # Give an input whose type can be inferred as different # than the default. op = self._lib.apply_op("AttrListTypeDefault", a=[1.0], b=[2.0], name="n") self.assertProtoEquals(""" name: 'n' op: 'AttrListTypeDefault' input: 'n/a_0' input: 'n/b_0' attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'N' value { i: 1 } } """, op.node_def) def testNIntsIn(self): self._add_op("name: 'NIntsIn' " "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") op = self._lib.apply_op("NIntsIn", a=[1, 2], name="n") self.assertProtoEquals(""" name: 'n' op: 'NIntsIn' input: 'n/a_0' input: 'n/a_1' attr { key: 'N' value { i: 2 } } """, op.node_def) op = self._lib.apply_op("NIntsIn", a=[5, 4, 3, 2, 1], name="o") self.assertProtoEquals(""" name: 'o' op: 'NIntsIn' input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4' attr { key: 'N' value { i: 5 } } """, op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=["foo", "bar"]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have types " "[string, string] that do not match expected type int32.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=[self.Tensor(dtypes.string), self.Tensor(dtypes.string)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have " "types [string, string] that do not match expected type " "int32.") with 
self.assertRaises(ValueError) as cm: self._lib.apply_op("NIntsIn", a=[99]) self.assertEqual(str(cm.exception), "List argument 'a' to 'NIntsIn' Op " "with length 1 shorter than " "minimum length 2.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=[38, "bar"]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op have types " "[int32, string] that do not match expected type int32.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=[self.Tensor(dtypes.int32), self.Tensor(dtypes.string)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NIntsIn' Op " "have types [int32, string] that do not match expected " "type int32.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsIn", a=17) self.assertStartsWith(str(cm.exception), "Expected list for 'a' argument " "to 'NIntsIn' Op, not ") def testNPolymorphicIn(self): self._add_op("name: 'NPolymorphicIn' " "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } " "attr { name: 'T' type: 'type' } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") op = self._lib.apply_op("NPolymorphicIn", a=[1, 2], name="n") self.assertProtoEquals(""" name: 'n' op: 'NPolymorphicIn' input: 'n/a_0' input: 'n/a_1' attr { key: 'T' value { type: DT_INT32 } } attr { key: 'N' value { i: 2 } } """, op.node_def) op = self._lib.apply_op("NPolymorphicIn", a=[5, 4, 3, 2, 1], name="o") self.assertProtoEquals(""" name: 'o' op: 'NPolymorphicIn' input: 'o/a_0' input: 'o/a_1' input: 'o/a_2' input: 'o/a_3' input: 'o/a_4' attr { key: 'T' value { type: DT_INT32 } } attr { key: 'N' value { i: 5 } } """, op.node_def) op = self._lib.apply_op("NPolymorphicIn", a=["foo", "bar"], name="p") self.assertProtoEquals(""" name: 'p' op: 'NPolymorphicIn' input: 'p/a_0' input: 'p/a_1' attr { key: 'T' value { type: DT_STRING } } attr { key: 'N' value { i: 2 } } """, op.node_def) op = self._lib.apply_op("NPolymorphicIn", a=[1, self.Tensor(dtypes.float32, 
name="x")], name="q") self.assertProtoEquals(""" name: 'q' op: 'NPolymorphicIn' input: 'q/a_0' input: 'x' attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'N' value { i: 2 } } """, op.node_def) op = self._lib.apply_op("NPolymorphicIn", a=[self.Tensor(dtypes.float32, name="y"), self.Tensor(dtypes.float32_ref, name="z")], name="r") self.assertProtoEquals(""" name: 'r' op: 'NPolymorphicIn' input: 'y' input: 'z' attr { key: 'T' value { type: DT_FLOAT } } attr { key: 'N' value { i: 2 } } """, op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("NPolymorphicIn", a=[99]) self.assertEqual(str(cm.exception), "List argument 'a' to 'NPolymorphicIn' Op with length 1 " "shorter than minimum length 2.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", a=[38, "bar"]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [int32, string] that don't all match.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", a=[38, self.Tensor(dtypes.string)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [int32, string] that don't all match.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", a=[38, None]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [int32, ] that " "don't all match.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", a=["abcd", self.Tensor(dtypes.int32)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'a' of 'NPolymorphicIn' Op " "have types [string, int32] that don't all match.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicIn", a=17) self.assertStartsWith(str(cm.exception), "Expected list for 'a' argument " "to 'NPolymorphicIn' Op, not ") def testNPolymorphicRestrictIn(self): self._add_op("name: 'NPolymorphicRestrictIn' " 
"input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } " "attr { name: 'T' type: 'type' allowed_values { " " list { type: DT_STRING type: DT_BOOL } } } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") op = self._lib.apply_op("NPolymorphicRestrictIn", a=["foo", "bar"], name="p") self.assertProtoEquals(""" name: 'p' op: 'NPolymorphicRestrictIn' input: 'p/a_0' input: 'p/a_1' attr { key: 'T' value { type: DT_STRING } } attr { key: 'N' value { i: 2 } } """, op.node_def) op = self._lib.apply_op("NPolymorphicRestrictIn", a=[False, True, False], name="b") self.assertProtoEquals(""" name: 'b' op: 'NPolymorphicRestrictIn' input: 'b/a_0' input: 'b/a_1' input: 'b/a_2' attr { key: 'T' value { type: DT_BOOL } } attr { key: 'N' value { i: 3 } } """, op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicRestrictIn", a=[1, 2]) self.assertEqual(str(cm.exception), "Value passed to parameter 'a' has DataType int32 not in " "list of allowed values: string, bool") def testNInTwice(self): self._add_op("name: 'NInTwice' " "input_arg { name: 'a' type: DT_INT32 number_attr: 'N' } " "input_arg { name: 'b' type: DT_STRING number_attr: 'N' } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }") op = self._lib.apply_op("NInTwice", a=[1, 2], b=["one", "two"], name="n") self.assertProtoEquals(""" name: 'n' op: 'NInTwice' input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1' attr { key: 'N' value { i: 2 } } """, op.node_def) op = self._lib.apply_op("NInTwice", a=[], b=[], name="o") self.assertProtoEquals(""" name: 'o' op: 'NInTwice' attr { key: 'N' value { i: 0 } } """, op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("NInTwice", a=[1, 2, 3], b=["too short"]) self.assertEqual(str(cm.exception), "List argument 'b' to 'NInTwice' Op " "with length 1 must match " "length 3 of argument 'a'.") def testNInPolymorphicTwice(self): self._add_op("name: 'NInPolymorphicTwice' " "input_arg { name: 'a' type_attr: 'T' 
number_attr: 'N' } " "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } " "attr { name: 'T' type: 'type' } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }") op = self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=[3, 4], name="n") self.assertProtoEquals(""" name: 'n' op: 'NInPolymorphicTwice' input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1' attr { key: 'T' value { type: DT_INT32 } } attr { key: 'N' value { i: 2 } } """, op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("NInPolymorphicTwice", a=[1, 2, 3], b=[5]) self.assertEqual(str(cm.exception), "List argument 'b' to 'NInPolymorphicTwice' Op " "with length 1 " "must match length 3 of argument 'a'.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NInPolymorphicTwice", a=[1, 2], b=["one", "two"]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'NInPolymorphicTwice' " "Op have types [string, string] that do not match type " "int32 inferred from earlier arguments.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NInPolymorphicTwice", a=[self.Tensor(dtypes.int32)], b=[self.Tensor(dtypes.string)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of " "'NInPolymorphicTwice' Op have types [string] that do not " "match type int32 inferred from earlier arguments.") def testNInTwoTypeVariables(self): self._add_op("name: 'NInTwoTypeVariables' " "input_arg { name: 'a' type_attr: 'S' number_attr: 'N' } " "input_arg { name: 'b' type_attr: 'T' number_attr: 'N' } " "attr { name: 'S' type: 'type' } " "attr { name: 'T' type: 'type' } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 }") op = self._lib.apply_op("NInTwoTypeVariables", a=[1, 2], b=[True, False], name="n") self.assertProtoEquals(""" name: 'n' op: 'NInTwoTypeVariables' input: 'n/a_0' input: 'n/a_1' input: 'n/b_0' input: 'n/b_1' attr { key: 'S' value { type: DT_INT32 } } attr { key: 'T' value { type: DT_BOOL } } attr { key: 'N' value { 
i: 2 } } """, op.node_def) op = self._lib.apply_op("NInTwoTypeVariables", a=[1, 2], b=[3, 4], name="o") self.assertProtoEquals(""" name: 'o' op: 'NInTwoTypeVariables' input: 'o/a_0' input: 'o/a_1' input: 'o/b_0' input: 'o/b_1' attr { key: 'S' value { type: DT_INT32 } } attr { key: 'T' value { type: DT_INT32 } } attr { key: 'N' value { i: 2 } } """, op.node_def) op = self._lib.apply_op("NInTwoTypeVariables", a=[self.Tensor(dtypes.int32, name="q")], b=[self.Tensor(dtypes.string, name="r")], name="p") self.assertProtoEquals(""" name: 'p' op: 'NInTwoTypeVariables' input: 'q' input: 'r' attr { key: 'S' value { type: DT_INT32 } } attr { key: 'T' value { type: DT_STRING } } attr { key: 'N' value { i: 1 } } """, op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("NInTwoTypeVariables", a=[1, 2, 3], b=["5"]) self.assertEqual(str(cm.exception), "List argument 'b' to 'NInTwoTypeVariables' Op " "with length 1 " "must match length 3 of argument 'a'.") def testInPolymorphicTwice(self): self._add_op("name: 'InPolymorphicTwice' " "input_arg { name: 'a' type_attr: 'T' number_attr: 'N' } " "input_arg { name: 'b' type_attr: 'T' number_attr: 'M' } " "attr { name: 'T' type: 'type' } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 0 } " "attr { name: 'M' type: 'int' has_minimum: true minimum: 0 } ") op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[3, 4, 5], name="n") self.assertProtoEquals(""" name: 'n' op: 'InPolymorphicTwice' input: 'n/a_0' input: 'n/b_0' input: 'n/b_1' input: 'n/b_2' attr { key: 'T' value { type: DT_INT32 } } attr { key: 'N' value { i: 1 } } attr { key: 'M' value { i: 3 } } """, op.node_def) op = self._lib.apply_op("InPolymorphicTwice", a=[8], b=[], name="o") self.assertProtoEquals(""" name: 'o' op: 'InPolymorphicTwice' input: 'o/a_0' attr { key: 'T' value { type: DT_INT32 } } attr { key: 'N' value { i: 1 } } attr { key: 'M' value { i: 0 } } """, op.node_def) with self.assertRaises(TypeError) as cm: 
self._lib.apply_op("InPolymorphicTwice", a=[], b=[3, 4, 5]) self.assertEqual(str(cm.exception), "Don't know how to infer type variable from empty input " "list passed to input 'a' of 'InPolymorphicTwice' Op.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("InPolymorphicTwice", a=[1, 2], b=["one", "two"]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'InPolymorphicTwice' Op " "have types [string, string] that do not match type int32 " "inferred from earlier arguments.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("InPolymorphicTwice", a=[self.Tensor(dtypes.int32)], b=[self.Tensor(dtypes.string)]) self.assertEqual(str(cm.exception), "Tensors in list passed to 'b' of 'InPolymorphicTwice' " "Op have types [string] that do not match type int32 " "inferred from earlier arguments.") def testNIntsOut(self): self._add_op("name: 'NIntsOut' " "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") out1, out2 = self._lib.apply_op("NIntsOut", N=2, name="n") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 'n' op: 'NIntsOut' attr { key: 'N' value { i: 2 } } """, out1.op.node_def) out1, out2, out3, out4, out5 = self._lib.apply_op( "NIntsOut", N=5, name="o") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) self.assertEqual(dtypes.int32, out3.dtype) self.assertEqual(dtypes.int32, out4.dtype) self.assertEqual(dtypes.int32, out5.dtype) self.assertProtoEquals(""" name: 'o' op: 'NIntsOut' attr { key: 'N' value { i: 5 } } """, out5.op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("NIntsOut", N=1) self.assertEqual(str(cm.exception), "Attr 'N' of 'NIntsOut' Op passed 1 less than minimum 2.") with self.assertRaises(TypeError) as cm: self._lib.apply_op("NIntsOut", N=[3]) self.assertEqual(str(cm.exception), "Expected int for argument 'N' not 
[3].") def testNIntsOutDefault(self): self._add_op("name: 'NIntsOutDefault' " "output_arg { name: 'a' type: DT_INT32 number_attr: 'N' } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 2" " default_value { i:3 } }") out1, out2, out3 = self._lib.apply_op( "NIntsOutDefault", N=None, name="z") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) self.assertEqual(dtypes.int32, out3.dtype) self.assertProtoEquals(""" name: 'z' op: 'NIntsOutDefault' attr { key: 'N' value { i: 3 } } """, out1.op.node_def) out1, out2 = self._lib.apply_op("NIntsOutDefault", N=2, name="y") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 'y' op: 'NIntsOutDefault' attr { key: 'N' value { i: 2 } } """, out2.op.node_def) def testNPolymorphicOut(self): self._add_op("name: 'NPolymorphicOut' " "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } " "attr { name: 'T' type: 'type' } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") out1, out2 = self._lib.apply_op("NPolymorphicOut", N=2, T=dtypes.int32, name="n") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 'n' op: 'NPolymorphicOut' attr { key: 'T' value { type: DT_INT32 } } attr { key: 'N' value { i: 2 } } """, out1.op.node_def) out1, out2, out3 = self._lib.apply_op( "NPolymorphicOut", T=dtypes.string, N=3, name="o") self.assertEqual(dtypes.string, out1.dtype) self.assertEqual(dtypes.string, out2.dtype) self.assertEqual(dtypes.string, out3.dtype) self.assertProtoEquals(""" name: 'o' op: 'NPolymorphicOut' attr { key: 'T' value { type: DT_STRING } } attr { key: 'N' value { i: 3 } } """, out3.op.node_def) with self.assertRaises(ValueError) as cm: self._lib.apply_op("NPolymorphicOut", N=1, T=dtypes.string) self.assertEqual(str(cm.exception), "Attr 'N' of 'NPolymorphicOut' Op " "passed 1 less than minimum 2.") with self.assertRaises(TypeError) as cm: 
self._lib.apply_op("NPolymorphicOut", N=3, T=[dtypes.string]) self.assertEqual( str(cm.exception), "Expected DataType for argument 'T' not [tf.string].") def testNPolymorphicOutDefault(self): self._add_op("name: 'NPolymorphicOutDefault' " "output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } " "attr { name: 'T' type: 'type'" " default_value { type: DT_BOOL } } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 " " default_value { i: 2 } }") out1, out2 = self._lib.apply_op( "NPolymorphicOutDefault", N=None, T=None, name="r") self.assertEqual(dtypes.bool, out1.dtype) self.assertEqual(dtypes.bool, out2.dtype) self.assertProtoEquals(""" name: 'r' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_BOOL } } attr { key: 'N' value { i: 2 } } """, out1.op.node_def) out1, out2, out3 = self._lib.apply_op( "NPolymorphicOutDefault", N=3, T=None, name="s") self.assertEqual(dtypes.bool, out1.dtype) self.assertEqual(dtypes.bool, out2.dtype) self.assertEqual(dtypes.bool, out3.dtype) self.assertProtoEquals(""" name: 's' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_BOOL } } attr { key: 'N' value { i: 3 } } """, out1.op.node_def) out1, out2 = self._lib.apply_op( "NPolymorphicOutDefault", N=None, T=dtypes.int32, name="t") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) self.assertProtoEquals(""" name: 't' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_INT32 } } attr { key: 'N' value { i: 2 } } """, out1.op.node_def) out1, out2, out3 = self._lib.apply_op( "NPolymorphicOutDefault", N=3, T=dtypes.int32, name="u") self.assertEqual(dtypes.int32, out1.dtype) self.assertEqual(dtypes.int32, out2.dtype) self.assertEqual(dtypes.int32, out3.dtype) self.assertProtoEquals(""" name: 'u' op: 'NPolymorphicOutDefault' attr { key: 'T' value { type: DT_INT32 } } attr { key: 'N' value { i: 3 } } """, out1.op.node_def) def testNPolymorphicRestrictOut(self): self._add_op("name: 'NPolymorphicRestrictOut' " 
"output_arg { name: 'a' type_attr: 'T' number_attr: 'N' } " "attr { name: 'T' type: 'type' allowed_values { " " list { type: DT_STRING type: DT_BOOL } } } " "attr { name: 'N' type: 'int' has_minimum: true minimum: 2 }") out1, out2, out3 = self._lib.apply_op( "NPolymorphicRestrictOut", N=3, T=dtypes.bool, name="u") self.assertEqual(dtypes.bool, out1.dtype) self.assertEqual(dtypes.bool, out2.dtype) self.assertEqual(dtypes.bool, out3.dtype) self.assertProtoEquals(""" name: 'u' op: 'NPolymorphicRestrictOut' attr { key: 'T' value { type: DT_BOOL } } attr { key: 'N' value { i: 3 } } """, out1.op.node_def) with self.assertRaises(TypeError) as cm: self._lib.apply_op("NPolymorphicRestrictOut", N=2, T=dtypes.int32) self.assertEqual(str(cm.exception), "Value passed to parameter 'T' has DataType int32 " "not in list of allowed values: string, bool") def testRef(self): self._add_op("name: 'RefIn' " "input_arg { name: 'a' type_attr: 'T' is_ref: true } " "attr { name: 'T' type: 'type' } ") self._add_op("name: 'TwoRefsIn' " "input_arg { name: 'a' type_attr: 'T' is_ref: true } " "input_arg { name: 'b' type_attr: 'T' is_ref: true } " "attr { name: 'T' type: 'type' } ") self._add_op("name: 'RefOut' " "output_arg { name: 'a' type_attr: 'T' is_ref: true } " "attr { name: 'T' type: 'type' } ") out = self._lib.apply_op("RefOut", T=dtypes.bool, name="o") self.assertEqual(dtypes.bool_ref, out.dtype) self.assertProtoEquals(""" name: 'o' op: 'RefOut' attr { key: 'T' value { type: DT_BOOL } } """, out.op.node_def) op = self._lib.apply_op("RefIn", a=out, name="i") self.assertProtoEquals(""" name: 'i' op: 'RefIn' input: 'o' attr { key: 'T' value { type: DT_BOOL } } attr { key: "_class" value { list { s: "loc:@o" } } } """, op.node_def) # Can pass ref to non-ref input. out = self._lib.apply_op("RefOut", T=dtypes.int32, name="r") out = self._lib.apply_op("Simple", a=out, name="s") self.assertProtoEquals(""" name: 's' op: 'Simple' input: 'r' """, out.op.node_def) # Can't pass non-ref to ref input. 
with self.assertRaises(TypeError) as cm: self._lib.apply_op("RefIn", a=2) self.assertEqual(str(cm.exception), "'RefIn' Op requires that input 'a' be a mutable tensor " + "(e.g.: a tf.Variable)") input_a = self._lib.apply_op("RefOut", T=dtypes.int32, name="t") input_b = self._lib.apply_op("RefOut", T=dtypes.int32, name="u") op = self._lib.apply_op("TwoRefsIn", a=input_a, b=input_b, name="v") # NOTE(mrry): The order of colocation constraints is an implementation # detail. self.assertProtoEquals(""" name: 'v' op: 'TwoRefsIn' input: 't' input: 'u' attr { key: 'T' value { type: DT_INT32 } } attr { key: "_class" value { list { s: "loc:@t" s: "loc:@u" } } } """, op.node_def) def testSpecifyDevice(self): with self._g.device("/job:ADevice"): self._lib.apply_op("Simple", a=3) # We look at the whole graph here to make sure the Const op is also given # the specified device. graph_def = self._g.as_graph_def() self.assertEqual(len(graph_def.node), 2) for node in graph_def.node: self.assertDeviceEqual(node.device, "/job:ADevice") def testStructuredOutputSingleList(self): self._add_op("name: 'SimpleStruct' " "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } " "attr { name: 'n_a' type: 'int' }") for n_a in [0, 1, 3]: a = self._lib.apply_op("SimpleStruct", n_a=n_a) self.assertTrue(isinstance(a, list)) self.assertEqual(n_a, len(a)) def testStructuredOutputListAndSingle(self): self._add_op("name: 'MixedStruct' " "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } " "output_arg { name: 'b' type: DT_FLOAT } " "attr { name: 'n_a' type: 'int' }") for n_a in [0, 1, 3]: a, b = self._lib.apply_op("MixedStruct", n_a=n_a) self.assertTrue(isinstance(a, list)) self.assertEqual(n_a, len(a)) self.assertTrue(all(x.dtype == dtypes.int32 for x in a)) self.assertTrue(isinstance(b, ops.Tensor)) self.assertEqual(dtypes.float32, b.dtype) def testStructuredOutputMultipleLists(self): self._add_op("name: 'ComplexStruct' " "output_arg { name: 'a' type: DT_INT32 number_attr: 'n_a' } " 
"output_arg { name: 'b' type: DT_INT64 number_attr: 'n_b' } " "output_arg { name: 'c' type_list_attr: 't_c' } " "attr { name: 'n_a' type: 'int' } " "attr { name: 'n_b' type: 'int' } " "attr { name: 't_c' type: 'list(type)' }") for n_a in [0, 1, 3]: for n_b in [0, 1, 3]: for t_c in [[], [dtypes.int32], [dtypes.int32, dtypes.float32]]: a, b, c = self._lib.apply_op("ComplexStruct", n_a=n_a, n_b=n_b, t_c=t_c) self.assertEqual(n_a, len(a)) self.assertTrue(all(x.dtype == dtypes.int32 for x in a)) self.assertEqual(n_b, len(b)) self.assertTrue(all(x.dtype == dtypes.int64 for x in b)) self.assertEqual(t_c, [x.dtype for x in c]) class OpDefLibraryGraphTest(test_util.TensorFlowTestCase): def setUp(self): self._lib = OpDefLibrary() self._g = ops.Graph() self._add_op("name: 'Simple' input_arg { name: 'a' type: DT_INT32 } " "output_arg { name: 'out' type: DT_FLOAT }") self._add_op("name: 'Binary' " "input_arg { name: 'a' type_attr: 'T' } " "input_arg { name: 'b' type_attr: 'T' } " "output_arg { name: 'out' type_attr: 'T' } " "attr { name: 'T' type: 'type' }") def _add_op(self, ascii): op_def = op_def_pb2.OpDef() text_format.Merge(ascii, op_def) self._lib.add_op(op_def) def testNoGraph(self): out = self._lib.apply_op("Simple", a=3) self.assertEqual(out.graph, ops.get_default_graph()) def testDefaultGraph(self): with self._g.as_default(): out = self._lib.apply_op("Simple", a=3) self.assertEqual(out.graph, self._g) def testDifferentGraphFails(self): with self._g.as_default(): a = self._lib.apply_op("Simple", a=3) other_g = ops.Graph() with other_g.as_default(): b = self._lib.apply_op("Simple", a=4) with self.assertRaises(ValueError) as cm: self._lib.apply_op("Binary", a=a, b=b) self.assertTrue("must be from the same graph" in str(cm.exception)) if __name__ == "__main__": googletest.main()