aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Pavithra Vijay <psv@google.com>2018-07-26 16:38:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 16:42:22 -0700
commite336ee65a5c887e9a2f0b4c82c333bca405707a5 (patch)
tree0fd79cdca3585c51b7e86b924a5bd4a0f1ded6de
parent403845d3e26291d6013c623b9130f4404c969ca6 (diff)
Fix: When sample_weight_mode is a list/dict set default sample_weight values so that we do not require sample_weight to be set during training/eval
PiperOrigin-RevId: 206242625
-rw-r--r--tensorflow/python/keras/engine/training.py57
-rw-r--r--tensorflow/python/keras/engine/training_test.py48
-rw-r--r--tensorflow/python/keras/engine/training_utils.py24
3 files changed, 87 insertions, 42 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 4df739254b..39d207cc6b 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -24,7 +24,6 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -374,21 +373,14 @@ class Model(Network):
'sample_weight_mode dictionary: "' + name + '". '
'Only expected the following keys: ' + str(self.output_names))
for i, name in enumerate(self.output_names):
- if i in skip_target_weighing_indices:
- weight = None
- sample_weight_modes.append(None)
- else:
- if name not in sample_weight_mode:
- raise ValueError(
- 'Output "' + name + '" missing from sample_weight_modes '
- 'dictionary')
- if sample_weight_mode.get(name) == 'temporal':
- weight = K.placeholder(ndim=2, name=name + '_sample_weights')
- sample_weight_modes.append('temporal')
- else:
- weight = K.placeholder(ndim=1, name=name + 'sample_weights')
- sample_weight_modes.append(None)
+ if (i not in skip_target_weighing_indices and
+ name not in sample_weight_mode):
+ raise ValueError('Output "' + name +
+ '" missing from sample_weight_modes dictionary')
+ weight, mode = training_utils.get_output_sample_weight_and_mode(
+ skip_target_weighing_indices, sample_weight_mode.get(name), name, i)
sample_weights.append(weight)
+ sample_weight_modes.append(mode)
elif isinstance(sample_weight_mode, list):
if len(sample_weight_mode) != len(self.outputs):
raise ValueError('When passing a list as sample_weight_mode, '
@@ -396,36 +388,17 @@ class Model(Network):
'The model has ' + str(len(self.outputs)) +
' outputs, but you passed '
'sample_weight_mode=' + str(sample_weight_mode))
- for i in range(len(self.output_names)):
- if i in skip_target_weighing_indices:
- weight = None
- sample_weight_modes.append(None)
- else:
- mode = sample_weight_mode[i]
- name = self.output_names[i]
- if mode == 'temporal':
- weight = K.placeholder(ndim=2, name=name + '_sample_weights')
- sample_weight_modes.append('temporal')
- else:
- weight = K.placeholder(ndim=1, name=name + '_sample_weights')
- sample_weight_modes.append(None)
+ for i, name in enumerate(self.output_names):
+ weight, mode = training_utils.get_output_sample_weight_and_mode(
+ skip_target_weighing_indices, sample_weight_mode[i], name, i)
sample_weights.append(weight)
+ sample_weight_modes.append(mode)
else:
for i, name in enumerate(self.output_names):
- if i in skip_target_weighing_indices:
- sample_weight_modes.append(None)
- sample_weights.append(None)
- else:
- if sample_weight_mode == 'temporal':
- sample_weights.append(array_ops.placeholder_with_default(
- constant_op.constant([[1.]], dtype=K.floatx()),
- shape=[None, None], name=name + '_sample_weights'))
- sample_weight_modes.append('temporal')
- else:
- sample_weights.append(array_ops.placeholder_with_default(
- constant_op.constant([1.], dtype=K.floatx()),
- shape=[None], name=name + '_sample_weights'))
- sample_weight_modes.append(None)
+ weight, mode = training_utils.get_output_sample_weight_and_mode(
+ skip_target_weighing_indices, sample_weight_mode, name, i)
+ sample_weights.append(weight)
+ sample_weight_modes.append(mode)
self.sample_weight_modes = sample_weight_modes
self._feed_sample_weight_modes = []
for i in range(len(self.outputs)):
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 301a6ca866..129441d159 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -731,6 +731,54 @@ class LossWeightingTest(test.TestCase):
model.fit(x_np, [y_np, y_np], epochs=1,
sample_weight={'1': bad_w_np})
+ def test_default_sample_weight(self):
+ """Verifies that fit works without having to set sample_weight."""
+
+ num_classes = 5
+ input_dim = 5
+ timesteps = 3
+ with self.test_session():
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.TimeDistributed(
+ keras.layers.Dense(num_classes),
+ input_shape=(timesteps, input_dim)))
+
+ x = np.random.random((10, timesteps, input_dim))
+ y = np.random.random((10, timesteps, num_classes))
+
+ # sample_weight_mode is a list and mode value is None
+ model.compile(loss='mse', optimizer='rmsprop', sample_weight_mode=[None])
+ model.fit(x, y, epochs=1, batch_size=10)
+
+ # sample_weight_mode is a list and mode value is `temporal`
+ model.compile(
+ loss='mse', optimizer='rmsprop', sample_weight_mode=['temporal'])
+ model.fit(x, y, epochs=1, batch_size=10)
+
+ # sample_weight_mode is a dict and mode value is None
+ model.compile(
+ loss='mse',
+ optimizer='rmsprop',
+ sample_weight_mode={'time_distributed': None})
+ model.fit(x, y, epochs=1, batch_size=10)
+
+ # sample_weight_mode is a dict and mode value is `temporal`
+ model.compile(
+ loss='mse',
+ optimizer='rmsprop',
+ sample_weight_mode={'time_distributed': 'temporal'})
+ model.fit(x, y, epochs=1, batch_size=10)
+
+ # sample_weight_mode is a not a list/dict and mode value is None
+ model.compile(loss='mse', optimizer='rmsprop', sample_weight_mode=None)
+ model.fit(x, y, epochs=1, batch_size=10)
+
+ # sample_weight_mode is a not a list/dict and mode value is `temporal`
+ model.compile(
+ loss='mse', optimizer='rmsprop', sample_weight_mode='temporal')
+ model.fit(x, y, epochs=1, batch_size=10)
+
class LossMaskingTest(test.TestCase):
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index dbbc87daf9..21495fd0bd 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -26,10 +26,12 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import metrics as metrics_module
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -856,3 +858,25 @@ def cast_if_floating_dtype(x):
for val in x
]
return math_ops.cast(x, dtype=K.floatx()) if x.dtype.is_floating else x
+
+
+def get_output_sample_weight_and_mode(skip_target_weighing_indices,
+ sample_weight_mode, output_name,
+ output_index):
+ """Returns the sample weight and weight mode for a single output."""
+ if output_index in skip_target_weighing_indices:
+ return None, None
+
+ if sample_weight_mode == 'temporal':
+ default_value = [[1.]]
+ shape = [None, None]
+ mode = 'temporal'
+ else:
+ default_value = [1.]
+ shape = [None]
+ mode = None
+ weight = array_ops.placeholder_with_default(
+ constant_op.constant(default_value, dtype=K.floatx()),
+ shape=shape,
+ name=output_name + '_sample_weights')
+ return weight, mode