diff options
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/framework/dtypes.py | 3 | ||||
-rw-r--r-- | tensorflow/python/framework/dtypes_test.py | 4 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 10 | ||||
-rw-r--r-- | tensorflow/python/framework/ops_test.py | 27 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/py_func_test.py | 4 | ||||
-rw-r--r-- | tensorflow/python/util/example_parser_configuration.py | 4 |
6 files changed, 44 insertions, 8 deletions
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index d373bac47a..3e6c04982b 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -270,6 +270,9 @@ class DType(object): """Returns the string name for this `DType`.""" return _TYPE_TO_STRING[self._type_enum] + def __int__(self): + return self._type_enum + def __str__(self): return "<dtype: %r>" % self.name diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py index f04f67ffed..5bb60763b6 100644 --- a/tensorflow/python/framework/dtypes_test.py +++ b/tensorflow/python/framework/dtypes_test.py @@ -45,8 +45,8 @@ class TypesTest(test_util.TensorFlowTestCase): for datatype_enum in types_pb2.DataType.values(): if datatype_enum == types_pb2.DT_INVALID: continue - self.assertEqual(datatype_enum, - dtypes.as_dtype(datatype_enum).as_datatype_enum) + dt = dtypes.as_dtype(datatype_enum) + self.assertEqual(datatype_enum, dt.as_datatype_enum) def testAllTypesConvertibleToNumpyDtype(self): for datatype_enum in types_pb2.DataType.values(): diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 2581ac6a3c..7f3a7eb876 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1666,12 +1666,18 @@ class Operation(object): if x.HasField("list"): for f in fields: if getattr(x.list, f): - return list(getattr(x.list, f)) + if f == "type": + return [dtypes.as_dtype(x) for x in list(getattr(x.list, f))] + else: + return list(getattr(x.list, f)) return [] else: for f in fields: if x.HasField(f): - return getattr(x, f) + if f == "type": + return dtypes.as_dtype(getattr(x, f)) + else: + return getattr(x, f) assert False, "Unsupported field type in " + str(x) def run(self, feed_dict=None, session=None): diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index 3e9f047a7d..eae36a8613 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -22,6 +22,7 @@ import gc import weakref from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import common_shapes @@ -356,6 +357,32 @@ class OperationTest(test_util.TensorFlowTestCase): ops._NodeDef("noop", "op1"), ops.Graph(), [], [dtypes.float32]) self.assertEqual("<tf.Operation 'op1' type=noop>", repr(op)) + def testGetAttr(self): + list_value = attr_value_pb2.AttrValue.ListValue() + list_value.type.append(types_pb2.DT_STRING) + list_value.type.append(types_pb2.DT_DOUBLE) + op = ops.Operation( + ops._NodeDef( + "noop", + "op1", + attrs={ + "value": attr_value_pb2.AttrValue(i=32), + "dtype": attr_value_pb2.AttrValue(type=types_pb2.DT_INT32), + "list": attr_value_pb2.AttrValue(list=list_value) + }), ops.Graph(), [], [dtypes.int32]) + self.assertEqual(32, op.get_attr("value")) + + d = op.get_attr("dtype") + # First check that d is a DType, because the assertEquals will + # work no matter what since DType overrides __eq__ + self.assertIsInstance(d, dtypes.DType) + self.assertEqual(dtypes.int32, d) + + l = op.get_attr("list") + for x in l: + self.assertIsInstance(x, dtypes.DType) + self.assertEqual([dtypes.string, dtypes.double], l) + class CreateOpTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py index c7fc7dd582..e098cf3ff9 100644 --- a/tensorflow/python/kernel_tests/py_func_test.py +++ b/tensorflow/python/kernel_tests/py_func_test.py @@ -186,9 +186,9 @@ class PyOpTest(test.TestCase): def bad(): # Non-string python objects aren't supported. - return dtypes.float32 + return {"foo": dtypes.float32} - z, = script_ops.py_func(bad, [], [dtypes.float64]) + z, = script_ops.py_func(bad, [], [dtypes.int64]) with self.assertRaisesRegexp(errors.UnimplementedError, "Unsupported object type"): diff --git a/tensorflow/python/util/example_parser_configuration.py b/tensorflow/python/util/example_parser_configuration.py index 8843016a97..a375085176 100644 --- a/tensorflow/python/util/example_parser_configuration.py +++ b/tensorflow/python/util/example_parser_configuration.py @@ -101,7 +101,7 @@ def extract_example_parser_configuration(parse_example_op, sess): fixed_config.shape.CopyFrom( tensor_shape.TensorShape(dense_shapes[i]).as_proto()) - fixed_config.dtype = dense_types[i] + fixed_config.dtype = int(dense_types[i]) # Get the output tensor name. fixed_config.values_output_tensor_name = parse_example_op.outputs[ dense_values_start + i].name @@ -111,7 +111,7 @@ def extract_example_parser_configuration(parse_example_op, sess): key = fetched[sparse_keys_start + i] feature_config = config.feature_map[key] var_len_feature = feature_config.var_len_feature - var_len_feature.dtype = sparse_types[i] + var_len_feature.dtype = int(sparse_types[i]) var_len_feature.indices_output_tensor_name = parse_example_op.outputs[ sparse_indices_start + i].name var_len_feature.values_output_tensor_name = parse_example_op.outputs[ |