aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/model_subclassing_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/model_subclassing_test.py')
-rw-r--r--tensorflow/python/keras/model_subclassing_test.py276
1 files changed, 266 insertions, 10 deletions
diff --git a/tensorflow/python/keras/model_subclassing_test.py b/tensorflow/python/keras/model_subclassing_test.py
index b7e16a41dd..5fbc191e78 100644
--- a/tensorflow/python/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/model_subclassing_test.py
@@ -29,9 +29,11 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.training.checkpointable import data_structures
from tensorflow.python.training.rmsprop import RMSPropOptimizer
try:
@@ -65,6 +67,22 @@ class SimpleTestModel(keras.Model):
return self.dense2(x)
+class SimpleConvTestModel(keras.Model):
+
+ def __init__(self, num_classes=10):
+ super(SimpleConvTestModel, self).__init__(name='test_model')
+ self.num_classes = num_classes
+
+ self.conv1 = keras.layers.Conv2D(32, (3, 3), activation='relu')
+ self.flatten = keras.layers.Flatten()
+ self.dense1 = keras.layers.Dense(num_classes, activation='softmax')
+
+ def call(self, x):
+ x = self.conv1(x)
+ x = self.flatten(x)
+ return self.dense1(x)
+
+
class MultiIOTestModel(keras.Model):
def __init__(self, use_bn=False, use_dp=False, num_classes=(2, 3)):
@@ -174,6 +192,213 @@ def get_nested_model_3(input_dim, num_classes):
class ModelSubclassingTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
+ def test_invalid_input_shape_build(self):
+ num_classes = 2
+ input_dim = 50
+
+ model = SimpleTestModel(num_classes=num_classes,
+ use_dp=True,
+ use_bn=True)
+
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ with self.assertRaisesRegexp(
+ ValueError, 'input shape is not one of the valid types'):
+ model.build(input_shape=tensor_shape.Dimension(input_dim))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_embed_dtype_with_subclass_build(self):
+ class Embedding(keras.layers.Layer):
+ """An Embedding layer."""
+
+ def __init__(self, vocab_size, embedding_dim, **kwargs):
+ super(Embedding, self).__init__(**kwargs)
+ self.vocab_size = vocab_size
+ self.embedding_dim = embedding_dim
+
+ def build(self, _):
+ self.embedding = self.add_variable(
+ 'embedding_kernel',
+ shape=[self.vocab_size, self.embedding_dim],
+ dtype=np.float32,
+ initializer=init_ops.random_uniform_initializer(-0.1, 0.1),
+ trainable=True)
+
+ def call(self, x):
+ return embedding_ops.embedding_lookup(self.embedding, x)
+
+ class EmbedModel(keras.Model):
+
+ def __init__(self, vocab_size, embed_size):
+ super(EmbedModel, self).__init__()
+ self.embed1 = Embedding(vocab_size, embed_size)
+
+ def call(self, inputs):
+ return self.embed1(inputs)
+
+ model = EmbedModel(100, 20)
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ with self.assertRaisesRegexp(
+ ValueError, 'if your layers do not support float type inputs'):
+ model.build(input_shape=(35, 20))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_single_time_step_rnn_build(self):
+ dim = 4
+ timesteps = 1
+ batch_input_shape = (None, timesteps, dim)
+ units = 3
+
+ class SimpleRNNModel(keras.Model):
+
+ def __init__(self):
+ super(SimpleRNNModel, self).__init__()
+ self.lstm = keras.layers.LSTM(units)
+
+ def call(self, inputs):
+ return self.lstm(inputs)
+
+ model = SimpleRNNModel()
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build(batch_input_shape)
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+ model(array_ops.ones((32, timesteps, dim)))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_single_io_subclass_build(self):
+ num_classes = 2
+ input_dim = 50
+ batch_size = None
+
+ model = SimpleTestModel(num_classes=num_classes,
+ use_dp=True,
+ use_bn=True)
+
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build(input_shape=(batch_size, input_dim))
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+ model(array_ops.ones((32, input_dim)))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_single_io_dimension_subclass_build(self):
+ num_classes = 2
+ input_dim = tensor_shape.Dimension(50)
+ batch_size = tensor_shape.Dimension(None)
+
+ model = SimpleTestModel(num_classes=num_classes,
+ use_dp=True,
+ use_bn=True)
+
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build(input_shape=(batch_size, input_dim))
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+ model(array_ops.ones((32, input_dim)))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_multidim_io_subclass_build(self):
+ num_classes = 10
+ # Input size, e.g. image
+ batch_size = 32
+ input_shape = (32, 32, 3)
+
+ model = SimpleConvTestModel(num_classes)
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ batch_input_shape = (batch_size,) + input_shape
+ model.build(input_shape=batch_input_shape)
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+
+ model(array_ops.ones(batch_input_shape))
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_tensorshape_io_subclass_build(self):
+ num_classes = 10
+ # Input size, e.g. image
+ batch_size = None
+ input_shape = (32, 32, 3)
+
+ model = SimpleConvTestModel(num_classes)
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build(
+ input_shape=tensor_shape.TensorShape((batch_size,) + input_shape))
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+
+ model(array_ops.ones((32,) + input_shape))
+
+ def test_subclass_save_model(self):
+ num_classes = 10
+ # Input size, e.g. image
+ batch_size = None
+ input_shape = (32, 32, 3)
+
+ model = SimpleConvTestModel(num_classes)
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build(
+ input_shape=tensor_shape.TensorShape((batch_size,) + input_shape))
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+ weights = model.get_weights()
+
+ tf_format_name = os.path.join(self.get_temp_dir(), 'ckpt')
+ model.save_weights(tf_format_name)
+ if h5py is not None:
+ hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5')
+ model.save_weights(hdf5_format_name)
+
+ model = SimpleConvTestModel(num_classes)
+ model.build(
+ input_shape=tensor_shape.TensorShape((batch_size,) + input_shape))
+ if h5py is not None:
+ model.load_weights(hdf5_format_name)
+ self.assertAllClose(weights, model.get_weights())
+ model.load_weights(tf_format_name)
+ self.assertAllClose(weights, model.get_weights())
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_multi_io_subclass_build(self):
+ batch_size = None
+ num_samples = 1000
+ input_dim = 50
+ model = MultiIOTestModel()
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ batch_input_shape = tensor_shape.TensorShape((batch_size, input_dim))
+ model.build(
+ input_shape=[batch_input_shape, batch_input_shape])
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+ x1 = array_ops.ones((num_samples, input_dim))
+ x2 = array_ops.ones((num_samples, input_dim))
+ model([x1, x2])
+
+ @test_util.run_in_graph_and_eager_modes
def test_single_io_workflow_with_np_arrays(self):
num_classes = 2
num_samples = 100
@@ -679,8 +904,8 @@ class ModelSubclassingTest(test.TestCase):
def __init__(self):
super(Foo, self).__init__()
self.isdep = keras.layers.Dense(1)
- self.notdep = checkpointable.NoDependency(keras.layers.Dense(2))
- self.notdep_var = checkpointable.NoDependency(
+ self.notdep = data_structures.NoDependency(keras.layers.Dense(2))
+ self.notdep_var = data_structures.NoDependency(
resource_variable_ops.ResourceVariable(1., name='notdep_var'))
m = Foo()
@@ -750,6 +975,16 @@ class CustomCallModel(keras.Model):
return combined
+class TrainingNoDefaultModel(keras.Model):
+
+ def __init__(self):
+ super(TrainingNoDefaultModel, self).__init__()
+ self.dense1 = keras.layers.Dense(1)
+
+ def call(self, x, training):
+ return self.dense1(x)
+
+
class CustomCallSignatureTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes
@@ -767,6 +1002,32 @@ class CustomCallSignatureTests(test.TestCase):
self.assertAllClose(expected_output, self.evaluate(output))
@test_util.run_in_graph_and_eager_modes
+ def test_training_args_call_build(self):
+ input_dim = 2
+
+ model = TrainingNoDefaultModel()
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ model.build((None, input_dim))
+ self.assertTrue(model.weights, ('Model should have weights now that it '
+ 'has been properly built.'))
+ self.assertTrue(model.built, 'Model should be built after calling `build`.')
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_custom_call_kwargs_and_build(self):
+ first_input_shape = (2, 3)
+ second_input_shape = (2, 5)
+
+ model = CustomCallModel()
+ self.assertFalse(model.built, 'Model should not have been built')
+ self.assertFalse(model.weights, ('Model should have no weights since it '
+ 'has not been built.'))
+ with self.assertRaisesRegexp(
+ ValueError, 'cannot build your model if it has positional'):
+ model.build(input_shape=[first_input_shape, second_input_shape])
+
+ @test_util.run_in_graph_and_eager_modes
def test_inputs_in_signature(self):
class HasInputsAndOtherPositional(keras.Model):
@@ -829,14 +1090,9 @@ class CustomCallSignatureTests(test.TestCase):
def test_training_no_default(self):
- class TrainingNoDefault(keras.Model):
-
- def call(self, x, training):
- return x
-
with context.graph_mode():
- model = TrainingNoDefault()
- arg = array_ops.ones([])
+ model = TrainingNoDefaultModel()
+ arg = array_ops.ones([1, 1])
model(arg, True)
six.assertCountEqual(self, [arg], model.inputs)