aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-08-29 18:24:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-29 18:29:25 -0700
commitbea2cb338c268eb25e32d929772dfe6c99f713bd (patch)
treeeab7f0b9f598b731252242ac3301eae5e092e048 /tensorflow/contrib/data
parent98dd0cd1539c8831ff2527895dd3025c7f12b187 (diff)
[tf.data] Add special case for single dataset in `tf.contrib.data.sample_from_datasets()`.
PiperOrigin-RevId: 210830214
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py2
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py6
2 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index 9b1857de1a..9020a499c4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -84,7 +84,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
# Use chi-squared test to assert that the observed distribution matches the
# expected distribution. Based on the implementation in
# "tensorflow/python/kernel_tests/multinomial_op_test.py".
- for probs in [[.85, .05, .1], rand_probs]:
+ for probs in [[.85, .05, .1], rand_probs, [1.]]:
probs = np.asarray(probs)
classes = len(probs)
freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 54a92ab185..38c0a09c33 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -235,6 +235,12 @@ def sample_from_datasets(datasets, weights=None, seed=None):
# to weights.
logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
+ # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
+ # is a `Dataset`, it is possible that evaluating it has a side effect the
+ # user depends on.
+ if len(datasets) == 1:
+ return datasets[0]
+
def select_dataset_constant_logits(seed):
return array_ops.squeeze(
stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])