diff options
author | Francois Chollet <fchollet@google.com> | 2018-10-04 13:14:07 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 13:18:26 -0700 |
commit | b949f9ee60522ca43f7f8a89b15ea6eeed2ac570 (patch) | |
tree | cb7bc95db0e38c0c919503b01ea0a5b4a0ee11f5 /tensorflow/python | |
parent | 7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (diff) |
Enable masking through a Sequential model.
PiperOrigin-RevId: 215790636
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/keras/engine/input_layer.py | 1 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/topology_test.py | 31 |
2 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py index 8a4018a0df..6a69d0ed90 100644 --- a/tensorflow/python/keras/engine/input_layer.py +++ b/tensorflow/python/keras/engine/input_layer.py @@ -82,6 +82,7 @@ class InputLayer(base_layer.Layer): self.built = True self.sparse = sparse self.batch_size = batch_size + self.supports_masking = True if isinstance(input_shape, tensor_shape.TensorShape): input_shape = tuple(input_shape.as_list()) diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py index a0da96334b..b4488033cd 100644 --- a/tensorflow/python/keras/engine/topology_test.py +++ b/tensorflow/python/keras/engine/topology_test.py @@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import test +from tensorflow.python.training import rmsprop try: import yaml # pylint:disable=g-import-not-at-top @@ -1182,6 +1183,36 @@ class DefaultShapeInferenceBehaviorTest(test.TestCase): output = model(sample_input) self.assertEqual(output.shape, (1, 3)) + @test_util.run_in_graph_and_eager_modes() + def test_sequential_as_downstream_of_masking_layer(self): + inputs = keras.layers.Input(shape=(3, 4)) + x = keras.layers.Masking(mask_value=0., input_shape=(3, 4))(inputs) + + s = keras.Sequential() + s.add(keras.layers.Dense(5, input_shape=(4,))) + + x = keras.layers.wrappers.TimeDistributed(s)(x) + model = keras.Model(inputs=inputs, outputs=x) + model.compile(optimizer=rmsprop.RMSPropOptimizer(1e-3), loss='mse') + + model_input = np.random.randint( + low=1, high=5, size=(10, 3, 4)).astype('float32') + for i in range(4): + model_input[i, i:, :] = 0. + model.fit(model_input, + np.random.random((10, 3, 5)), epochs=1, batch_size=6) + + if not context.executing_eagerly(): + # Note: this doesn't work in eager due to DeferredTensor/ops compatibility + # issue. + mask_outputs = [model.layers[1].compute_mask(model.layers[1].input)] + mask_outputs += [model.layers[2].compute_mask( + model.layers[2].input, mask_outputs[-1])] + func = keras.backend.function([model.input], mask_outputs) + mask_outputs_val = func([model_input]) + self.assertAllClose(mask_outputs_val[0], np.any(model_input, axis=-1)) + self.assertAllClose(mask_outputs_val[1], np.any(model_input, axis=-1)) + class GraphUtilsTest(test.TestCase): |