aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar joel-shor <joelshor@google.com>2018-04-29 01:16:16 +0300
committerGravatar joel-shor <joelshor@google.com>2018-04-29 01:21:28 +0300
commitb384c339ee7d8440b6d4e39c09202c19f900aebe (patch)
tree99347f27049dbb11d96d59e12f7755f381f8f538
parentc45b05197623b375a056dd9577a778c5d5cc7d03 (diff)
[tf.data] Possible bug fix to fix Winsows build.
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/resample_test.py4
-rw-r--r--tensorflow/contrib/data/python/ops/resampling.py1
2 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index c08283a041..bbb8ca22f6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -60,9 +60,9 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("InitialDistributionKnown", True),
- ("InitialDistributionUnknown", True)) # THIS IS TO TEST THE WINDOWS BUILD DONT SUBMIT
+ ("InitialDistributionUnknown", False))
def testDistribution(self, initial_known):
- classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
+ classes = np.random.randint(5, size=(20000,), dtype=np.int64)
target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
initial_dist = [0.2] * 5 if initial_known else None
dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py
index 1194b8447a..bad6edd514 100644
--- a/tensorflow/contrib/data/python/ops/resampling.py
+++ b/tensorflow/contrib/data/python/ops/resampling.py
@@ -79,7 +79,6 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
lambda accept_prob, _: accept_prob)
prob_of_original_ds = acceptance_and_original_prob_ds.map(
lambda _, prob_original: prob_original)
- prob_of_original = None
filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
class_values_ds, seed)
# Prefetch filtered dataset for speed.