aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/proto
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-07-18 21:33:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-18 21:36:51 -0700
commit874de86fe803823589d6b1c1e2dbe4adc5d3408c (patch)
treee611dbd0b45c9a56e770e2a7b9145aebf2b05439 /tensorflow/contrib/proto
parent2422a250654757480e1c3e301a2a4d3564e9ff25 (diff)
Changing the mapping between proto and TF types.
PiperOrigin-RevId: 205185039
Diffstat (limited to 'tensorflow/contrib/proto')
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py31
-rw-r--r--tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py72
2 files changed, 54 insertions, 49 deletions
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
index 5f7f510352..e3570e38a3 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
@@ -106,34 +106,27 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase):
self.assertEqual(v, ev)
continue
- # This can be a little confusing. For testing we are using TestValue in
- # two ways: it's the proto that we decode for testing, and it's used in
- # the expected value as a union type.
- #
- # The two cases are slightly different: this is the second case. We may be
- # fetching the uint64_value from the test proto, but in the expected proto
- # we store it in the int64_value field because TensorFlow doesn't support
- # unsigned int64.
tf_type_to_primitive_value_field = {
+ dtypes.bool:
+ 'bool_value',
dtypes.float32:
'float_value',
dtypes.float64:
'double_value',
- dtypes.int32:
- 'int32_value',
- dtypes.uint8:
- 'uint8_value',
dtypes.int8:
'int8_value',
- dtypes.string:
- 'string_value',
+ dtypes.int32:
+ 'int32_value',
dtypes.int64:
'int64_value',
- dtypes.bool:
- 'bool_value',
- # Unhandled TensorFlow types:
- # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
- # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
+ dtypes.string:
+ 'string_value',
+ dtypes.uint8:
+ 'uint8_value',
+ dtypes.uint32:
+ 'uint32_value',
+ dtypes.uint64:
+ 'uint64_value',
}
tf_field_name = tf_type_to_primitive_value_field.get(field.dtype)
if tf_field_name is None:
diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
index cbc7b3d3f8..2950c7dfdc 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
@@ -44,7 +44,7 @@ class ProtoOpTestBase(test.TestCase):
("minmax", ProtoOpTestBase.minmax_test_case()),
("nested", ProtoOpTestBase.nested_test_case()),
("optional", ProtoOpTestBase.optional_test_case()),
- ("promote_unsigned", ProtoOpTestBase.promote_unsigned_test_case()),
+ ("promote", ProtoOpTestBase.promote_test_case()),
("ragged", ProtoOpTestBase.ragged_test_case()),
("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
("simple", ProtoOpTestBase.simple_test_case()),
@@ -83,13 +83,13 @@ class ProtoOpTestBase(test.TestCase):
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "uint64_value_with_default"
- field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(4)
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(4)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "fixed64_value_with_default"
- field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(6)
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(6)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "int32_value_with_default"
@@ -108,13 +108,13 @@ class ProtoOpTestBase(test.TestCase):
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "uint32_value_with_default"
- field.dtype = types_pb2.DT_INT32
- field.value.int32_value.append(9)
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(9)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "fixed32_value_with_default"
- field.dtype = types_pb2.DT_INT32
- field.value.int32_value.append(7)
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(7)
test_case.sizes.append(0)
field = test_case.fields.add()
field.name = "bool_value_with_default"
@@ -202,15 +202,15 @@ class ProtoOpTestBase(test.TestCase):
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "uint64_value"
- field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(0)
- field.value.int64_value.append(-1)
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(0)
+ field.value.uint64_value.append(18446744073709551615)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "fixed64_value"
- field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(0)
- field.value.int64_value.append(-1)
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(0)
+ field.value.uint64_value.append(18446744073709551615)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "int32_value"
@@ -232,15 +232,15 @@ class ProtoOpTestBase(test.TestCase):
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "uint32_value"
- field.dtype = types_pb2.DT_INT32
- field.value.int32_value.append(0)
- field.value.int32_value.append(-1)
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(0)
+ field.value.uint32_value.append(4294967295)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "fixed32_value"
- field.dtype = types_pb2.DT_INT32
- field.value.int32_value.append(0)
- field.value.int32_value.append(-1)
+ field.dtype = types_pb2.DT_UINT32
+ field.value.uint32_value.append(0)
+ field.value.uint32_value.append(4294967295)
test_case.sizes.append(2)
field = test_case.fields.add()
field.name = "bool_value"
@@ -289,28 +289,40 @@ class ProtoOpTestBase(test.TestCase):
return test_case
@staticmethod
- def promote_unsigned_test_case():
+ def promote_test_case():
test_case = test_example_pb2.TestCase()
value = test_case.values.add()
+ value.sint32_value.append(2147483647)
+ value.sfixed32_value.append(2147483647)
+ value.int32_value.append(2147483647)
value.fixed32_value.append(4294967295)
value.uint32_value.append(4294967295)
test_case.shapes.append(1)
test_case.sizes.append(1)
field = test_case.fields.add()
- field.name = "fixed32_value"
+ field.name = "sint32_value"
field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(4294967295)
+ field.value.int64_value.append(2147483647)
test_case.sizes.append(1)
field = test_case.fields.add()
- field.name = "uint32_value"
+ field.name = "sfixed32_value"
field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(4294967295)
- # Comes from an explicitly-specified default
- test_case.sizes.append(0)
+ field.value.int64_value.append(2147483647)
+ test_case.sizes.append(1)
field = test_case.fields.add()
- field.name = "uint32_value_with_default"
+ field.name = "int32_value"
field.dtype = types_pb2.DT_INT64
- field.value.int64_value.append(9)
+ field.value.int64_value.append(2147483647)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "fixed32_value"
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(4294967295)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = "uint32_value"
+ field.dtype = types_pb2.DT_UINT64
+ field.value.uint64_value.append(4294967295)
return test_case
@staticmethod