diff options
author | 2018-07-18 21:33:42 -0700 | |
---|---|---|
committer | 2018-07-18 21:36:51 -0700 | |
commit | 874de86fe803823589d6b1c1e2dbe4adc5d3408c (patch) | |
tree | e611dbd0b45c9a56e770e2a7b9145aebf2b05439 /tensorflow/contrib/proto | |
parent | 2422a250654757480e1c3e301a2a4d3564e9ff25 (diff) |
Changing the mapping between proto and TF types.
PiperOrigin-RevId: 205185039
Diffstat (limited to 'tensorflow/contrib/proto')
-rw-r--r-- | tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py | 31 | ||||
-rw-r--r-- | tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py | 72 |
2 files changed, 54 insertions, 49 deletions
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py index 5f7f510352..e3570e38a3 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py @@ -106,34 +106,27 @@ class DecodeProtoOpTestBase(test_base.ProtoOpTestBase, parameterized.TestCase): self.assertEqual(v, ev) continue - # This can be a little confusing. For testing we are using TestValue in - # two ways: it's the proto that we decode for testing, and it's used in - # the expected value as a union type. - # - # The two cases are slightly different: this is the second case. We may be - # fetching the uint64_value from the test proto, but in the expected proto - # we store it in the int64_value field because TensorFlow doesn't support - # unsigned int64. tf_type_to_primitive_value_field = { + dtypes.bool: + 'bool_value', dtypes.float32: 'float_value', dtypes.float64: 'double_value', - dtypes.int32: - 'int32_value', - dtypes.uint8: - 'uint8_value', dtypes.int8: 'int8_value', - dtypes.string: - 'string_value', + dtypes.int32: + 'int32_value', dtypes.int64: 'int64_value', - dtypes.bool: - 'bool_value', - # Unhandled TensorFlow types: - # DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32 - # DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16 + dtypes.string: + 'string_value', + dtypes.uint8: + 'uint8_value', + dtypes.uint32: + 'uint32_value', + dtypes.uint64: + 'uint64_value', } tf_field_name = tf_type_to_primitive_value_field.get(field.dtype) if tf_field_name is None: diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py index cbc7b3d3f8..2950c7dfdc 100644 --- a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py +++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py @@ -44,7 +44,7 @@ class ProtoOpTestBase(test.TestCase): ("minmax", ProtoOpTestBase.minmax_test_case()), ("nested", ProtoOpTestBase.nested_test_case()), ("optional", ProtoOpTestBase.optional_test_case()), - ("promote_unsigned", ProtoOpTestBase.promote_unsigned_test_case()), + ("promote", ProtoOpTestBase.promote_test_case()), ("ragged", ProtoOpTestBase.ragged_test_case()), ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()), ("simple", ProtoOpTestBase.simple_test_case()), @@ -83,13 +83,13 @@ class ProtoOpTestBase(test.TestCase): test_case.sizes.append(0) field = test_case.fields.add() field.name = "uint64_value_with_default" - field.dtype = types_pb2.DT_INT64 - field.value.int64_value.append(4) + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4) test_case.sizes.append(0) field = test_case.fields.add() field.name = "fixed64_value_with_default" - field.dtype = types_pb2.DT_INT64 - field.value.int64_value.append(6) + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(6) test_case.sizes.append(0) field = test_case.fields.add() field.name = "int32_value_with_default" @@ -108,13 +108,13 @@ class ProtoOpTestBase(test.TestCase): test_case.sizes.append(0) field = test_case.fields.add() field.name = "uint32_value_with_default" - field.dtype = types_pb2.DT_INT32 - field.value.int32_value.append(9) + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(9) test_case.sizes.append(0) field = test_case.fields.add() field.name = "fixed32_value_with_default" - field.dtype = types_pb2.DT_INT32 - field.value.int32_value.append(7) + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(7) test_case.sizes.append(0) field = test_case.fields.add() field.name = "bool_value_with_default" @@ -202,15 +202,15 @@ class ProtoOpTestBase(test.TestCase): test_case.sizes.append(2) field = test_case.fields.add() field.name = "uint64_value" - field.dtype = types_pb2.DT_INT64 - field.value.int64_value.append(0) - field.value.int64_value.append(-1) + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(0) + field.value.uint64_value.append(18446744073709551615) test_case.sizes.append(2) field = test_case.fields.add() field.name = "fixed64_value" - field.dtype = types_pb2.DT_INT64 - field.value.int64_value.append(0) - field.value.int64_value.append(-1) + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(0) + field.value.uint64_value.append(18446744073709551615) test_case.sizes.append(2) field = test_case.fields.add() field.name = "int32_value" @@ -232,15 +232,15 @@ class ProtoOpTestBase(test.TestCase): test_case.sizes.append(2) field = test_case.fields.add() field.name = "uint32_value" - field.dtype = types_pb2.DT_INT32 - field.value.int32_value.append(0) - field.value.int32_value.append(-1) + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(0) + field.value.uint32_value.append(4294967295) test_case.sizes.append(2) field = test_case.fields.add() field.name = "fixed32_value" - field.dtype = types_pb2.DT_INT32 - field.value.int32_value.append(0) - field.value.int32_value.append(-1) + field.dtype = types_pb2.DT_UINT32 + field.value.uint32_value.append(0) + field.value.uint32_value.append(4294967295) test_case.sizes.append(2) field = test_case.fields.add() field.name = "bool_value" @@ -289,28 +289,40 @@ class ProtoOpTestBase(test.TestCase): return test_case @staticmethod - def promote_unsigned_test_case(): + def promote_test_case(): test_case = test_example_pb2.TestCase() value = test_case.values.add() + value.sint32_value.append(2147483647) + value.sfixed32_value.append(2147483647) + value.int32_value.append(2147483647) value.fixed32_value.append(4294967295) value.uint32_value.append(4294967295) test_case.shapes.append(1) test_case.sizes.append(1) field = test_case.fields.add() - field.name = "fixed32_value" + field.name = "sint32_value" field.dtype = types_pb2.DT_INT64 - field.value.int64_value.append(4294967295) + field.value.int64_value.append(2147483647) test_case.sizes.append(1) field = test_case.fields.add() - field.name = "uint32_value" + field.name = "sfixed32_value" field.dtype = types_pb2.DT_INT64 - field.value.int64_value.append(4294967295) - # Comes from an explicitly-specified default - test_case.sizes.append(0) + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) field = test_case.fields.add() - field.name = "uint32_value_with_default" + field.name = "int32_value" field.dtype = types_pb2.DT_INT64 - field.value.int64_value.append(9) + field.value.int64_value.append(2147483647) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "fixed32_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4294967295) + test_case.sizes.append(1) + field = test_case.fields.add() + field.name = "uint32_value" + field.dtype = types_pb2.DT_UINT64 + field.value.uint64_value.append(4294967295) return test_case @staticmethod |