path: root/tensorflow/contrib
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/BUILD | 53
-rw-r--r--  tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py | 5
-rw-r--r--  tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py | 4
-rw-r--r--  tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py | 18
-rw-r--r--  tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py | 24
-rw-r--r--  tensorflow/contrib/cmake/python_modules.txt | 1
-rw-r--r--  tensorflow/contrib/data/BUILD | 38
-rw-r--r--  tensorflow/contrib/data/kernels/BUILD | 139
-rw-r--r--  tensorflow/contrib/data/kernels/assert_next_dataset_op.cc | 155
-rw-r--r--  tensorflow/contrib/data/kernels/csv_dataset_op.cc | 859
-rw-r--r--  tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc | 280
-rw-r--r--  tensorflow/contrib/data/kernels/identity_indexed_dataset.cc | 155
-rw-r--r--  tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc | 141
-rw-r--r--  tensorflow/contrib/data/kernels/indexed_dataset.cc | 373
-rw-r--r--  tensorflow/contrib/data/kernels/indexed_dataset.h | 119
-rw-r--r--  tensorflow/contrib/data/kernels/lmdb_dataset_op.cc | 217
-rw-r--r--  tensorflow/contrib/data/kernels/prefetching_kernels.cc | 481
-rw-r--r--  tensorflow/contrib/data/kernels/threadpool_dataset_op.cc | 219
-rw-r--r--  tensorflow/contrib/data/kernels/unique_dataset_op.cc | 223
-rw-r--r--  tensorflow/contrib/data/ops/dataset_ops.cc | 208
-rw-r--r--  tensorflow/contrib/data/ops/indexed_dataset_ops.cc | 80
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD | 40
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py | 9
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/bucketing_test.py | 9
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py | 43
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py | 15
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py | 4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/BUILD | 9
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py | 14
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py | 7
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py | 10
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/resample_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py | 11
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py | 8
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py | 5
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py | 4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/test_utils.py | 73
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py | 4
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py | 3
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD | 57
-rw-r--r--  tensorflow/contrib/data/python/ops/contrib_op_loader.py | 24
-rw-r--r--  tensorflow/contrib/data/python/ops/error_ops.py | 5
-rw-r--r--  tensorflow/contrib/data/python/ops/indexed_dataset_ops.py | 25
-rw-r--r--  tensorflow/contrib/data/python/ops/interleave_ops.py | 13
-rw-r--r--  tensorflow/contrib/data/python/ops/optimization.py | 7
-rw-r--r--  tensorflow/contrib/data/python/ops/prefetching_ops.py | 37
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py | 6
-rw-r--r--  tensorflow/contrib/data/python/ops/threadpool.py | 9
-rw-r--r--  tensorflow/contrib/data/python/ops/unique.py | 5
-rw-r--r--  tensorflow/contrib/decision_trees/proto/BUILD | 1
-rw-r--r--  tensorflow/contrib/distribute/README.md | 3
-rw-r--r--  tensorflow/contrib/distribute/python/BUILD | 28
-rw-r--r--  tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py | 2
-rw-r--r--  tensorflow/contrib/distribute/python/combinations.py | 3
-rw-r--r--  tensorflow/contrib/distribute/python/keras_test.py | 121
-rw-r--r--  tensorflow/contrib/distribute/python/metrics_v1_test.py | 3
-rw-r--r--  tensorflow/contrib/distribute/python/minimize_loss_test.py | 26
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py | 14
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 12
-rw-r--r--  tensorflow/contrib/distribute/python/monitor.py | 1
-rw-r--r--  tensorflow/contrib/distribute/python/optimizer_v2_test.py | 8
-rw-r--r--  tensorflow/contrib/distribute/python/prefetching_ops_v2.py | 232
-rw-r--r--  tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py | 90
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn.py | 7
-rw-r--r--  tensorflow/contrib/distribute/python/step_fn_test.py | 1
-rw-r--r--  tensorflow/contrib/distribute/python/values.py | 63
-rw-r--r--  tensorflow/contrib/distribute/python/values_test.py | 23
-rw-r--r--  tensorflow/contrib/factorization/BUILD | 9
-rw-r--r--  tensorflow/contrib/lite/BUILD | 19
-rw-r--r--  tensorflow/contrib/lite/examples/android/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/java/aar_with_jni.bzl | 53
-rw-r--r--  tensorflow/contrib/lite/java/demo/app/src/main/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/java/ovic/demo/app/BUILD | 1
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor.h | 4
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h | 29
-rw-r--r--  tensorflow/contrib/lite/kernels/internal/tensor_test.cc | 36
-rw-r--r--  tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD | 1
-rw-r--r--  tensorflow/contrib/makefile/Makefile | 3
-rw-r--r--  tensorflow/contrib/opt/BUILD | 5
-rw-r--r--  tensorflow/contrib/opt/python/training/shampoo_test.py | 40
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/BUILD | 7
-rw-r--r--  tensorflow/contrib/timeseries/python/timeseries/head_test.py | 2
-rw-r--r--  tensorflow/contrib/tpu/__init__.py | 3
-rw-r--r--  tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc | 42
-rw-r--r--  tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc | 3
-rw-r--r--  tensorflow/contrib/tpu/profiler/op_profile.proto | 4
-rw-r--r--  tensorflow/contrib/tpu/proto/optimization_parameters.proto | 33
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/async_checkpoint.py | 12
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_support.py | 11
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py | 53
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 13
-rw-r--r--  tensorflow/contrib/training/BUILD | 1
113 files changed, 958 insertions, 4380 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index ae5ca32bcf..98dff965a9 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -112,26 +112,14 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python/estimator:estimator_py",
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
- "//tensorflow/contrib/kafka",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
- "//tensorflow/contrib/kinesis",
- ],
- "//conditions:default": [],
- }) + if_not_windows_cuda([
- "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
- ]) + if_not_windows([
- ]) + select({
"//tensorflow:linux_s390x": [],
"//tensorflow:windows": [],
"//conditions:default": [
"//tensorflow/contrib/bigtable",
"//tensorflow/contrib/cloud:cloud_py",
+ "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols
+ "//tensorflow/contrib/kafka",
+ "//tensorflow/contrib/kinesis",
"//tensorflow/contrib/tensorrt:init_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
],
@@ -144,7 +132,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_kernels",
"//tensorflow/contrib/coder:all_kernels",
- "//tensorflow/contrib/data/kernels:dataset_kernels",
"//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/hadoop:dataset_kernels",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
@@ -159,20 +146,14 @@ cc_library(
] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([
"//tensorflow/contrib/nccl:nccl_kernels",
]) + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
"//tensorflow/contrib/kafka:dataset_kernels",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis:dataset_kernels",
+ "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
],
- "//conditions:default": [],
- }) + if_not_windows([
- "//tensorflow/contrib/tensorrt:trt_engine_op_kernel",
- ]),
+ }),
)
cc_library(
@@ -181,8 +162,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/coder:all_ops",
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
@@ -198,18 +177,12 @@ cc_library(
"//tensorflow/contrib/text:all_ops",
"//tensorflow/contrib/tpu:all_ops",
] + select({
- "//tensorflow:with_kafka_support_windows_override": [],
- "//tensorflow:with_kafka_support": [
+ "//tensorflow:linux_s390x": [],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
"//tensorflow/contrib/kafka:dataset_ops_op_lib",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/contrib/kinesis:dataset_ops_op_lib",
+ "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
],
- "//conditions:default": [],
- }) + if_not_windows([
- "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
- ]),
+ }),
)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
index 6b6fe9663a..839eedd3a8 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
@@ -188,9 +188,8 @@ class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
# Train for a few steps.
est.train(input_fn=_train_input_fn, steps=1000)
- # 10 steps for dnn + 3 for 1 tree of depth 3 + 1 after the tree finished
- # + 1 for resource variables.
- self._assert_checkpoint(est.model_dir, global_step=15)
+ # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished
+ self._assert_checkpoint(est.model_dir, global_step=14)
res = est.evaluate(input_fn=_eval_input_fn, steps=1)
self.assertLess(0.5, res["auc"])
est.predict(input_fn=_eval_input_fn)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index d7b14e00ba..c155128c0e 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -238,8 +238,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
output_leaf_index=False)
classifier.fit(input_fn=_train_input_fn, steps=15)
- # When no override of global steps, 6 steps were used.
- self._assert_checkpoint(classifier.model_dir, global_step=6)
+ # When no override of global steps, 5 steps were used.
+ self._assert_checkpoint(classifier.model_dir, global_step=5)
def testOverridesGlobalSteps(self):
learner_config = learner_pb2.LearnerConfig()
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index c7eb2493a8..8531e97f90 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -402,13 +402,13 @@ class GradientBoostedDecisionTreeModel(object):
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
self._num_quantiles = num_quantiles
- self._max_tree_depth = variables.Variable(
+ self._max_tree_depth = variables.VariableV1(
initial_value=self._learner_config.constraints.max_tree_depth)
- self._attempted_trees = variables.Variable(
+ self._attempted_trees = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
name="attempted_trees")
- self._finalized_trees = variables.Variable(
+ self._finalized_trees = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
name="finalized_trees")
@@ -770,28 +770,28 @@ class GradientBoostedDecisionTreeModel(object):
fc_name_idx += 1
# Create ensemble stats variables.
- num_layer_examples = variables.Variable(
+ num_layer_examples = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layer_examples",
trainable=False)
- num_layer_steps = variables.Variable(
+ num_layer_steps = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layer_steps",
trainable=False)
- num_layers = variables.Variable(
+ num_layers = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layers",
trainable=False)
- active_tree = variables.Variable(
+ active_tree = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="active_tree",
trainable=False)
- active_layer = variables.Variable(
+ active_layer = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="active_layer",
trainable=False)
# Variable that becomes false once bias centering is done.
- continue_centering = variables.Variable(
+ continue_centering = variables.VariableV1(
initial_value=self._center_bias,
name="continue_centering",
trainable=False)
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index 9d9941f696..6d20a2e7f4 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -239,7 +239,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -503,7 +503,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -607,7 +607,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -711,7 +711,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -783,7 +783,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -847,7 +847,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1090,7 +1090,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1194,7 +1194,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1299,7 +1299,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1405,7 +1405,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1524,7 +1524,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1656,7 +1656,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index c0763f4c0e..2975b167ec 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -132,7 +132,6 @@ tensorflow/contrib/cudnn_rnn/python
tensorflow/contrib/cudnn_rnn/python/layers
tensorflow/contrib/cudnn_rnn/python/ops
tensorflow/contrib/data
-tensorflow/contrib/data/kernels
tensorflow/contrib/data/python
tensorflow/contrib/data/python/kernel_tests
tensorflow/contrib/data/python/kernel_tests/serialization
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 9f710613dd..38f1c65a4d 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -4,17 +4,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load(
- "//tensorflow:tensorflow.bzl",
- "tf_custom_op_library",
- "tf_gen_op_libs",
- "if_not_windows",
-)
-load(
- "//tensorflow/core:platform/default/build_config_root.bzl",
- "if_static",
-)
-
py_library(
name = "data",
srcs = ["__init__.py"],
@@ -25,30 +14,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-cc_library(
- name = "lib_proto_parsing_for_dataset_ops",
- deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]),
-)
-
-tf_custom_op_library(
- name = "_dataset_ops.so",
- srcs = [
- "ops/dataset_ops.cc",
- "ops/indexed_dataset_ops.cc",
- ],
- deps = [
- "//tensorflow/contrib/data/kernels:dataset_kernels",
- "//tensorflow/contrib/data/kernels:indexed_dataset",
- ] + if_static(
- extra_deps = [":lib_proto_parsing_for_dataset_ops"],
- otherwise = [],
- ),
-)
-
-tf_gen_op_libs(
- op_lib_names = [
- "dataset_ops",
- "indexed_dataset_ops",
- ],
-)
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD
deleted file mode 100644
index ec6cb37193..0000000000
--- a/tensorflow/contrib/data/kernels/BUILD
+++ /dev/null
@@ -1,139 +0,0 @@
-# Description:
-# Contains kernels for datasets and iterators.
-package(default_visibility = ["//tensorflow:internal"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-cc_library(
- name = "indexed_dataset_headers",
- hdrs = ["indexed_dataset.h"],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
- ],
-)
-
-cc_library(
- name = "indexed_dataset",
- srcs = [
- "identity_indexed_dataset.cc",
- "indexed_dataset.cc",
- ],
- deps = [
- ":indexed_dataset_headers",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "prefetching_kernels",
- srcs = ["prefetching_kernels.cc"],
- deps = [
- "//tensorflow/core:core_cpu_headers_lib",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "directed_interleave_dataset_op",
- srcs = ["directed_interleave_dataset_op.cc"],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "csv_dataset_op",
- srcs = ["csv_dataset_op.cc"],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "ignore_errors_dataset_op",
- srcs = ["ignore_errors_dataset_op.cc"],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "lmdb_dataset_op",
- srcs = ["lmdb_dataset_op.cc"],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@lmdb",
- "@protobuf_archive//:protobuf_headers",
- ],
-)
-
-cc_library(
- name = "threadpool_dataset_op",
- srcs = ["threadpool_dataset_op.cc"],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "unique_dataset_op",
- srcs = ["unique_dataset_op.cc"],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "assert_next_dataset_op",
- srcs = ["assert_next_dataset_op.cc"],
- deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "dataset_kernels",
- deps = [
- ":assert_next_dataset_op",
- ":csv_dataset_op",
- ":directed_interleave_dataset_op",
- ":ignore_errors_dataset_op",
- ":indexed_dataset",
- ":lmdb_dataset_op",
- ":prefetching_kernels",
- ":threadpool_dataset_op",
- ":unique_dataset_op",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
- ],
-)
diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
deleted file mode 100644
index c19a609780..0000000000
--- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
+++ /dev/null
@@ -1,155 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <map>
-
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/framework/partial_tensor_shape.h"
-#include "tensorflow/core/framework/tensor.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-// See documentation in ../ops/dataset_ops.cc for a high-level
-// description of the following op.
-class AssertNextDatasetOp : public UnaryDatasetOpKernel {
- public:
- explicit AssertNextDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- }
-
- protected:
- void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
- DatasetBase** output) override {
- std::vector<string> transformations;
- OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations",
- &transformations));
- *output =
- new Dataset(ctx, input, transformations, output_types_, output_shapes_);
- }
-
- private:
- class Dataset : public DatasetBase {
- public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input,
- const std::vector<string>& transformations,
- const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes)
- : DatasetBase(DatasetContext(ctx)),
- input_(input),
- transformations_(transformations),
- output_types_(output_types),
- output_shapes_(output_shapes) {
- input_->Ref();
- }
-
- ~Dataset() override { input_->Unref(); }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Assert")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return output_types_;
- }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return output_shapes_;
- }
-
- string DebugString() const override {
- return "AssertNextDatasetOp::Dataset";
- }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* input_graph_node = nullptr;
- TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
- Node* transformations_node = nullptr;
- TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
- TF_RETURN_IF_ERROR(b->AddDataset(
- this, {input_graph_node, transformations_node}, output));
- return Status::OK();
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status Initialize(IteratorContext* ctx) override {
- std::vector<string> tokens =
- str_util::Split(prefix(), ':', str_util::SkipEmpty());
- if (dataset()->transformations_.size() > tokens.size() - 2) {
- return errors::InvalidArgument(
- "Asserted next ", dataset()->transformations_.size(),
- " transformations but encountered only ", tokens.size() - 2, ".");
- }
- int n = tokens.size();
- for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
- if (dataset()->transformations_[i] != tokens[n - 2 - i]) {
- return errors::InvalidArgument(
- "Asserted ", dataset()->transformations_[i],
- " transformation at offset ", i, " but encountered ",
- tokens[n - 2 - i], " transformation instead.");
- }
- }
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
- }
-
- private:
- std::unique_ptr<IteratorBase> input_impl_;
- };
-
- const DatasetBase* input_;
- const std::vector<string> transformations_;
- const DataTypeVector output_types_;
- const std::vector<PartialTensorShape> output_shapes_;
- };
-
- DataTypeVector output_types_;
- std::vector<PartialTensorShape> output_shapes_;
-};
-
-REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
- AssertNextDatasetOp);
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
deleted file mode 100644
index 21ec50fb6b..0000000000
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ /dev/null
@@ -1,859 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// See docs in ../ops/parsing_ops.cc.
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-#include "tensorflow/core/lib/io/inputstream_interface.h"
-#include "tensorflow/core/lib/io/random_inputstream.h"
-#include "tensorflow/core/lib/io/zlib_compression_options.h"
-#include "tensorflow/core/lib/io/zlib_inputstream.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-class CSVDatasetOp : public DatasetOpKernel {
- public:
- explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- }
-
- void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- const Tensor* filenames_tensor;
- OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
- OP_REQUIRES(
- ctx, filenames_tensor->dims() <= 1,
- errors::InvalidArgument("`filenames` must be a scalar or a vector."));
-
- string compression_type;
- OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
- &compression_type));
-
- OpInputList record_defaults_list;
- OP_REQUIRES_OK(ctx,
- ctx->input_list("record_defaults", &record_defaults_list));
- for (int i = 0; i < record_defaults_list.size(); ++i) {
- OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1,
- errors::InvalidArgument(
- "Each record default should be at most rank 1"));
- OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2,
- errors::InvalidArgument(
- "There should only be 1 default per field but field ", i,
- " has ", record_defaults_list[i].NumElements()));
- }
-
- const Tensor* select_cols_tensor;
- OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor));
- OP_REQUIRES(ctx, select_cols_tensor->dims() == 1,
- errors::InvalidArgument("`select_cols` must be a vector."));
-
- int64 buffer_size;
- OP_REQUIRES_OK(
- ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
- OP_REQUIRES(ctx, buffer_size > 0,
- errors::InvalidArgument("buffer_size should be positive"));
-
- string delim;
- OP_REQUIRES_OK(ctx,
- ParseScalarArgument<string>(ctx, "field_delim", &delim));
- OP_REQUIRES(ctx, delim.size() == 1,
- errors::InvalidArgument("field_delim should be only 1 char"));
-
- bool header;
- OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "header", &header));
-
- bool use_quote_delim;
- OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "use_quote_delim",
- &use_quote_delim));
- string na_value;
- OP_REQUIRES_OK(ctx,
- ParseScalarArgument<string>(ctx, "na_value", &na_value));
-
- std::vector<Tensor> record_defaults;
- record_defaults.reserve(record_defaults_list.size());
- for (const Tensor& t : record_defaults_list) {
- record_defaults.push_back(t);
- }
-
- std::vector<string> filenames;
- filenames.reserve(filenames_tensor->NumElements());
- for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
- filenames.push_back(filenames_tensor->flat<string>()(i));
- }
-
- io::ZlibCompressionOptions zlib_compression_options =
- io::ZlibCompressionOptions::DEFAULT();
- if (compression_type == "ZLIB") {
- zlib_compression_options = io::ZlibCompressionOptions::DEFAULT();
- } else if (compression_type == "GZIP") {
- zlib_compression_options = io::ZlibCompressionOptions::GZIP();
- } else {
- OP_REQUIRES(ctx, compression_type.empty(),
- errors::InvalidArgument(
- "Unsupported compression_type: ", compression_type, "."));
- }
- zlib_compression_options.input_buffer_size = buffer_size;
-
- std::vector<int64> select_cols;
- select_cols.reserve(select_cols_tensor->NumElements());
- for (int i = 0; i < select_cols_tensor->NumElements(); ++i) {
- select_cols.push_back(select_cols_tensor->flat<int64>()(i));
- }
- OP_REQUIRES(
- ctx, output_types_.size() == select_cols.size() || select_cols.empty(),
- errors::InvalidArgument("select_cols should match output size"));
- for (int i = 1; i < select_cols.size(); i++) {
- OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i],
- errors::InvalidArgument(
- "select_cols should be strictly increasing indices"));
- }
- OP_REQUIRES(
- ctx, select_cols.empty() || select_cols.front() >= 0,
- errors::InvalidArgument("select_cols should be non-negative indices"));
-
- *output = new Dataset(ctx, std::move(filenames), header,
- std::move(compression_type), zlib_compression_options,
- output_types_, output_shapes_,
- std::move(record_defaults), std::move(select_cols),
- use_quote_delim, delim[0], std::move(na_value));
- }
-
- private:
- class Dataset : public DatasetBase {
- public:
- Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header,
- string compression_type, io::ZlibCompressionOptions options,
- const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes,
- std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
- bool use_quote_delim, char delim, string na_value)
- : DatasetBase(DatasetContext(ctx)),
- filenames_(std::move(filenames)),
- header_(header),
- out_type_(output_types),
- output_shapes_(output_shapes),
- record_defaults_(std::move(record_defaults)),
- select_cols_(std::move(select_cols)),
- use_quote_delim_(use_quote_delim),
- delim_(delim),
- na_value_(std::move(na_value)),
- use_compression_(!compression_type.empty()),
- compression_type_(std::move(compression_type)),
- options_(options) {}
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::CSV")}));
- }
-
- const DataTypeVector& output_dtypes() const override { return out_type_; }
-
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return output_shapes_;
- }
-
- string DebugString() const override { return "CSVDatasetOp::Dataset"; }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* filenames = nullptr;
- Node* compression_type = nullptr;
- Node* buffer_size = nullptr;
- Node* header = nullptr;
- Node* delim = nullptr;
- Node* use_quote_delim = nullptr;
- Node* na_value = nullptr;
- Node* select_cols = nullptr;
-
- std::vector<Node*> record_defaults;
- record_defaults.reserve(record_defaults_.size());
- for (const Tensor& t : record_defaults_) {
- Node* node;
- TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
- record_defaults.emplace_back(node);
- }
-
- TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
- TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
- TF_RETURN_IF_ERROR(
- b->AddScalar(options_.input_buffer_size, &buffer_size));
- TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
-
- string delim_string(1, delim_);
- TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
- TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
- TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
- TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
-
- TF_RETURN_IF_ERROR(b->AddDataset(
- this,
- {std::make_pair(0, filenames), std::make_pair(1, compression_type),
- std::make_pair(2, buffer_size), std::make_pair(3, header),
- std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
- std::make_pair(6, na_value),
- std::make_pair(7, select_cols)}, // Single tensor inputs
- {std::make_pair(8, record_defaults)}, // Tensor list inputs
- {}, output));
- return Status::OK();
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- mutex_lock l(mu_);
- bool select_all = dataset()->select_cols_.empty();
- do {
- // We are currently processing a file, so try to read the next record
- if (input_stream_) {
- Status s = ReadRecord(ctx, out_tensors, select_all,
- dataset()->select_cols_);
- if (s.ok()) {
- // Validate output
- if (out_tensors->size() != dataset()->out_type_.size()) {
- return errors::InvalidArgument(
- "Expect ", dataset()->out_type_.size(), " fields but have ",
- out_tensors->size(), " in record");
- }
-
- *end_of_sequence = false;
- return s;
- }
- if (!errors::IsOutOfRange(s)) {
- // Not at the end of file, return OK or non-EOF errors to caller.
- *end_of_sequence = false;
- return s;
- }
- // We have reached the end of the current file, so maybe
- // move on to next file.
- ResetStreamsLocked();
- ++current_file_index_;
- }
- // Iteration ends when there are no more files to process.
- if (current_file_index_ == dataset()->filenames_.size()) {
- *end_of_sequence = true;
- return Status::OK();
- }
- TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
- } while (true);
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
- current_file_index_));
- // `input_stream_` is empty if
- // 1. GetNext has not been called even once.
- // 2. All files have been read and the iterator has been exhausted.
- if (input_stream_ && num_buffer_reads_ > 0) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
- // If num_buffer_reads_ == 0, the buffer hasn't been filled even once.
- TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
- num_buffer_reads_));
- }
- return Status::OK();
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- ResetStreamsLocked();
- int64 current_file_index;
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
- &current_file_index));
- current_file_index_ = size_t(current_file_index);
- // The keys "pos" and "num_buffer_reads" are written only if
- // the iterator was saved with an open, partially read file.
- if (reader->Contains(full_name("pos"))) {
- int64 pos, num_buffer_reads;
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
- &num_buffer_reads));
-
- TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
-
- num_buffer_reads_ = size_t(num_buffer_reads - 1);
-
- // Restores the most recently held buffer
- Status s = input_stream_->SkipNBytes(
- num_buffer_reads_ * dataset()->options_.input_buffer_size);
- if (!s.ok() && !errors::IsOutOfRange(s)) {
- // We might get out of range error here if the size of the file
- // is not an exact multiple of the buffer size, and the last buffer
- // read is < buffer_size. This is valid and we do not surface the
- // error.
- return s;
- }
-
- Status s2 = FillBuffer(&buffer_);
- if (!s2.ok() && !errors::IsOutOfRange(s2)) {
- return s2;
- }
- pos_ = size_t(pos);
- }
- return Status::OK();
- }
-
- private:
- // Reads an entire CSV row from the input stream, either from the
- // existing buffer or by filling the buffer as needed. Converts extracted
- // fields to output tensors as we go.
- //
- // When this function is called, pos_ should be the index of the first
- // character of the record in buffer_, or past the end of the buffer.
- // Note: ctx and out_tensors are only used in this function
- // when fields are included in the record.
- Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
- bool select_all, const std::vector<int64>& selected)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (pos_ >= buffer_.size()) {
- // At the end of the file, this will return errors::OutOfRange
- TF_RETURN_IF_ERROR(FillBuffer(&buffer_));
- pos_ = 0;
- }
-
- // The first character may be \n if this is the continuation of a
- // \r\n linebreak between this and the previous record. If so, skip it.
-
- bool end_of_record = false; // Keep track of when we find \n, \r or EOF
- size_t num_parsed = 0;
- size_t num_selected_parsed = 0;
-
- Status result;
-
- while (!end_of_record) { // Read till we reach \n, \r or EOF
- bool include =
- select_all || (num_selected_parsed < selected.size() &&
- selected[num_selected_parsed] == num_parsed);
-
- // Don't fail fast, so that the next call to GetNext may still return
- // a valid record
- result.Update(
- ParseOneField(ctx, out_tensors, &end_of_record, include));
-
- num_parsed++;
- if (include) num_selected_parsed++;
- }
-
- return result;
- }
-
- // Parses one field from position pos_ in the buffer. Fields are
- // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of
- // the next field.
- Status ParseOneField(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_record, bool include)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (pos_ >= buffer_.size()) {
- // If we get here, this means the previous field's end coincided
- // with the end of the buffer. We can fill the buffer without abandon.
- Status s = FillBuffer(&buffer_);
-
- if (errors::IsOutOfRange(s)) {
- // Reached EOF, and last field is empty
- *end_of_record = true;
- if (include) {
- return FieldToOutput(ctx, StringPiece(), out_tensors);
- } else {
- return Status::OK();
- }
- } else if (!s.ok()) {
- return s; // Surface other errors back to caller
- }
-
- pos_ = 0;
- }
-
- if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') {
- return ParseQuotedField(ctx, out_tensors, end_of_record, include);
- }
-
- return ParseUnquotedField(ctx, out_tensors, end_of_record, include);
- }
-
- // For keeping track of relevant parts of a field from a previous buffer
- struct Piece {
- size_t start;
- size_t len;
- string buffer;
-
- Piece(string buffer, size_t start, size_t len)
- : start(start), len(len), buffer(std::move(buffer)) {}
- };
-
- // Given that pos_ exceeds the buffer, saves the relevant part of the
- // current buffer (if necessary), fills the buffer, and resets indices to
- // 0.
- Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces,
- size_t* start, bool include)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- string temp_buffer;
-
- buffer_.swap(temp_buffer);
- if (include && pos_ > *start) {
- earlier_pieces->push_back(
- Piece(std::move(temp_buffer), *start, pos_ - *start));
- }
- pos_ = 0;
- *start = 0;
- return FillBuffer(&buffer_);
- }
-
- // Parses unquoted field from position pos_ in the buffer. Continually
- // reads from buffer until end of field is reached (delim, CRLF, or EOF).
- // Advances pos_ to keep track of our position in the buffer as we go,
- // stopping at the first character of the next field.
- Status ParseQuotedField(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_record, bool include)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- std::vector<Piece> earlier_pieces;
- size_t start = pos_;
- pos_++; // Starting quotation mark
-
- Status parse_result;
- while (true) { // Each iter reads 1 char, filling buffer if necessary
- if (pos_ >= buffer_.size()) {
- Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
- if (errors::IsOutOfRange(s)) {
- return errors::InvalidArgument(
- "Reached end of file without closing quoted field in "
- "record");
- } else if (!s.ok()) {
- return s; // Surface all other errors to caller
- }
- }
-
- char ch = buffer_[pos_];
- if (ch == '"') {
- // When we encounter a quote, we look ahead to the next character to
- // decide what to do
- pos_++;
- if (pos_ >= buffer_.size()) {
- Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
- if (errors::IsOutOfRange(s)) {
- // This was the last field. We are done
- *end_of_record = true;
- parse_result.Update(QuotedFieldToOutput(
- ctx, StringPiece(), out_tensors, earlier_pieces, include));
- return parse_result;
- } else if (!s.ok()) {
- return s;
- }
- }
-
- char next = buffer_[pos_];
- pos_++;
- if (next == dataset()->delim_) {
- parse_result.Update(QuotedFieldToOutput(
- ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
- out_tensors, earlier_pieces, include));
- return parse_result;
-
- } else if (next == '\n' || next == '\r') {
- *end_of_record = true;
- parse_result.Update(QuotedFieldToOutput(
- ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
- out_tensors, earlier_pieces, include));
- if (next == '\r') SkipNewLineIfNecessary();
- return parse_result;
- } else if (next != '"') {
- // Take note of the error, but keep going to end of field.
- include = false; // So we don't get funky errors when trying to
- // unescape the quotes.
- parse_result.Update(errors::InvalidArgument(
- "Quote inside a string has to be escaped by another quote"));
- }
-
- } else {
- pos_++;
- }
- }
- }
-
- // Converts quoted field to an output tensor, removing the starting
- // and ending quotes from it and unescaping double quotations if
- // necessary.
- Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field,
- std::vector<Tensor>* out_tensors,
- const std::vector<Piece>& earlier_pieces,
- bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (!include) return Status::OK();
-
- if (earlier_pieces.empty()) {
- if (field.find('\"', 1) == field.size() - 1) {
- // `field` contains no escaped quotation marks.
- // Exclude framing quotation marks
- field.remove_prefix(1);
- field.remove_suffix(1);
- return FieldToOutput(ctx, field, out_tensors);
- }
- }
- string field_complete;
- size_t str_len = field.size();
- for (const Piece& p : earlier_pieces) {
- str_len += p.len;
- }
- field_complete.reserve(str_len);
-
- // This bool flips every time we see a quote, so that we skip the second
- // quote of every pair of adjacent quotes in the field. We need to track
- // this across iterations of the for loop because adjacent double quotes
- // may be in different buffers. Initialize to true because we also skip
- // the opening quotation mark of the quoted field.
- bool skip_next_quote = true;
- for (const Piece& p : earlier_pieces) {
- AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len),
- &field_complete, &skip_next_quote);
- }
- AppendUnescapedPiece(field, &field_complete, &skip_next_quote);
- StringPiece result = StringPiece(field_complete);
- result.remove_suffix(1); // Skip final quote
-
- return FieldToOutput(ctx, result, out_tensors);
- }
-
- void AppendUnescapedPiece(StringPiece piece, string* field_complete,
- bool* skip_next_quote) {
- size_t from = 0;
- size_t found = piece.find('\"', from);
- while (found != string::npos) {
- if (!*skip_next_quote) {
- // This is the first quote in a pair of adjacent double quotes
- field_complete->append(piece.data() + from, found + 1 - from);
- }
- *skip_next_quote = !*skip_next_quote;
- from = found + 1;
- found = piece.find('\"', from);
- }
- // Include the chunk after the last quotation mark in the string
- if (from < piece.size()) {
- field_complete->append(piece.data() + from, piece.size() - from);
- }
- }
-
- // Parses unquoted field from position pos_ in the buffer. Continually
- // reads from buffer until end of field is reached (delim, CRLF, or EOF).
- // Advances pos_ to keep track of our position in the buffer as we go,
- // stopping at the first character of the next field.
- Status ParseUnquotedField(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_record, bool include)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- std::vector<Piece> earlier_pieces;
- size_t start = pos_;
- Status parse_result;
-
- while (true) { // Each iter reads 1 char, filling buffer if necessary
- if (pos_ >= buffer_.size()) {
- Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
- // Handle errors
- if (errors::IsOutOfRange(s)) {
- // Whatever we have is the last field of the last record
- *end_of_record = true;
- parse_result.Update(UnquotedFieldToOutput(
- ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
- earlier_pieces, include));
- return parse_result;
- } else if (!s.ok()) {
- return s; // Surface all other errors to caller
- }
- }
-
- char ch = buffer_[pos_];
-
- if (ch == dataset()->delim_) {
- parse_result.Update(UnquotedFieldToOutput(
- ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
- earlier_pieces, include));
- pos_++;
- return parse_result;
- }
- if (ch == '\n' || ch == '\r') {
- // need special case to skip over first \n of record if the line
- // breaks are \r\n
- parse_result.Update(UnquotedFieldToOutput(
- ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
- earlier_pieces, include));
- *end_of_record = true;
- pos_++;
- if (ch == '\r') SkipNewLineIfNecessary();
- return parse_result;
- }
- if (dataset()->use_quote_delim_ && ch == '"') {
- // Take note of the error, but keep going to end of field.
- parse_result.Update(errors::InvalidArgument(
- "Unquoted fields cannot have quotes inside"));
- }
- // Otherwise, go to next character
- pos_++;
- }
- }
-
- Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- result->clear();
- ++num_buffer_reads_;
- Status s = input_stream_->ReadNBytes(
- dataset()->options_.input_buffer_size, result);
-
- if (errors::IsOutOfRange(s) && !result->empty()) {
- // Ignore OutOfRange error when ReadNBytes read < N bytes.
- return Status::OK();
- }
- return s;
- }
-
- // Given a field, converts it to the right output tensor type
- Status FieldToOutput(IteratorContext* ctx, StringPiece field,
- std::vector<Tensor>* out_tensors) {
- size_t output_idx = out_tensors->size();
- if (output_idx >= dataset()->out_type_.size()) {
- // We can get here if we're selecting all columns, but the number of
- // fields exceeds the number of defaults provided
- return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
- " fields but have more in record");
- }
- const DataType& dtype = dataset()->out_type_[output_idx];
- Tensor component(ctx->allocator({}), dtype, {});
- if ((field.empty() || field == dataset()->na_value_) &&
- dataset()->record_defaults_[output_idx].NumElements() != 1) {
- // If the field is empty or NA value, and default is not given,
- // report error.
- return errors::InvalidArgument("Field ", output_idx,
- " is required but missing in record!");
- }
-
- switch (dtype) {
- // For each case, if the field is empty, we use the default.
- // Otherwise, we convert it to the right type.
- case DT_INT32: {
- if (field.empty() || field == dataset()->na_value_) {
- component.scalar<int32>()() =
- dataset()->record_defaults_[output_idx].flat<int32>()(0);
- } else {
- int32 value;
- if (!strings::safe_strto32(field, &value)) {
- return errors::InvalidArgument(
- "Field ", output_idx,
- " in record is not a valid int32: ", field);
- }
- component.scalar<int32>()() = value;
- }
- break;
- }
- case DT_INT64: {
- if (field.empty() || field == dataset()->na_value_) {
- component.scalar<int64>()() =
- dataset()->record_defaults_[output_idx].flat<int64>()(0);
- } else {
- int64 value;
- if (!strings::safe_strto64(field, &value)) {
- return errors::InvalidArgument(
- "Field ", output_idx,
- " in record is not a valid int64: ", field);
- }
- component.scalar<int64>()() = value;
- }
- break;
- }
- case DT_FLOAT: {
- if (field.empty() || field == dataset()->na_value_) {
- component.scalar<float>()() =
- dataset()->record_defaults_[output_idx].flat<float>()(0);
- } else {
- float value;
- if (!strings::safe_strtof(field, &value)) {
- return errors::InvalidArgument(
- "Field ", output_idx,
- " in record is not a valid float: ", field);
- }
- component.scalar<float>()() = value;
- }
- break;
- }
- case DT_DOUBLE: {
- if (field.empty() || field == dataset()->na_value_) {
- component.scalar<double>()() =
- dataset()->record_defaults_[output_idx].flat<double>()(0);
- } else {
- double value;
- if (!strings::safe_strtod(field, &value)) {
- return errors::InvalidArgument(
- "Field ", output_idx,
- " in record is not a valid double: ", field);
- }
- component.scalar<double>()() = value;
- }
- break;
- }
- case DT_STRING: {
- if (field.empty() || field == dataset()->na_value_) {
- component.scalar<string>()() =
- dataset()->record_defaults_[output_idx].flat<string>()(0);
- } else {
- component.scalar<string>()() = string(field);
- }
- break;
- }
- default:
- return errors::InvalidArgument("csv: data type ", dtype,
- " not supported in field ",
- output_idx);
- }
- out_tensors->push_back(std::move(component));
- return Status::OK();
- }
-
- // Records can be delimited by "\r\n" line breaks. When we encounter a
- // '\r', we have to check the next character to see if it is part of the
- // linebreak, and ignore it if so.
- void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (pos_ >= buffer_.size()) {
- Status s = FillBuffer(&buffer_);
- pos_ = 0;
- // If we failed to fill buffer, it doesn't matter because we're done
- // with the record
- if (!s.ok()) return;
- }
- if (buffer_[pos_] == '\n') {
- pos_++;
- }
- }
-
- // Given a string field, and its index in the output,
- // converts it to a Tensor of the right type and adds it to the
- // out_tensors vector.
- Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field,
- std::vector<Tensor>* out_tensors,
- const std::vector<Piece>& earlier_pieces,
- bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (!include) return Status::OK();
-
- if (earlier_pieces.empty()) {
- return FieldToOutput(ctx, field, out_tensors);
- }
-
- size_t str_len = field.size();
- for (const Piece& p : earlier_pieces) {
- str_len += p.len;
- }
- string field_complete;
- field_complete.reserve(str_len);
-
- for (const Piece& p : earlier_pieces) {
- field_complete.append(p.buffer, p.start, p.len);
- }
-
- field_complete.append(field.data(), field.size());
- return FieldToOutput(ctx, field_complete, out_tensors);
- }
-
- // Sets up reader streams to read from the file at `current_file_index_`.
- Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (current_file_index_ >= dataset()->filenames_.size()) {
- return errors::InvalidArgument(
- "current_file_index_:", current_file_index_,
- " >= filenames_.size():", dataset()->filenames_.size());
- }
-
- // Actually move on to next file.
- TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
- dataset()->filenames_[current_file_index_], &file_));
- random_access_input_stream_ =
- std::make_shared<io::RandomAccessInputStream>(file_.get(), false);
-
- if (dataset()->use_compression_) {
- input_stream_ = std::make_shared<io::ZlibInputStream>(
- random_access_input_stream_.get(),
- dataset()->options_.input_buffer_size,
- dataset()->options_.input_buffer_size, dataset()->options_);
- } else {
- input_stream_ = random_access_input_stream_;
- }
- buffer_.clear();
- pos_ = 0;
- num_buffer_reads_ = 0;
- if (dataset()->header_) {
-          // Read one line, but don't include it. Pass nullptrs as dummy
-          // pointers to objects that shouldn't be invoked anyway.
-          // We need to process this as a full record here, instead of just
-          // finding the first newline, because the header might contain
-          // quoted fields that themselves contain newlines.
- std::vector<int64> empty;
- Status s = ReadRecord(nullptr, nullptr, false, empty);
- if (!s.ok()) {
- return errors::InvalidArgument("Can't read header of file");
- }
- }
- return Status::OK();
- }
-
- // Resets all reader streams.
- void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- input_stream_.reset();
- file_.reset();
- }
-
- mutex mu_;
- string buffer_ GUARDED_BY(mu_); // Maintain our own buffer
- size_t pos_ GUARDED_BY(
- mu_); // Index into the buffer must be maintained between iters
- size_t num_buffer_reads_ GUARDED_BY(mu_);
- std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
- GUARDED_BY(mu_);
- std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
- size_t current_file_index_ GUARDED_BY(mu_) = 0;
- std::unique_ptr<RandomAccessFile> file_
- GUARDED_BY(mu_); // must outlive input_stream_
- }; // class Iterator
-
- const std::vector<string> filenames_;
- const bool header_;
- const DataTypeVector out_type_;
- const std::vector<PartialTensorShape> output_shapes_;
- const std::vector<Tensor> record_defaults_;
- const std::vector<int64> select_cols_;
- const bool use_quote_delim_;
- const char delim_;
- const string na_value_;
- const bool use_compression_;
- const string compression_type_;
- const io::ZlibCompressionOptions options_;
- }; // class Dataset
-
- DataTypeVector output_types_;
- std::vector<PartialTensorShape> output_shapes_;
-}; // class CSVDatasetOp
-
-// Register the kernel implementation for CSVDataset.
-REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
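
As a reference for the per-field conversion rule implemented by the deleted CSVDataset kernel above, here is a minimal standalone sketch (plain C++, no TensorFlow types; FieldToInt32 and its parameters are illustrative): an empty field, or one equal to the NA value, falls back to the column default, and anything else must parse cleanly as the target type or the record is rejected.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

int32_t FieldToInt32(const std::string& field, const std::string& na_value,
                     int32_t default_value) {
  if (field.empty() || field == na_value) return default_value;  // use default
  size_t consumed = 0;
  const long long parsed = std::stoll(field, &consumed);  // throws on garbage
  if (consumed != field.size() || parsed < INT32_MIN || parsed > INT32_MAX) {
    throw std::invalid_argument("not a valid int32: " + field);
  }
  return static_cast<int32_t>(parsed);
}

int main() {
  std::cout << FieldToInt32("", "NA", 7) << "\n";    // 7  (default used)
  std::cout << FieldToInt32("NA", "NA", 7) << "\n";  // 7  (default used)
  std::cout << FieldToInt32("42", "NA", 7) << "\n";  // 42 (parsed)
}
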
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
deleted file mode 100644
index a5321620bf..0000000000
--- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
+++ /dev/null
@@ -1,280 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/framework/partial_tensor_shape.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/hash/hash.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-// See documentation in ../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
-class DirectedInterleaveDatasetOp : public DatasetOpKernel {
- public:
- explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
- : DatasetOpKernel(ctx) {}
-
- void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- DatasetBase* selector_input;
- OP_REQUIRES_OK(ctx,
- GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
-
- OP_REQUIRES(
- ctx,
- selector_input->output_dtypes().size() == 1 &&
- selector_input->output_dtypes()[0] == DT_INT64 &&
- selector_input->output_shapes().size() == 1 &&
- selector_input->output_shapes()[0].IsCompatibleWith(
- PartialTensorShape({})),
- errors::InvalidArgument(
- "The selector input must be a dataset of scalar int64 elements."));
-
- std::vector<DatasetBase*> data_inputs;
- for (size_t i = 1; i < ctx->num_inputs(); ++i) {
- DatasetBase* input;
- OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
- data_inputs.push_back(input);
-
- OP_REQUIRES(
- ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
- errors::InvalidArgument(
- "All inputs must have the same output_dtypes. First input "
- "has types ",
- DataTypeVectorString(data_inputs[0]->output_dtypes()),
- ", and input ", i - 1, " has types ",
- DataTypeVectorString(input->output_dtypes())));
- }
- *output = new Dataset(ctx, selector_input, std::move(data_inputs));
- }
-
- private:
- class Dataset : public DatasetBase {
- public:
- Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
- std::vector<DatasetBase*> data_inputs)
- : DatasetBase(DatasetContext(ctx)),
- selector_input_(selector_input),
- data_inputs_(std::move(data_inputs)) {
- selector_input_->Ref();
-
- output_shapes_ = data_inputs_[0]->output_shapes();
- data_inputs_[0]->Ref();
- for (size_t i = 1; i < data_inputs_.size(); ++i) {
- const DatasetBase* data_input = data_inputs_[i];
- data_input->Ref();
- for (size_t j = 0; j < output_shapes_.size(); ++j) {
- output_shapes_[j] = MostSpecificCompatibleShape(
- output_shapes_[j], data_input->output_shapes()[j]);
- }
- }
- }
-
- ~Dataset() override {
- selector_input_->Unref();
- for (DatasetBase* data_input : data_inputs_) {
- data_input->Unref();
- }
- }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::DirectedInterleave")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return data_inputs_[0]->output_dtypes();
- }
-
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return output_shapes_;
- }
-
- string DebugString() const override {
- return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
- }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* selector_input_node;
- TF_RETURN_IF_ERROR(
- b->AddInputDataset(ctx, selector_input_, &selector_input_node));
- std::vector<Node*> data_input_nodes(data_inputs_.size());
- for (size_t i = 0; i < data_inputs_.size(); ++i) {
- TF_RETURN_IF_ERROR(
- b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
- }
- TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
- {{1, data_input_nodes}}, {}, output));
- return Status::OK();
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- num_active_inputs_(params.dataset->data_inputs_.size()) {}
-
- Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
- ctx, strings::StrCat(prefix(), ".selector"),
- &selector_input_impl_));
- data_input_impls_.resize(dataset()->data_inputs_.size());
- for (size_t i = 0; i < data_input_impls_.size(); ++i) {
- const DatasetBase* data_input = dataset()->data_inputs_[i];
- TF_RETURN_IF_ERROR(data_input->MakeIterator(
- ctx, strings::StrCat(prefix(), "[", i, "]"),
- &data_input_impls_[i]));
- }
- return Status::OK();
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- mutex_lock l(mu_);
- if (!selector_input_impl_) {
- *end_of_sequence = true;
- return Status::OK();
- }
-
- while (true) {
- std::vector<Tensor> selector_result;
- *end_of_sequence = false;
- TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
- ctx, &selector_result, end_of_sequence));
- if (*end_of_sequence) {
- selector_input_impl_.reset();
- for (auto& data_input_impl : data_input_impls_) {
- data_input_impl.reset();
- }
- return Status::OK();
- }
-
- int64 selected_input = selector_result[0].scalar<int64>()();
-          if (selected_input < 0 || selected_input >= data_input_impls_.size()) {
- return errors::InvalidArgument(
- "Selector index out of range: ", selected_input,
- " >= ", data_input_impls_.size());
- }
-
- if (data_input_impls_[selected_input]) {
- bool end_of_selected_input = false;
- TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
- ctx, out_tensors, &end_of_selected_input));
-
- if (!end_of_selected_input) {
- return Status::OK();
- }
-
- data_input_impls_[selected_input].reset();
- --num_active_inputs_;
-
- if (num_active_inputs_ == 0) {
- selector_input_impl_.reset();
- *end_of_sequence = true;
- return Status::OK();
- }
- }
-
- LOG(WARNING) << "DirectedInterleave selected an exhausted input: "
- << selected_input;
- }
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- if (selector_input_impl_) {
- TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_));
- } else {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
- }
- for (size_t i = 0; i < data_input_impls_.size(); ++i) {
- const auto& data_input_impl = data_input_impls_[i];
- if (data_input_impl) {
- TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl));
- } else {
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
- ""));
- }
- }
- return Status::OK();
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- if (!reader->Contains(full_name("selector_input_impl_empty"))) {
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
- } else {
- selector_input_impl_.reset();
- }
- for (size_t i = 0; i < data_input_impls_.size(); ++i) {
- if (!reader->Contains(full_name(
- strings::StrCat("data_input_impl_empty[", i, "]")))) {
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
- } else {
- data_input_impls_[i].reset();
- }
- }
- return Status::OK();
- }
-
- private:
- mutex mu_;
- std::unique_ptr<IteratorBase> selector_input_impl_ GUARDED_BY(mu_);
- std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
- GUARDED_BY(mu_);
- int64 num_active_inputs_ GUARDED_BY(mu_);
- };
-
- static PartialTensorShape MostSpecificCompatibleShape(
- const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
- PartialTensorShape output_tensorshape;
- if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
- return output_tensorshape;
- auto dims1 = ts1.dim_sizes();
- auto dims2 = ts2.dim_sizes();
- for (int d = 0; d < ts1.dims(); d++) {
- if (dims1[d] == dims2[d])
- output_tensorshape.Concatenate(dims1[d]);
- else
- output_tensorshape.Concatenate(-1);
- }
- return output_tensorshape;
- }
-
- const DatasetBase* const selector_input_;
- const std::vector<DatasetBase*> data_inputs_;
- std::vector<PartialTensorShape> output_shapes_;
- };
-};
-
-REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
- DirectedInterleaveDatasetOp);
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
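
The MostSpecificCompatibleShape helper in the deleted op above merges the output shapes of the interleaved inputs. A rough standalone illustration of that rule (plain C++, not TensorFlow code; an empty vector stands in for an unknown rank and -1 for an unknown dimension): dimensions that agree are kept, disagreements become unknown, and a rank mismatch collapses to a fully unknown shape.

#include <iostream>
#include <vector>

std::vector<long long> MostSpecificCompatibleShape(
    const std::vector<long long>& a, const std::vector<long long>& b) {
  if (a.size() != b.size()) return {};  // rank mismatch -> unknown rank
  std::vector<long long> merged(a.size());
  for (size_t d = 0; d < a.size(); ++d) {
    merged[d] = (a[d] == b[d]) ? a[d] : -1;  // keep agreement, else unknown
  }
  return merged;
}

int main() {
  for (long long dim : MostSpecificCompatibleShape({32, 10}, {64, 10})) {
    std::cout << dim << " ";  // prints: -1 10
  }
  std::cout << "\n";
}
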
diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
deleted file mode 100644
index c3cb45dbf7..0000000000
--- a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
+++ /dev/null
@@ -1,155 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
-#include "tensorflow/core/lib/core/errors.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
- public:
- using IndexedDatasetOpKernel::IndexedDatasetOpKernel;
-
- void MakeIndexedDataset(OpKernelContext* ctx,
- IndexedDataset** output) override {
- uint64 size = -1;
- OP_REQUIRES_OK(ctx, ParseScalarArgument<uint64>(ctx, "size", &size));
- OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0"));
- *output = new Dataset(ctx, size);
- }
-
- class Dataset : public IndexedDataset {
- public:
- Dataset(OpKernelContext* ctx, uint64 size)
- : IndexedDataset(DatasetContext(ctx)), size_(size) {}
-
- Status MaterializeDataset(
- std::shared_ptr<MaterializedIndexedDataset>* materialized) override {
- materialized->reset(new Materialized(this));
- return Status::OK();
- }
-
- const DataTypeVector& output_dtypes() const override {
- static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64});
- return *dtypes;
- }
-
- const std::vector<PartialTensorShape>& output_shapes() const override {
- static std::vector<PartialTensorShape>* shapes =
- new std::vector<PartialTensorShape>({{}});
- return *shapes;
- }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::IdentityIndexedDataset")}));
- }
-
- string DebugString() const override {
- return "IdentityIndexedDataset::Dataset";
- }
-
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** node) const override {
- return errors::Unimplemented(
- "identity_indexed_dataset.AsGraphDefInternal");
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- mutex_lock l(mu_);
- if (cur_ < dataset()->size_) {
- Tensor result_tensor(ctx->allocator({}), DT_UINT64, {});
- result_tensor.scalar<uint64>()() = cur_++;
- out_tensors->emplace_back(std::move(result_tensor));
- *end_of_sequence = false;
- return Status::OK();
- }
- *end_of_sequence = true;
- return Status::OK();
- }
-
- private:
- mutex mu_;
-      uint64 cur_ GUARDED_BY(mu_) = 0;
- };
-
- class Materialized : public MaterializedIndexedDataset {
- public:
- explicit Materialized(Dataset* dataset) : dataset_(dataset) {
- dataset->Ref();
- }
-
- ~Materialized() override {
- // TODO(saeta): Pull this into MaterializedIndexedDataset
- dataset_->Unref();
- }
-
- const DataTypeVector& output_dtypes() const override {
- return dataset_->output_dtypes();
- }
-
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return dataset_->output_shapes();
- }
-
- Status Get(IteratorContext&& ctx, uint64 index,
- std::vector<Tensor>* out_tensors) const override {
- LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index
- << ")";
- if (index >= dataset_->size_) {
- // Note: use InvalidArgument instead of OutOfRange error because many
- // things consider OutOfRange to be a "clean termination" error.
- return errors::InvalidArgument(
- "Index ", index,
- " is out of range for this dataset. (Size is: ", dataset_->size_,
- ".)");
- }
- Tensor result_tensor(ctx.allocator({}), DT_UINT64, {});
- result_tensor.scalar<uint64>()() = index;
- out_tensors->emplace_back(std::move(result_tensor));
- return Status::OK();
- }
-
- Status Size(uint64* size) const override {
- *size = dataset_->size_;
- return Status::OK();
- }
-
- private:
- const Dataset* const dataset_; // Not owned.
- };
-
- const uint64 size_;
- std::shared_ptr<Materialized> materialized_;
- };
-};
-
-REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU),
- IdentityIndexedDatasetOp);
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
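
The dataset above is the simplest possible IndexedDataset: element i is simply the value i, and Get() rejects out-of-range indices. A standalone sketch of that random-access contract (plain C++; the struct is illustrative, not the TensorFlow interface):

#include <cstdint>
#include <iostream>
#include <optional>

struct IdentityIndexed {
  uint64_t size;
  // Random access: element i is simply i; out-of-range lookups fail.
  std::optional<uint64_t> Get(uint64_t index) const {
    if (index >= size) return std::nullopt;
    return index;
  }
};

int main() {
  IdentityIndexed ds{5};
  std::cout << *ds.Get(3) << "\n";             // 3
  std::cout << ds.Get(9).has_value() << "\n";  // 0 (out of range)
}
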
diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
deleted file mode 100644
index beec344534..0000000000
--- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
+++ /dev/null
@@ -1,141 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/framework/partial_tensor_shape.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/random/random.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-// See documentation in ../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
-class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
- public:
- explicit IgnoreErrorsDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
-
- void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
- DatasetBase** output) override {
- *output = new Dataset(ctx, input);
- }
-
- private:
- class Dataset : public DatasetBase {
- public:
- explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
- : DatasetBase(DatasetContext(ctx)), input_(input) {
- input_->Ref();
- }
-
- ~Dataset() override { input_->Unref(); }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return input_->output_dtypes();
- }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return input_->output_shapes();
- }
-
- string DebugString() const override {
- return "IgnoreErrorsDatasetOp::Dataset";
- }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* input_graph_node = nullptr;
- TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
- TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
- return Status::OK();
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- {
- tf_shared_lock l(mu_);
- if (!input_impl_) {
- *end_of_sequence = true;
- return Status::OK();
- }
- Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
- while (!s.ok()) {
- out_tensors->clear();
- s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
- }
- }
- if (*end_of_sequence) {
- mutex_lock l(mu_);
- input_impl_.reset();
- }
- return Status::OK();
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- if (input_impl_)
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- else
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("input_impls_empty"), ""));
- return Status::OK();
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- if (reader->Contains(full_name("input_impls_empty")))
- input_impl_.reset();
- else
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
- }
-
- private:
- mutex mu_;
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- };
-
- const DatasetBase* const input_;
- };
-};
-
-REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU),
- IgnoreErrorsDatasetOp);
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
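
A standalone sketch of the ignore-errors behaviour implemented by GetNextInternal above (plain C++, with std::nullopt standing in for an element whose production failed): errored elements are dropped and the iterator keeps pulling until it sees a good element or the input ends.

#include <deque>
#include <iostream>
#include <optional>

// A toy input stream: nullopt marks an element whose production "failed".
std::deque<std::optional<int>> source = {std::optional<int>(1), std::nullopt,
                                         std::optional<int>(3)};

// Returns the next good element, silently skipping failed ones; returns
// nullopt only at the true end of the sequence.
std::optional<int> NextIgnoringErrors() {
  while (!source.empty()) {
    std::optional<int> item = source.front();
    source.pop_front();
    if (item.has_value()) return item;
  }
  return std::nullopt;
}

int main() {
  while (std::optional<int> v = NextIgnoringErrors()) {
    std::cout << *v << " ";  // prints: 1 3
  }
  std::cout << "\n";
}
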
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/contrib/data/kernels/indexed_dataset.cc
deleted file mode 100644
index ced8ab0d60..0000000000
--- a/tensorflow/contrib/data/kernels/indexed_dataset.cc
+++ /dev/null
@@ -1,373 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
-
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/cleanup.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-Status VerifyTypesMatch(const DataTypeVector& expected,
- const DataTypeVector& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " types but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (expected[i] != received[i]) {
- return errors::InvalidArgument("Data type mismatch at component ", i,
- ": expected ", DataTypeString(expected[i]),
- " but got ", DataTypeString(received[i]),
- ".");
- }
- }
- return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
- const std::vector<PartialTensorShape>& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " shapes but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (!expected[i].IsCompatibleWith(received[i])) {
- return errors::InvalidArgument("Incompatible shapes at component ", i,
- ": expected ", expected[i].DebugString(),
- " but got ", received[i].DebugString(),
- ".");
- }
- }
-
- return Status::OK();
-}
-
-class MaterializedDatasetResource : public ResourceBase {
- public:
- MaterializedDatasetResource(
- const DataTypeVector& output_dtypes,
- const std::vector<PartialTensorShape>& output_shapes)
- : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}
-
- string DebugString() override {
- return "Materialized IndexedDataset resource";
- }
-
- Status Get(IteratorContext&& ctx, uint64 index,
- std::vector<Tensor>* out_tensors) {
- std::shared_ptr<MaterializedIndexedDataset> captured(materialized_);
- if (captured) {
- return captured->Get(std::move(ctx), index, out_tensors);
- } else {
- return errors::FailedPrecondition(
- "Get() failed because the MaterializedIndexedDataset has not been "
- "initialized. Ensure that you have run the materialization operation "
- "for this MaterializedIndexedDataset before retrieving elements.");
- }
- }
-
- // TODO(saeta): Implement Save and Restore
-
- const DataTypeVector& output_dtypes() const { return output_dtypes_; }
- const std::vector<PartialTensorShape>& output_shapes() const {
- return output_shapes_;
- }
-
- Status set_materialized_dataset(
- const std::shared_ptr<MaterializedIndexedDataset>& dataset) {
- if (dataset) {
- TF_RETURN_IF_ERROR(
- VerifyTypesMatch(output_dtypes_, dataset->output_dtypes()));
- TF_RETURN_IF_ERROR(
- VerifyShapesCompatible(output_shapes_, dataset->output_shapes()));
- }
- materialized_ = dataset;
- return Status::OK();
- }
-
- private:
- std::shared_ptr<MaterializedIndexedDataset> materialized_;
- const DataTypeVector output_dtypes_;
- const std::vector<PartialTensorShape> output_shapes_;
-};
-
-// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT
-// tensor. Objects of the wrapper class own a reference on an instance of an
-// `IndexedDataset` and the wrapper's copy constructor and destructor take care
-// of managing the reference count.
-//
-// NOTE: This is not a feature-complete implementation of the DT_VARIANT
-// specification. In particular, we cannot currently serialize an arbitrary
-// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not
-// implemented.
-//
-// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just
-// use `tensorflow::DatasetVariantWrapper`.
-class IndexedDatasetVariantWrapper {
- public:
- IndexedDatasetVariantWrapper() : dataset_(nullptr) {}
-
- // Transfers ownership of `dataset` to `*this`.
- explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset)
- : dataset_(dataset) {}
-
- IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other)
- : dataset_(other.dataset_) {
- if (dataset_) dataset_->Ref();
- }
-
- ~IndexedDatasetVariantWrapper() {
- if (dataset_) dataset_->Unref();
- }
-
- IndexedDataset* get() const { return dataset_; }
-
- string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; }
- string DebugString() const {
- if (dataset_) {
- return dataset_->DebugString();
- } else {
- return "<Uninitialized IndexedDatasetVariantWrapper>";
- }
- }
-
- void Encode(VariantTensorData* data) const {
- LOG(ERROR) << "The Encode() method is not implemented for "
- "IndexedDatasetVariantWrapper objects.";
- }
-
- bool Decode(const VariantTensorData& data) {
- LOG(ERROR) << "The Decode() method is not implemented for "
- "IndexedDatasetVariantWrapper objects.";
- return false;
- }
-
- private:
- IndexedDataset* const dataset_; // Owns one reference.
-};
-
-} // namespace
-
-Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
- IndexedDataset** out_dataset) {
-  if (!(tensor.dtype() == DT_VARIANT &&
- TensorShapeUtils::IsScalar(tensor.shape()))) {
- return errors::InvalidArgument(
- "IndexedDataset tensor must be a scalar of dtype DT_VARIANT.");
- }
- const Variant& variant = tensor.scalar<Variant>()();
- const IndexedDatasetVariantWrapper* wrapper =
- variant.get<IndexedDatasetVariantWrapper>();
- if (wrapper == nullptr) {
- return errors::InvalidArgument("Tensor must be an IndexedDataset object.");
- }
- *out_dataset = wrapper->get();
- if (*out_dataset == nullptr) {
- return errors::Internal("Read uninitialized IndexedDataset variant.");
- }
- return Status::OK();
-}
-
-Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
- Tensor* tensor) {
-  if (!(tensor->dtype() == DT_VARIANT &&
- TensorShapeUtils::IsScalar(tensor->shape()))) {
- return errors::InvalidArgument(
- "Dataset tensor must be a scalar of dtype DT_VARIANT.");
- }
- tensor->scalar<Variant>()() = IndexedDatasetVariantWrapper(dataset);
- return Status::OK();
-}
-
-void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) {
- IndexedDataset* dataset = nullptr;
- MakeIndexedDataset(ctx, &dataset);
-
- if (ctx->status().ok()) {
- OP_REQUIRES(ctx, dataset != nullptr,
- errors::Internal("MakeIndexedDataset did not correctly "
- "construct the IndexedDataset"));
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
- OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output));
- }
-}
-
-namespace {
-
-class MaterializedHandleOp : public OpKernel {
- public:
- explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- }
-
- ~MaterializedHandleOp() override {
- if (resource_ != nullptr) {
- resource_->Unref();
- if (cinfo_.resource_is_private_to_kernel()) {
- if (!cinfo_.resource_manager()
- ->template Delete<MaterializedDatasetResource>(
- cinfo_.container(), cinfo_.name())
- .ok()) {
-          // Do nothing; the resource may have been deleted by session resets.
- // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h
- }
- }
- }
- }
-
- void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- if (resource_ == nullptr) {
- ResourceMgr* mgr = context->resource_manager();
- OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
-
- MaterializedDatasetResource* resource;
- OP_REQUIRES_OK(context,
- mgr->LookupOrCreate<MaterializedDatasetResource>(
- cinfo_.container(), cinfo_.name(), &resource,
- [this](MaterializedDatasetResource** ret)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- *ret = new MaterializedDatasetResource(
- output_dtypes_, output_shapes_);
- return Status::OK();
- }));
- Status s = VerifyResource(resource);
- if (TF_PREDICT_FALSE(!s.ok())) {
- resource->Unref();
- context->SetStatus(s);
- return;
- }
-
- resource_ = resource;
- }
- }
- OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
- context, 0, cinfo_.container(), cinfo_.name(),
- MakeTypeIndex<MaterializedDatasetResource>()));
- }
-
- private:
-  // During the first Compute(), the resource is either created or looked up
-  // using shared_name. In the latter case, the resource found should be
-  // verified to ensure it is compatible with this op's configuration. The
-  // verification may fail when, for example, two graphs ask resources with
-  // the same shared name to have inconsistent output types or shapes.
- Status VerifyResource(MaterializedDatasetResource* resource) {
- TF_RETURN_IF_ERROR(
- VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
- TF_RETURN_IF_ERROR(
- VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
- return Status::OK();
- }
-
- mutex mu_;
- ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
- MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr;
- DataTypeVector output_dtypes_;
- std::vector<PartialTensorShape> output_shapes_;
-};
-
-// TODO(saeta): Make async.
-class MaterializeDatasetOp : public OpKernel {
- public:
- explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- IndexedDataset* dataset;
- OP_REQUIRES_OK(ctx,
- GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset));
-
- MaterializedDatasetResource* materialized_resource;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
- &materialized_resource));
- core::ScopedUnref unref(materialized_resource);
- std::shared_ptr<MaterializedIndexedDataset> materialized;
- OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized));
- OP_REQUIRES_OK(
- ctx, materialized_resource->set_materialized_dataset(materialized));
- }
-};
-
-// TODO(saeta): Make async
-class IndexedDatasetGet : public OpKernel {
- public:
- explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- MaterializedDatasetResource* materialized_resource;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0),
- &materialized_resource));
- auto cleanup = gtl::MakeCleanup([materialized_resource] {
- materialized_resource->Unref(); // Note: can't use core::ScopedUnref.
- });
-
- const Tensor* index_t;
- OP_REQUIRES_OK(ctx, ctx->input("index", &index_t));
- // TODO(saeta): Support batch reads (indexes should be non-scalar!)
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()),
- errors::InvalidArgument("index must be a scalar"));
- const uint64 index = index_t->scalar<uint64>()();
-
- std::vector<Tensor> out_tensors;
- Status s =
- materialized_resource->Get(IteratorContext(ctx), index, &out_tensors);
-
- // Note: Unref materialized_resource to avoid destruction races. (Important
- // in a [future] async op implementation.)
- cleanup.release()();
-
- if (!s.ok()) {
- ctx->SetStatus(s);
- } else {
- auto expected_shapes = materialized_resource->output_shapes();
- auto expected_types = materialized_resource->output_dtypes();
- for (size_t i = 0; i < out_tensors.size(); ++i) {
- OP_REQUIRES(
- ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()),
- errors::Internal(
- "Materialized dataset output at index ", i,
- " is incompatible with the expected shape. (Expected: ",
- expected_shapes[i], ", got: ", out_tensors[i].shape(), ")"));
- OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i],
- errors::Internal("Materialized dataset output at index ", i,
- " was not the expected dtype. (Expected: ",
- expected_types[i],
- ", got: ", out_tensors[i].dtype(), ")"));
- ctx->set_output(i, out_tensors[i]);
- }
- }
- }
-};
-
-REGISTER_KERNEL_BUILDER(
- Name("MaterializedIndexDatasetHandle").Device(DEVICE_CPU),
- MaterializedHandleOp);
-REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU),
- MaterializeDatasetOp);
-REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU),
- IndexedDatasetGet);
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
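
The IndexedDatasetVariantWrapper above follows the usual manual reference-counting convention for datasets stored in DT_VARIANT tensors: construction from a raw pointer takes ownership of one reference, copying takes another, and each destruction releases one. A minimal standalone illustration of that convention (plain C++; the class names here are illustrative, not TensorFlow types):

#include <iostream>

struct RefCounted {
  int refs = 1;  // the creator starts with one reference
  void Ref() { ++refs; }
  void Unref() {
    if (--refs == 0) {
      std::cout << "destroyed\n";
      delete this;
    }
  }
};

class Wrapper {
 public:
  explicit Wrapper(RefCounted* obj) : obj_(obj) {}  // transfers one reference
  Wrapper(const Wrapper& other) : obj_(other.obj_) {
    if (obj_) obj_->Ref();  // a copy owns its own reference
  }
  ~Wrapper() {
    if (obj_) obj_->Unref();
  }

 private:
  RefCounted* obj_;
};

int main() {
  {
    Wrapper a(new RefCounted);
    Wrapper b(a);  // refcount is now 2
  }  // both wrappers destroyed -> prints "destroyed" exactly once
}
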
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/contrib/data/kernels/indexed_dataset.h
deleted file mode 100644
index 7aa2d3fdbc..0000000000
--- a/tensorflow/contrib/data/kernels/indexed_dataset.h
+++ /dev/null
@@ -1,119 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
-#define TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
-
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/framework/op_kernel.h"
-
-namespace tensorflow {
-namespace data {
-
-// TODO(saeta): Urgh, this is ugly.
-class MaterializedIndexedDataset {
- public:
- virtual ~MaterializedIndexedDataset() = default;
-
- // Retrieve the element at a given index. The output tensors are stored in
- // out_tensors.
- //
-  // If `index` is greater than or equal to `Size()`,
-  // tensorflow::errors::OutOfRangeError is returned.
- //
- // Get is thread-safe.
- virtual Status Get(IteratorContext&& ctx, uint64 index,
- std::vector<Tensor>* out_tensors) const = 0;
-
- // Size determines the number of elements in this IndexedDataset.
- //
- // Size is thread-safe.
- virtual Status Size(uint64* size) const = 0;
-
- // Returns a vector of DataType values, representing the respective
- // element types of each tuple component in the outputs of this dataset.
- virtual const DataTypeVector& output_dtypes() const = 0;
-
- // Returns a vector of tensor shapes, representing the respective
- // (and possibly partially defined) shapes of each tuple component
- // in the outputs of this dataset.
- virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
-};
-
-// IndexedDataset represents a dataset that supports random access in addition
-// to iterator-based sequential access.
-//
-// Note: IndexedDatasets are HIGHLY experimental at this time. Expect
-// significant (backwards incompatible) changes!
-class IndexedDataset : public DatasetBase {
- public:
- IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {}
-
- // Materialize (if necessary) the dataset, and return a pointer.
- // TODO(saeta): Add in `IteratorContext* ctx` when materializing.
- virtual Status MaterializeDataset(
- std::shared_ptr<MaterializedIndexedDataset>* materialized) = 0;
-};
-
-// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the
-// rest of the TensorFlow runtime.
-//
-// Most IndexedDatasets will be private members of classes inheriting from this
-// class.
-class IndexedDatasetOpKernel : public OpKernel {
- public:
- IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
- void Compute(OpKernelContext* ctx) final;
-
- protected:
- // Subclasses should implement this method. It will be called during Compute
- // execution.
- virtual void MakeIndexedDataset(OpKernelContext* ctx,
- IndexedDataset** output) = 0;
-
- template <typename T>
- Status ParseScalarArgument(OpKernelContext* ctx,
- const StringPiece& argument_name, T* output) {
- const Tensor* argument_t;
- TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
- if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
- return errors::InvalidArgument(argument_name, " must be a scalar");
- }
- *output = argument_t->scalar<T>()();
- return Status::OK();
- }
-};
-
-// Validates and extracts an `IndexedDataset` object from `tensor`.
-//
-// `tensor` must have been written by a call to
-// `StoreIndexedDatasetInVariantTensor`
-//
-// The retrieved pointer is a borrowed reference to the dataset, which is owned
-// by the tensor. The consumer must either acquire its own reference to the
-// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
-// destroyed or mutated while the retrieved pointer is in use.
-Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
- IndexedDataset** out_dataset);
-
-// Stores an `IndexedDataset` object in `tensor`.
-//
-// The ownership of `dataset` is transferred to `tensor`.
-Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
- Tensor* tensor);
-
-} // namespace data
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
deleted file mode 100644
index d233c1f8ec..0000000000
--- a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
+++ /dev/null
@@ -1,217 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <sys/stat.h>
-
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/lib/io/buffered_inputstream.h"
-#include "tensorflow/core/platform/file_system.h"
-
-#include "lmdb.h" // NOLINT(build/include)
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-class LMDBDatasetOp : public DatasetOpKernel {
- public:
- using DatasetOpKernel::DatasetOpKernel;
- void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- const Tensor* filenames_tensor;
- OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
- OP_REQUIRES(
- ctx, filenames_tensor->dims() <= 1,
- errors::InvalidArgument("`filenames` must be a scalar or a vector."));
-
- std::vector<string> filenames;
- filenames.reserve(filenames_tensor->NumElements());
- for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
- filenames.push_back(filenames_tensor->flat<string>()(i));
- }
-
- *output = new Dataset(ctx, filenames);
- }
-
- private:
- class Dataset : public DatasetBase {
- public:
- Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
- : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {}
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::LMDB")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- static DataTypeVector* dtypes =
- new DataTypeVector({DT_STRING, DT_STRING});
- return *dtypes;
- }
-
- const std::vector<PartialTensorShape>& output_shapes() const override {
- static std::vector<PartialTensorShape>* shapes =
- new std::vector<PartialTensorShape>({{}, {}});
- return *shapes;
- }
-
- string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* filenames = nullptr;
- TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
- TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
- return Status::OK();
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- mutex_lock l(mu_);
- do {
- if (mdb_cursor_) {
- Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
- key_tensor.scalar<string>()() = string(
- static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
- out_tensors->emplace_back(std::move(key_tensor));
-
- Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
- value_tensor.scalar<string>()() =
- string(static_cast<const char*>(mdb_value_.mv_data),
- mdb_value_.mv_size);
- out_tensors->emplace_back(std::move(value_tensor));
-
- int val;
- val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT);
- if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
- return errors::InvalidArgument(mdb_strerror(val));
- }
- if (val == MDB_NOTFOUND) {
- ResetStreamsLocked();
- ++current_file_index_;
- }
- *end_of_sequence = false;
- return Status::OK();
- }
- if (current_file_index_ == dataset()->filenames_.size()) {
- *end_of_sequence = true;
- return Status::OK();
- }
-
- TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
- } while (true);
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- return errors::Unimplemented(
- "Checkpointing is currently not supported for LMDBDataset.");
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- return errors::Unimplemented(
- "Checkpointing is currently not supported for LMDBDataset.");
- }
-
- private:
- Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (current_file_index_ >= dataset()->filenames_.size()) {
- return errors::InvalidArgument(
- "current_file_index_:", current_file_index_,
- " >= filenames_.size():", dataset()->filenames_.size());
- }
- const string& filename = dataset()->filenames_[current_file_index_];
-
- int val = mdb_env_create(&mdb_env_);
- if (val != MDB_SUCCESS) {
- return errors::InvalidArgument(mdb_strerror(val));
- }
- int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
-
- struct stat source_stat;
- if (stat(filename.c_str(), &source_stat) == 0 &&
- (source_stat.st_mode & S_IFREG)) {
- flags |= MDB_NOSUBDIR;
- }
- val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664);
- if (val != MDB_SUCCESS) {
- return errors::InvalidArgument(mdb_strerror(val));
- }
- val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_);
- if (val != MDB_SUCCESS) {
- return errors::InvalidArgument(mdb_strerror(val));
- }
- val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_);
- if (val != MDB_SUCCESS) {
- return errors::InvalidArgument(mdb_strerror(val));
- }
- val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_);
- if (val != MDB_SUCCESS) {
- return errors::InvalidArgument(mdb_strerror(val));
- }
- val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST);
- if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
- return errors::InvalidArgument(mdb_strerror(val));
- }
- if (val == MDB_NOTFOUND) {
- ResetStreamsLocked();
- }
- return Status::OK();
- }
- void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (mdb_env_ != nullptr) {
- if (mdb_cursor_) {
- mdb_cursor_close(mdb_cursor_);
- mdb_cursor_ = nullptr;
- }
- mdb_dbi_close(mdb_env_, mdb_dbi_);
- mdb_txn_abort(mdb_txn_);
- mdb_env_close(mdb_env_);
- mdb_txn_ = nullptr;
- mdb_dbi_ = 0;
- mdb_env_ = nullptr;
- }
- }
- mutex mu_;
- size_t current_file_index_ GUARDED_BY(mu_) = 0;
- MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr;
- MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr;
- MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0;
- MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr;
-
- MDB_val mdb_key_ GUARDED_BY(mu_);
- MDB_val mdb_value_ GUARDED_BY(mu_);
- };
-
- const std::vector<string> filenames_;
- };
-};
-
-REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
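
Both the LMDB iterator above and the CSV iterator earlier in this change follow the same multi-file pattern: drain the currently open source, and when it is exhausted, tear it down and advance current_file_index_ until every file has been read. A rough standalone sketch of that loop (plain C++; plain vectors stand in for the open LMDB cursor, so no lmdb dependency is needed):

#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::vector<std::vector<std::string>> files = {{"a", "b"}, {}, {"c"}};
  size_t current_file_index = 0;  // which file is "open"
  size_t cursor = 0;              // position within the open file
  while (true) {
    if (current_file_index == files.size()) break;  // end of sequence
    if (cursor < files[current_file_index].size()) {
      std::cout << files[current_file_index][cursor++] << "\n";  // emit element
    } else {
      ++current_file_index;  // "ResetStreamsLocked", move to the next file
      cursor = 0;            // "SetupStreamsLocked" for the new file
    }
  }
}
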
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
deleted file mode 100644
index 96f1dd0059..0000000000
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ /dev/null
@@ -1,481 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <deque>
-
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/framework/function.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_op_kernel.h"
-#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/util/device_name_utils.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-struct BufferElement {
- // The producer sets `status` if getting the input element fails.
- Status status;
- // The buffered data element.
- std::vector<Tensor> value;
-};
-
-using FunctionBufferCallback = std::function<void(const BufferElement&)>;
-
-class FunctionBufferingResource : public ResourceBase {
- public:
- FunctionBufferingResource(FunctionLibraryRuntime* lib,
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
- const NameAttrList& func, int64 buffer_size,
- const string& source_device,
- const string& target_device,
- const std::vector<Tensor>& func_args,
- const DataTypeVector& output_types)
- : lib_(lib),
- pflr_(std::move(pflr)),
- func_(func),
- buffer_size_(buffer_size),
- source_device_(source_device),
- target_device_(target_device),
- func_args_(func_args),
- output_types_(output_types),
- handle_(kInvalidHandle),
- is_buffering_(false),
- end_of_sequence_(false),
- cancelled_(false) {}
-
- ~FunctionBufferingResource() override {
- Cancel();
- }
-
- string DebugString() override {
- return strings::StrCat("FunctionBufferingResource. Size: ", buffer_size_,
- "; target_device: ", target_device_);
- }
-
- // Instantiates the function the first time it's called. After that it caches
- // the handle.
- Status Instantiate() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- // Re-use existing handle if it's been set, effectively caching it.
- if (handle_ != kInvalidHandle) {
- return Status::OK();
- }
- AttrValueMap attr_values = func_.attr();
- FunctionLibraryRuntime::InstantiateOptions opts;
- opts.target = target_device_;
- return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts,
- &handle_);
- }
-
-  // Returns true if we've reached the end of the sequence and exhausted the
- // buffer.
- bool Finished() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- return end_of_sequence_ && buffer_.empty();
- }
-
- // Cancels any buffering / prefetching going on.
- void Cancel() LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- cancelled_ = true;
- while (is_buffering_) {
- cond_var_.wait(l);
- }
- }
-
- // Cancels all pending operations and then clears out the state.
- void Reset() LOCKS_EXCLUDED(mu_) {
- Cancel();
- mutex_lock l(mu_);
- buffer_.clear();
- requests_.clear();
- is_buffering_ = false;
- end_of_sequence_ = false;
- cancelled_ = false;
- }
-
-  // If the buffer has anything, runs `callback` on the first element in the
-  // buffer; otherwise queues `callback` to be run once an element becomes
-  // available. Also starts prefetching if it is not already in progress.
- void MaybeGet(FunctionBufferCallback callback) LOCKS_EXCLUDED(mu_) {
- bool start_buffering = false;
- bool produced_output = false;
- BufferElement buffer_element;
- {
- mutex_lock l(mu_);
- if (!is_buffering_ && !end_of_sequence_) {
- start_buffering = true;
- }
- if (!buffer_.empty()) {
- produced_output = true;
- std::swap(buffer_element, buffer_.front());
- buffer_.pop_front();
- } else {
- produced_output = false;
- requests_.push_back(std::move(callback));
- }
- }
- if (produced_output) {
- callback(buffer_element);
- }
- if (start_buffering) {
- FillBuffer();
- }
- }
-
- private:
- void FillBuffer() LOCKS_EXCLUDED(mu_) {
- FunctionLibraryRuntime::Handle handle;
- std::vector<FunctionBufferCallback> cancellation_callbacks;
- std::vector<BufferElement> cancellation_buffer_elements;
- bool cancelled = false;
- {
- mutex_lock l(mu_);
- handle = handle_;
- if (cancelled_) {
- cancelled = true;
- // Run through and fulfill all pending requests, if possible.
- while (!requests_.empty()) {
- if (!buffer_.empty()) {
- cancellation_buffer_elements.push_back(std::move(buffer_.front()));
- buffer_.pop_front();
- cancellation_callbacks.push_back(std::move(requests_.front()));
- requests_.pop_front();
- } else {
- LOG(ERROR) << "Buffer ran out of elements and we couldn't satisfy: "
- << requests_.size() << " requests";
- break;
- }
- }
- is_buffering_ = false;
- } else {
- is_buffering_ = true;
- }
- }
- if (cancelled) {
- for (int i = 0; i < cancellation_callbacks.size(); ++i) {
- cancellation_callbacks[i](cancellation_buffer_elements[i]);
- }
- cond_var_.notify_all();
- return;
- }
- FunctionLibraryRuntime::Options opts;
- // Copied from CapturedFunction::generate_step_id();
- opts.step_id = -std::abs(static_cast<int64>(random::New64()));
- opts.source_device = source_device_;
- AllocatorAttributes arg_alloc_attr;
- arg_alloc_attr.set_on_host(true);
- opts.args_alloc_attrs.push_back(arg_alloc_attr);
- for (const auto& dtype : output_types_) {
- AllocatorAttributes ret_alloc_attrs;
- if (DataTypeAlwaysOnHost(dtype)) {
- ret_alloc_attrs.set_on_host(true);
- }
- opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
- }
- if (opts.source_device != target_device_) {
- opts.remote_execution = true;
- }
- opts.create_rendezvous = true;
- auto* rets = new std::vector<Tensor>;
- lib_->Run(opts, handle, func_args_, rets,
- [this, rets](const Status& status) {
- FunctionBufferCallback callback = nullptr;
- BufferElement buffer_front;
- bool restart_buffering = false;
- {
- mutex_lock l(mu_);
- BufferElement buffer_element;
- buffer_element.status = status;
- if (status.ok()) {
- buffer_element.value.swap(*rets);
- } else {
- end_of_sequence_ = true;
- is_buffering_ = false;
- }
- buffer_.push_back(std::move(buffer_element));
- if (!requests_.empty()) {
- buffer_front = std::move(buffer_.front());
- buffer_.pop_front();
- callback = std::move(requests_.front());
- requests_.pop_front();
- }
- if (buffer_.size() < buffer_size_ && !end_of_sequence_) {
- restart_buffering = true;
- } else {
- // When the buffer is full, we don't want to call
- // FillBuffer() unless we're in cancellation phase in which
- // case FillBuffer() will do the final cleanup post
- // cancellation.
- if (cancelled_) {
- restart_buffering = true;
- }
- is_buffering_ = false;
- }
- }
- if (callback != nullptr) {
- callback(buffer_front);
- }
- if (restart_buffering) {
- FillBuffer();
- }
- });
- }
-
- mutex mu_;
- FunctionLibraryRuntime* lib_;
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
- NameAttrList func_;
- const int64 buffer_size_;
- const string source_device_;
- const string target_device_;
- const std::vector<Tensor> func_args_;
- const DataTypeVector output_types_;
- FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_);
- std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
- std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_);
- bool is_buffering_ GUARDED_BY(mu_);
- bool end_of_sequence_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_);
- condition_variable cond_var_;
-};
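
A standalone sketch of the MaybeGet/FillBuffer hand-off implemented by the resource above (plain C++, single-threaded and without the mutex, cancellation, or function-library plumbing, so only the buffering contract is shown): a ready element is handed to the callback immediately, otherwise the callback waits in requests_ until the producer delivers the next element.

#include <deque>
#include <functional>
#include <iostream>
#include <string>

using Callback = std::function<void(const std::string&)>;

std::deque<std::string> buffer_;  // elements produced ahead of demand
std::deque<Callback> requests_;   // consumers waiting for an element

void MaybeGet(Callback callback) {
  if (!buffer_.empty()) {
    const std::string element = buffer_.front();
    buffer_.pop_front();
    callback(element);  // served straight from the buffer
  } else {
    requests_.push_back(std::move(callback));  // parked until produced
  }
}

void ProducerDeliver(const std::string& element) {
  if (!requests_.empty()) {
    Callback waiting = std::move(requests_.front());
    requests_.pop_front();
    waiting(element);  // satisfy a parked request
  } else {
    buffer_.push_back(element);  // stash for a later MaybeGet
  }
}

int main() {
  ProducerDeliver("x");
  MaybeGet([](const std::string& v) { std::cout << v << "\n"; });  // x
  MaybeGet([](const std::string& v) { std::cout << v << "\n"; });  // parked
  ProducerDeliver("y");  // the parked callback now prints: y
}
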
-
-class FunctionBufferResourceHandleOp : public OpKernel {
- public:
- explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), flib_def_(nullptr) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- }
-
- ~FunctionBufferResourceHandleOp() override {
- if (cinfo_.resource_is_private_to_kernel()) {
- if (!cinfo_.resource_manager()
- ->Delete<FunctionBufferingResource>(cinfo_.container(),
- cinfo_.name())
- .ok()) {
-        // Do nothing; the resource may have been deleted by session resets.
- }
- }
- }
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor* string_arg;
- OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg));
- std::vector<Tensor> func_args;
- func_args.push_back(*string_arg);
-
- const string& source_device = ctx->device()->name();
-
- // Obtain and canonicalize target_device.
- const Tensor* target_arg;
- OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg));
- string target_device;
- OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName(
- target_arg->scalar<string>()(), source_device,
- &target_device));
-
- FunctionLibraryRuntime* lib = ctx->function_library();
- OP_REQUIRES(ctx, lib != nullptr,
- errors::Internal("No function library is provided."));
-
- mutex_lock l(mu_);
- if (!initialized_) {
- OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
- FunctionLibraryRuntime* clone_lib;
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr;
- OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib));
- // Create the resource.
- FunctionBufferingResource* buffer;
- OP_REQUIRES_OK(
- ctx,
- ctx->resource_manager()->LookupOrCreate<FunctionBufferingResource>(
- cinfo_.container(), cinfo_.name(), &buffer,
- [clone_lib, &pflr, &source_device, &target_device, func_args,
- this](FunctionBufferingResource** ptr) {
- *ptr = new FunctionBufferingResource(
- clone_lib, std::move(pflr), func_, buffer_size_,
- source_device, target_device, func_args, output_types_);
- return Status::OK();
- }));
- core::ScopedUnref s(buffer);
- OP_REQUIRES_OK(ctx, buffer->Instantiate());
- initialized_ = true;
- }
-
- OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
- ctx, 0, cinfo_.container(), cinfo_.name(),
- MakeTypeIndex<FunctionBufferingResource>()));
- }
-
- private:
- mutex mu_;
- ContainerInfo cinfo_ GUARDED_BY(mu_);
- bool initialized_ GUARDED_BY(mu_) = false;
- std::unique_ptr<FunctionLibraryDefinition> flib_def_;
- NameAttrList func_;
- int64 buffer_size_;
- string container_;
- string name_;
- DataTypeVector output_types_;
-};
-
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
- .Device(DEVICE_CPU)
- .HostMemory("resource")
- .HostMemory("string_arg")
- .HostMemory("target_device"),
- FunctionBufferResourceHandleOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
- .Device(DEVICE_GPU)
- .HostMemory("resource")
- .HostMemory("string_arg")
- .HostMemory("target_device"),
- FunctionBufferResourceHandleOp);
-#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
- .Device(DEVICE_SYCL)
- .HostMemory("resource")
- .HostMemory("string_arg")
- .HostMemory("target_device"),
- FunctionBufferResourceHandleOp);
-#endif // TENSORFLOW_USE_SYCL
-
-// Gets the next element out of a FunctionBufferingResource; when the buffer
-// runs low, this triggers the resource to refill it by calling the function.
-class FunctionBufferingResourceGetNextOp : public AsyncOpKernel {
- public:
- explicit FunctionBufferingResourceGetNextOp(OpKernelConstruction* ctx)
- : AsyncOpKernel(ctx) {}
-
- ~FunctionBufferingResourceGetNextOp() override {}
-
- void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- ResourceHandle handle;
- OP_REQUIRES_OK_ASYNC(
- ctx, HandleFromInput(ctx, "function_buffer_resource", &handle), done);
- FunctionBufferingResource* buffer = nullptr;
- OP_REQUIRES_OK_ASYNC(
- ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer),
- done);
-
- if (buffer->Finished()) {
- buffer->Unref();
- ctx->SetStatus(errors::OutOfRange("end_of_sequence"));
- done();
- return;
- }
-
- FunctionBufferCallback callback =
- [ctx, buffer, done](const BufferElement& buffer_element) {
- Status s = buffer_element.status;
- if (!s.ok()) {
- ctx->SetStatus(s);
- buffer->Unref();
- done();
- return;
- }
- for (size_t i = 0; i < buffer_element.value.size(); ++i) {
- ctx->set_output(i, buffer_element.value[i]);
- }
- buffer->Unref();
- done();
- };
- buffer->MaybeGet(std::move(callback));
- }
-};
-
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
- .Device(DEVICE_CPU)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceGetNextOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
- .Device(DEVICE_GPU)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceGetNextOp);
-#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
- .Device(DEVICE_SYCL)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceGetNextOp);
-#endif // TENSORFLOW_USE_SYCL
-
-// Resets the FunctionBufferingResource, cancelling all pending requests and
-// clearing out the buffer.
-class FunctionBufferingResourceResetOp : public OpKernel {
- public:
- explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
-
- ~FunctionBufferingResourceResetOp() override {}
-
- void Compute(OpKernelContext* ctx) override {
- ResourceHandle handle;
- OP_REQUIRES_OK(ctx,
- HandleFromInput(ctx, "function_buffer_resource", &handle));
- FunctionBufferingResource* buffer = nullptr;
- OP_REQUIRES_OK(
- ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer));
- core::ScopedUnref s(buffer);
-
- buffer->Reset();
- }
-};
-
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
- .Device(DEVICE_CPU)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceResetOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
- .Device(DEVICE_GPU)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceResetOp);
-#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
- .Device(DEVICE_SYCL)
- .HostMemory("function_buffer_resource"),
- FunctionBufferingResourceResetOp);
-#endif // TENSORFLOW_USE_SYCL
-
-class IteratorGetDeviceOp : public OpKernel {
- public:
- using OpKernel::OpKernel;
-
- void Compute(OpKernelContext* ctx) override {
-    // NOTE(mrry): We do not currently validate that the handle
- // corresponds to a real IteratorResource, because that symbol is
- // not exposed from the framework library.
- Tensor* device_name_t;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_output(0, TensorShape({}), &device_name_t));
- // NOTE(mrry): Since the operation's input is a resource, we must be
- // colocated with it, and so we can simply return the current device's
- // name without looking at the input.
- device_name_t->scalar<string>()() = ctx->device()->name();
- }
-};
-
-REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
- IteratorGetDeviceOp);
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
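The FunctionBufferingResource removed above is, at its core, a producer/consumer queue: a fill loop keeps up to `buffer_size` function results buffered, and each MaybeGet() either pops an element immediately or parks a callback that the next fill completes. A minimal, framework-free Python sketch of that control flow (class and method names here are illustrative, not TensorFlow API):

import threading
from collections import deque


class BufferingResource:
    """Illustrative stand-in for the deleted FunctionBufferingResource."""

    def __init__(self, fn, buffer_size):
        self._fn = fn                    # called once per buffered element
        self._buffer_size = buffer_size
        self._buffer = deque()           # prefetched elements
        self._requests = deque()         # parked consumer callbacks
        self._lock = threading.Lock()
        self._end_of_sequence = False
        self._fill()                     # initial prefetch

    def _fill(self):
        # Keep producing until the buffer is full or the source is exhausted,
        # handing elements to parked consumers first (mirrors FillBuffer()).
        while True:
            try:
                element = self._fn()
            except StopIteration:
                with self._lock:
                    self._end_of_sequence = True
                return
            callback = None
            with self._lock:
                if self._requests:
                    callback = self._requests.popleft()
                else:
                    self._buffer.append(element)
                full = len(self._buffer) >= self._buffer_size
            if callback is not None:
                callback(element)
            if full:
                return

    def maybe_get(self, callback):
        # Mirrors MaybeGet(): serve from the buffer if possible, otherwise park
        # the callback; in either case restart buffering in the background.
        served = False
        with self._lock:
            if self._buffer:
                element, served = self._buffer.popleft(), True
            elif self._end_of_sequence:
                raise RuntimeError("end_of_sequence")
            else:
                self._requests.append(callback)
        if served:
            callback(element)
        threading.Thread(target=self._fill).start()


source = iter(range(10))
resource = BufferingResource(lambda: next(source), buffer_size=4)
resource.maybe_get(print)  # prints 0, then the buffer is topped up again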
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
deleted file mode 100644
index 30fa97a636..0000000000
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ /dev/null
@@ -1,219 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/util/work_sharder.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-class ThreadPoolResource : public ResourceBase {
- public:
- ThreadPoolResource(Env* env, const ThreadOptions& thread_options,
- const string& name, int num_threads, bool low_latency_hint,
- int max_intra_op_parallelism)
- : thread_pool_(env, thread_options, name, num_threads, low_latency_hint),
- max_intra_op_parallelism_(max_intra_op_parallelism) {}
-
- // Schedules fn() for execution in the pool of threads.
- void Schedule(std::function<void()> fn) {
- if (max_intra_op_parallelism_ < 0) {
- thread_pool_.Schedule(std::move(fn));
- } else {
- thread_pool_.Schedule(std::bind(
- [this](std::function<void()> bound_fn) {
- // TODO(mrry): Consider moving this thread-local configuration to
- // the threads themselves.
- ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_);
- bound_fn();
- },
- std::move(fn)));
- }
- }
-
- string DebugString() override { return "ThreadPoolResource"; }
-
- private:
- thread::ThreadPool thread_pool_;
- const int max_intra_op_parallelism_;
-};
-
-// Creates a handle to a ThreadPool resource. Note that we don't use
-// ResourceOpKernel here because the ThreadPoolResource constructor requires
-// access to `OpKernelContext::env()`, which isn't provided by
-// `ResourceOpKernel<T>::CreateResource()`.
-class ThreadPoolHandleOp : public OpKernel {
- public:
- explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
- &max_intra_op_parallelism_));
- OP_REQUIRES(
- ctx, num_threads_ > 0,
- errors::InvalidArgument("`num_threads` must be greater than zero."));
- }
-
-  // The resource is deleted from the resource manager only when it is private
-  // to this kernel. Ideally the resource would be deleted when it is no longer
-  // held by anyone, but that would break backward compatibility.
- ~ThreadPoolHandleOp() override {
- if (cinfo_.resource_is_private_to_kernel()) {
- if (!cinfo_.resource_manager()
- ->Delete<ThreadPoolResource>(cinfo_.container(), cinfo_.name())
- .ok()) {
-        // Do nothing; the resource may have been deleted by session resets.
- }
- }
- }
-
- void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
- if (!initialized_) {
- ResourceMgr* mgr = ctx->resource_manager();
- OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
- ThreadPoolResource* resource;
- OP_REQUIRES_OK(ctx, mgr->LookupOrCreate<ThreadPoolResource>(
- cinfo_.container(), cinfo_.name(), &resource,
- [this, ctx](ThreadPoolResource** ret)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-                                  *ret = new ThreadPoolResource(
-                                      ctx->env(), {}, display_name_,
-                                      num_threads_,
-                                      false /* low_latency_hint */,
-                                      max_intra_op_parallelism_);
- return Status::OK();
- }));
- initialized_ = true;
- }
- OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
- ctx, 0, cinfo_.container(), cinfo_.name(),
- MakeTypeIndex<ThreadPoolResource>()));
- }
-
- private:
- mutex mu_;
- ContainerInfo cinfo_ GUARDED_BY(mu_);
- bool initialized_ GUARDED_BY(mu_) = false;
- string display_name_;
- int num_threads_;
- int max_intra_op_parallelism_;
-};
-
-class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
- public:
- explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
-
- void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
- DatasetBase** output) override {
- ThreadPoolResource* threadpool_resource;
- OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
- &threadpool_resource));
- core::ScopedUnref unref_iterator(threadpool_resource);
-
- *output = new Dataset(ctx, input, threadpool_resource);
- }
-
- private:
- class Dataset : public DatasetBase {
- public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input,
- ThreadPoolResource* threadpool)
- : DatasetBase(DatasetContext(ctx)),
- input_(input),
- threadpool_(threadpool) {
- input_->Ref();
- threadpool_->Ref();
- }
-
- ~Dataset() override {
- input_->Unref();
- threadpool_->Unref();
- }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::ThreadPool")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return input_->output_dtypes();
- }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return input_->output_shapes();
- }
-
- string DebugString() const override {
- return "ThreadPoolDatasetOp::Dataset";
- }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
-      return errors::Unimplemented(DebugString(),
-                                   " does not support serialization");
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- ThreadPoolResource* pool = dataset()->threadpool_;
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = [pool](std::function<void()> c) {
- pool->Schedule(std::move(c));
- };
- params.stats_aggregator_getter = ctx->stats_aggregator_getter();
- params.lib = ctx->lib();
- params.function_library = ctx->function_library();
- params.allocator_getter = ctx->allocator_getter();
- IteratorContext threadpool_ctx(params);
- return input_impl_->GetNext(&threadpool_ctx, out_tensors,
- end_of_sequence);
- }
-
- private:
- std::unique_ptr<IteratorBase> input_impl_;
- };
-
- const DatasetBase* const input_;
- ThreadPoolResource* const threadpool_;
- };
-};
-
-REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU),
- ThreadPoolHandleOp);
-REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
- ThreadPoolDatasetOp);
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
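ThreadPoolResource::Schedule above wraps every closure so that work running on the custom pool sees a bounded intra-op parallelism. A rough Python analogue of that wrapping, with a thread-local standing in for ScopedPerThreadMaxParallelism (all names below are illustrative, not TensorFlow API):

import threading
from concurrent.futures import ThreadPoolExecutor

_per_thread_limit = threading.local()


class ScopedMaxParallelism:
    """Stand-in for ScopedPerThreadMaxParallelism: a thread-local limit."""

    def __init__(self, limit):
        self._limit = limit

    def __enter__(self):
        self._previous = getattr(_per_thread_limit, "value", None)
        _per_thread_limit.value = self._limit
        return self

    def __exit__(self, *exc_info):
        _per_thread_limit.value = self._previous


class ThreadPool:
    def __init__(self, num_threads, max_intra_op_parallelism):
        self._pool = ThreadPoolExecutor(max_workers=num_threads)
        self._max_intra_op_parallelism = max_intra_op_parallelism

    def schedule(self, fn):
        # A negative limit means "don't constrain", exactly as in Schedule().
        if self._max_intra_op_parallelism < 0:
            return self._pool.submit(fn)

        def bound_fn():
            # Bind the per-thread limit around the scheduled closure.
            with ScopedMaxParallelism(self._max_intra_op_parallelism):
                fn()

        return self._pool.submit(bound_fn)


pool = ThreadPool(num_threads=2, max_intra_op_parallelism=1)
pool.schedule(lambda: print("limit:", _per_thread_limit.value)).result()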
diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
deleted file mode 100644
index 57fc5697a4..0000000000
--- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc
+++ /dev/null
@@ -1,223 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/framework/partial_tensor_shape.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/hash/hash.h"
-
-namespace tensorflow {
-namespace data {
-namespace {
-
-// See documentation in ../ops/dataset_ops.cc for a high-level
-// description of the following op.
-
-class UniqueDatasetOp : public UnaryDatasetOpKernel {
- public:
- explicit UniqueDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
-
- void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
- DatasetBase** output) override {
- OP_REQUIRES(ctx, input->output_dtypes().size() == 1,
- errors::InvalidArgument("UniqueDataset only supports "
- "inputs with a single component."));
-
- DataType input_dtype = input->output_dtypes()[0];
- OP_REQUIRES(ctx,
- input_dtype == DT_INT32 || input_dtype == DT_INT64 ||
- input_dtype == DT_STRING,
- errors::InvalidArgument(
- "UniqueDataset only supports inputs with a single "
- "`tf.int32`, `tf.int64`, or `tf.string` component."));
-
- *output = new Dataset(ctx, input);
- }
-
- private:
- class Dataset : public DatasetBase {
- public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input)
- : DatasetBase(DatasetContext(ctx)), input_(input) {
- input_->Ref();
- }
-
- ~Dataset() override { input_->Unref(); }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Unique")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return input_->output_dtypes();
- }
-
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return input_->output_shapes();
- }
-
- string DebugString() const override {
-      return "UniqueDatasetOp::Dataset";
- }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* input_graph_node = nullptr;
- TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
- TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
- return Status::OK();
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const typename Iterator::Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- mutex_lock l(mu_);
- bool saw_new_value;
- do {
- saw_new_value = false;
- out_tensors->clear();
- TF_RETURN_IF_ERROR(
- input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
- if (*end_of_sequence) {
- break;
- }
- DCHECK_EQ(1, out_tensors->size());
- saw_new_value = unique_elements_.insert((*out_tensors)[0]).second;
- } while (!saw_new_value);
- return Status::OK();
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- if (input_impl_) {
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- } else {
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("input_impl_empty"), ""));
- }
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name("unique_elements_size"), unique_elements_.size()));
- size_t i = 0;
- for (const Tensor& t : unique_elements_) {
- TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(strings::StrCat("unique_elements[", i++, "]")), t));
- }
- return Status::OK();
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- if (!reader->Contains(full_name("input_impl_empty"))) {
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- } else {
- input_impl_.reset();
- }
- int64 num_unique_elements;
- unique_elements_.clear();
- TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"),
- &num_unique_elements));
- for (int64 i = 0; i < num_unique_elements; ++i) {
- Tensor unique_element;
- TF_RETURN_IF_ERROR(reader->ReadTensor(
- full_name(strings::StrCat("unique_elements[", i, "]")),
- &unique_element));
- auto insert_result = unique_elements_.insert(unique_element);
- if (!insert_result.second) {
- return errors::InvalidArgument(
- "Checkpoint contained two unique elements with the same "
- "value.");
- }
- }
- return Status::OK();
- }
-
- private:
- struct TensorHash {
- size_t operator()(const Tensor& t) const {
- if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) {
- return Hash64(t.tensor_data().data(), t.tensor_data().size());
- } else {
- DCHECK_EQ(DT_STRING, t.dtype());
- auto flat_t = t.flat<string>();
- uint64 hash = 0;
- for (int64 i = 0; i < t.NumElements(); ++i) {
- hash = Hash64Combine(hash, Hash64(flat_t(i)));
- }
- return static_cast<size_t>(hash);
- }
- }
- };
-
- struct TensorKeyEqual {
- bool operator()(const Tensor& lhs, const Tensor& rhs) const {
- if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) {
- return false;
- }
- switch (lhs.dtype()) {
-#define HANDLE_TYPE(T) \
- case T: \
- do { \
- auto lhs_flat = lhs.flat<EnumToDataType<T>::Type>(); \
- auto rhs_flat = rhs.flat<EnumToDataType<T>::Type>(); \
- for (int64 i = 0; i < lhs.NumElements(); ++i) { \
- if (lhs_flat(i) != rhs_flat(i)) { \
- return false; \
- } \
- } \
- return true; \
- } while (0)
-
- HANDLE_TYPE(DT_INT32);
- HANDLE_TYPE(DT_INT64);
- HANDLE_TYPE(DT_STRING);
- default:
- LOG(FATAL) << "UniqueDataset unhandled data type: "
- << DataTypeString(lhs.dtype());
- }
- }
- };
-
- mutex mu_;
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- std::unordered_set<Tensor, TensorHash, TensorKeyEqual> unique_elements_
- GUARDED_BY(mu_);
- };
-
- const DatasetBase* const input_;
- };
-};
-
-REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
- UniqueDatasetOp);
-
-} // namespace
-} // namespace data
-} // namespace tensorflow
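The iterator above keeps pulling from its input until it sees an element that is not already in unique_elements_; TensorHash and TensorKeyEqual merely make whole tensors usable as set keys. The same de-duplication over an ordinary Python stream, as a quick reference for the semantics:

def unique(elements):
    """Yield each distinct element the first time it appears, preserving order."""
    seen = set()
    for element in elements:
        # insert(...).second in the C++ iterator: True only for new values.
        if element not in seen:
            seen.add(element)
            yield element


assert list(unique([1, 2, 1, 3, 2, 4])) == [1, 2, 3, 4]
assert list(unique("aabca")) == ["a", "b", "c"]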
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
deleted file mode 100644
index d1a771f005..0000000000
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ /dev/null
@@ -1,208 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-REGISTER_OP("DirectedInterleaveDataset")
- .Input("selector_input_dataset: variant")
- .Input("data_input_datasets: N * variant")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .Attr("N: int >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
-
-selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines
- which of the `N` data inputs should produce the next output element.
-data_input_datasets: `N` datasets with the same type that will be interleaved
- according to the values of `selector_input_dataset`.
-)doc");
-
-REGISTER_OP("CSVDataset")
- .Input("filenames: string")
- .Input("compression_type: string")
- .Input("buffer_size: int64")
- .Input("header: bool")
- .Input("field_delim: string")
- .Input("use_quote_delim: bool")
- .Input("na_value: string")
- .Input("select_cols: int64")
- .Input("record_defaults: output_types")
- .Output("handle: variant")
- .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
- // stateful to inhibit constant folding.
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- // `filenames` must be a scalar or a vector.
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
- // `compression_type`, `buffer_size`, `header`, `field_delim`,
- // `use_quote_delim`, `na_value` must be scalars
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
- // `select_cols` must be a vector
- TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
-      // Each `record_defaults` tensor must be a scalar or a vector of length 0 or 1.
- for (size_t i = 8; i < c->num_inputs(); ++i) {
- shape_inference::ShapeHandle v;
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
- if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
- return errors::InvalidArgument(
- "Shape of a default must be a length-0 or length-1 vector, or a "
- "scalar.");
- }
- }
- return shape_inference::ScalarShape(c);
- });
-
-REGISTER_OP("IgnoreErrorsDataset")
- .Input("input_dataset: variant")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the elements of `input_dataset` ignoring errors.
-)doc");
-
-REGISTER_OP("UniqueDataset")
- .Input("input_dataset: variant")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the unique elements of `input_dataset`.
-)doc");
-
-REGISTER_OP("IteratorGetDevice")
- .Input("resource: resource")
- .Output("device: string")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Returns the name of the device on which `resource` has been placed.
-)doc");
-
-REGISTER_OP("FunctionBufferingResource")
- .Input("string_arg: string")
- .Input("target_device: string")
- .Output("resource: resource")
- .Attr("shared_name: string")
- .Attr("container: string")
- .Attr("f: func")
- .Attr("buffer_size: int")
- .Attr("output_types: list(type)")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Creates a resource that fills up a buffer by making function calls.
-
-string_arg: String argument to the function call.
-target_device: Target device to execute the function on.
-resource: Handle to the resource created.
-f: Function to be executed.
-buffer_size: Size of the buffer.
-container: If non-empty, this resource is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this resource will be shared under the given name
- across multiple sessions.
-output_types: The type list for the return values.
-)doc");
-
-REGISTER_OP("FunctionBufferingResourceGetNext")
- .Input("function_buffer_resource: resource")
- .Attr("output_types: list(type)")
- .Output("output: output_types")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Gets the next element from a FunctionBufferingResource.
-
-function_buffer_resource: The FunctionBufferingResource handle.
-output: A list of return values.
-output_types: The type list for the return values.
-)doc");
-
-REGISTER_OP("FunctionBufferingResourceReset")
- .Input("function_buffer_resource: resource")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Resets the FunctionBufferingResource.
-
-function_buffer_resource: The FunctionBufferingResource handle.
-)doc");
-
-REGISTER_OP("ThreadPoolDataset")
- .Input("input_dataset: variant")
- .Input("thread_pool: resource")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that uses a custom thread pool to compute `input_dataset`.
-
-thread_pool: A resource produced by the ThreadPoolHandle op.
-)doc");
-
-REGISTER_OP("ThreadPoolHandle")
- .Output("handle: resource")
- .SetShapeFn(shape_inference::ScalarShape)
- .Attr("num_threads: int")
- .Attr("max_intra_op_parallelism: int = 1")
- .Attr("display_name: string")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Doc(R"doc(
-Creates a custom thread pool with the given number of threads.
-
-handle: A resource that can be consumed by one or more ThreadPoolDataset ops.
-num_threads: The number of threads in the thread pool.
-max_intra_op_parallelism: The maximum degree of parallelism to use within
- operations that execute on this threadpool.
-display_name: A human-readable name for the threads that may be visible in
- some visualizations.
-)doc");
-
-REGISTER_OP("AssertNextDataset")
- .Input("input_dataset: variant")
- .Input("transformations: string")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- // transformations should be a vector.
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
- return shape_inference::ScalarShape(c);
- });
-
-REGISTER_OP("LMDBDataset")
- .Input("filenames: string")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
- // stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape);
-
-} // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
deleted file mode 100644
index cd9b7c68a0..0000000000
--- a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-REGISTER_OP("IdentityIndexedDataset")
- .Input("size: uint64")
- .Output("handle: variant")
- .SetIsStateful()
- .SetShapeFn(
- shape_inference::ScalarShape); // TODO(saeta): check input shapes.
-
-///////////////////////////////////////////////////////////////////////////////
-// IndexedDataset Internals
-///////////////////////////////////////////////////////////////////////////////
-
-// Creates the handle.
-REGISTER_OP("MaterializedIndexDatasetHandle")
- .Output("handle: resource")
- .Attr("container: string")
- .Attr("shared_name: string")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
-
-// Materializes the dataset into the given materialized-dataset handle.
-REGISTER_OP("IndexedDatasetMaterialize")
- .Input("dataset: variant")
- .Input("materialized: resource")
- .SetShapeFn(shape_inference::NoOutputs);
-
-namespace {
-
-Status GetShapeFn(shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
-        output_shapes.size(), " vs. ", c->num_outputs(), ").");
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
-}
-
-} // namespace
-
-REGISTER_OP("IndexedDatasetGet")
- .Input("materialized: resource")
- .Input("index: uint64")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(GetShapeFn)
- .Doc(R"doc(
-Gets the element at `index` from `materialized` IndexedDataset.
-)doc");
-
-} // namespace tensorflow
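Taken together, the ops above describe a three-step protocol: create a materialization handle, materialize an indexed dataset into it, then fetch elements by index at random. A toy Python model of that protocol (names are illustrative only):

class MaterializedHandle:
    """Stand-in for the resource behind MaterializedIndexDatasetHandle."""

    def __init__(self):
        self._elements = None

    def materialize(self, indexed_dataset):
        # IndexedDatasetMaterialize: evaluate the dataset into random-access form.
        self._elements = list(indexed_dataset)

    def get(self, index):
        # IndexedDatasetGet: constant-time lookup once materialized.
        if self._elements is None:
            raise RuntimeError("dataset has not been materialized")
        return self._elements[index]


handle = MaterializedHandle()
handle.materialize(range(16))   # plays the role of IdentityIndexedDataset(size=16)
assert handle.get(3) == 3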
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index ce52c990ce..33784afa3f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -31,6 +31,7 @@ py_test(
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -54,6 +55,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -77,6 +79,7 @@ py_test(
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@@ -97,6 +100,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
],
@@ -112,6 +116,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:random_seed",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -130,6 +135,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -139,12 +145,12 @@ py_test(
name = "indexed_dataset_ops_test",
srcs = ["indexed_dataset_ops_test.py"],
deps = [
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
- "//tensorflow/contrib/data/python/ops:gen_dataset_ops",
"//tensorflow/contrib/data/python/ops:indexed_dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -170,6 +176,7 @@ py_test(
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@six_archive//:six",
],
@@ -189,6 +196,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/estimator:estimator_py",
],
@@ -215,6 +223,7 @@ py_test(
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//third_party/py/numpy",
],
)
@@ -240,6 +249,7 @@ py_test(
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -259,6 +269,7 @@ py_test(
"//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -283,6 +294,7 @@ py_test(
"//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
],
)
@@ -301,6 +313,7 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
@@ -316,6 +329,7 @@ cuda_py_test(
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
@@ -341,6 +355,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -366,6 +381,7 @@ py_library(
"//tensorflow/python:lib",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:readers",
],
@@ -412,6 +428,7 @@ py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -434,6 +451,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
@@ -454,6 +472,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -471,6 +490,7 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -490,6 +510,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python/data/kernel_tests:test_base",
"@org_sqlite//:python",
],
)
@@ -534,6 +555,7 @@ py_library(
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
],
)
@@ -550,6 +572,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:script_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -568,6 +591,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -588,6 +612,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -605,17 +630,8 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
],
)
-
-py_library(
- name = "test_utils",
- srcs = ["test_utils.py"],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/util:nest",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index e2508de9e9..fed7de5f2b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -40,12 +41,8 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
@@ -723,7 +720,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
-class RestructuredDatasetTest(test.TestCase):
+class RestructuredDatasetTest(test_base.DatasetTestBase):
def test_assert_element_shape(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 48971f2ccc..ae401f786c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -22,6 +22,7 @@ import random
import numpy as np
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -35,7 +36,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class GroupByReducerTest(test.TestCase):
+class GroupByReducerTest(test_base.DatasetTestBase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
@@ -198,7 +199,7 @@ class GroupByReducerTest(test.TestCase):
self.assertEqual(y, 45)
-class GroupByWindowTest(test.TestCase):
+class GroupByWindowTest(test_base.DatasetTestBase):
def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
@@ -345,7 +346,7 @@ class GroupByWindowTest(test.TestCase):
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
-class BucketTest(test.TestCase):
+class BucketTest(test_base.DatasetTestBase):
def _dynamicPad(self, bucket, window, window_size):
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
@@ -570,7 +571,7 @@ def _get_record_shape(sparse):
return tensor_shape.TensorShape([None])
-class BucketBySequenceLength(test.TestCase):
+class BucketBySequenceLength(test_base.DatasetTestBase):
def testBucket(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
index f8e74e4583..5b3c512b64 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -43,37 +44,7 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
-class CsvDatasetOpTest(test.TestCase):
-
- def _get_next(self, dataset):
- # Returns a no argument function whose result is fed to self.evaluate to
- # yield the next element
- it = dataset.make_one_shot_iterator()
- if context.executing_eagerly():
- return it.get_next
- else:
- get_next = it.get_next()
- return lambda: get_next
-
- def _assert_datasets_equal(self, ds1, ds2):
- assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
- '%s') % (ds1.output_shapes,
- ds2.output_shapes)
- assert ds1.output_types == ds2.output_types
- assert ds1.output_classes == ds2.output_classes
- next1 = self._get_next(ds1)
- next2 = self._get_next(ds2)
- # Run through datasets and check that outputs match, or errors match.
- while True:
- try:
- op1 = self.evaluate(next1())
- except (errors.OutOfRangeError, ValueError) as e:
- # If op1 throws an exception, check that op2 throws same exception.
- with self.assertRaises(type(e)):
- self.evaluate(next2())
- break
- op2 = self.evaluate(next2())
- self.assertAllEqual(op1, op2)
+class CsvDatasetOpTest(test_base.DatasetTestBase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
@@ -108,7 +79,7 @@ class CsvDatasetOpTest(test.TestCase):
"""Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
dataset_actual, dataset_expected = self._make_test_datasets(
inputs, **kwargs)
- self._assert_datasets_equal(dataset_actual, dataset_expected)
+ self.assertDatasetsEqual(dataset_actual, dataset_expected)
def _verify_output_or_err(self,
dataset,
@@ -116,7 +87,7 @@ class CsvDatasetOpTest(test.TestCase):
expected_err_re=None):
if expected_err_re is None:
# Verify that output is expected, without errors
- nxt = self._get_next(dataset)
+ nxt = self.getNext(dataset)
expected_output = [[
v.encode('utf-8') if isinstance(v, str) else v for v in op
] for op in expected_output]
@@ -128,7 +99,7 @@ class CsvDatasetOpTest(test.TestCase):
else:
# Verify that OpError is produced as expected
with self.assertRaisesOpError(expected_err_re):
- nxt = self._get_next(dataset)
+ nxt = self.getNext(dataset)
while True:
try:
self.evaluate(nxt())
@@ -354,7 +325,7 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1,,3,4', '5,6,,8']]
ds_actual, ds_expected = self._make_test_datasets(
inputs, record_defaults=record_defaults)
- self._assert_datasets_equal(
+ self.assertDatasetsEqual(
ds_actual.repeat(5).prefetch(1),
ds_expected.repeat(5).prefetch(1))
@@ -377,7 +348,7 @@ class CsvDatasetOpTest(test.TestCase):
ds = readers.make_csv_dataset(
file_path, batch_size=1, shuffle=False, num_epochs=1)
- nxt = self._get_next(ds)
+ nxt = self.getNext(ds)
result = list(self.evaluate(nxt()).values())
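With this refactor the iteration helpers come from DatasetTestBase rather than per-test code; judging from the calls above, a test now reads roughly as follows (a sketch that assumes the base class exposes getNext() working in both graph and eager modes):

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test


class ExampleDatasetTest(test_base.DatasetTestBase):

  def testRange(self):
    dataset = dataset_ops.Dataset.range(3)
    nxt = self.getNext(dataset)          # replaces the removed _get_next helper
    for expected in range(3):
      self.assertEqual(expected, self.evaluate(nxt()))


if __name__ == "__main__":
  test.main()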
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
index a2ab3de52e..722e87e555 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -25,7 +26,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def testRestructureDataset(self):
components = (array_ops.placeholder(dtypes.int32),
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index eb110324d1..bc10c21472 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -20,13 +20,14 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.platform import test
-class DirectedInterleaveDatasetTest(test.TestCase):
+class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
def testBasic(self):
selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index f3968cdc15..cc22ea1df7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import get_single_element
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GetSingleElementTest(test.TestCase, parameterized.TestCase):
+class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("Zero", 0, 1),
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
index 9c508d686d..d4d3d4adb2 100644
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
@@ -19,29 +19,30 @@ from __future__ import print_function
import unittest
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import indexed_dataset_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.platform import test
-class IndexedDatasetOpsTest(test.TestCase):
+class IndexedDatasetOpsTest(test_base.DatasetTestBase):
def testLowLevelIndexedDatasetOps(self):
- identity = gen_dataset_ops.identity_indexed_dataset(
+ identity = ged_ops.experimental_identity_indexed_dataset(
ops.convert_to_tensor(16, dtype=dtypes.uint64))
- handle = gen_dataset_ops.materialized_index_dataset_handle(
+ handle = ged_ops.experimental_materialized_index_dataset_handle(
container="",
shared_name="",
output_types=[dtypes.uint64],
output_shapes=[[]])
- materialize = gen_dataset_ops.indexed_dataset_materialize(identity, handle)
+ materialize = ged_ops.experimental_indexed_dataset_materialize(
+ identity, handle)
index = array_ops.placeholder(dtypes.uint64)
- get_op = gen_dataset_ops.indexed_dataset_get(
+ get_op = ged_ops.experimental_indexed_dataset_get(
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
with self.cached_session() as sess:
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index b9e74dfddb..28bd670ab5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -25,6 +25,7 @@ import time
from six.moves import zip_longest
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -36,7 +37,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class ParallelInterleaveDatasetTest(test.TestCase):
+class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 7e2326bd17..58a1d7c93b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import iterator_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
@@ -33,7 +34,7 @@ from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-class CheckpointInputPipelineHookTest(test.TestCase):
+class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
@staticmethod
def _model_fn(features, labels, mode, config):
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
index 1cc5ddc9a2..d2a72272db 100644
--- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -22,6 +22,7 @@ import os
import shutil
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,7 @@ from tensorflow.python.util import compat
prefix_path = "tensorflow/core/lib"
-class LMDBDatasetTest(test.TestCase):
+class LMDBDatasetTest(test_base.DatasetTestBase):
def setUp(self):
super(LMDBDatasetTest, self).setUp()
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index e8519381d6..385c4ef6ea 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -41,7 +42,7 @@ from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42
-class MapDatasetTest(test.TestCase):
+class MapDatasetTest(test_base.DatasetTestBase):
def testMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 25aea0393f..751e6d5b30 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -21,6 +21,7 @@ import time
from tensorflow.contrib.data.python.ops import map_defun
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,7 +34,8 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapDefunTest(test.TestCase):
+
+class MapDefunTest(test_base.DatasetTestBase):
def testMapDefunSimple(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
index 1ae92bdeff..d7b5edcd9a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -15,6 +15,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -31,6 +32,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -57,7 +59,6 @@ py_test(
srcs = ["map_vectorization_test.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/kernel_tests:test_utils",
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
@@ -67,6 +68,7 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -85,6 +87,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -102,6 +105,7 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -121,6 +125,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -137,6 +142,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -151,6 +157,7 @@ py_test(
"//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
index d10da80442..fe1b5280ba 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -18,12 +18,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class AssertNextDatasetTest(test.TestCase):
+class AssertNextDatasetTest(test_base.DatasetTestBase):
def testAssertNext(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
index 9518c2e1ad..b43efb5c7c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -31,7 +32,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class HoistRandomUniformTest(test.TestCase, parameterized.TestCase):
+class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
index e75edf6086..e9e3fc81e5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -28,7 +29,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
+class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
index dd547db086..f7907eb890 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class MapParallelizationTest(test.TestCase, parameterized.TestCase):
+class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
index 5b493f44c9..a5ea85f454 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -22,9 +22,9 @@ import time
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import test_utils
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
+class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def _get_test_datasets(self,
base_dataset,
@@ -85,7 +85,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
num_parallel_calls)
- self._assert_datasets_equal(unoptimized, optimized)
+ self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationBadMapFn(self):
# Test map functions that give an error
@@ -112,7 +112,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
# TODO(rachelim): when this optimization works, turn on expect_optimized
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(optimized, unoptimized)
+ self.assertDatasetsEqual(optimized, unoptimized)
def testOptimizationIgnoreStateful(self):
@@ -124,7 +124,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(
+ self.assertDatasetsRaiseSameError(
unoptimized, optimized, errors.InvalidArgumentError,
[("OneShotIterator", "OneShotIterator_1", 1),
("IteratorGetNext", "IteratorGetNext_1", 1)])
@@ -138,7 +138,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(unoptimized, optimized)
+ self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationIgnoreRaggedMap(self):
# Don't optimize when the output of the map fn shapes are unknown.
@@ -148,7 +148,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(
+ self.assertDatasetsRaiseSameError(
unoptimized, optimized, errors.InvalidArgumentError,
[("OneShotIterator", "OneShotIterator_1", 1),
("IteratorGetNext", "IteratorGetNext_1", 1)])
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
index 3b62a7e468..33c250ab2a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
@@ -23,12 +23,13 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class ModelDatasetTest(test.TestCase):
+class ModelDatasetTest(test_base.DatasetTestBase):
def testModelMap(self):
k = 1024 * 1024
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
index 507feda3ad..b9e60cfa4e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -26,7 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class NoopEliminationTest(test.TestCase):
+class NoopEliminationTest(test_base.DatasetTestBase):
def testNoopElimination(self):
a = constant_op.constant(1, dtype=dtypes.int64)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
index a3fb824ce9..04f499f8c5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -28,7 +29,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class OptimizeDatasetTest(test.TestCase):
+class OptimizeDatasetTest(test_base.DatasetTestBase):
def testOptimizationDefault(self):
dataset = dataset_ops.Dataset.range(10).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
index c4623bca73..66ccaceea5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -72,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
i += 1
-class ParseExampleTest(test.TestCase):
+class ParseExampleTest(test_base.DatasetTestBase):
def _test(self,
input_tensor,
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 33a64ea767..7a6a7a709a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -22,6 +22,7 @@ import threading
from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -35,7 +36,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class PrefetchingKernelsOpsTest(test.TestCase):
+class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
def setUp(self):
self._event = threading.Event()
@@ -244,7 +245,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
sess.run(destroy_op)
-class PrefetchToDeviceTest(test.TestCase):
+class PrefetchToDeviceTest(test_base.DatasetTestBase):
def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -445,7 +446,7 @@ class PrefetchToDeviceTest(test.TestCase):
sess.run(next_element)
-class CopyToDeviceTest(test.TestCase):
+class CopyToDeviceTest(test_base.DatasetTestBase):
def testCopyToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index db8fe6aa1b..2e901587f4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.contrib.data.python.ops import counter
from tensorflow.contrib.data.python.ops import enumerate_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -27,7 +28,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class RangeDatasetTest(test.TestCase):
+class RangeDatasetTest(test_base.DatasetTestBase):
def testEnumerateDataset(self):
components = (["a", "b"], [1, 2], [37.0, 38])
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index ed75b27a44..66ed547b6d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
@@ -242,7 +243,7 @@ class ReadBatchFeaturesTest(
self.assertEqual(32, shape[0])
-class MakeCsvDatasetTest(test.TestCase):
+class MakeCsvDatasetTest(test_base.DatasetTestBase):
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
return readers.make_csv_dataset(
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
index 08b9f03816..f443b5501b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
@@ -25,6 +25,7 @@ import zlib
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import constant_op
@@ -32,11 +33,10 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class FixedLengthRecordDatasetTestBase(test.TestCase):
+class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing FixedLengthRecordDataset."""
def setUp(self):
@@ -63,7 +63,7 @@ class FixedLengthRecordDatasetTestBase(test.TestCase):
return filenames
-class ReadBatchFeaturesTestBase(test.TestCase):
+class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing `make_batched_feature_dataset`."""
def setUp(self):
@@ -273,7 +273,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
self.assertAllEqual(expected_batch[i], actual_batch[i])
-class TextLineDatasetTestBase(test.TestCase):
+class TextLineDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing TextLineDataset."""
def _lineText(self, f, l):
@@ -313,7 +313,7 @@ class TextLineDatasetTestBase(test.TestCase):
return filenames
-class TFRecordDatasetTestBase(test.TestCase):
+class TFRecordDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing TFRecordDataset."""
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index 16b1441baa..32474bd411 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -24,6 +24,7 @@ import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.data.python.ops import resampling
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -57,7 +58,7 @@ def _time_resampling(
return end_time - start_time
-class ResampleTest(test.TestCase, parameterized.TestCase):
+class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("InitialDistributionKnown", True),
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index dde678bd54..bdf80eae4e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -22,6 +22,7 @@ import itertools
import numpy as np
from tensorflow.contrib.data.python.ops import scan_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -33,7 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ScanDatasetTest(test.TestCase):
+class ScanDatasetTest(test_base.DatasetTestBase):
def _counting_dataset(self, start, scan_fn):
return dataset_ops.Dataset.from_tensors(0).repeat().apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
index 14cd3e9c4a..a10f85263a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -90,6 +91,16 @@ class StatsDatasetSerializationTest(
lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
None, num_outputs)
+ def _build_dataset_stats_aggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ return dataset_ops.Dataset.range(10).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+
+ def test_set_stats_aggregator_not_support_checkpointing(self):
+ with self.assertRaisesRegexp(errors.UnimplementedError,
+ "does not support checkpointing"):
+ self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 440e48db30..c97002a255 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -20,13 +20,14 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class ShuffleAndRepeatTest(test.TestCase):
+class ShuffleAndRepeatTest(test_base.DatasetTestBase):
def _build_ds(self, seed, count=5, num_elements=20):
return dataset_ops.Dataset.range(num_elements).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 90d18dca2a..c5a7862322 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import sliding
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class SlideDatasetTest(test.TestCase, parameterized.TestCase):
+class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", 20, 14, 7, 1),
@@ -197,11 +198,6 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
sliding.sliding_window_batch(
window_size=1, stride=1, window_shift=1, window_stride=1))
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSlideSparse(self):
def _sparse(i):
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
index 1f5c725a92..319a2ea263 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
@@ -24,12 +24,13 @@ import os
import sqlite3
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SqlDatasetTestBase(test.TestCase):
+class SqlDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing SqlDataset."""
def _createSqlDataset(self, output_types, num_repeats=1):
@@ -92,5 +93,3 @@ class SqlDatasetTestBase(test.TestCase):
9007199254740992.0)])
conn.commit()
conn.close()
-
-
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index b1b4c23510..80f2625927 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -19,10 +19,10 @@ from __future__ import print_function
from tensorflow.core.framework import summary_pb2
-from tensorflow.python.platform import test
+from tensorflow.python.data.kernel_tests import test_base
-class StatsDatasetTestBase(test.TestCase):
+class StatsDatasetTestBase(test_base.DatasetTestBase):
"""Base class for testing statistics gathered in `StatsAggregator`."""
def _assertSummaryContains(self, summary_str, tag):
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
deleted file mode 100644
index 4c3353fe40..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test utilities for tf.data functionality."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class DatasetTestBase(test.TestCase):
- """Base class for dataset tests."""
-
- def _assert_datasets_equal(self, dataset1, dataset2):
- # TODO(rachelim): support sparse tensor outputs
- next1 = dataset1.make_one_shot_iterator().get_next()
- next2 = dataset2.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- while True:
- try:
- op1 = sess.run(next1)
- except errors.OutOfRangeError:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next2)
- break
- op2 = sess.run(next2)
-
- op1 = nest.flatten(op1)
- op2 = nest.flatten(op2)
- assert len(op1) == len(op2)
- for i in range(len(op1)):
- self.assertAllEqual(op1[i], op2[i])
-
- def _assert_datasets_raise_same_error(self,
- dataset1,
- dataset2,
- exception_class,
- replacements=None):
- # We are defining next1 and next2 in the same line so that we get identical
- # file:line_number in the error messages
- # pylint: disable=line-too-long
- next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next()
- # pylint: enable=line-too-long
- with self.cached_session() as sess:
- try:
- sess.run(next1)
- raise ValueError(
- "Expected dataset to raise an error of type %s, but it did not." %
- repr(exception_class))
- except exception_class as e:
- expected_message = e.message
- for old, new, count in replacements:
- expected_message = expected_message.replace(old, new, count)
- # Check that the first segment of the error messages are the same.
- with self.assertRaisesRegexp(exception_class,
- re.escape(expected_message)):
- sess.run(next2)
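Note: the helpers deleted above are superseded by the shared base class in tensorflow/python/data/kernel_tests/test_base.py, whose camel-case methods assertDatasetsEqual and assertDatasetsRaiseSameError are what the updated map_vectorization_test calls. A minimal sketch (not part of this patch) of a contrib kernel test written against that base class; the test name and dataset below are illustrative placeholders:

from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test


class ExampleDatasetTest(test_base.DatasetTestBase):

  def testIdentityMapMatchesInput(self):
    # assertDatasetsEqual replaces the removed _assert_datasets_equal helper;
    # like the deleted code, it iterates both datasets and compares elements.
    dataset = dataset_ops.Dataset.range(10)
    self.assertDatasetsEqual(dataset, dataset.map(lambda x: x))


if __name__ == "__main__":
  test.main()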
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 8d335e87d5..08de3a9143 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import threadpool
from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,8 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
+class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
+ parameterized.TestCase):
@parameterized.named_parameters(
("1", 1, None),
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
index f994c8563f..8856ce5afb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,7 +26,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class UniqueDatasetTest(test.TestCase):
+class UniqueDatasetTest(test_base.DatasetTestBase):
def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 8b7b3ac0f7..79134c7bc6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _structuredDataset(self, structure, shape, dtype):
if structure is None:
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
index 867ee2ba37..fca546a570 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import os
from tensorflow.contrib.data.python.ops import writers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class TFRecordWriterTest(test.TestCase):
+class TFRecordWriterTest(test_base.DatasetTestBase):
def setUp(self):
super(TFRecordWriterTest, self).setUp()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index a14781cd93..5cd1ed542b 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -78,7 +78,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":batching",
- ":gen_dataset_ops",
":interleave_ops",
":optimization",
":parsing_ops",
@@ -86,6 +85,7 @@ py_library(
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
@@ -148,8 +148,7 @@ py_library(
srcs = ["error_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -179,12 +178,11 @@ py_library(
srcs = ["interleave_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
":random_ops",
"//tensorflow/contrib/stateless",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
@@ -199,9 +197,8 @@ py_library(
srcs = ["optimization.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -304,8 +301,7 @@ py_library(
srcs = ["threadpool.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
@@ -321,9 +317,8 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
@@ -342,47 +337,11 @@ py_library(
],
)
-tf_gen_op_wrapper_py(
- name = "gen_dataset_ops",
- out = "gen_dataset_ops.py",
- deps = [
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
- ],
-)
-
-tf_kernel_library(
- name = "dataset_ops_kernels",
- deps = [
- "//tensorflow/contrib/data/kernels:dataset_kernels",
- "//tensorflow/core:framework",
- ],
- alwayslink = 1,
-)
-
-tf_custom_op_py_library(
- name = "contrib_op_loader",
- srcs = ["contrib_op_loader.py"],
- dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
- kernels = [
- ":dataset_ops_kernels",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":gen_dataset_ops",
- "//tensorflow/contrib/util:util_py",
- "//tensorflow/python:platform",
- ],
-)
-
py_library(
name = "indexed_dataset_ops",
srcs = ["indexed_dataset_ops.py"],
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
@@ -394,7 +353,7 @@ py_library(
name = "prefetching_ops",
srcs = ["prefetching_ops.py"],
deps = [
- ":contrib_op_loader",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
diff --git a/tensorflow/contrib/data/python/ops/contrib_op_loader.py b/tensorflow/contrib/data/python/ops/contrib_op_loader.py
deleted file mode 100644
index 8f495a9dc9..0000000000
--- a/tensorflow/contrib/data/python/ops/contrib_op_loader.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Python helper for loading contrib ops and kernels."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.util import loader
-from tensorflow.python.platform import resource_loader
-
-_dataset_ops = loader.load_op_library(
- resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 615dbcabd4..f962e623ee 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,9 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
def ignore_errors():
@@ -60,7 +59,7 @@ class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
self._input_dataset = input_dataset
def _as_variant_tensor(self):
- return gen_dataset_ops.ignore_errors_dataset(
+ return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
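The Python-level ignore_errors() transformation keeps its public behavior; only the generated op it dispatches to moves to gen_experimental_dataset_ops. A rough usage sketch, assuming the check_numerics pattern used by the contrib tests (the values and error message are placeholders, not part of this change):

import numpy as np

from tensorflow.contrib.data.python.ops import error_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops

components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
dataset = (dataset_ops.Dataset.from_tensor_slices(components)
           .map(lambda x: array_ops.check_numerics(x, "message"))
           # Elements whose map fn raises (the NaN here) are silently dropped.
           .apply(error_ops.ignore_errors()))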
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
index cc76ab0850..9c06474a2f 100644
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
@@ -19,14 +19,13 @@ from __future__ import print_function
import abc
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
class MaterializedIndexedDataset(object):
@@ -57,7 +56,7 @@ class MaterializedIndexedDataset(object):
A tensor containing the values corresponding to `index`.
"""
# TODO(saeta): nest.pack_sequence_as(...)
- return gen_dataset_ops.indexed_dataset_get(
+ return ged_ops.experimental_indexed_dataset_get(
self._materialized_resource,
index,
output_types=nest.flatten(
@@ -90,16 +89,18 @@ class IndexedDataset(dataset_ops.Dataset):
container = ""
if shared_name is None:
shared_name = ""
- materialized_resource = gen_dataset_ops.materialized_index_dataset_handle(
- container=container,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_types(self.output_shapes, self.output_classes)))
+ materialized_resource = (
+ ged_ops.experimental_materialized_index_dataset_handle(
+ container=container,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_types(self.output_shapes,
+ self.output_classes))))
with ops.colocate_with(materialized_resource):
- materializer = gen_dataset_ops.indexed_dataset_materialize(
+ materializer = ged_ops.experimental_indexed_dataset_materialize(
self._as_variant_tensor(), materialized_resource)
return MaterializedIndexedDataset(materialized_resource, materializer,
self.output_classes, self.output_types,
@@ -170,7 +171,7 @@ class IdentityIndexedDataset(IndexedDataset):
return tensor_shape.scalar()
def _as_variant_tensor(self):
- return gen_dataset_ops.identity_indexed_dataset(self._size)
+ return ged_ops.experimental_identity_indexed_dataset(self._size)
def _inputs(self):
return []
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index bfa3fdf543..1ee9db1aa8 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -18,8 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import stateless
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import random_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
@@ -28,6 +26,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import deprecation
@@ -167,10 +166,12 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
def _as_variant_tensor(self):
# pylint: disable=protected-access
- return gen_dataset_ops.directed_interleave_dataset(
- self._selector_input._as_variant_tensor(),
- [data_input._as_variant_tensor() for data_input in self._data_inputs],
- **dataset_ops.flat_structure(self))
+ return (
+ gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
+ self._selector_input._as_variant_tensor(), [
+ data_input._as_variant_tensor()
+ for data_input in self._data_inputs
+ ], **dataset_ops.flat_structure(self)))
# pylint: enable=protected-access
def _inputs(self):
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
index 3eb172acd5..30348ede36 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/contrib/data/python/ops/optimization.py
@@ -17,12 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
@@ -54,7 +53,7 @@ def model():
Returns:
A `Dataset` transformation function, which can be passed to
- @{tf.data.Dataset.apply}.
+ `tf.data.Dataset.apply`.
"""
def _apply_fn(dataset):
@@ -97,7 +96,7 @@ class _AssertNextDataset(dataset_ops.UnaryDataset):
transformations, dtype=dtypes.string, name="transformations")
def _as_variant_tensor(self):
- return contrib_gen_dataset_ops.assert_next_dataset(
+ return gen_experimental_dataset_ops.experimental_assert_next_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._transformations,
**dataset_ops.flat_structure(self))
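As the docstring touched above notes, model() (like the other helpers in this module) returns a transformation function intended for `tf.data.Dataset.apply`. A minimal sketch of wiring it into a pipeline, with a range dataset standing in for real input:

from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(10)
# apply() threads the dataset through the transformation returned by model().
dataset = dataset.apply(optimization.model())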
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 58395879e6..46f82e453a 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -19,8 +19,6 @@ from __future__ import print_function
import warnings
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
@@ -32,7 +30,8 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
@@ -64,7 +63,7 @@ def function_buffering_resource(string_arg,
"""
if shared_name is None:
shared_name = ""
- return gen_dataset_ops.function_buffering_resource(
+ return ged_ops.experimental_function_buffering_resource(
string_arg=string_arg,
target_device=target_device,
shared_name=shared_name,
@@ -78,14 +77,14 @@ def function_buffering_resource(string_arg,
def function_buffering_resource_get_next(function_buffer_resource,
output_types,
name=None):
- return gen_dataset_ops.function_buffering_resource_get_next(
+ return ged_ops.experimental_function_buffering_resource_get_next(
function_buffer_resource=function_buffer_resource,
output_types=output_types,
name=name)
def function_buffering_resource_reset(function_buffer_resource, name=None):
- return gen_dataset_ops.function_buffering_resource_reset(
+ return ged_ops.experimental_function_buffering_resource_reset(
function_buffer_resource=function_buffer_resource, name=name)
@@ -136,7 +135,7 @@ class _PrefetchToDeviceIterator(object):
ret = remote_iterator.get_next()
return nest.flatten(sparse.serialize_sparse_tensors(ret))
- iterator_device = gen_dataset_ops.iterator_get_device(
+ iterator_device = ged_ops.experimental_iterator_get_device(
self._input_iterator._iterator_resource)
with ops.device(device):
@@ -162,10 +161,11 @@ class _PrefetchToDeviceIterator(object):
if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
- flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
self._buffering_resource,
- output_types=nest.flatten(sparse.as_dense_types(
- self.output_types, self.output_classes)), name=name)
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
ret = sparse.deserialize_sparse_tensors(
nest.pack_sequence_as(self.output_types, flat_ret),
@@ -219,7 +219,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
buffer_size):
with ops.device("/device:CPU:0"):
super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
- input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle(
+ input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
self._resource)
self._device = device
@@ -238,7 +238,8 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
self._buffering_resource = function_buffering_resource(
f=_prefetch_fn,
output_types=self._flat_output_types,
- target_device=gen_dataset_ops.iterator_get_device(self._resource),
+ target_device=ged_ops.experimental_iterator_get_device(
+ self._resource),
string_arg=input_iterator_handle,
buffer_size=buffer_size,
shared_name=iterator_ops._generate_shared_name(
@@ -252,7 +253,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
# TODO(b/77291417): Fix
with context.execution_mode(context.SYNC):
with ops.device(self._device):
- ret = gen_dataset_ops.function_buffering_resource_get_next(
+ ret = ged_ops.experimental_function_buffering_resource_get_next(
function_buffer_resource=self._buffering_resource,
output_types=self._flat_output_types)
return sparse.deserialize_sparse_tensors(
@@ -409,12 +410,12 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
"""
# pylint: disable=protected-access
ds_variant = self._input_dataset._as_variant_tensor()
- resource = core_gen_dataset_ops.anonymous_iterator(
+ resource = gen_dataset_ops.anonymous_iterator(
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
with ops.control_dependencies(
- [core_gen_dataset_ops.make_iterator(ds_variant, resource)]):
- return core_gen_dataset_ops.iterator_to_string_handle(resource)
+ [gen_dataset_ops.make_iterator(ds_variant, resource)]):
+ return gen_dataset_ops.iterator_to_string_handle(resource)
@function.Defun()
def _remote_init_func():
@@ -463,7 +464,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
Returns:
Tensor constant 0
"""
- iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2(
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
string_handle,
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
@@ -504,7 +505,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
def _as_variant_tensor(self):
with ops.device(self._target_device):
- return core_gen_dataset_ops.generator_dataset(
+ return gen_dataset_ops.generator_dataset(
self._init_captured_args,
self._next_captured_args,
self._finalize_captured_args,
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index d9d06e2703..360971e200 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -23,7 +23,6 @@ import csv
import numpy as np
from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.data.python.ops import optimization
from tensorflow.contrib.data.python.ops import parsing_ops
@@ -38,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
@@ -629,7 +629,7 @@ class CsvDataset(dataset_ops.DatasetSource):
def _as_variant_tensor(self):
# Constructs graph node for the dataset op.
- return contrib_gen_dataset_ops.csv_dataset(
+ return gen_experimental_dataset_ops.experimental_csv_dataset(
filenames=self._filenames,
record_defaults=self._record_defaults,
buffer_size=self._buffer_size,
@@ -1013,7 +1013,7 @@ class LMDBDataset(dataset_ops.DatasetSource):
filenames, dtype=dtypes.string, name="filenames")
def _as_variant_tensor(self):
- return contrib_gen_dataset_ops.lmdb_dataset(
+ return gen_experimental_dataset_ops.experimental_lmdb_dataset(
self._filenames,
output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes))
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index 9d165ad52a..f73c3fd9cb 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -19,10 +19,9 @@ from __future__ import print_function
import threading
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import resource_variable_ops
_uid_counter = 0
@@ -47,7 +46,7 @@ class PrivateThreadPool(object):
"""Creates a `PrivateThreadPool` with the given number of threads."""
if context.executing_eagerly():
shared_name = _generate_shared_name("privatethreadpool")
- self._resource = gen_dataset_ops.thread_pool_handle(
+ self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name,
@@ -55,7 +54,7 @@ class PrivateThreadPool(object):
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device=context.context().device_name)
else:
- self._resource = gen_dataset_ops.thread_pool_handle(
+ self._resource = ged_ops.experimental_thread_pool_handle(
num_threads=num_threads,
max_intra_op_parallelism=max_intra_op_parallelism,
display_name=display_name)
@@ -70,7 +69,7 @@ class _ThreadPoolDataset(dataset_ops.UnaryDataset):
self._thread_pool = thread_pool
def _as_variant_tensor(self):
- return gen_dataset_ops.thread_pool_dataset(
+ return ged_ops.experimental_thread_pool_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._thread_pool._resource, # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index bad67a580d..ed363a7090 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -17,10 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_experimental_dataset_ops
def unique():
@@ -61,7 +60,7 @@ class _UniqueDataset(dataset_ops.UnaryDataset):
"`tf.int32`, `tf.int64`, or `tf.string` component.")
def _as_variant_tensor(self):
- return gen_dataset_ops.unique_dataset(
+ return gen_experimental_dataset_ops.experimental_unique_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
**dataset_ops.flat_structure(self))
diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD
index 3b50a48336..06940a90d5 100644
--- a/tensorflow/contrib/decision_trees/proto/BUILD
+++ b/tensorflow/contrib/decision_trees/proto/BUILD
@@ -17,7 +17,6 @@ tf_proto_library(
name = "generic_tree_model",
srcs = ["generic_tree_model.proto"],
cc_api_version = 2,
- java_api_version = 2,
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 91a27f97b7..2e025765e4 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -231,7 +231,8 @@ The same `input_fn` will be used for all workers if you use
important to shuffle your dataset in your `input_fn`.
`MirroredStrategy` will insert a `tf.data.Dataset.shard` call in your
-`input_fn`. As a result, each worker gets a fraction of your input data.
+`input_fn` if `auto_shard_dataset` is set to `True`. As a result, each worker
+gets a fraction of your input data.
### Performance Tips
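Illustrative only: a per-worker `input_fn` that shuffles before batching, as the updated paragraph recommends; the data and sizes below are placeholders rather than part of this change.

import numpy as np
import tensorflow as tf

def input_fn():
  # The same input_fn runs on every worker, so shuffle inside it; when
  # auto_shard_dataset is True, MirroredStrategy adds a shard call so each
  # worker ends up with a fraction of the (shuffled) input data.
  features = np.arange(100, dtype=np.float32).reshape(100, 1)
  labels = np.arange(100, dtype=np.int64)
  dataset = tf.data.Dataset.from_tensor_slices((features, labels))
  return dataset.shuffle(buffer_size=100).repeat().batch(10)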
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index e329b964c4..422983dbef 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -22,6 +22,7 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":input_ops",
+ ":prefetching_ops_v2",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
@@ -29,7 +30,6 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
- "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
@@ -648,6 +648,32 @@ cuda_py_test(
)
py_library(
+ name = "prefetching_ops_v2",
+ srcs = ["prefetching_ops_v2.py"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:prefetching_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+cuda_py_test(
+ name = "prefetching_ops_v2_test",
+ srcs = ["prefetching_ops_v2_test.py"],
+ additional_deps = [
+ ":prefetching_ops_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
+
+py_library(
name = "input_ops",
srcs = ["input_ops.py"],
visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index c900b41e14..9809204f8f 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -216,7 +216,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Configures the object.
Args:
- session_config: a @{tf.ConfigProto}
+ session_config: a `tf.ConfigProto`
cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
cluster configurations.
task_type: the current task type, such as "worker".
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 244d1fcec8..82ca041cc2 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -59,6 +59,7 @@ from tensorflow.python.training import adagrad
from tensorflow.python.training import adam
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import rmsprop
from tensorflow.python.util import tf_inspect
@@ -354,6 +355,8 @@ gradient_descent_optimizer_v1_fn = NamedObject(
"GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
adagrad_optimizer_v1_fn = NamedObject(
"AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+rmsprop_optimizer_v1_fn = NamedObject(
+ "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
adagrad_optimizer_v1_fn]
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index a0b8bde132..3aab2c521f 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -173,13 +173,42 @@ def batch_wrapper(dataset, batch_size, distribution):
return dataset.batch(batch_size)
-def all_combinations():
+def get_model():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+ return model
+
+
+def get_dataset(distribution):
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.repeat(100)
+ dataset = batch_wrapper(dataset, 10, distribution)
+ return dataset
+
+
+strategies = [combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus,
+ combinations.tpu_strategy_one_step]
+
+
+def strategy_combinations():
return combinations.combine(
- distribution=[combinations.default_strategy,
- combinations.one_device_strategy,
- combinations.mirrored_strategy_with_gpu_and_cpu,
- combinations.mirrored_strategy_with_two_gpus,
- combinations.tpu_strategy_one_step],
+ distribution=strategies,
+ mode=['graph'])
+
+
+def strategy_and_optimizer_combinations():
+ return combinations.combine(
+ distribution=strategies,
+ optimizer=[combinations.adagrad_optimizer_v1_fn,
+ combinations.adam_optimizer_v1_fn,
+ combinations.gradient_descent_optimizer_v1_fn,
+ combinations.rmsprop_optimizer_v1_fn],
mode=['graph'])
@@ -360,9 +389,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_calling_model_with_numpy_arrays(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
@@ -392,23 +419,17 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
# with batch_size
model.predict(inputs, batch_size=8)
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_calling_model_on_same_dataset(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
metrics = ['mae', keras.metrics.CategoricalAccuracy()]
model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = batch_wrapper(dataset, 10, distribution)
+ dataset = get_dataset(distribution)
# Call fit with validation data
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
@@ -461,23 +482,17 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
metrics = ['mae', keras.metrics.CategoricalAccuracy()]
model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = batch_wrapper(dataset, 10, distribution)
+ dataset = get_dataset(distribution)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
@@ -486,11 +501,23 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
validation_data=dataset, validation_steps=2)
+ @combinations.generate(strategy_and_optimizer_combinations())
+ def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer):
+ with self.cached_session():
+ model = get_model()
+
+ loss = 'mse'
+ model.compile(optimizer(), loss, distribute=distribution)
+
+ dataset = get_dataset(distribution)
+
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
+ model.evaluate(dataset, steps=2, verbose=1)
+ model.predict(dataset, steps=2)
+
def test_unsupported_features(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
@@ -500,11 +527,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = get_dataset(strategy)
# Test with validation split
with self.assertRaisesRegexp(
@@ -541,9 +564,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_calling_with_unsupported_predefined_callbacks(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
@@ -552,11 +573,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
'/device:GPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
- dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = get_dataset(strategy)
def schedule(_):
return 0.001
@@ -580,9 +597,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_dataset_input_shape_validation(self):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
@@ -616,17 +631,13 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
mode=['graph']))
def test_dataset_input_shape_fully_defined(self, distribution):
with self.cached_session():
- x = keras.layers.Input(shape=(3,), name='input')
- y = keras.layers.Dense(4, name='dense')(x)
- model = keras.Model(x, y)
+ model = get_model()
optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
loss = 'mse'
model.compile(optimizer, loss, distribute=distribution)
- inputs = np.zeros((10, 3), dtype=np.float32)
- targets = np.zeros((10, 4), dtype=np.float32)
- dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = get_dataset(distribution)
# Input shapes are not fully known. Batch dimension is unknown as we are
# not using the drop_remainder argument.
dataset = dataset.repeat(100).batch(10)
@@ -698,7 +709,7 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase):
class NormalizationLayerWithDistributionStrategyTest(
test.TestCase, parameterized.TestCase):
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_batchnorm_correctness(self, distribution):
with self.cached_session():
model = keras.models.Sequential()
@@ -726,7 +737,7 @@ class NormalizationLayerWithDistributionStrategyTest(
class CorrectnessWithDistributionStrategyTest(test.TestCase,
parameterized.TestCase):
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_metric_correctness(self, distribution):
with self.cached_session():
keras.backend.set_image_data_format('channels_last')
@@ -756,7 +767,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase,
history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10)
self.assertEqual(history.history['binary_accuracy'], [1.0])
- @combinations.generate(all_combinations())
+ @combinations.generate(strategy_combinations())
def test_correctness(self, distribution):
with self.cached_session():
keras.backend.set_image_data_format('channels_last')
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index f7773aff4f..8163494c8e 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -86,11 +86,10 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
with ops.Graph().as_default(), distribution.scope():
iterator = distribution.distribute_dataset(
- dataset_fn).make_initializable_iterator()
+ dataset_fn).make_one_shot_iterator()
value, update = distribution.call_for_each_tower(
metric_fn, iterator.get_next())
update = distribution.group(update)
- self.evaluate(iterator.initializer)
self.evaluate(variables.local_variables_initializer())
# TODO(josh11b): Once we switch to using a global batch size for input,
# replace "distribution.num_towers" with "1".
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index d082d5c419..ba147e7824 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -41,14 +41,6 @@ from tensorflow.python.ops.losses import losses_impl
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
- def _get_iterator(self, ds):
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
- self.evaluate(iterator.initializer)
- return iterator
-
@combinations.generate(
combinations.times(
combinations.distributions_and_v1_optimizers(),
@@ -70,7 +62,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -106,7 +99,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.group(
@@ -165,7 +159,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -249,7 +244,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
return control_flow_ops.group(fetches)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -342,7 +338,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, x, y, run_concurrently=False))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -435,7 +432,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
output=loss)
return distribution.group(train_op)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
initial_loss = lambda: constant_op.constant(1e7)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 945f450387..4d7516063c 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -347,6 +347,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
data to devices.
+ auto_shard_dataset: whether to auto-shard the dataset when there are
+ multiple workers.
"""
def __init__(self,
@@ -354,11 +356,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus=None,
num_gpus_per_worker=None,
cross_tower_ops=None,
- prefetch_on_device=None):
+ prefetch_on_device=None,
+ auto_shard_dataset=False):
super(MirroredStrategy, self).__init__()
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
+ self._auto_shard_dataset = auto_shard_dataset
# Remember num GPUs which might be needed by the `configure` method.
if num_gpus is not None and num_gpus_per_worker is not None:
raise ValueError(
@@ -477,13 +481,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if self._cluster_spec:
return values.MultiWorkerDataset(
partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
- self._prefetch_on_device)
+ self._prefetch_on_device, self._auto_shard_dataset)
else:
return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn),
- self._devices,
- self._prefetch_on_device,
- source_device=device_util.resolve("/device:CPU:0"))
+ self._call_dataset_fn(dataset_fn), self._devices,
+ self._prefetch_on_device)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _run_steps_on_dataset(self, fn, iterator, iterations,
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 04c712ce1d..f51e543624 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -300,15 +300,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
- ds = dist.distribute_dataset(
- lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
- self.evaluate([iterator.initializer])
-
- features = iterator.get_next()
+ features = dist.distribute_dataset(
+ lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
+ ).make_one_shot_iterator().get_next()
with dist.scope():
result = dist.call_for_each_tower(
diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py
index 17b7ab74f6..7644acedc9 100644
--- a/tensorflow/contrib/distribute/python/monitor.py
+++ b/tensorflow/contrib/distribute/python/monitor.py
@@ -51,7 +51,6 @@ class Monitor(object):
else:
if session is None:
raise ValueError("Should provide a `session` in Graph mode.")
- session.run(step_callable._iterator.initializer) # pylint: disable=protected-access
self._run_step = session.make_callable(step_callable())
session.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index 3064433129..6e9ba37a19 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -42,11 +42,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- ds = distribution.distribute_dataset(dataset_fn)
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return control_flow_ops.group(distribution.unwrap(
@@ -55,7 +52,6 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
if not context.executing_eagerly():
with self.cached_session() as sess:
- sess.run(iterator.initializer)
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
new file mode 100644
index 0000000000..8d949943b7
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -0,0 +1,232 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extension of prefetching_ops to support more than one device."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from tensorflow.contrib.data.python.ops import prefetching_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.util import nest as data_nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.util import nest
+
+
+# pylint: disable=protected-access
+class _PrefetchToDeviceIterator(object):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset.
+ one_shot: If true, we make a one shot iterator that's already initialized.
+ devices: Devices on which to prefetch.
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be shared
+ under the given name across multiple sessions that share the same devices
+ (e.g. when using a remote server). Only used if one_shot is False.
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ one_shot,
+ devices,
+ buffer_size,
+ shared_name=None):
+ self._input_dataset = input_dataset
+ self._get_next_call_count = 0
+ self._one_shot = one_shot
+ if shared_name is None:
+ shared_name = ""
+ self._devices = devices
+
+ if self._one_shot:
+ self._input_iterator = input_dataset.make_one_shot_iterator()
+ else:
+ self._input_iterator = iterator_ops.Iterator.from_structure(
+ self._input_dataset.output_types, self._input_dataset.output_shapes,
+ shared_name, self._input_dataset.output_classes)
+ input_iterator_handle = self._input_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self._input_iterator.output_types,
+ self._input_iterator.output_shapes,
+ self._input_iterator.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ target_device = ged_ops.experimental_iterator_get_device(
+ self._input_iterator._iterator_resource)
+ self._buffering_resources = []
+ for device in nest.flatten(self._devices):
+ with ops.device(device):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_prefetch_fn,
+ output_types=data_nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes)),
+ target_device=target_device,
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=shared_name)
+ self._buffering_resources.append(buffer_resource_handle)
+
+ if not self._one_shot:
+ reset_ops = []
+ for buffer_resource in self._buffering_resources:
+ reset_ops.append(
+ ged_ops.experimental_function_buffering_resource_reset(
+ buffer_resource))
+ with ops.control_dependencies(reset_ops):
+ self._initializer = self._input_iterator.make_initializer(
+ self._input_dataset)
+
+ def get_next(self, name=None):
+ """See `tf.data.Iterator.get_next`."""
+ self._get_next_call_count += 1
+ if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
+ warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
+
+ flat_result = []
+ # TODO(priyag): This will fail if the input size (typically number of
+ # batches) is not divisible by number of devices.
+ # How do we handle that more gracefully / let the user know?
+ for buffer_resource in self._buffering_resources:
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
+ buffer_resource,
+ output_types=data_nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
+
+ ret = sparse.deserialize_sparse_tensors(
+ data_nest.pack_sequence_as(self.output_types, flat_ret),
+ self.output_types, self.output_shapes, self.output_classes)
+
+ for tensor, shape in zip(
+ data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
+ if isinstance(tensor, ops.Tensor):
+ tensor.set_shape(shape)
+ flat_result.append(ret)
+
+ return nest.pack_sequence_as(self._devices, flat_result)
+
+ @property
+ def initializer(self):
+ if self._one_shot:
+ raise NotImplementedError("Can't initialize a one_shot_iterator")
+ return self._initializer
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+# pylint: enable=protected-access
+
+
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` whose iterator prefetches elements to other device(s)."""
+
+ def __init__(self, input_dataset, devices, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._devices = devices
+ self._buffer_size = buffer_size if buffer_size is not None else 1
+
+ def make_one_shot_iterator(self):
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=True,
+ devices=self._devices,
+ buffer_size=self._buffer_size)
+
+ def make_initializable_iterator(self, shared_name=None):
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "make_initializable_iterator is not supported when eager "
+ "execution is enabled.")
+
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=False,
+ devices=self._devices,
+ buffer_size=self._buffer_size,
+ shared_name=shared_name)
+
+ def _as_variant_tensor(self):
+ # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
+  # transformation methods is called).
+ # TODO(mrry): Investigate support for chaining further transformations after
+ # the prefetch, including GPU support.
+ raise NotImplementedError("`prefetch_to_devices()` must be the last "
+ "transformation in a dataset pipeline.")
+
+ # TODO(priyag): Fix the output types, shapes and classes to match the result
+ # of get_next (which has the additional nesting layer of devices now).
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+def prefetch_to_devices(devices, buffer_size=None):
+ """A transformation that prefetches dataset values to the given `devices`.
+
+ NOTE: Although the transformation creates a `tf.data.Dataset`, the
+ transformation must be the final `Dataset` in the input pipeline.
+
+ Args:
+ devices: A nested structure of devices on which to prefetch the data. It can
+ be a single device name, or a tuple or list of device names.
+ buffer_size: (Optional.) The number of elements to buffer on each device.
+ Defaults to an automatically chosen value.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
+
+ return _apply_fn
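A short usage sketch of the new transformation (assuming a GPU is available, mirroring the tests added below); it also illustrates the design note in `_as_variant_tensor`: nothing may be chained after the prefetch.

    import tensorflow as tf
    from tensorflow.contrib.distribute.python import prefetching_ops_v2

    dataset = tf.data.Dataset.range(10)
    # Must be the final transformation: chaining e.g. .batch(2) afterwards
    # fails at iterator-creation time, because _as_variant_tensor raises
    # NotImplementedError.
    dataset = dataset.apply(
        prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
    next_element = dataset.make_one_shot_iterator().get_next()
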
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
new file mode 100644
index 0000000000..16799104e8
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
@@ -0,0 +1,90 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for prefetching_ops_v2."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class PrefetchingOpsV2Test(test.TestCase):
+
+ def testPrefetchToOneDevice(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToTwoDevicesInAList(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ output = []
+ # TODO(rohanj): Modify test to go till the end of the dataset when we
+ # switch to MultiDeviceIterator.
+ with self.cached_session() as sess:
+ for _ in range(4):
+ result = sess.run(next_element)
+ self.assertEqual(2, len(result))
+ output.extend(result)
+      self.assertEqual(set(range(8)), set(output))
+
+ def testPrefetchToTwoDevicesWithReinit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
+
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ # TODO(rohanj): Modify test to go till the end of the dataset when we
+ # switch to MultiDeviceIterator.
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(4):
+ sess.run(next_element)
+ sess.run(iterator.initializer)
+ for _ in range(4):
+ sess.run(next_element)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index 23bf36184f..1b5a4f64e5 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
from tensorflow.python.training import optimizer as optimizer_lib
@@ -51,11 +50,7 @@ class StandardInputStep(Step):
def __init__(self, dataset_fn, distribution):
super(StandardInputStep, self).__init__(distribution)
self._distributed_input = distribution.distribute_dataset(dataset_fn)
- if context.executing_eagerly():
- self._iterator = self._distributed_input.make_one_shot_iterator()
- else:
- # TODO(priyag): Expose initializer via some initializer property.
- self._iterator = self._distributed_input.make_initializable_iterator()
+ self._iterator = self._distributed_input.make_one_shot_iterator()
class StandardSingleLossStep(StandardInputStep):
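For context, a sketch of the simplified iterator pattern this change (and the related test changes above) relies on; the strategy and dataset below are illustrative assumptions. Because `PerDeviceDataset` one-shot iterators now work in graph mode, the eager/graph branch and the explicit initializer run are no longer needed:

    import numpy as np
    import tensorflow as tf

    distribution = tf.contrib.distribute.MirroredStrategy(["/device:CPU:0"])

    def dataset_fn():
      return tf.data.Dataset.from_tensor_slices(
          np.zeros((10, 3), dtype=np.float32)).repeat().batch(5)

    # Same code path in eager and graph mode; no iterator.initializer to run.
    iterator = distribution.distribute_dataset(
        dataset_fn).make_one_shot_iterator()
    features = iterator.get_next()
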
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index 1ff9b9ceec..f1ada49fa3 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -50,7 +50,6 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
run_step = single_loss_step
else:
with self.cached_session() as sess:
- sess.run(single_loss_step._iterator.initializer)
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index a0cd029f51..4955ded4d5 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -26,7 +26,7 @@ import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
@@ -683,7 +683,7 @@ class PerDeviceDataIterator(object):
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
- data_list = self._iterator.get_next()
+ data_list = self._iterator.get_next(name=name)
index = dict(zip(self._devices, data_list))
else:
batch = self._iterator.get_next(name=name)
@@ -703,26 +703,21 @@ class PerDeviceDataIterator(object):
class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
- def __init__(
- self,
- dataset,
- devices,
- prefetch_on_device=None,
- source_device="/cpu:0",
- ):
+ def __init__(self, dataset, devices, prefetch_on_device=None):
self._devices = devices
- self._source_device = source_device if source_device is not None else "/cpu:0"
# Default to using prefetching in graph mode, unless specified.
- # TODO(rohanj): Enable prefetching in eager mode.
+ # TODO(priyag): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently")
- self._dataset = dataset
- if not self._prefetch_on_device:
+ if self._prefetch_on_device:
+ self._dataset = dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(self._devices))
+ else:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
@@ -730,33 +725,15 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
- # Graph mode prefetching with one shot iterator is disabled.
- if not context.executing_eagerly():
- raise ValueError("Cannot create a one shot iterator. Please use "
- "`make_initializable_iterator()` instead.")
- # Eager mode prefetching would error out in constructor. Only remaining
- # cases are non-prefetching eager / graph mode. We delegate to
- # PerDeviceDataIterator to handle them.
dataset_iterator = self._dataset.make_one_shot_iterator()
- return PerDeviceDataIterator(
- dataset_iterator, self._devices, prefetch_on_device=False)
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
- # Eager mode generates already initialized iterators. Hence we cannot create
- # an initializable iterator.
- if context.executing_eagerly():
- raise ValueError("Cannot create initializable iterator in Eager mode. "
- "Please use `make_one_shot_iterator` instead.")
- if self._prefetch_on_device:
- dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
- self._dataset, self._devices, source_device=self._source_device)
- else:
- dataset_iterator = self._dataset.make_initializable_iterator()
- return PerDeviceDataIterator(
- dataset_iterator,
- self._devices,
- prefetch_on_device=self._prefetch_on_device)
+ dataset_iterator = self._dataset.make_initializable_iterator()
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
class MultiWorkerDataIterator(object):
@@ -816,7 +793,8 @@ class MultiWorkerDataset(object):
eager mode.
"""
- def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
+ def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None,
+ auto_shard=False):
"""Initialize the MultiWorkerDataset object.
Args:
@@ -824,6 +802,7 @@ class MultiWorkerDataset(object):
worker_device_map: a dict mapping from each worker to a list of devices
that belong to this worker.
prefetch_on_device: whether to prefetch to devices.
+ auto_shard: whether to auto-shard the dataset.
"""
self._worker_device_map = worker_device_map
self._datasets = {}
@@ -833,13 +812,11 @@ class MultiWorkerDataset(object):
six.iteritems(worker_device_map)):
with ops.device(worker):
worker_input = dataset_fn()
- worker_input = input_ops.auto_shard_dataset(
- worker_input, len(worker_device_map), i)
+ if auto_shard:
+ worker_input = input_ops.auto_shard_dataset(
+ worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
- worker_input,
- worker_devices,
- source_device=worker,
- prefetch_on_device=prefetch_on_device)
+ worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
def make_one_shot_iterator(self):
iterators = {}
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 002d61f46e..ae3e134333 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -349,11 +349,7 @@ class PerDeviceDatasetTest(test.TestCase):
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=False)
- if context.executing_eagerly():
- iterator = per_device_dataset.make_one_shot_iterator()
- else:
- iterator = per_device_dataset.make_initializable_iterator()
- self.evaluate([iterator.initializer])
+ iterator = per_device_dataset.make_one_shot_iterator()
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -370,14 +366,21 @@ class PerDeviceDatasetTest(test.TestCase):
if not context.executing_eagerly():
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=True)
- iterator = per_device_dataset.make_initializable_iterator()
- self.evaluate([iterator.initializer])
+ iterator = per_device_dataset.make_one_shot_iterator()
+ # With prefetching, we cannot guarantee which input ends up on which
+ # device, so we verify that the complete set seen on all devices is
+ # correct, and equal numbers are distributed to each device.
+ combined_actual = []
+ combined_expected = []
for expected_value in expected_values:
next_element = iterator.get_next()
- computed_value = self.evaluate(
- [values.select_device(d, next_element) for d in devices])
- self.assertEqual(expected_value, computed_value)
+ combined_actual.extend(
+ self.evaluate(
+ [values.select_device(d, next_element) for d in devices]))
+ combined_expected.extend(expected_value)
+
+ self.assertEqual(set(combined_expected), set(combined_actual))
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index e344d7a23b..510f292508 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -154,6 +154,8 @@ tf_py_test(
],
tags = [
"no_pip", # b/38283730
+ "noasan", # b/116875897
+ "nomsan",
"notsan", # Flaky: b/30756419
],
)
@@ -177,7 +179,11 @@ tf_py_test(
"//tensorflow/python:random_seed",
"//tensorflow/python:variables",
],
- tags = ["notsan"], # b/62863147
+ tags = [
+ "noasan", # b/116875897
+ "nomsan",
+ "notsan", # b/62863147
+ ],
)
py_library(
@@ -276,6 +282,7 @@ tf_py_test(
"manual",
"noasan", # times out b/63678675
"nomsan",
+ "notsan", # b/116875897
],
)
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index f320b53d94..f3ebe3b245 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -26,6 +26,14 @@ config_setting(
},
)
+# Enables inclusion of TensorFlow kernels via the TF Lite Flex delegate.
+# WARNING: This build flag is experimental and subject to change.
+config_setting(
+ name = "with_tflite_flex",
+ define_values = {"with_tflite_flex": "true"},
+ visibility = ["//visibility:public"],
+)
+
cc_library(
name = "schema_fbs_version",
hdrs = ["version.h"],
@@ -157,6 +165,10 @@ cc_library(
"stderr_reporter.h",
],
copts = tflite_copts(),
+ defines = select({
+ ":with_tflite_flex": ["TFLITE_FLEX"],
+ "//conditions:default": [],
+ }),
linkopts = [
] + select({
"//tensorflow:android": [
@@ -180,7 +192,12 @@ cc_library(
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
"//tensorflow/contrib/lite/profiling:profiler",
"//tensorflow/contrib/lite/schema:schema_fbs",
- ],
+ ] + select({
+ ":with_tflite_flex": [
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ ],
+ "//conditions:default": [],
+ }),
)
cc_library(
diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD
index 4d2437e7d3..d180cb4785 100644
--- a/tensorflow/contrib/lite/examples/android/BUILD
+++ b/tensorflow/contrib/lite/examples/android/BUILD
@@ -28,6 +28,7 @@ android_binary(
srcs = glob([
"app/src/main/java/**/*.java",
]),
+ aapt_version = "aapt",
# Package assets from assets dir as well as all model targets.
# Remove undesired models (and corresponding Activities in source)
# to reduce APK size.
diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl
index db837cf29e..9d2aead266 100644
--- a/tensorflow/contrib/lite/java/aar_with_jni.bzl
+++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl
@@ -3,12 +3,12 @@
load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
def aar_with_jni(name, android_library):
- # Generate dummy AndroidManifest.xml for dummy apk usage
- # (dummy apk is generated by <name>_dummy_app_for_so target below)
- native.genrule(
- name = name + "_binary_manifest_generator",
- outs = [name + "_generated_AndroidManifest.xml"],
- cmd = """
+ # Generate dummy AndroidManifest.xml for dummy apk usage
+ # (dummy apk is generated by <name>_dummy_app_for_so target below)
+ native.genrule(
+ name = name + "_binary_manifest_generator",
+ outs = [name + "_generated_AndroidManifest.xml"],
+ cmd = """
cat > $(OUTS) <<EOF
<manifest
xmlns:android="http://schemas.android.com/apk/res/android"
@@ -17,27 +17,28 @@ cat > $(OUTS) <<EOF
</manifest>
EOF
""",
- )
+ )
- # Generate dummy apk including .so files and later we extract out
- # .so files and throw away the apk.
- android_binary(
- name = name + "_dummy_app_for_so",
- manifest = name + "_generated_AndroidManifest.xml",
- custom_package = "dummy.package.for.so",
- deps = [android_library],
- # In some platforms we don't have an Android SDK/NDK and this target
- # can't be built. We need to prevent the build system from trying to
- # use the target in that case.
- tags = ["manual"],
- )
+ # Generate dummy apk including .so files and later we extract out
+ # .so files and throw away the apk.
+ android_binary(
+ name = name + "_dummy_app_for_so",
+ aapt_version = "aapt",
+ manifest = name + "_generated_AndroidManifest.xml",
+ custom_package = "dummy.package.for.so",
+ deps = [android_library],
+ # In some platforms we don't have an Android SDK/NDK and this target
+ # can't be built. We need to prevent the build system from trying to
+ # use the target in that case.
+ tags = ["manual"],
+ )
- native.genrule(
- name = name,
- srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
- outs = [name + ".aar"],
- tags = ["manual"],
- cmd = """
+ native.genrule(
+ name = name,
+ srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
+ outs = [name + ".aar"],
+ tags = ["manual"],
+ cmd = """
cp $(location {}.aar) $(location :{}.aar)
chmod +w $(location :{}.aar)
origdir=$$PWD
@@ -46,4 +47,4 @@ unzip $$origdir/$(location :{}_dummy_app_for_so_unsigned.apk) "lib/*"
cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
- )
+ )
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
index 220d6c2159..5ad738389e 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
@@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2.0
android_binary(
name = "TfLiteCameraDemo",
srcs = glob(["java/**/*.java"]),
+ aapt_version = "aapt",
assets = [
"//tensorflow/contrib/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
"@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
index b2e3a9bd7d..058240aada 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
@@ -8,6 +8,7 @@ android_binary(
srcs = [
"OvicBenchmarkerActivity.java",
],
+ aapt_version = "aapt",
assets = [
"//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
"//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index 765c3a03ef..689cea03e7 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -37,10 +37,6 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
: nullptr;
}
-inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
- return GetTensorDims(data.data(), data.size());
-}
-
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
return RuntimeShape(data.size(), data.data());
}
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
index 5e688ce452..9f5b33d217 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
@@ -86,35 +86,6 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr ? tensor->data.b : nullptr;
}
-// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
-// even if the original tensors were not 4D. We should consider rewriting them
-// to take a more generic 'shape' object.
-inline Dims<4> GetTensorDims(const int data[], const int size) {
- Dims<4> d;
- for (int i = 0; i < 4; ++i) {
- int src = size - i - 1;
- if (src >= 0) {
- d.sizes[i] = data[src];
- } else {
- d.sizes[i] = 1;
- }
- }
- d.strides[0] = 1;
- for (int i = 1; i < 4; i++) {
- d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
- }
- return d;
-}
-
-inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return Dims<4>();
- }
-
- auto* dims = tensor->dims;
- return GetTensorDims(dims->data, dims->size);
-}
-
inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
if (tensor == nullptr) {
return RuntimeShape();
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
index bf2068d320..2ed73ba82d 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc
@@ -21,28 +21,32 @@ namespace {
using ::testing::ElementsAre;
-TEST(TensorTest, GetTensorDims4D) {
- Dims<4> d = GetTensorDims({2, 3, 4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+TEST(TensorTest, GetTensorShape4D) {
+ RuntimeShape d = GetTensorShape({2, 3, 4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(2, 3, 4, 5));
}
-TEST(TensorTest, GetTensorDims3D) {
- Dims<4> d = GetTensorDims({3, 4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60));
+TEST(TensorTest, GetTensorShape3D) {
+ RuntimeShape d = GetTensorShape({3, 4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(3, 4, 5));
}
-TEST(TensorTest, GetTensorDims2D) {
- Dims<4> d = GetTensorDims({4, 5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20));
+TEST(TensorTest, GetTensorShape2D) {
+ RuntimeShape d = GetTensorShape({4, 5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(4, 5));
}
-TEST(TensorTest, GetTensorDims1D) {
- Dims<4> d = GetTensorDims({5});
- EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1));
- EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5));
+TEST(TensorTest, GetTensorShape1D) {
+ RuntimeShape d = GetTensorShape({5});
+ EXPECT_THAT(
+ std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()),
+ ElementsAre(5));
}
} // namespace
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
index f18a2ca07a..2e5033dab1 100644
--- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
@@ -20,6 +20,7 @@ filegroup(
android_binary(
name = "SmartReplyDemo",
srcs = glob(["java/**/*.java"]),
+ aapt_version = "aapt",
assets = [":assets"],
assets_dir = "",
custom_package = "com.example.android.smartreply",
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index d962a5e12d..36125c198e 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -133,7 +133,8 @@ $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*/*benchmark*.cc) \
-tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc
+tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc \
+tensorflow/contrib/makefile/downloads/absl/absl/hash/internal/print_hash_of.cc
ABSL_CC_SRCS := $(filter-out $(ABSL_CC_EXCLUDE_SRCS), $(ABSL_CC_ALL_SRCS))
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index f4ac70eb1a..6a67c6295d 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -377,6 +377,11 @@ py_test(
size = "large",
srcs = ["python/training/shampoo_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "noasan", # b/116875897
+ "nomsan",
+ "notsan",
+ ],
deps = [
":opt_py",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
index 05bcf2cfa3..a2fd8fbd87 100644
--- a/tensorflow/contrib/opt/python/training/shampoo_test.py
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -54,9 +54,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -105,9 +105,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -164,9 +164,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1], size[2])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -254,9 +254,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -310,9 +310,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -383,9 +383,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(sample_size_2, size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = ops.IndexedSlices(
constant_op.constant(grad_np, dtype=dtypes.float32),
@@ -463,9 +463,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(sample_size, size[1], size[2])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = ops.IndexedSlices(
constant_op.constant(grad_np, dtype=dtypes.float32),
@@ -533,9 +533,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
gbar_weight = 0.1
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -628,9 +628,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3 = np.zeros_like(mat_g3_a)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = array_ops.placeholder(dtypes.float32, shape=size)
@@ -705,9 +705,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3 = np.zeros_like(mat_g3_a)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = array_ops.placeholder(dtypes.float32, shape=size)
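The Variable -> VariableV1 migration above keeps these tests parameterizable over both ref and resource variables, since VariableV1 still accepts the use_resource flag. A minimal, self-contained sketch of that pattern (the test class name and shapes are illustrative, not taken from the Shampoo tests):

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class VariableV1MigrationSketch(test.TestCase):
  """Shows why the tests switch to VariableV1: it still takes use_resource,
  so one parameterized test covers ref variables and resource variables."""

  def testBothVariableKinds(self):
    init_var_np = np.zeros(4)
    for use_resource_var in (False, True):
      with self.cached_session() as sess:
        var = variables.VariableV1(
            init_var_np, dtype=dtypes.float32,
            use_resource=use_resource_var)
        variables.global_variables_initializer().run()
        self.assertAllClose(init_var_np, sess.run(var))


if __name__ == "__main__":
  test.main()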
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index c230919168..cb1f707028 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -159,7 +159,12 @@ py_test(
],
shard_count = 4,
srcs_version = "PY2AND3",
- tags = ["no_pip_gpu"], # b/63391119
+ tags = [
+ "no_pip_gpu", # b/63391119
+ "noasan", # b/116875897
+ "nomsan",
+ "notsan",
+ ],
deps = [
":estimators",
":feature_keys",
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index 647455ae42..04d17bc123 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -104,7 +104,7 @@ class EvaluationMetricsTests(test.TestCase):
"ticker":
array_ops.reshape(
math_ops.cast(
- variables.Variable(
+ variables.VariableV1(
name="ticker",
initial_value=0,
dtype=dtypes.int64,
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 766466968a..6ce6b779a2 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -55,7 +55,9 @@
@@TPUDistributionStrategy
@@keras_to_tpu_model
+
@@AsyncCheckpointSaverHook
+@@TPUInMemoryEvalHook
"""
from __future__ import absolute_import
@@ -65,6 +67,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.contrib.tpu.python import profiler
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
+from tensorflow.contrib.tpu.python.tpu.async_checkpoint import *
from tensorflow.contrib.tpu.python.tpu.bfloat16 import *
from tensorflow.contrib.tpu.python.tpu.device_assignment import *
from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 1bd1a31e11..bc1a0c5284 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -104,18 +104,9 @@ Status RegisterPerTableLoadOpsForAlgorithmBody(
}
}
{
- auto* table_id_attr = op_def->add_attr();
- table_id_attr->set_name("table_id");
- table_id_attr->set_type("int");
- table_id_attr->set_has_minimum(true);
- table_id_attr->set_minimum(-1);
- table_id_attr->mutable_default_value()->set_i(-1);
- }
- {
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
- table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
@@ -147,11 +138,9 @@ parameters that are loaded from a checkpoint before a training loop is
executed.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto (overrides table_id).
+ EmbeddingLayerConfiguration proto.
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
-table_id: Index of this table in the EmbeddingLayerConfiguration proto
- (deprecated).
)doc",
parameter_descriptions.c_str()));
op_def->set_is_commutative(false);
@@ -160,14 +149,10 @@ table_id: Index of this table in the EmbeddingLayerConfiguration proto
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
- int table_id;
- TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
- // Exactly one must be non-default.
- if ((table_id >= 0) == (!table_name.empty())) {
- return errors::InvalidArgument(
- "exactly one of table_id or table_name must be non-default");
+ if (table_name.empty()) {
+ return errors::InvalidArgument("table_name attribute must be set");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
@@ -241,18 +226,9 @@ Status RegisterPerTableRetrieveOpsForAlgorithmBody(
}
}
{
- auto* table_id_attr = op_def->add_attr();
- table_id_attr->set_name("table_id");
- table_id_attr->set_type("int");
- table_id_attr->set_has_minimum(true);
- table_id_attr->set_minimum(-1);
- table_id_attr->mutable_default_value()->set_i(-1);
- }
- {
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
- table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
@@ -283,11 +259,9 @@ the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto (overrides table_id).
+ EmbeddingLayerConfiguration proto.
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
-table_id: Index of this table in the EmbeddingLayerConfiguration proto
- (deprecated).
)doc",
parameter_descriptions.c_str()));
op_def->set_is_commutative(false);
@@ -296,14 +270,10 @@ table_id: Index of this table in the EmbeddingLayerConfiguration proto
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
- int table_id;
- TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
- // Exactly one must be non-default.
- if ((table_id >= 0) == (!table_name.empty())) {
- return errors::InvalidArgument(
- "exactly one of table_id or table_name must be non-default");
+ if (table_name.empty()) {
+ return errors::InvalidArgument("table_name must be non-empty");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
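For readers skimming the C++ above: the per-table load/retrieve ops drop the deprecated table_id attribute, and the shape-inference check is simplified accordingly. A plain-Python restatement of the before/after validation rule (illustrative only, not TensorFlow code):

def validate_table_ref_old(table_id, table_name):
  # Before this change: exactly one of table_id (>= 0) or table_name had to
  # be supplied; supplying both or neither was rejected.
  if (table_id >= 0) == (table_name != ""):
    raise ValueError(
        "exactly one of table_id or table_name must be non-default")


def validate_table_ref_new(table_name):
  # After this change: table_id is gone and table_name is mandatory.
  if not table_name:
    raise ValueError("table_name attribute must be set")


# The old rule accepted (table_id=-1, table_name="emb0") or (table_id=3,
# table_name=""); the new rule only checks that the name is set.
validate_table_ref_old(-1, "emb0")
validate_table_ref_new("emb0")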
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index b498599962..8e6e9aa0cd 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -156,8 +156,7 @@ bool NewSession(const string& service_addr,
channel_args));
NewProfileSessionResponse new_session_response;
TF_QCHECK_OK(FromGrpcStatus(
- stub->NewSession(&context, new_session_request, &new_session_response)))
- << new_session_response.error_message();
+ stub->NewSession(&context, new_session_request, &new_session_response)));
std::cout << "Profile session succeed for host(s):"
<< str_util::Join(hostnames, ",") << std::endl;
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index b25d06dda8..292108f949 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -66,8 +66,8 @@ message Metrics {
// - it does not reveal the peak core FLOPS of the hardware
double flops = 2;
- // The VMEM bandwidth used to load operands from HBM, as a fraction of
- // thereotical VMEM bandwidth on the specific hardware.
+ // The memory bandwidth used to load operands, as a fraction of
+ // theoretical memory bandwidth on the specific hardware.
double memory_bandwidth = 3;
double raw_time = 11; // Elapsed core-time in picoseconds.
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index fc1320501b..a43f45554f 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -22,13 +22,22 @@ message LearningRate {
}
}
+// Each optimizer's parameter proto has a link to its documentation and CPU
+// implementation (if available) for user reference.
+
+// https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L151
message AdagradParameters {
float initial_accumulator = 1;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L423
message StochasticGradientDescentParameters {
}
+// https://www.tensorflow.org/api_docs/python/tf/train/FtrlOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L192
message FtrlParameters {
float l1 = 1;
float l2 = 2;
@@ -41,21 +50,38 @@ message FtrlParameters {
// learning rate feature instead, setting the learning rate to:
// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
+//
+// https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54
+//
+// Note that the code by default implements the lazy version of Adam
+// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer)
+// unless the use_non_lazy_adam parameter is set, in which case it implements
+// the normal version of Adam that updates all parameters in the embedding
+// table, even for entries that are not used in the current minibatch
+// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
+// use_non_lazy_adam is enabled, use_gradient_accumulation is also required in
+// order to get correct results; a warning will be printed otherwise (which may
+// change to an error in the future).
message AdamParameters {
float beta1 = 3;
float beta2 = 4;
float epsilon = 5;
float initial_m = 6;
float initial_v = 7;
+ bool use_non_lazy_adam = 8;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L271
message MomentumParameters {
float momentum = 1;
bool use_nesterov = 2;
float initial_accum = 3;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L356
message RmsPropParameters {
float rho = 1;
float momentum = 2;
@@ -64,6 +90,8 @@ message RmsPropParameters {
float initial_mom = 5;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L372
message CenteredRmsPropParameters {
float rho = 1;
float momentum = 2;
@@ -73,6 +101,7 @@ message CenteredRmsPropParameters {
float initial_mg = 6;
}
+// Variant of algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
message MdlAdagradLightParameters {
float l2 = 1;
float lr_power = 2;
@@ -91,6 +120,8 @@ message MdlAdagradLightParameters {
float initial_benefit = 15;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/AdadeltaOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L68
message AdadeltaParameters {
float rho = 1;
float epsilon = 2;
@@ -98,6 +129,8 @@ message AdadeltaParameters {
float initial_update = 4;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/ProximalAdagradOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L164
message ProximalAdagradParameters {
float l1 = 1;
float l2 = 2;
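The AdamParameters comment above says that, without gradient accumulation, callers should fold Adam's bias correction into the dynamic learning rate. A small illustrative helper (not part of the proto or the TPU runtime) that computes that factor:

import math


def adam_bias_corrected_lr(user_lr, beta1, beta2, t):
  """Learning rate the comment prescribes at global step t (t >= 1)."""
  return user_lr * math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)


# With the common defaults beta1=0.9, beta2=0.999 the correction starts well
# below 1 and approaches 1 as t grows:
print(adam_bias_corrected_lr(0.001, 0.9, 0.999, 1))      # ~0.000316
print(adam_bias_corrected_lr(0.001, 0.9, 0.999, 10000))  # ~0.001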
diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
index e06a720e82..20b7ba0997 100644
--- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
+++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
-
"""Hook for asynchronous checkpointing.
This hook dispatches checkpoint writing operations in a separate thread to
@@ -28,18 +27,16 @@ import threading
import time
from tensorflow.core.util.event_pb2 import SessionLog
-
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
-class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
+class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
"""Saves checkpoints every N steps or seconds."""
def __init__(self,
@@ -67,7 +64,7 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
ValueError: One of `save_steps` or `save_secs` should be set.
ValueError: At most one of `saver` or `scaffold` should be set.
"""
- logging.info("Create CheckpointSaverHook.")
+ logging.info("Create AsyncCheckpointSaverHook.")
if saver is not None and scaffold is not None:
raise ValueError("You cannot provide both saver and scaffold.")
self._saver = saver
@@ -144,6 +141,10 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
def _save(self, session, step, asynchronous=True):
"""Saves the latest checkpoint, returns should_stop."""
+ # Skip saving on step 0
+ if step == 0:
+ return
+
def _save_fn():
"""Run the saver process."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
@@ -162,7 +163,6 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
end_time - start_time)
logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
- logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
for l in self._listeners:
l.before_save(session, step)
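Since tpu/__init__.py now exports AsyncCheckpointSaverHook and the class inherits from CheckpointSaverHook, it should drop into existing training loops the same way the synchronous hook does. A hedged usage sketch, assuming the constructor mirrors tf.train.CheckpointSaverHook (checkpoint_dir plus save_steps/save_secs); the directory and step counts below are placeholders:

import tensorflow as tf

from tensorflow.contrib.tpu.python.tpu.async_checkpoint import AsyncCheckpointSaverHook

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)

hook = AsyncCheckpointSaverHook(
    checkpoint_dir="/tmp/async_ckpt_demo",  # placeholder path
    save_steps=100)  # saves run on a background thread; step 0 is skipped

with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
  for _ in range(300):
    sess.run(train_op)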
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 956d0142a3..696656e840 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -959,7 +959,16 @@ class TPUFunction(object):
# Compute our outfeed depending on the execution mode
if is_training:
- self._cloned_model._make_train_function()
+ if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer):
+ # For Keras optimizer, we try to place the variable weights on the TPU
+ # device. Keras creates optimizer variables (e.g. momentum values for
+ # the Momentum optimizer) when _make_train_function is invoked.
+ with keras_tpu_variables.replicated_variable_for_optimizer(
+ self._tpu_assignment.num_towers):
+ self._cloned_model._make_train_function()
+ else:
+ self._cloned_model._make_train_function()
+
self._outfeed_spec = [
tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
for tensor in self._cloned_model.train_function.outputs
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index 170977d8ab..598da7418e 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -25,10 +25,15 @@ from __future__ import print_function
import contextlib
+import numpy as np
+
from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -285,3 +290,51 @@ def replicated_scope(num_replicas):
return variable_scope.variable_scope(
"", custom_getter=_replicated_variable_getter)
+
+
+@contextlib.contextmanager
+def replicated_variable_for_optimizer(num_replicas):
+ """Context manager for optimizer weights. Overrides K.variable."""
+ if num_replicas == 1:
+ yield
+ return
+
+ try:
+ old_v = backend.variable
+
+ def opt_variable(value, dtype=None, name=None, constraint=None):
+ """Instantiates a variable and returns it."""
+ if dtype is None:
+ dtype = backend.floatx()
+
+ variables = []
+ for i in range(num_replicas):
+ # Keras holds the variables in the optimizer class instance, so the name
+ # does not matter here. The ResourceVariable constructor will find a
+ # unique name (even for name=None) for each replica.
+ with ops.device("device:TPU:{}".format(i)):
+ v = resource_variable_ops.ResourceVariable(
+ value,
+ dtype=dtypes_module.as_dtype(dtype),
+ name=name,
+ constraint=constraint)
+ variables.append(v)
+ name = "replicate_{}_{}".format("variable" if name is None else name,
+ ops.uid())
+ v = ReplicatedVariable(name, variables)
+
+ # pylint: disable=protected-access
+
+ if isinstance(value, np.ndarray):
+ v._keras_shape = value.shape
+ elif hasattr(value, "shape"):
+ v._keras_shape = backend.int_shape(value)
+ v._uses_learning_phase = False
+ backend.track_variable(v)
+ return v
+
+ backend.variable = opt_variable
+ yield
+
+ finally:
+ backend.variable = old_v
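replicated_variable_for_optimizer above works by temporarily swapping out Keras' backend.variable factory and restoring it afterwards. A generic, standalone sketch of that patch-and-restore pattern (names here are generic stand-ins, not TensorFlow APIs):

import contextlib


@contextlib.contextmanager
def patched_attribute(owner, attr_name, replacement):
  """Temporarily replaces owner.attr_name, restoring it even on error."""
  original = getattr(owner, attr_name)
  setattr(owner, attr_name, replacement)
  try:
    yield
  finally:
    setattr(owner, attr_name, original)


# A tiny namespace object standing in for keras.backend:
class _FakeBackend(object):
  pass

fake_backend = _FakeBackend()
fake_backend.variable = lambda value: ("plain", value)

with patched_attribute(fake_backend, "variable",
                       lambda value: ("replicated", value)):
  print(fake_backend.variable(3.0))  # ("replicated", 3.0)
print(fake_backend.variable(3.0))    # ("plain", 3.0)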
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 23c54511ca..545cee637f 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -231,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
`metric_fn` runs on CPU to generate metrics and `tensors` represents the
`Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.
To be precise, TPU evaluation expects a slightly different signature from the
- @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
+ `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
`tensors` usually specify the model logits, which are transferred back from
@@ -254,7 +254,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
sending tensors from TPU to CPU. To reduce the overhead, try reducing the
size of the tensors. The `tensors` are concatenated along their major (batch)
dimension, and so must be >= rank 1. The `host_call` is useful for writing
- summaries with @{tf.contrib.summary.create_file_writer}.
+ summaries with `tf.contrib.summary.create_file_writer`.
"""
def __new__(cls,
@@ -404,12 +404,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
self._feed_error = None
self._finished = False
+ self._should_initialize_tpu = True
def begin(self):
logging.info('TPU job name %s', self._master_job)
self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
- self._init_ops = [tpu.initialize_system(job=self._master_job)]
- self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
+ if self._should_initialize_tpu:
+ self._init_ops = [tpu.initialize_system(job=self._master_job)]
+ self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
+ else:
+ self._init_ops = []
+ self._finalize_ops = []
summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
self._init_ops.extend(summary_writer_init_ops)
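The TPUEstimatorSpec docstring edits above describe eval_metrics as a (metric_fn, tensors) tuple rather than the eval_metric_ops dict used by EstimatorSpec. A hedged model_fn fragment showing that contract; the model itself (a single dense layer over features["x"]) is illustrative and not part of this change:

import tensorflow as tf
from tensorflow.contrib import tpu as contrib_tpu


def metric_fn(labels, logits):
  # Runs on the CPU host with the tensors transferred back from the TPU.
  predictions = tf.argmax(logits, axis=-1)
  return {"accuracy": tf.metrics.accuracy(labels, predictions)}


def model_fn(features, labels, mode, params):
  logits = tf.layers.dense(features["x"], 10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  if mode == tf.estimator.ModeKeys.EVAL:
    return contrib_tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        # A (metric_fn, tensors) tuple; tensors may be a list or a dict.
        eval_metrics=(metric_fn, [labels, logits]))
  optimizer = contrib_tpu.CrossShardOptimizer(
      tf.train.GradientDescentOptimizer(learning_rate=0.01))
  train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
  return contrib_tpu.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)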
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index ddf8365d61..b565ebd073 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -313,6 +313,5 @@ tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
cc_api_version = 2,
- java_api_version = 2,
visibility = ["//visibility:public"],
)