diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-01 12:00:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-01 12:04:27 -0700 |
commit | ec618a458f88a01cb5c5c179da1cb5b610c8cbd4 (patch) | |
tree | 1f9204d87c16c8d8efc647d0ddd33e74a4e276d1 /tensorflow/contrib/autograph | |
parent | 09c8964acdeeb11634c43bd5ac0c68d7588f2c01 (diff) |
Better error message when @autograph.convert(recursive=True) fails
PiperOrigin-RevId: 206967298
Diffstat (limited to 'tensorflow/contrib/autograph')
3 files changed, 54 insertions, 2 deletions
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py index 73125eb452..7e7ef5a3e2 100644 --- a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py +++ b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py @@ -44,6 +44,33 @@ class ModelWithStaticConditional(object): return x +class BasicBlock(tf.keras.Model): + + def __init__(self): + super(BasicBlock, self).__init__() + self.conv1 = tf.keras.layers.Conv2D(8, 3) + self.pool = tf.keras.layers.GlobalAveragePooling2D() + self.dense = tf.keras.layers.Dense(3) + + def call(self, x): + x = self.conv1(x) + x = self.pool(x) + x = self.dense(x) + return x + + +class CompoundModel(tf.keras.Model): + + def __init__(self): + super(CompoundModel, self).__init__() + self.block = BasicBlock() + + @autograph.convert(recursive=True) + def call(self, x): + x = self.block(x) # pylint: disable=not-callable + return x + + class KerasTest(tf.test.TestCase): def test_basic(self): @@ -57,6 +84,20 @@ class KerasTest(tf.test.TestCase): model = ModelWithStaticConditional(True) self.assertEqual(model.call(), 25) + def test_recursive_true(self): + with self.assertRaisesRegexp(NotImplementedError, + 'Object conversion is not yet supported.'): + with tf.Graph().as_default(): + model = CompoundModel() + model.build(tf.TensorShape((None, 10, 10, 1))) + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + sample_input = tf.random_uniform((1, 10, 10, 1)) + output = model(sample_input) # pylint: disable=not-callable + self.assertEqual(sess.run(output).shape, (1, 3)) + if __name__ == '__main__': tf.test.main() diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py index afb10d4d8b..fc8a976d3f 100644 --- a/tensorflow/contrib/autograph/impl/conversion.py +++ b/tensorflow/contrib/autograph/impl/conversion.py @@ -118,6 +118,17 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types): node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) elif tf_inspect.ismethod(o): node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) + # TODO(mdan,yashkatariya): Remove when object conversion is implemented. + elif hasattr(o, '__class__'): + raise NotImplementedError( + 'Object conversion is not yet supported. If you are ' + 'trying to convert code that uses an existing object, ' + 'try including the creation of that object in the ' + 'conversion. For example, instead of converting the method ' + 'of a class, try converting the entire class instead. ' + 'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/' + 'contrib/autograph/README.md#using-the-functional-api ' + 'for more information.') else: raise ValueError( 'Entity "%s" has unsupported type "%s". Only functions and classes are ' @@ -181,7 +192,7 @@ def class_to_graph(c, program_ctx): class_name = namer.compiled_class_name(c.__name__, c) # TODO(mdan): This needs to be explained more thoroughly. - # Process any base classes: if the sueprclass if of a whitelisted type, an + # Process any base classes: if the superclass if of a whitelisted type, an # absolute import line is generated. Otherwise, it is marked for conversion # (as a side effect of the call to namer.compiled_class_name() followed by # program_ctx.update_name_map(namer)). diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py index 1c5d4d09c4..86432573a7 100644 --- a/tensorflow/contrib/autograph/impl/conversion_test.py +++ b/tensorflow/contrib/autograph/impl/conversion_test.py @@ -50,7 +50,7 @@ class ConversionTest(test.TestCase): self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant)) def test_entity_to_graph_unsupported_types(self): - with self.assertRaises(ValueError): + with self.assertRaises(NotImplementedError): program_ctx = self._simple_program_ctx() conversion.entity_to_graph('dummy', program_ctx, None, None) |