aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/ops_test.py')
-rw-r--r--tensorflow/python/framework/ops_test.py87
1 files changed, 42 insertions, 45 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 4e931e00c5..3087d6060b 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -31,11 +31,9 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
@@ -359,55 +357,54 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
def testGetAttr(self):
- op = test_ops.default_attrs()
- self.assertEqual(op.get_attr("string_val"), b"abc")
- self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
- self.assertEqual(op.get_attr("int_val"), 123)
- self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
- self.assertEqual(op.get_attr("float_val"), 10.0)
- self.assertEqual(op.get_attr("float_list_val"), [10.0])
- self.assertEqual(op.get_attr("bool_val"), True)
- self.assertEqual(op.get_attr("bool_list_val"), [True, False])
- self.assertEqual(op.get_attr("shape_val"),
- tensor_shape.as_shape([2, 1]).as_proto())
- self.assertEqual(op.get_attr("shape_list_val"),
- [tensor_shape.as_shape([]).as_proto(),
- tensor_shape.as_shape([1]).as_proto()])
- self.assertEqual(op.get_attr("tensor_val"),
- tensor_util.make_tensor_proto(1, dtypes.int32))
- self.assertEqual(op.get_attr("tensor_list_val"),
- [tensor_util.make_tensor_proto(1, dtypes.int32)])
-
- type_val = op.get_attr("type_val")
- # First check that type_val is a DType, because the assertEquals will work
- # no matter what since DType overrides __eq__
- self.assertIsInstance(type_val, dtypes.DType)
- self.assertEqual(type_val, dtypes.int32)
-
- type_list_val = op.get_attr("type_list_val")
- self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
- self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])
-
- @function.Defun(dtypes.float32, func_name="MyFunc")
- def func(x):
- return x
-
- op = test_ops.func_attr(func)
- self.assertEqual(op.get_attr("f"),
- attr_value_pb2.NameAttrList(name="MyFunc"))
-
- # Try fetching missing attr
+ # TODO(b/65162920): implement all tests for get_attr with C API
if ops._USE_C_API:
- error_msg = "Operation 'FuncAttr' has no attr named 'FakeAttr'."
- else:
- error_msg = "No attr named 'FakeAttr' in name: \"FuncAttr\""
+ op = test_ops.int_attr().op
+ self.assertEqual(op.get_attr("foo"), 1)
+
+ op_str = test_ops.string_list_attr(a=["z"], b="y")
+ self.assertEqual(op_str.get_attr("a"), [b"z"])
+ self.assertEqual(op_str.get_attr("b"), b"y")
- with self.assertRaisesRegexp(ValueError, error_msg):
- op.get_attr("FakeAttr")
+ else:
+ list_value = attr_value_pb2.AttrValue.ListValue()
+
+ list_value.type.append(types_pb2.DT_STRING)
+ list_value.type.append(types_pb2.DT_DOUBLE)
+ op = ops.Operation(
+ ops._NodeDef(
+ "None",
+ "op1",
+ attrs={
+ "value":
+ attr_value_pb2.AttrValue(i=32),
+ "dtype":
+ attr_value_pb2.AttrValue(type=types_pb2.DT_INT32),
+ "list":
+ attr_value_pb2.AttrValue(list=list_value),
+ "func":
+ attr_value_pb2.AttrValue(
+ func=attr_value_pb2.NameAttrList())
+ }), ops.Graph(), [], [dtypes.int32])
+ self.assertEqual(32, op.get_attr("value"))
+ self.assertEqual("", op.get_attr("func").name)
+
+ d = op.get_attr("dtype")
+ # First check that d is a DType, because the assertEquals will
+ # work no matter what since DType overrides __eq__
+ self.assertIsInstance(d, dtypes.DType)
+ self.assertEqual(dtypes.int32, d)
+
+ l = op.get_attr("list")
+ for x in l:
+ self.assertIsInstance(x, dtypes.DType)
+ self.assertEqual([dtypes.string, dtypes.double], l)
# TODO(b/65162920): remove this test when users who are directly mutating the
# node_def have been updated to proper usage.
def testSetAttr(self):
+ if not ops._USE_C_API:
+ return
op = test_ops.int_attr().op
op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
# TODO(skyewm): add node_def check