From f87caf7cd62be25a9c7e390b55b4933fcdcc784c Mon Sep 17 00:00:00 2001
From: Skye Wanderman-Milne
Date: Thu, 8 Feb 2018 16:05:31 -0800
Subject: Make SavedModelTest.testStripDefaultAttrsInconsistentConsumerDefaults
 work with C API.

The test originally altered the Python version of the op registry, which is
not reflected in the C API. This changes the test to alter the serialized
node def instead of the op def, and renames the test to
testInconsistentConsumerDefaultAttrs.

PiperOrigin-RevId: 185067838
---
 tensorflow/python/framework/test_ops.cc           | 19 +++
 tensorflow/python/saved_model/BUILD               |  1 +
 tensorflow/python/saved_model/saved_model_test.py | 91 +++++++++++++----------
 3 files changed, 72 insertions(+), 39 deletions(-)

diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc
index c6c6c2233c..070b5ac11f 100644
--- a/tensorflow/python/framework/test_ops.cc
+++ b/tensorflow/python/framework/test_ops.cc
@@ -76,6 +76,11 @@ REGISTER_OP("TestStringOutput")
     .Output("output2: string")
     .SetShapeFn(shape_inference::UnknownShape);
 
+REGISTER_OP("TestAttr")
+    .Output("out: T")
+    .Attr("T: {float, double}")
+    .SetShapeFn(shape_inference::UnknownShape);
+
 namespace {
 enum KernelLabel { DEFAULT_LABEL, OVERLOAD_1_LABEL, OVERLOAD_2_LABEL };
 }  // namespace
@@ -188,6 +193,20 @@ class ResourceUsingOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("ResourceUsingOp").Device(DEVICE_CPU),
                         ResourceUsingOp);
 
+class TestAttrOp : public OpKernel {
+ public:
+  explicit TestAttrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    Tensor* output;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
+    output->scalar<float>()() = 1.0;
+  }
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("TestAttr").Device(DEVICE_CPU).TypeConstraint<float>("T"), TestAttrOp);
+
 // Various test ops without kernels. These are used to test graph construction.
 REGISTER_OP("A")
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index e34aa7cc2c..30e0a099d8 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -148,6 +148,7 @@ py_test(
         "//tensorflow/python:math_ops",
         "//tensorflow/python:saver_test_utils",
         "//tensorflow/python:state_ops",
+        "//tensorflow/python:test_ops",
         "//tensorflow/python:util",
         "//tensorflow/python:variables",
     ],
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index f92247d52e..d1f6bc27ef 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import os
 
-from tensorflow.core.framework import op_def_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
@@ -28,8 +27,8 @@ from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
-from tensorflow.python.framework import op_def_registry
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import control_flow_ops
@@ -945,61 +944,75 @@ class SavedModelTest(test.TestCase):
       self.assertIn("T", node_def.attr)
       self.assertIn("Tout", node_def.attr)
 
-  def testStripDefaultAttrsInconsistentConsumerDefaults(self):
-    if ops._USE_C_API: return  # TODO(skyewm): get this working
-
+  # Tests the behavior of loading SavedModels that have missing attrs or attrs
+  # with incorrect types.
+  def testInconsistentConsumerDefaultAttrs(self):
     export_dir = self._get_export_dir(
         "test_strip_default_attrs_no_consumer_defaults")
     builder = saved_model_builder.SavedModelBuilder(export_dir)
 
-    # Add a graph with two float32 variables and a Complex Op composing them
-    # with strip_default_attrs enabled. This must remove the following
-    # defaults for the "Complex" Op:
-    #   o "T"    : float32.   (input type)
-    #   o "Tout" : complex64. (output type)
+    # Add a graph with a single variable and a test op with a defaultless
+    # float32 attr, "test_attr".
     with session.Session(graph=ops.Graph()) as sess:
-      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
-      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
-      math_ops.complex(real_num, imag_num, name="complex")
+      variables.Variable(1.0, dtype=dtypes.float64, name="var")
+      test_ops.test_attr(T=dtypes.float32, name="test_attr")
       sess.run(variables.global_variables_initializer())
-      builder.add_meta_graph_and_variables(
-          sess, ["foo"], strip_default_attrs=True)
+      builder.add_meta_graph_and_variables(sess, ["foo"])
 
     # Save the SavedModel to disk in text format.
     builder.save(as_text=True)
 
-    # Update the Op registry to remove defaults for all attrs("T", "Tout") from
-    # the "Complex" OpDef.
-    complex_op_def = op_def_registry.get_registered_ops()["Complex"]
-    original_complex_op_def = op_def_pb2.OpDef()
-    original_complex_op_def.CopyFrom(complex_op_def)
-    for attr_def in complex_op_def.attr:
-      attr_def.ClearField("default_value")
+    # Rewrite the SavedModel to remove the T attr from "test_attr".
+    saved_model_file = os.path.join(
+        export_dir, constants.SAVED_MODEL_FILENAME_PBTXT)
+    with open(saved_model_file) as f:
+      original_saved_model = f.read()
+
+    no_attr_saved_model = original_saved_model.replace("""
+      attr {
+        key: "T"
+        value {
+          type: DT_FLOAT
+        }
+      }""", "")
+    with open(saved_model_file, "w") as f:
+      f.write(no_attr_saved_model)
 
     # Loading the SavedModel via the loader must fail because the SavedModel
-    # does not have any attr values for the "Complex" node and the current
-    # op registry does not have have any default values for the "Complex" op.
+    # does not have any attr values for the "TestAttr" node, and there is no
+    # default specified in the TestAttr OpDef.
     sess = session.Session(graph=ops.Graph())
-    with self.assertRaisesRegexp(
-        ValueError,
-        "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
+    if ops._USE_C_API:
+      error_message = "NodeDef missing attr 'T' from Op<name=TestAttr"
+    else:
+      error_message = ("Expected one attr with name .*T(out)?.* in name: "
+                       "\"test_attr\".*")
+    with self.assertRaisesRegexp(ValueError, error_message):
       loader.load(sess, ["foo"], export_dir)
 
-    # Update the Op registry to change the defaults for attr "Tout"
-    # (complex64 -> complex128).
-    complex_op_def.CopyFrom(original_complex_op_def)
-    for attr_def in complex_op_def.attr:
-      if attr_def.name == "Tout":
-        attr_def.default_value.type = types_pb2.DT_COMPLEX128
-
-    # Loading the SavedModel via the loader must set "Tout" attr_value for the
-    # "Complex" node according to the latest defaults (complex128). This is
-    # expected to fail the model import as there is no OpKernel registered to
-    # handle attrs "T" (float32) and "Tout" (complex128).
+    # Rewrite the SavedModel to change the type of the T attr in "test_attr"
+    bad_type_saved_model = original_saved_model.replace("""
+      attr {
+        key: "T"
+        value {
+          type: DT_FLOAT
+        }
+      }""", """
+      attr {
+        key: "T"
+        value {
+          type: DT_DOUBLE
+        }
+      }""")
+    with open(saved_model_file, "w") as f:
+      f.write(bad_type_saved_model)
+
+    # Loading the SavedModel via the loader must fail because there is no
+    # OpKernel registered to handle T = double.
     sess = session.Session(graph=ops.Graph())
     with self.assertRaisesRegexp(
         errors.InvalidArgumentError,
-        ".*No OpKernel was registered to support Op \'Complex\' with these "
+        ".*No OpKernel was registered to support Op \'TestAttr\' with these "
         "attrs..*"):
       loader.load(sess, ["foo"], export_dir)
 
-- 
cgit v1.2.3
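
Note on the technique used in the patch above: builder.save(as_text=True) writes the SavedModel as a text-format protobuf (saved_model.pbtxt), so the test can corrupt a single NodeDef attr with a plain string replacement before calling loader.load, instead of mutating the Python op registry. Below is a minimal standalone sketch of that rewrite step; the helper name drop_node_attr and the idea of passing the attr block in as an argument are illustrative assumptions, not part of the patch.

# Illustrative sketch only (not part of the patch): strip one serialized attr
# block from a SavedModel that was written with as_text=True, so a later
# loader.load() sees a NodeDef with that attr missing.
import os

from tensorflow.python.saved_model import constants


def drop_node_attr(export_dir, attr_block):
  """Removes `attr_block` (a text-format proto snippet) from saved_model.pbtxt."""
  path = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PBTXT)
  with open(path) as f:
    text = f.read()
  # The text-format output is deterministic, so a plain string replacement is
  # enough for a test; no proto parsing is needed.
  with open(path, "w") as f:
    f.write(text.replace(attr_block, ""))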