diff options
author | 2017-03-15 22:08:46 -0800 | |
---|---|---|
committer | 2017-03-15 23:31:14 -0700 | |
commit | 44550613ba6846d9fa701847f98bf776a4d008a7 (patch) | |
tree | 43d0b361bc74505b5a67f27628a6db9c7db1b300 | |
parent | dedd9ae9b026553d4e3c88ecee43dd73671a76c4 (diff) |
Use default descriptor pool in assertProtoEquals
This makes it possible to use this function for protos with Any fields.
Change: 150286292
-rw-r--r-- | tensorflow/python/framework/test_util.py | 4 | ||||
-rw-r--r-- | tensorflow/python/framework/test_util_test.py | 78 | ||||
-rw-r--r-- | tensorflow/python/util/protobuf/compare.py | 11 |
3 files changed, 63 insertions, 30 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4b49cafa51..b8195c14fa 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -29,6 +29,7 @@ import threading import numpy as np import six +from google.protobuf import descriptor_pool from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 @@ -218,7 +219,8 @@ class TensorFlowTestCase(googletest.TestCase): self._AssertProtoEquals(expected_message, message) elif isinstance(expected_message_maybe_ascii, str): expected_message = type(message)() - text_format.Merge(expected_message_maybe_ascii, expected_message) + text_format.Merge(expected_message_maybe_ascii, expected_message, + descriptor_pool=descriptor_pool.Default()) self._AssertProtoEquals(expected_message, message) else: assert False, ("Can't compare protos of type %s and %s" % diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index cb021c1170..e457b35f00 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -22,11 +22,11 @@ import random import threading import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin from google.protobuf import text_format from tensorflow.core.framework import graph_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -88,6 +88,34 @@ class TestUtilTest(test_util.TensorFlowTestCase): # test original comparison self.assertProtoEquals(graph_def, graph_def) + def testAssertProtoEqualsAny(self): + # Test assertProtoEquals with a protobuf.Any field. + meta_graph_def_str = """ + meta_info_def { + meta_graph_version: "outer" + any_info { + [type.googleapis.com/tensorflow.MetaGraphDef] { + meta_info_def { + meta_graph_version: "inner" + } + } + } + } + """ + meta_graph_def_outer = meta_graph_pb2.MetaGraphDef() + meta_graph_def_outer.meta_info_def.meta_graph_version = "outer" + meta_graph_def_inner = meta_graph_pb2.MetaGraphDef() + meta_graph_def_inner.meta_info_def.meta_graph_version = "inner" + meta_graph_def_outer.meta_info_def.any_info.Pack(meta_graph_def_inner) + self.assertProtoEquals(meta_graph_def_str, meta_graph_def_outer) + self.assertProtoEquals(meta_graph_def_outer, meta_graph_def_outer) + + # Check if the assertion failure message contains the content of + # the inner proto. + with self.assertRaisesRegexp(AssertionError, + r'meta_graph_version: "inner"'): + self.assertProtoEquals("", meta_graph_def_outer) + def testNDArrayNear(self): a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) @@ -196,50 +224,50 @@ class TestUtilTest(test_util.TensorFlowTestCase): def testAssertAllCloseAccordingToType(self): # test float64 self.assertAllCloseAccordingToType( - np.asarray([1e-8], dtype=np.float64), - np.asarray([2e-8], dtype=np.float64), - rtol=1e-8, atol=1e-8 + np.asarray([1e-8], dtype=np.float64), + np.asarray([2e-8], dtype=np.float64), + rtol=1e-8, atol=1e-8 ) with (self.assertRaises(AssertionError)): self.assertAllCloseAccordingToType( - np.asarray([1e-7], dtype=np.float64), - np.asarray([2e-7], dtype=np.float64), - rtol=1e-8, atol=1e-8 + np.asarray([1e-7], dtype=np.float64), + np.asarray([2e-7], dtype=np.float64), + rtol=1e-8, atol=1e-8 ) # test float32 self.assertAllCloseAccordingToType( - np.asarray([1e-7], dtype=np.float32), - np.asarray([2e-7], dtype=np.float32), - rtol=1e-8, atol=1e-8, - float_rtol=1e-7, float_atol=1e-7 + np.asarray([1e-7], dtype=np.float32), + np.asarray([2e-7], dtype=np.float32), + rtol=1e-8, atol=1e-8, + float_rtol=1e-7, float_atol=1e-7 ) with (self.assertRaises(AssertionError)): self.assertAllCloseAccordingToType( - np.asarray([1e-6], dtype=np.float32), - np.asarray([2e-6], dtype=np.float32), - rtol=1e-8, atol=1e-8, - float_rtol=1e-7, float_atol=1e-7 + np.asarray([1e-6], dtype=np.float32), + np.asarray([2e-6], dtype=np.float32), + rtol=1e-8, atol=1e-8, + float_rtol=1e-7, float_atol=1e-7 ) # test float16 self.assertAllCloseAccordingToType( - np.asarray([1e-4], dtype=np.float16), - np.asarray([2e-4], dtype=np.float16), - rtol=1e-8, atol=1e-8, - float_rtol=1e-7, float_atol=1e-7, - half_rtol=1e-4, half_atol=1e-4 + np.asarray([1e-4], dtype=np.float16), + np.asarray([2e-4], dtype=np.float16), + rtol=1e-8, atol=1e-8, + float_rtol=1e-7, float_atol=1e-7, + half_rtol=1e-4, half_atol=1e-4 ) with (self.assertRaises(AssertionError)): self.assertAllCloseAccordingToType( - np.asarray([1e-3], dtype=np.float16), - np.asarray([2e-3], dtype=np.float16), - rtol=1e-8, atol=1e-8, - float_rtol=1e-7, float_atol=1e-7, - half_rtol=1e-4, half_atol=1e-4 + np.asarray([1e-3], dtype=np.float16), + np.asarray([2e-3], dtype=np.float16), + rtol=1e-8, atol=1e-8, + float_rtol=1e-7, float_atol=1e-7, + half_rtol=1e-4, half_atol=1e-4 ) def testRandomSeed(self): diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py index c07ded6746..a0e6bf65cf 100644 --- a/tensorflow/python/util/protobuf/compare.py +++ b/tensorflow/python/util/protobuf/compare.py @@ -67,6 +67,7 @@ import collections import six from google.protobuf import descriptor +from google.protobuf import descriptor_pool from google.protobuf import message from google.protobuf import text_format @@ -88,8 +89,9 @@ def assertProtoEqual(self, a, b, check_initialized=True, # pylint: disable=inva numbers before comparison. msg: if specified, is used as the error message on failure. """ + pool = descriptor_pool.Default() if isinstance(a, six.string_types): - a = text_format.Merge(a, b.__class__()) + a = text_format.Merge(a, b.__class__(), descriptor_pool=pool) for pb in a, b: if check_initialized: @@ -99,9 +101,10 @@ def assertProtoEqual(self, a, b, check_initialized=True, # pylint: disable=inva if normalize_numbers: NormalizeNumberFields(pb) - self.assertMultiLineEqual(text_format.MessageToString(a), - text_format.MessageToString(b), - msg=msg) + self.assertMultiLineEqual( + text_format.MessageToString(a, descriptor_pool=pool), + text_format.MessageToString(b, descriptor_pool=pool), + msg=msg) def NormalizeNumberFields(pb): |