about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/data/python/ops/interleave_ops.py
diff options
context:
space:
mode:
author: Yifei Feng <yifeif@google.com> 2018-04-23 21:19:14 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-04-23 21:21:38 -0700
commit 22f3a97b8b089202f60bb0c7697feb0c8e0713cc (patch)
tree d16f95826e4be15bbb3b0f22bed0ca25d3eb5897 /tensorflow/contrib/data/python/ops/interleave_ops.py
parent 24b7c9a800ab5086d45a7d83ebcd6218424dc9e3 (diff)
Merge changes from github.
PiperOrigin-RevId: 194031845
Diffstat (limited to 'tensorflow/contrib/data/python/ops/interleave_ops.py')
-rw-r--r-- tensorflow/contrib/data/python/ops/interleave_ops.py 26
1 file changed, 13 insertions, 13 deletions
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 106a1ef388..812a50ecbf 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -200,10 +200,11 @@ def sample_from_datasets(datasets, weights=None, seed=None):
Args:
datasets: A list of @{tf.data.Dataset} objects with compatible structure.
- weights: (Optional.) A list of `len(datasets)` floating-point values,
- where `weights[i]` represents the probability with which an element
- should be sampled from `datasets[i]`. Defaults to a uniform distribution
- across `datasets`.
+ weights: (Optional.) A list of `len(datasets)` floating-point values where
+ `weights[i]` represents the probability with which an element should be
+ sampled from `datasets[i]`, or a @{tf.data.Dataset} object where each
+ element is such a list. Defaults to a uniform distribution across
+ `datasets`.
seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
random seed that will be used to create the distribution. See
@{tf.set_random_seed} for behavior.
@@ -219,24 +220,23 @@ def sample_from_datasets(datasets, weights=None, seed=None):
"""
num_datasets = len(datasets)
if weights is None:
- weights = array_ops.ones(
- [num_datasets], dtype=dtypes.float32, name="weights")
- else:
+ weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat()
+ elif not isinstance(weights, dataset_ops.Dataset):
weights = ops.convert_to_tensor(weights, name="weights")
if weights.dtype not in (dtypes.float32, dtypes.float64):
raise TypeError("`weights` must be convertible to a tensor of "
"`tf.float32` or `tf.float64` elements.")
if not weights.shape.is_compatible_with([num_datasets]):
raise ValueError("`weights` must be a vector of length `len(datasets)`.")
+ weights = dataset_ops.Dataset.from_tensors(weights).repeat()
# The `stateless_multinomial()` op expects log-probabilities, as opposed to
# weights.
- logits = math_ops.log(weights, name="logits")
-
- def select_dataset(seed):
+ logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
+ def select_dataset(logits, seed):
return array_ops.squeeze(
- stateless.stateless_multinomial([logits], 1, seed=seed), axis=[0, 1])
-
- selector_input = random_ops.RandomDataset(seed).batch(2).map(select_dataset)
+ stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
+ selector_input = dataset_ops.Dataset.zip(
+ (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset)
return DirectedInterleaveDataset(selector_input, datasets)