aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-01 12:00:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 12:04:27 -0700
commitec618a458f88a01cb5c5c179da1cb5b610c8cbd4 (patch)
tree1f9204d87c16c8d8efc647d0ddd33e74a4e276d1 /tensorflow/contrib/autograph
parent09c8964acdeeb11634c43bd5ac0c68d7588f2c01 (diff)
Better error message when @autograph.convert(recursive=True) fails
PiperOrigin-RevId: 206967298
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/examples/integration_tests/keras_test.py41
-rw-r--r--tensorflow/contrib/autograph/impl/conversion.py13
-rw-r--r--tensorflow/contrib/autograph/impl/conversion_test.py2
3 files changed, 54 insertions, 2 deletions
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
index 73125eb452..7e7ef5a3e2 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
+++ b/tensorflow/contrib/autograph/examples/integration_tests/keras_test.py
@@ -44,6 +44,33 @@ class ModelWithStaticConditional(object):
return x
+class BasicBlock(tf.keras.Model):
+
+ def __init__(self):
+ super(BasicBlock, self).__init__()
+ self.conv1 = tf.keras.layers.Conv2D(8, 3)
+ self.pool = tf.keras.layers.GlobalAveragePooling2D()
+ self.dense = tf.keras.layers.Dense(3)
+
+ def call(self, x):
+ x = self.conv1(x)
+ x = self.pool(x)
+ x = self.dense(x)
+ return x
+
+
+class CompoundModel(tf.keras.Model):
+
+ def __init__(self):
+ super(CompoundModel, self).__init__()
+ self.block = BasicBlock()
+
+ @autograph.convert(recursive=True)
+ def call(self, x):
+ x = self.block(x) # pylint: disable=not-callable
+ return x
+
+
class KerasTest(tf.test.TestCase):
def test_basic(self):
@@ -57,6 +84,20 @@ class KerasTest(tf.test.TestCase):
model = ModelWithStaticConditional(True)
self.assertEqual(model.call(), 25)
+ def test_recursive_true(self):
+ with self.assertRaisesRegexp(NotImplementedError,
+ 'Object conversion is not yet supported.'):
+ with tf.Graph().as_default():
+ model = CompoundModel()
+ model.build(tf.TensorShape((None, 10, 10, 1)))
+ init = tf.global_variables_initializer()
+
+ with tf.Session() as sess:
+ sess.run(init)
+ sample_input = tf.random_uniform((1, 10, 10, 1))
+ output = model(sample_input) # pylint: disable=not-callable
+ self.assertEqual(sess.run(output).shape, (1, 3))
+
if __name__ == '__main__':
tf.test.main()
diff --git a/tensorflow/contrib/autograph/impl/conversion.py b/tensorflow/contrib/autograph/impl/conversion.py
index afb10d4d8b..fc8a976d3f 100644
--- a/tensorflow/contrib/autograph/impl/conversion.py
+++ b/tensorflow/contrib/autograph/impl/conversion.py
@@ -118,6 +118,17 @@ def entity_to_graph(o, program_ctx, arg_values, arg_types):
node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
elif tf_inspect.ismethod(o):
node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
+ # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
+ elif hasattr(o, '__class__'):
+ raise NotImplementedError(
+ 'Object conversion is not yet supported. If you are '
+ 'trying to convert code that uses an existing object, '
+ 'try including the creation of that object in the '
+ 'conversion. For example, instead of converting the method '
+ 'of a class, try converting the entire class instead. '
+ 'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
+ 'contrib/autograph/README.md#using-the-functional-api '
+ 'for more information.')
else:
raise ValueError(
'Entity "%s" has unsupported type "%s". Only functions and classes are '
@@ -181,7 +192,7 @@ def class_to_graph(c, program_ctx):
class_name = namer.compiled_class_name(c.__name__, c)
# TODO(mdan): This needs to be explained more thoroughly.
- # Process any base classes: if the sueprclass if of a whitelisted type, an
+ # Process any base classes: if the superclass if of a whitelisted type, an
# absolute import line is generated. Otherwise, it is marked for conversion
# (as a side effect of the call to namer.compiled_class_name() followed by
# program_ctx.update_name_map(namer)).
diff --git a/tensorflow/contrib/autograph/impl/conversion_test.py b/tensorflow/contrib/autograph/impl/conversion_test.py
index 1c5d4d09c4..86432573a7 100644
--- a/tensorflow/contrib/autograph/impl/conversion_test.py
+++ b/tensorflow/contrib/autograph/impl/conversion_test.py
@@ -50,7 +50,7 @@ class ConversionTest(test.TestCase):
self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant))
def test_entity_to_graph_unsupported_types(self):
- with self.assertRaises(ValueError):
+ with self.assertRaises(NotImplementedError):
program_ctx = self._simple_program_ctx()
conversion.entity_to_graph('dummy', program_ctx, None, None)