aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-05-10 12:31:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-10 17:45:01 -0700
commit15f32d93a48fa78f7f2f17c0a46f5625637c82aa (patch)
treea9a63150e48109cbd7296e7434d41ac649f46df3 /tensorflow
parent32482dea01e2fc672eee9df5305bccc0b432eb7b (diff)
get_attr returns dtype objects instead of raw ints
Originally by @Mycosynth, but edited to make it work using integer conversion. Fixes #447. If this breaks you, do an explicit cast of the return of get_attr to int, if you are parsing a dtype object. PiperOrigin-RevId: 155661630
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/framework/dtypes.py3
-rw-r--r--tensorflow/python/framework/dtypes_test.py4
-rw-r--r--tensorflow/python/framework/ops.py10
-rw-r--r--tensorflow/python/framework/ops_test.py27
-rw-r--r--tensorflow/python/kernel_tests/py_func_test.py4
-rw-r--r--tensorflow/python/util/example_parser_configuration.py4
6 files changed, 44 insertions, 8 deletions
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index d373bac47a..3e6c04982b 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -270,6 +270,9 @@ class DType(object):
"""Returns the string name for this `DType`."""
return _TYPE_TO_STRING[self._type_enum]
+ def __int__(self):
+ return self._type_enum
+
def __str__(self):
return "<dtype: %r>" % self.name
diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py
index f04f67ffed..5bb60763b6 100644
--- a/tensorflow/python/framework/dtypes_test.py
+++ b/tensorflow/python/framework/dtypes_test.py
@@ -45,8 +45,8 @@ class TypesTest(test_util.TensorFlowTestCase):
for datatype_enum in types_pb2.DataType.values():
if datatype_enum == types_pb2.DT_INVALID:
continue
- self.assertEqual(datatype_enum,
- dtypes.as_dtype(datatype_enum).as_datatype_enum)
+ dt = dtypes.as_dtype(datatype_enum)
+ self.assertEqual(datatype_enum, dt.as_datatype_enum)
def testAllTypesConvertibleToNumpyDtype(self):
for datatype_enum in types_pb2.DataType.values():
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 2581ac6a3c..7f3a7eb876 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1666,12 +1666,18 @@ class Operation(object):
if x.HasField("list"):
for f in fields:
if getattr(x.list, f):
- return list(getattr(x.list, f))
+ if f == "type":
+ return [dtypes.as_dtype(x) for x in list(getattr(x.list, f))]
+ else:
+ return list(getattr(x.list, f))
return []
else:
for f in fields:
if x.HasField(f):
- return getattr(x, f)
+ if f == "type":
+ return dtypes.as_dtype(getattr(x, f))
+ else:
+ return getattr(x, f)
assert False, "Unsupported field type in " + str(x)
def run(self, feed_dict=None, session=None):
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 3e9f047a7d..eae36a8613 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -22,6 +22,7 @@ import gc
import weakref
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import common_shapes
@@ -356,6 +357,32 @@ class OperationTest(test_util.TensorFlowTestCase):
ops._NodeDef("noop", "op1"), ops.Graph(), [], [dtypes.float32])
self.assertEqual("<tf.Operation 'op1' type=noop>", repr(op))
+ def testGetAttr(self):
+ list_value = attr_value_pb2.AttrValue.ListValue()
+ list_value.type.append(types_pb2.DT_STRING)
+ list_value.type.append(types_pb2.DT_DOUBLE)
+ op = ops.Operation(
+ ops._NodeDef(
+ "noop",
+ "op1",
+ attrs={
+ "value": attr_value_pb2.AttrValue(i=32),
+ "dtype": attr_value_pb2.AttrValue(type=types_pb2.DT_INT32),
+ "list": attr_value_pb2.AttrValue(list=list_value)
+ }), ops.Graph(), [], [dtypes.int32])
+ self.assertEqual(32, op.get_attr("value"))
+
+ d = op.get_attr("dtype")
+ # First check that d is a DType, because the assertEquals will
+ # work no matter what since DType overrides __eq__
+ self.assertIsInstance(d, dtypes.DType)
+ self.assertEqual(dtypes.int32, d)
+
+ l = op.get_attr("list")
+ for x in l:
+ self.assertIsInstance(x, dtypes.DType)
+ self.assertEqual([dtypes.string, dtypes.double], l)
+
class CreateOpTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/kernel_tests/py_func_test.py b/tensorflow/python/kernel_tests/py_func_test.py
index c7fc7dd582..e098cf3ff9 100644
--- a/tensorflow/python/kernel_tests/py_func_test.py
+++ b/tensorflow/python/kernel_tests/py_func_test.py
@@ -186,9 +186,9 @@ class PyOpTest(test.TestCase):
def bad():
# Non-string python objects aren't supported.
- return dtypes.float32
+ return {"foo": dtypes.float32}
- z, = script_ops.py_func(bad, [], [dtypes.float64])
+ z, = script_ops.py_func(bad, [], [dtypes.int64])
with self.assertRaisesRegexp(errors.UnimplementedError,
"Unsupported object type"):
diff --git a/tensorflow/python/util/example_parser_configuration.py b/tensorflow/python/util/example_parser_configuration.py
index 8843016a97..a375085176 100644
--- a/tensorflow/python/util/example_parser_configuration.py
+++ b/tensorflow/python/util/example_parser_configuration.py
@@ -101,7 +101,7 @@ def extract_example_parser_configuration(parse_example_op, sess):
fixed_config.shape.CopyFrom(
tensor_shape.TensorShape(dense_shapes[i]).as_proto())
- fixed_config.dtype = dense_types[i]
+ fixed_config.dtype = int(dense_types[i])
# Get the output tensor name.
fixed_config.values_output_tensor_name = parse_example_op.outputs[
dense_values_start + i].name
@@ -111,7 +111,7 @@ def extract_example_parser_configuration(parse_example_op, sess):
key = fetched[sparse_keys_start + i]
feature_config = config.feature_map[key]
var_len_feature = feature_config.var_len_feature
- var_len_feature.dtype = sparse_types[i]
+ var_len_feature.dtype = int(sparse_types[i])
var_len_feature.indices_output_tensor_name = parse_example_op.outputs[
sparse_indices_start + i].name
var_len_feature.values_output_tensor_name = parse_example_op.outputs[