aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2018-02-01 14:11:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-01 17:32:59 -0800
commitf1f1d6d482e332f11452d9103a29149e2adc7125 (patch)
treeae747c9485cfb79b3641b0239652e920fb249e0b
parenta964248ae9aaee99165594e80427152576e803fe (diff)
Throw an exception when the user's batch size isn't divisible by GPUs.
The alternative to this is to have an adaptive approach that would unevenly split input into per-tower batches. The concern with that was that all towers will be as slow as the one with more input reducing the performance. Batch size seems to be commonly tailored to the available hardware. PiperOrigin-RevId: 184192793
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py10
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py79
2 files changed, 85 insertions, 4 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index caa9dd8323..c9153c9352 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -457,6 +457,13 @@ def _get_local_devices(device_type):
def _split_batch(features, labels, number_of_shards, device):
"""Split input features and labes into batches."""
+ def ensure_divisible_by_shards(sequence):
+ batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0]
+ if batch_size % number_of_shards != 0:
+ raise ValueError(
+ 'Batch size {} needs to be divisible by the number of GPUs, which '
+ 'is {}.'.format(batch_size, number_of_shards))
+
def split_dictionary(dictionary):
"""Split a dictionary into shards."""
shards = [{} for _ in range(number_of_shards)]
@@ -467,6 +474,7 @@ def _split_batch(features, labels, number_of_shards, device):
sp_input=tensor, num_split=number_of_shards, axis=0)):
shards[i][name] = shard
else:
+ ensure_divisible_by_shards(tensor)
for i, shard in enumerate(array_ops.split(tensor, number_of_shards)):
shards[i][name] = shard
return shards
@@ -476,6 +484,7 @@ def _split_batch(features, labels, number_of_shards, device):
if isinstance(features, dict):
feature_shards = split_dictionary(features)
else:
+ ensure_divisible_by_shards(features)
feature_shards = array_ops.split(features, number_of_shards)
if labels is None:
@@ -483,6 +492,7 @@ def _split_batch(features, labels, number_of_shards, device):
elif isinstance(labels, dict):
label_shards = split_dictionary(labels)
else:
+ ensure_divisible_by_shards(labels)
label_shards = array_ops.split(labels, number_of_shards)
return feature_shards, label_shards
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index 03d31226af..6936f8a131 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -37,6 +37,7 @@ from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -433,6 +434,17 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
'probabilities': np.array([[0.1], [0.02]])
}, session.run(estimator_spec.predictions))
+ def test_batch_size_that_is_not_divisible_by_the_number_of_gpus(self):
+ features = np.array([[1.0], [2.0], [3.0]])
+ labels = np.array([[1.0], [2.0], [3.0]])
+
+ with self.assertRaisesRegexp(
+ ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'):
+ replicated_model_fn = replicate_model_fn.replicate_model_fn(
+ self.model_fn, devices=['/gpu:0', '/gpu:1'])
+ _ = replicated_model_fn(
+ features, labels, model_fn_lib.ModeKeys.TRAIN, self.params)
+
def test_unsupported_loss_reduction(self):
with self.assertRaisesRegexp(ValueError,
'.+none.+reduction.+is.+specified.+'):
@@ -981,8 +993,13 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
return list(map(evaluate_items, first_list)), list(
map(evaluate_items, second_list))
+ def assertSparseValuesEqual(self, a, b):
+ self.assertAllEqual(a.indices, b.indices)
+ self.assertAllEqual(a.values, b.values)
+ self.assertAllEqual(a.dense_shape, b.dense_shape)
+
def test_simple_half_split(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.test_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -995,7 +1012,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards)
def test_to_each_their_own(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.test_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -1008,7 +1025,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards)
def test_one_batch(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.test_session():
features = [0.0, 1.0, 2.0, 3.0]
labels = [10.0, 11.0, 12.0, 13.0]
feature_shards, label_shards = replicate_model_fn._split_batch(
@@ -1021,7 +1038,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards)
def test_half_split_in_dictionary(self):
- with self.test_session() as session: # pylint: disable=unused-variable
+ with self.test_session():
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}
labels = [10.0, 11.0, 12.0, 13.0]
@@ -1035,6 +1052,60 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
self.assertAllEqual([10.0, 11.0], label_shards[0].eval())
self.assertAllEqual([12.0, 13.0], label_shards[1].eval())
+ def test_sparse_tensor_can_be_split_unevenly(self):
+ with self.test_session():
+ features = {
+ 'x':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 2], [2, 2]],
+ values=[1.0, 2.0, 3.0],
+ dense_shape=[3, 4])
+ }
+ labels = np.array([[1.0], [2.0]])
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
+ self.assertSparseValuesEqual(
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[2, 4]),
+ feature_shards[0]['x'].eval())
+ self.assertSparseValuesEqual(
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 2]], values=[3.], dense_shape=[1, 4]),
+ feature_shards[1]['x'].eval())
+ self.assertAllEqual([[1.0]], label_shards[0].eval())
+ self.assertAllEqual([[2.0]], label_shards[1].eval())
+
+ def test_sparse_tensor_can_be_split_unevenly_repeated_row(self):
+ with self.test_session():
+ features = {
+ 'x':
+ sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[1.0, 2.0, 3.0],
+ dense_shape=[3, 4])
+ }
+ labels = np.array([[1.0], [2.0]])
+
+ feature_shards, label_shards = replicate_model_fn._split_batch(
+ features, labels, 2, device='/gpu:0')
+
+ print(feature_shards[0]['x'].eval())
+ print(feature_shards[1]['x'].eval())
+ self.assertSparseValuesEqual(
+ sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [1, 1]],
+ values=[1., 2., 3.],
+ dense_shape=[2, 4]), feature_shards[0]['x'].eval())
+
+ second_batch = feature_shards[1]['x'].eval()
+ self.assertFalse(len(second_batch.indices))
+ self.assertFalse(len(second_batch.values))
+ self.assertAllEqual([1, 4], second_batch.dense_shape)
+ self.assertAllEqual([[1.0]], label_shards[0].eval())
+ self.assertAllEqual([[2.0]], label_shards[1].eval())
+
def test_one_batch_in_dictionary(self):
with self.test_session() as session: # pylint: disable=unused-variable
features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]}