diff options
author | Derek Murray <mrry@google.com> | 2018-08-29 18:24:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-29 18:29:25 -0700 |
commit | bea2cb338c268eb25e32d929772dfe6c99f713bd (patch) | |
tree | eab7f0b9f598b731252242ac3301eae5e092e048 /tensorflow/contrib/data | |
parent | 98dd0cd1539c8831ff2527895dd3025c7f12b187 (diff) |
[tf.data] Add special case for single dataset in `tf.contrib.data.sample_from_datasets()`.
PiperOrigin-RevId: 210830214
Diffstat (limited to 'tensorflow/contrib/data')
-rw-r--r-- | tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/data/python/ops/interleave_ops.py | 6 |
2 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index 9b1857de1a..9020a499c4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -84,7 +84,7 @@ class DirectedInterleaveDatasetTest(test.TestCase): # Use chi-squared test to assert that the observed distribution matches the # expected distribution. Based on the implementation in # "tensorflow/python/kernel_tests/multinomial_op_test.py". - for probs in [[.85, .05, .1], rand_probs]: + for probs in [[.85, .05, .1], rand_probs, [1.]]: probs = np.asarray(probs) classes = len(probs) freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples) diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 54a92ab185..38c0a09c33 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -235,6 +235,12 @@ def sample_from_datasets(datasets, weights=None, seed=None): # to weights. logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0) + # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it + # is a `Dataset`, it is possible that evaluating it has a side effect the + # user depends on. + if len(datasets) == 1: + return datasets[0] + def select_dataset_constant_logits(seed): return array_ops.squeeze( stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1]) |