diff options
Diffstat (limited to 'tensorflow/python/keras/model_subclassing_test.py')
-rw-r--r-- | tensorflow/python/keras/model_subclassing_test.py | 276 |
1 files changed, 266 insertions, 10 deletions
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py index b7e16a41dd..5fbc191e78 100644 --- a/tensorflow/python/keras/model_subclassing_test.py +++ b/tensorflow/python/keras/model_subclassing_test.py @@ -29,9 +29,11 @@ from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -from tensorflow.python.training.checkpointable import base as checkpointable +from tensorflow.python.training.checkpointable import data_structures from tensorflow.python.training.rmsprop import RMSPropOptimizer try: @@ -65,6 +67,22 @@ class SimpleTestModel(keras.Model): return self.dense2(x) +class SimpleConvTestModel(keras.Model): + + def __init__(self, num_classes=10): + super(SimpleConvTestModel, self).__init__(name='test_model') + self.num_classes = num_classes + + self.conv1 = keras.layers.Conv2D(32, (3, 3), activation='relu') + self.flatten = keras.layers.Flatten() + self.dense1 = keras.layers.Dense(num_classes, activation='softmax') + + def call(self, x): + x = self.conv1(x) + x = self.flatten(x) + return self.dense1(x) + + class MultiIOTestModel(keras.Model): def __init__(self, use_bn=False, use_dp=False, num_classes=(2, 3)): @@ -174,6 +192,213 @@ def get_nested_model_3(input_dim, num_classes): class ModelSubclassingTest(test.TestCase): @test_util.run_in_graph_and_eager_modes + def test_invalid_input_shape_build(self): + num_classes = 2 + input_dim = 50 + + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + with self.assertRaisesRegexp( + ValueError, 'input shape is not one of the valid types'): + model.build(input_shape=tensor_shape.Dimension(input_dim)) + + @test_util.run_in_graph_and_eager_modes + def test_embed_dtype_with_subclass_build(self): + class Embedding(keras.layers.Layer): + """An Embedding layer.""" + + def __init__(self, vocab_size, embedding_dim, **kwargs): + super(Embedding, self).__init__(**kwargs) + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + + def build(self, _): + self.embedding = self.add_variable( + 'embedding_kernel', + shape=[self.vocab_size, self.embedding_dim], + dtype=np.float32, + initializer=init_ops.random_uniform_initializer(-0.1, 0.1), + trainable=True) + + def call(self, x): + return embedding_ops.embedding_lookup(self.embedding, x) + + class EmbedModel(keras.Model): + + def __init__(self, vocab_size, embed_size): + super(EmbedModel, self).__init__() + self.embed1 = Embedding(vocab_size, embed_size) + + def call(self, inputs): + return self.embed1(inputs) + + model = EmbedModel(100, 20) + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + with self.assertRaisesRegexp( + ValueError, 'if your layers do not support float type inputs'): + model.build(input_shape=(35, 20)) + + @test_util.run_in_graph_and_eager_modes + def test_single_time_step_rnn_build(self): + dim = 4 + timesteps = 1 + batch_input_shape = (None, timesteps, dim) + units = 3 + + class SimpleRNNModel(keras.Model): + + def __init__(self): + super(SimpleRNNModel, self).__init__() + self.lstm = keras.layers.LSTM(units) + + def call(self, inputs): + return self.lstm(inputs) + + model = SimpleRNNModel() + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build(batch_input_shape) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + model(array_ops.ones((32, timesteps, dim))) + + @test_util.run_in_graph_and_eager_modes + def test_single_io_subclass_build(self): + num_classes = 2 + input_dim = 50 + batch_size = None + + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build(input_shape=(batch_size, input_dim)) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + model(array_ops.ones((32, input_dim))) + + @test_util.run_in_graph_and_eager_modes + def test_single_io_dimension_subclass_build(self): + num_classes = 2 + input_dim = tensor_shape.Dimension(50) + batch_size = tensor_shape.Dimension(None) + + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build(input_shape=(batch_size, input_dim)) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + model(array_ops.ones((32, input_dim))) + + @test_util.run_in_graph_and_eager_modes + def test_multidim_io_subclass_build(self): + num_classes = 10 + # Input size, e.g. image + batch_size = 32 + input_shape = (32, 32, 3) + + model = SimpleConvTestModel(num_classes) + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + batch_input_shape = (batch_size,) + input_shape + model.build(input_shape=batch_input_shape) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + + model(array_ops.ones(batch_input_shape)) + + @test_util.run_in_graph_and_eager_modes + def test_tensorshape_io_subclass_build(self): + num_classes = 10 + # Input size, e.g. image + batch_size = None + input_shape = (32, 32, 3) + + model = SimpleConvTestModel(num_classes) + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build( + input_shape=tensor_shape.TensorShape((batch_size,) + input_shape)) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + + model(array_ops.ones((32,) + input_shape)) + + def test_subclass_save_model(self): + num_classes = 10 + # Input size, e.g. image + batch_size = None + input_shape = (32, 32, 3) + + model = SimpleConvTestModel(num_classes) + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build( + input_shape=tensor_shape.TensorShape((batch_size,) + input_shape)) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + weights = model.get_weights() + + tf_format_name = os.path.join(self.get_temp_dir(), 'ckpt') + model.save_weights(tf_format_name) + if h5py is not None: + hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5') + model.save_weights(hdf5_format_name) + + model = SimpleConvTestModel(num_classes) + model.build( + input_shape=tensor_shape.TensorShape((batch_size,) + input_shape)) + if h5py is not None: + model.load_weights(hdf5_format_name) + self.assertAllClose(weights, model.get_weights()) + model.load_weights(tf_format_name) + self.assertAllClose(weights, model.get_weights()) + + @test_util.run_in_graph_and_eager_modes + def test_multi_io_subclass_build(self): + batch_size = None + num_samples = 1000 + input_dim = 50 + model = MultiIOTestModel() + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + batch_input_shape = tensor_shape.TensorShape((batch_size, input_dim)) + model.build( + input_shape=[batch_input_shape, batch_input_shape]) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + x1 = array_ops.ones((num_samples, input_dim)) + x2 = array_ops.ones((num_samples, input_dim)) + model([x1, x2]) + + @test_util.run_in_graph_and_eager_modes def test_single_io_workflow_with_np_arrays(self): num_classes = 2 num_samples = 100 @@ -679,8 +904,8 @@ class ModelSubclassingTest(test.TestCase): def __init__(self): super(Foo, self).__init__() self.isdep = keras.layers.Dense(1) - self.notdep = checkpointable.NoDependency(keras.layers.Dense(2)) - self.notdep_var = checkpointable.NoDependency( + self.notdep = data_structures.NoDependency(keras.layers.Dense(2)) + self.notdep_var = data_structures.NoDependency( resource_variable_ops.ResourceVariable(1., name='notdep_var')) m = Foo() @@ -750,6 +975,16 @@ class CustomCallModel(keras.Model): return combined +class TrainingNoDefaultModel(keras.Model): + + def __init__(self): + super(TrainingNoDefaultModel, self).__init__() + self.dense1 = keras.layers.Dense(1) + + def call(self, x, training): + return self.dense1(x) + + class CustomCallSignatureTests(test.TestCase): @test_util.run_in_graph_and_eager_modes @@ -767,6 +1002,32 @@ class CustomCallSignatureTests(test.TestCase): self.assertAllClose(expected_output, self.evaluate(output)) @test_util.run_in_graph_and_eager_modes + def test_training_args_call_build(self): + input_dim = 2 + + model = TrainingNoDefaultModel() + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + model.build((None, input_dim)) + self.assertTrue(model.weights, ('Model should have weights now that it ' + 'has been properly built.')) + self.assertTrue(model.built, 'Model should be built after calling `build`.') + + @test_util.run_in_graph_and_eager_modes + def test_custom_call_kwargs_and_build(self): + first_input_shape = (2, 3) + second_input_shape = (2, 5) + + model = CustomCallModel() + self.assertFalse(model.built, 'Model should not have been built') + self.assertFalse(model.weights, ('Model should have no weights since it ' + 'has not been built.')) + with self.assertRaisesRegexp( + ValueError, 'cannot build your model if it has positional'): + model.build(input_shape=[first_input_shape, second_input_shape]) + + @test_util.run_in_graph_and_eager_modes def test_inputs_in_signature(self): class HasInputsAndOtherPositional(keras.Model): @@ -829,14 +1090,9 @@ class CustomCallSignatureTests(test.TestCase): def test_training_no_default(self): - class TrainingNoDefault(keras.Model): - - def call(self, x, training): - return x - with context.graph_mode(): - model = TrainingNoDefault() - arg = array_ops.ones([]) + model = TrainingNoDefaultModel() + arg = array_ops.ones([1, 1]) model(arg, True) six.assertCountEqual(self, [arg], model.inputs) |