aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-15 22:08:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-15 23:31:14 -0700
commit44550613ba6846d9fa701847f98bf776a4d008a7 (patch)
tree43d0b361bc74505b5a67f27628a6db9c7db1b300
parentdedd9ae9b026553d4e3c88ecee43dd73671a76c4 (diff)
Use default descriptor pool in assertProtoEquals
This makes it possible to use this function for protos with Any fields. Change: 150286292
-rw-r--r--tensorflow/python/framework/test_util.py4
-rw-r--r--tensorflow/python/framework/test_util_test.py78
-rw-r--r--tensorflow/python/util/protobuf/compare.py11
3 files changed, 63 insertions, 30 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4b49cafa51..b8195c14fa 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -29,6 +29,7 @@ import threading
import numpy as np
import six
+from google.protobuf import descriptor_pool
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
@@ -218,7 +219,8 @@ class TensorFlowTestCase(googletest.TestCase):
self._AssertProtoEquals(expected_message, message)
elif isinstance(expected_message_maybe_ascii, str):
expected_message = type(message)()
- text_format.Merge(expected_message_maybe_ascii, expected_message)
+ text_format.Merge(expected_message_maybe_ascii, expected_message,
+ descriptor_pool=descriptor_pool.Default())
self._AssertProtoEquals(expected_message, message)
else:
assert False, ("Can't compare protos of type %s and %s" %
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index cb021c1170..e457b35f00 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -22,11 +22,11 @@ import random
import threading
import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -88,6 +88,34 @@ class TestUtilTest(test_util.TensorFlowTestCase):
# test original comparison
self.assertProtoEquals(graph_def, graph_def)
+ def testAssertProtoEqualsAny(self):
+ # Test assertProtoEquals with a protobuf.Any field.
+ meta_graph_def_str = """
+ meta_info_def {
+ meta_graph_version: "outer"
+ any_info {
+ [type.googleapis.com/tensorflow.MetaGraphDef] {
+ meta_info_def {
+ meta_graph_version: "inner"
+ }
+ }
+ }
+ }
+ """
+ meta_graph_def_outer = meta_graph_pb2.MetaGraphDef()
+ meta_graph_def_outer.meta_info_def.meta_graph_version = "outer"
+ meta_graph_def_inner = meta_graph_pb2.MetaGraphDef()
+ meta_graph_def_inner.meta_info_def.meta_graph_version = "inner"
+ meta_graph_def_outer.meta_info_def.any_info.Pack(meta_graph_def_inner)
+ self.assertProtoEquals(meta_graph_def_str, meta_graph_def_outer)
+ self.assertProtoEquals(meta_graph_def_outer, meta_graph_def_outer)
+
+ # Check if the assertion failure message contains the content of
+ # the inner proto.
+ with self.assertRaisesRegexp(AssertionError,
+ r'meta_graph_version: "inner"'):
+ self.assertProtoEquals("", meta_graph_def_outer)
+
def testNDArrayNear(self):
a1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
a2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -196,50 +224,50 @@ class TestUtilTest(test_util.TensorFlowTestCase):
def testAssertAllCloseAccordingToType(self):
# test float64
self.assertAllCloseAccordingToType(
- np.asarray([1e-8], dtype=np.float64),
- np.asarray([2e-8], dtype=np.float64),
- rtol=1e-8, atol=1e-8
+ np.asarray([1e-8], dtype=np.float64),
+ np.asarray([2e-8], dtype=np.float64),
+ rtol=1e-8, atol=1e-8
)
with (self.assertRaises(AssertionError)):
self.assertAllCloseAccordingToType(
- np.asarray([1e-7], dtype=np.float64),
- np.asarray([2e-7], dtype=np.float64),
- rtol=1e-8, atol=1e-8
+ np.asarray([1e-7], dtype=np.float64),
+ np.asarray([2e-7], dtype=np.float64),
+ rtol=1e-8, atol=1e-8
)
# test float32
self.assertAllCloseAccordingToType(
- np.asarray([1e-7], dtype=np.float32),
- np.asarray([2e-7], dtype=np.float32),
- rtol=1e-8, atol=1e-8,
- float_rtol=1e-7, float_atol=1e-7
+ np.asarray([1e-7], dtype=np.float32),
+ np.asarray([2e-7], dtype=np.float32),
+ rtol=1e-8, atol=1e-8,
+ float_rtol=1e-7, float_atol=1e-7
)
with (self.assertRaises(AssertionError)):
self.assertAllCloseAccordingToType(
- np.asarray([1e-6], dtype=np.float32),
- np.asarray([2e-6], dtype=np.float32),
- rtol=1e-8, atol=1e-8,
- float_rtol=1e-7, float_atol=1e-7
+ np.asarray([1e-6], dtype=np.float32),
+ np.asarray([2e-6], dtype=np.float32),
+ rtol=1e-8, atol=1e-8,
+ float_rtol=1e-7, float_atol=1e-7
)
# test float16
self.assertAllCloseAccordingToType(
- np.asarray([1e-4], dtype=np.float16),
- np.asarray([2e-4], dtype=np.float16),
- rtol=1e-8, atol=1e-8,
- float_rtol=1e-7, float_atol=1e-7,
- half_rtol=1e-4, half_atol=1e-4
+ np.asarray([1e-4], dtype=np.float16),
+ np.asarray([2e-4], dtype=np.float16),
+ rtol=1e-8, atol=1e-8,
+ float_rtol=1e-7, float_atol=1e-7,
+ half_rtol=1e-4, half_atol=1e-4
)
with (self.assertRaises(AssertionError)):
self.assertAllCloseAccordingToType(
- np.asarray([1e-3], dtype=np.float16),
- np.asarray([2e-3], dtype=np.float16),
- rtol=1e-8, atol=1e-8,
- float_rtol=1e-7, float_atol=1e-7,
- half_rtol=1e-4, half_atol=1e-4
+ np.asarray([1e-3], dtype=np.float16),
+ np.asarray([2e-3], dtype=np.float16),
+ rtol=1e-8, atol=1e-8,
+ float_rtol=1e-7, float_atol=1e-7,
+ half_rtol=1e-4, half_atol=1e-4
)
def testRandomSeed(self):
diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py
index c07ded6746..a0e6bf65cf 100644
--- a/tensorflow/python/util/protobuf/compare.py
+++ b/tensorflow/python/util/protobuf/compare.py
@@ -67,6 +67,7 @@ import collections
import six
from google.protobuf import descriptor
+from google.protobuf import descriptor_pool
from google.protobuf import message
from google.protobuf import text_format
@@ -88,8 +89,9 @@ def assertProtoEqual(self, a, b, check_initialized=True, # pylint: disable=inva
numbers before comparison.
msg: if specified, is used as the error message on failure.
"""
+ pool = descriptor_pool.Default()
if isinstance(a, six.string_types):
- a = text_format.Merge(a, b.__class__())
+ a = text_format.Merge(a, b.__class__(), descriptor_pool=pool)
for pb in a, b:
if check_initialized:
@@ -99,9 +101,10 @@ def assertProtoEqual(self, a, b, check_initialized=True, # pylint: disable=inva
if normalize_numbers:
NormalizeNumberFields(pb)
- self.assertMultiLineEqual(text_format.MessageToString(a),
- text_format.MessageToString(b),
- msg=msg)
+ self.assertMultiLineEqual(
+ text_format.MessageToString(a, descriptor_pool=pool),
+ text_format.MessageToString(b, descriptor_pool=pool),
+ msg=msg)
def NormalizeNumberFields(pb):