aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-02-08 16:05:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-08 16:09:41 -0800
commitf87caf7cd62be25a9c7e390b55b4933fcdcc784c (patch)
tree2ff94486c422b91c9964d41b3161292e3cc4d61d /tensorflow/python/saved_model
parent5a2c810376ad4ea4d55f8eaa696250011c1f5005 (diff)
Make SavedModelTest.testStripDefaultAttrsInconsistentConsumerDefaults work with C API.
The test originally altered the Python version of the op registry, which is not reflected in the C API. This changes the test to alter the serialized node def instead of the op def, and renames the test to testInconsistentConsumerDefaultAttrs. PiperOrigin-RevId: 185067838
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/BUILD1
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py91
2 files changed, 53 insertions, 39 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index e34aa7cc2c..30e0a099d8 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -148,6 +148,7 @@ py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:saver_test_utils",
"//tensorflow/python:state_ops",
+ "//tensorflow/python:test_ops",
"//tensorflow/python:util",
"//tensorflow/python:variables",
],
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index f92247d52e..d1f6bc27ef 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import os
-from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
@@ -28,8 +27,8 @@ from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import control_flow_ops
@@ -945,61 +944,75 @@ class SavedModelTest(test.TestCase):
self.assertIn("T", node_def.attr)
self.assertIn("Tout", node_def.attr)
- def testStripDefaultAttrsInconsistentConsumerDefaults(self):
- if ops._USE_C_API: return # TODO(skyewm): get this working
-
+ # Tests the behavior of loading SavedModels that having missing attrs or attrs
+ # with incorrect types.
+ def testInconsistentConsumerDefaultAttrs(self):
export_dir = self._get_export_dir(
"test_strip_default_attrs_no_consumer_defaults")
builder = saved_model_builder.SavedModelBuilder(export_dir)
- # Add a graph with two float32 variables and a Complex Op composing them
- # with strip_default_attrs enabled. This must remove the following
- # defaults for the "Complex" Op:
- # o "T" : float32. (input type)
- # o "Tout" : complex64. (output type)
+ # Add a graph with a single variable and a test op with a defaultless
+ # float32 attr, "test_attr".
with session.Session(graph=ops.Graph()) as sess:
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
- math_ops.complex(real_num, imag_num, name="complex")
+ variables.Variable(1.0, dtype=dtypes.float64, name="var")
+ test_ops.test_attr(T=dtypes.float32, name="test_attr")
sess.run(variables.global_variables_initializer())
- builder.add_meta_graph_and_variables(
- sess, ["foo"], strip_default_attrs=True)
+ builder.add_meta_graph_and_variables(sess, ["foo"])
# Save the SavedModel to disk in text format.
builder.save(as_text=True)
- # Update the Op registry to remove defaults for all attrs("T", "Tout") from
- # the "Complex" OpDef.
- complex_op_def = op_def_registry.get_registered_ops()["Complex"]
- original_complex_op_def = op_def_pb2.OpDef()
- original_complex_op_def.CopyFrom(complex_op_def)
- for attr_def in complex_op_def.attr:
- attr_def.ClearField("default_value")
+ # Rewrite the SavedModel to remove the T attr from "test_attr".
+ saved_model_file = os.path.join(
+ export_dir, constants.SAVED_MODEL_FILENAME_PBTXT)
+ with open(saved_model_file) as f:
+ original_saved_model = f.read()
+
+ no_attr_saved_model = original_saved_model.replace("""
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }""", "")
+ with open(saved_model_file, "w") as f:
+ f.write(no_attr_saved_model)
# Loading the SavedModel via the loader must fail because the SavedModel
- # does not have any attr values for the "Complex" node and the current
- # op registry does not have have any default values for the "Complex" op.
+ # does not have any attr values for the "TestAttr" node, and there is no
+ # default specified in the TestAttr OpDef.
sess = session.Session(graph=ops.Graph())
- with self.assertRaisesRegexp(
- ValueError,
- "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
+ if ops._USE_C_API:
+ error_message = "NodeDef missing attr 'T' from Op<name=TestAttr"
+ else:
+ error_message = ("Expected one attr with name .*T(out)?.* in name: "
+ "\"test_attr\".*")
+ with self.assertRaisesRegexp(ValueError, error_message):
loader.load(sess, ["foo"], export_dir)
- # Update the Op registry to change the defaults for attr "Tout"
- # (complex64 -> complex128).
- complex_op_def.CopyFrom(original_complex_op_def)
- for attr_def in complex_op_def.attr:
- if attr_def.name == "Tout":
- attr_def.default_value.type = types_pb2.DT_COMPLEX128
-
- # Loading the SavedModel via the loader must set "Tout" attr_value for the
- # "Complex" node according to the latest defaults (complex128). This is
- # expected to fail the model import as there is no OpKernel registered to
- # handle attrs "T" (float32) and "Tout" (complex128).
+ # Rewrite the SavedModel to change the type of the T attr in "test_attr"
+ bad_type_saved_model = original_saved_model.replace("""
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }""", """
+ attr {
+ key: "T"
+ value {
+ type: DT_DOUBLE
+ }
+ }""")
+ with open(saved_model_file, "w") as f:
+ f.write(bad_type_saved_model)
+
+ # Loading the SavedModel via the loader must fail because there is no
+ # OpKernel registered to handle T = double.
sess = session.Session(graph=ops.Graph())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
- ".*No OpKernel was registered to support Op \'Complex\' with these "
+ ".*No OpKernel was registered to support Op \'TestAttr\' with these "
"attrs..*"):
loader.load(sess, ["foo"], export_dir)