diff options
Diffstat (limited to 'tensorflow/python/framework/ops_test.py')
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 87 |
1 files changed, 42 insertions, 45 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 4e931e00c5..3087d6060b 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -31,11 +31,9 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.framework import test_ops from tensorflow.python.framework import test_util from tensorflow.python.framework import versions @@ -359,55 +357,54 @@ class OperationTest(test_util.TensorFlowTestCase): self.assertEqual("<tf.Operation 'op1' type=None>", repr(op)) def testGetAttr(self): - op = test_ops.default_attrs() - self.assertEqual(op.get_attr("string_val"), b"abc") - self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""]) - self.assertEqual(op.get_attr("int_val"), 123) - self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3]) - self.assertEqual(op.get_attr("float_val"), 10.0) - self.assertEqual(op.get_attr("float_list_val"), [10.0]) - self.assertEqual(op.get_attr("bool_val"), True) - self.assertEqual(op.get_attr("bool_list_val"), [True, False]) - self.assertEqual(op.get_attr("shape_val"), - tensor_shape.as_shape([2, 1]).as_proto()) - self.assertEqual(op.get_attr("shape_list_val"), - [tensor_shape.as_shape([]).as_proto(), - tensor_shape.as_shape([1]).as_proto()]) - self.assertEqual(op.get_attr("tensor_val"), - tensor_util.make_tensor_proto(1, dtypes.int32)) - self.assertEqual(op.get_attr("tensor_list_val"), - [tensor_util.make_tensor_proto(1, dtypes.int32)]) - - type_val = op.get_attr("type_val") - # First check that type_val is a DType, because the assertEquals will work - # no matter what since DType overrides __eq__ - self.assertIsInstance(type_val, dtypes.DType) - self.assertEqual(type_val, dtypes.int32) - - type_list_val = op.get_attr("type_list_val") - self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val)) - self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32]) - - @function.Defun(dtypes.float32, func_name="MyFunc") - def func(x): - return x - - op = test_ops.func_attr(func) - self.assertEqual(op.get_attr("f"), - attr_value_pb2.NameAttrList(name="MyFunc")) - - # Try fetching missing attr + # TODO(b/65162920): implement all tests for get_attr with C API if ops._USE_C_API: - error_msg = "Operation 'FuncAttr' has no attr named 'FakeAttr'." - else: - error_msg = "No attr named 'FakeAttr' in name: \"FuncAttr\"" + op = test_ops.int_attr().op + self.assertEqual(op.get_attr("foo"), 1) + + op_str = test_ops.string_list_attr(a=["z"], b="y") + self.assertEqual(op_str.get_attr("a"), [b"z"]) + self.assertEqual(op_str.get_attr("b"), b"y") - with self.assertRaisesRegexp(ValueError, error_msg): - op.get_attr("FakeAttr") + else: + list_value = attr_value_pb2.AttrValue.ListValue() + + list_value.type.append(types_pb2.DT_STRING) + list_value.type.append(types_pb2.DT_DOUBLE) + op = ops.Operation( + ops._NodeDef( + "None", + "op1", + attrs={ + "value": + attr_value_pb2.AttrValue(i=32), + "dtype": + attr_value_pb2.AttrValue(type=types_pb2.DT_INT32), + "list": + attr_value_pb2.AttrValue(list=list_value), + "func": + attr_value_pb2.AttrValue( + func=attr_value_pb2.NameAttrList()) + }), ops.Graph(), [], [dtypes.int32]) + self.assertEqual(32, op.get_attr("value")) + self.assertEqual("", op.get_attr("func").name) + + d = op.get_attr("dtype") + # First check that d is a DType, because the assertEquals will + # work no matter what since DType overrides __eq__ + self.assertIsInstance(d, dtypes.DType) + self.assertEqual(dtypes.int32, d) + + l = op.get_attr("list") + for x in l: + self.assertIsInstance(x, dtypes.DType) + self.assertEqual([dtypes.string, dtypes.double], l) # TODO(b/65162920): remove this test when users who are directly mutating the # node_def have been updated to proper usage. def testSetAttr(self): + if not ops._USE_C_API: + return op = test_ops.int_attr().op op._set_attr("foo", attr_value_pb2.AttrValue(i=2)) # TODO(skyewm): add node_def check |