aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2018-10-04 13:14:07 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 13:18:26 -0700
commitb949f9ee60522ca43f7f8a89b15ea6eeed2ac570 (patch)
treecb7bc95db0e38c0c919503b01ea0a5b4a0ee11f5 /tensorflow/python
parent7fcb05ff475a0c6c1076eacf9d11e17323d98bc2 (diff)
Enable masking through a Sequential model.
PiperOrigin-RevId: 215790636
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/keras/engine/input_layer.py1
-rw-r--r--tensorflow/python/keras/engine/topology_test.py31
2 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py
index 8a4018a0df..6a69d0ed90 100644
--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py
@@ -82,6 +82,7 @@ class InputLayer(base_layer.Layer):
self.built = True
self.sparse = sparse
self.batch_size = batch_size
+ self.supports_masking = True
if isinstance(input_shape, tensor_shape.TensorShape):
input_shape = tuple(input_shape.as_list())
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index a0da96334b..b4488033cd 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
try:
import yaml # pylint:disable=g-import-not-at-top
@@ -1182,6 +1183,36 @@ class DefaultShapeInferenceBehaviorTest(test.TestCase):
output = model(sample_input)
self.assertEqual(output.shape, (1, 3))
+ @test_util.run_in_graph_and_eager_modes()
+ def test_sequential_as_downstream_of_masking_layer(self):
+ inputs = keras.layers.Input(shape=(3, 4))
+ x = keras.layers.Masking(mask_value=0., input_shape=(3, 4))(inputs)
+
+ s = keras.Sequential()
+ s.add(keras.layers.Dense(5, input_shape=(4,)))
+
+ x = keras.layers.wrappers.TimeDistributed(s)(x)
+ model = keras.Model(inputs=inputs, outputs=x)
+ model.compile(optimizer=rmsprop.RMSPropOptimizer(1e-3), loss='mse')
+
+ model_input = np.random.randint(
+ low=1, high=5, size=(10, 3, 4)).astype('float32')
+ for i in range(4):
+ model_input[i, i:, :] = 0.
+ model.fit(model_input,
+ np.random.random((10, 3, 5)), epochs=1, batch_size=6)
+
+ if not context.executing_eagerly():
+ # Note: this doesn't work in eager due to DeferredTensor/ops compatibility
+ # issue.
+ mask_outputs = [model.layers[1].compute_mask(model.layers[1].input)]
+ mask_outputs += [model.layers[2].compute_mask(
+ model.layers[2].input, mask_outputs[-1])]
+ func = keras.backend.function([model.input], mask_outputs)
+ mask_outputs_val = func([model_input])
+ self.assertAllClose(mask_outputs_val[0], np.any(model_input, axis=-1))
+ self.assertAllClose(mask_outputs_val[1], np.any(model_input, axis=-1))
+
class GraphUtilsTest(test.TestCase):