author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-11-07 22:33:54 -0800
committer Andrew Selle <aselle@andyselle.com>  2017-11-10 16:14:35 -0800
commit 18135df3a56d0fb0a4f8e93d7b8332e4de3283e2 (patch)
tree   3910bcbb035ad4a8aef3f0559ce09eb78a9cc205
parent 1e53f4caff649fb86ef74a2485547312372d399f (diff)
Automated g4 rollback of changelist 174912490
PiperOrigin-RevId: 174961746
-rw-r--r--  tensorflow/contrib/cmake/tf_core_kernels.cmake                          |   1
-rw-r--r--  tensorflow/contrib/cmake/tf_core_ops.cmake                              |   1
-rwxr-xr-x  tensorflow/contrib/cmake/tf_python.cmake                                |   2
-rw-r--r--  tensorflow/contrib/data/BUILD                                           |  13
-rw-r--r--  tensorflow/contrib/data/__init__.py                                     |   2
-rw-r--r--  tensorflow/contrib/data/ops/dataset_ops.cc                              | 232
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py        |   2
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py    |   2
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py  |   2
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD                                |  40
-rw-r--r--  tensorflow/contrib/data/python/ops/batching.py                          |   2
-rw-r--r--  tensorflow/contrib/data/python/ops/dataset_ops.py                       |   8
-rw-r--r--  tensorflow/contrib/data/python/ops/error_ops.py                         |   2
-rw-r--r--  tensorflow/contrib/data/python/ops/grouping.py                          |   2
-rw-r--r--  tensorflow/contrib/data/python/ops/interleave_ops.py                    |   2
-rw-r--r--  tensorflow/contrib/data/python/ops/iterator_ops.py                      |   2
-rw-r--r--  tensorflow/contrib/data/python/ops/readers.py                           |   2
-rw-r--r--  tensorflow/contrib/data/python/ops/scan_ops.py                          |   2
-rw-r--r--  tensorflow/core/ops/compat/ops_history.v1.pbtxt                         | 452
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc                                      | 197
-rw-r--r--  tensorflow/python/kernel_tests/iterator_ops_test.py                     |  62
-rw-r--r--  tensorflow/python/kernel_tests/range_dataset_op_test.py                 | 330
-rw-r--r--  tensorflow/python/kernel_tests/reader_dataset_ops_test.py               | 298
23 files changed, 1366 insertions, 292 deletions
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index 5b62598aa5..f978c8ccd5 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -70,7 +70,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc"
- "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc"
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 03c168795c..4a61ed7a35 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -81,7 +81,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t
GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc")
-GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index a14b733158..7636e9ba6e 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -776,8 +776,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py)
-GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops"
- DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops"
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 7bcf5a5f4d..eaede0e00e 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -35,19 +35,8 @@ tf_custom_op_library(
],
)
-# TODO(mrry): Move the kernels out of the core library into this library.
-tf_custom_op_library(
- name = "_dataset_ops.so",
- srcs = [
- "ops/dataset_ops.cc",
- ],
-)
-
tf_gen_op_libs(
- op_lib_names = [
- "dataset_ops",
- "prefetching_ops",
- ],
+ op_lib_names = ["prefetching_ops"],
)
filegroup(
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 0c7e793689..824ac4298f 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -41,8 +41,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
# pylint: disable=unused-import
+
from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder
from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch
from tensorflow.contrib.data.python.ops.batching import unbatch
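The transformations re-exported above (for example `dense_to_sparse_batch`, which backs the `DenseToSparseBatchDataset` op changed later in this patch) are applied to a dataset rather than called directly. The following is a minimal sketch, assuming the TF 1.4-era `Dataset.apply()` pattern; it is not part of this patch.

# Minimal sketch (not part of this patch): applying dense_to_sparse_batch,
# assuming the TF 1.4-era contrib API where transformations are passed to
# Dataset.apply(). batch_size/row_shape mirror the DenseToSparseBatchDataset
# op documented elsewhere in this change.
import tensorflow as tf

# Each element is a variable-length int64 vector of length x.
dataset = tf.data.Dataset.range(10).map(
    lambda x: tf.fill([tf.cast(x, tf.int32)], x))

# Batch 4 elements at a time; each batch is emitted as a single SparseTensor
# whose dense shape is [batch_size] + row_shape.
batched = dataset.apply(
    tf.contrib.data.dense_to_sparse_batch(batch_size=4, row_shape=[10]))

iterator = batched.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  print(sess.run(next_element))  # A SparseTensorValue for the first batch.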
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
deleted file mode 100644
index 1574384cb2..0000000000
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ /dev/null
@@ -1,232 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_def_builder.h"
-#include "tensorflow/core/framework/shape_inference.h"
-
-namespace tensorflow {
-
-// --------------------------------------------------------------------------
-
-// The ops in this section can be composed to define an input
-// pipeline. Each op produces a DT_VARIANT tensor that represents
-// a DAG of "dataset" objects. An "dataset" object can be converted
-// to a stateful "iterator" by passing the "dataset" to the
-// "MakeIterator" op.
-//
-// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are
-// not presently serializable. To avoid issues with constant folding, ensure
-// that any "source dataset" ops (i.e. ops that output a dataset and do not
-// take one as input) are marked "stateful".
-
-REGISTER_OP("IgnoreErrorsDataset")
- .Input("input_dataset: variant")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the elements of `input_dataset` ignoring errors.
-)doc");
-
-REGISTER_OP("MapAndBatchDataset")
- .Input("input_dataset: variant")
- .Input("other_arguments: Targuments")
- .Input("batch_size: int64")
- .Input("num_parallel_batches: int64")
- .Output("handle: variant")
- .Attr("f: func")
- .Attr("Targuments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that applies `f` to the outputs of `input_dataset` and then
-batches `batch_size` of them.
-
-Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
-to `batch_size * num_parallel_batches` copies of `f` in parallel.
-
-batch_size: A scalar representing the number of elements to accumulate in a
- batch. It determines the number of concurrent invocations of `f` that process
- elements from `input_dataset` in parallel.
-num_parallel_batches: A scalar representing the number of batches to create in
- parallel. Processing multiple batches in parallel benefits workloads prone to
- stragglers.
-)doc");
-
-REGISTER_OP("ScanDataset")
- .Input("input_dataset: variant")
- .Input("initial_state: Tstate")
- .Input("other_arguments: Targuments")
- .Output("handle: variant")
- .Attr("f: func")
- .Attr("Tstate: list(type) >= 1")
- .Attr("Targuments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset successively reduces `f` over the elements of `input_dataset`.
-)doc");
-
-REGISTER_OP("ParallelInterleaveDataset")
- .Input("input_dataset: variant")
- .Input("other_arguments: Targuments")
- .Input("cycle_length: int64")
- .Input("block_length: int64")
- .Input("sloppy: bool")
- .Output("handle: variant")
- .Attr("f: func")
- .Attr("Targuments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that applies `f` to the outputs of `input_dataset`.
-
-The resulting dataset is similar to the `InterleaveDataset`, with the exception
-that if retrieving the next value from a dataset would cause the requester to
-block, it will skip that input dataset. This dataset is especially useful
-when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
-allows the training step to proceed so long as some data is available.
-
-!! WARNING !! This dataset is not deterministic!
-
-f: A function mapping elements of `input_dataset`, concatenated with
- `other_arguments`, to a Dataset variant that contains elements matching
- `output_types` and `output_shapes`.
-)doc");
-
-REGISTER_OP("GroupByWindowDataset")
- .Input("input_dataset: variant")
- .Input("key_func_other_arguments: Tkey_func_other_arguments")
- .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
- .Input(
- "window_size_func_other_arguments: Twindow_size_func_other_arguments")
- .Output("handle: variant")
- .Attr("key_func: func")
- .Attr("reduce_func: func")
- .Attr("window_size_func: func")
- .Attr("Tkey_func_other_arguments: list(type) >= 0")
- .Attr("Treduce_func_other_arguments: list(type) >= 0")
- .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that computes a windowed group-by on `input_dataset`.
-
-// TODO(mrry): Support non-int64 keys.
-
-key_func: A function mapping an element of `input_dataset`, concatenated
- with `key_func_other_arguments` to a scalar value of type DT_INT64.
-)doc");
-
-REGISTER_OP("DenseToSparseBatchDataset")
- .Input("input_dataset: variant")
- .Input("batch_size: int64")
- .Input("row_shape: int64")
- .Output("handle: variant")
- // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
- .Attr("output_types: list(type) >= 1")
- // NOTE(mrry): the 1st and 2nd elements will be vectors.
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that yields a SparseTensor for each element of the input.
-
-input_dataset: A handle to an input dataset. Must have a single component.
-batch_size: A scalar representing the number of elements to accumulate in a
- batch.
-row_shape: A vector representing the dense shape of each row in the produced
- SparseTensor. The shape may be partially specified, using `-1` to indicate
- that a particular dimension should use the maximum size of all batch elements.
-)doc");
-
-REGISTER_OP("SqlDataset")
- .Input("driver_name: string")
- .Input("data_source_name: string")
- .Input("query: string")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
- // stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that executes a SQL query and emits rows of the result set.
-
-driver_name: The database type. Currently, the only supported type is 'sqlite'.
-data_source_name: A connection string to connect to the database.
-query: A SQL query to execute.
-)doc");
-
-REGISTER_OP("DatasetToSingleElement")
- .Input("dataset: variant")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
- })
- .Doc(R"doc(
-Outputs the single element from the given dataset.
-
-dataset: A handle to a dataset that contains a single element.
-components: The components of the single element of `input`.
-)doc");
-
-REGISTER_OP("SerializeIterator")
- .Input("resource_handle: resource")
- .Output("serialized: variant")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Converts the given `resource_handle` representing an iterator to a variant tensor.
-
-resource_handle: A handle to an iterator resource.
-serialized: A variant tensor storing the state of the iterator contained in the
- resource.
-)doc");
-
-REGISTER_OP("DeserializeIterator")
- .Input("resource_handle: resource")
- .Input("serialized: variant")
- .SetShapeFn(shape_inference::NoOutputs)
- .Doc(R"doc(
-Converts the given variant tensor to an iterator and stores it in the given resource.
-
-resource_handle: A handle to an iterator resource.
-serialized: A variant tensor storing the state of the iterator contained in the
- resource.
-)doc");
-
-} // namespace tensorflow
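The header comment of the deleted file describes the model all of these ops share: each op emits a DT_VARIANT tensor representing a dataset DAG, which only becomes useful once it is bound to a stateful iterator via `MakeIterator` and read with `IteratorGetNext`. Below is a minimal Python sketch of that flow, assuming the TF 1.4-era `tf.data` API; it is not part of this patch.

# Minimal sketch (not part of this patch): how the variant-valued dataset ops
# are consumed from Python. Dataset transformations build the DT_VARIANT
# dataset DAG; make_initializable_iterator() adds a MakeIterator op that binds
# the dataset to a stateful iterator resource, and get_next() adds
# IteratorGetNext.
import tensorflow as tf

dataset = tf.data.Dataset.range(5).map(lambda x: x * 2)  # dataset DAG
iterator = dataset.make_initializable_iterator()         # MakeIterator
next_element = iterator.get_next()                        # IteratorGetNext

with tf.Session() as sess:
  sess.run(iterator.initializer)
  while True:
    try:
      print(sess.run(next_element))  # 0, 2, 4, 6, 8
    except tf.errors.OutOfRangeError:
      break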
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index 271d80a54b..bda9a2a4a3 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -21,7 +21,6 @@ import os
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
@@ -34,6 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index 329dc80ba5..f59ac760dc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -21,7 +21,6 @@ import os
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import enumerate_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -30,6 +29,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 8033f1d388..3ae8f71d77 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -21,7 +21,6 @@ import gzip
import os
import zlib
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
@@ -34,6 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 727c5d1c38..1b81cf5be9 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -12,6 +12,20 @@ load(
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
py_library(
+ name = "dataset_ops",
+ srcs = [
+ "dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":transformation_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_library(
name = "iterator_ops",
srcs = [
"iterator_ops.py",
@@ -59,7 +73,6 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":gen_dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dataset_ops_gen",
@@ -115,31 +128,6 @@ tf_custom_op_py_library(
],
)
-tf_gen_op_wrapper_py(
- name = "gen_dataset_ops",
- out = "gen_dataset_ops.py",
- deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"],
-)
-
-tf_custom_op_py_library(
- name = "dataset_ops",
- srcs = ["dataset_ops.py"],
- dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
- kernels = [
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":gen_dataset_ops",
- ":transformation_ops",
- "//tensorflow/contrib/util:util_py",
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index e6e5f716b6..abc9212a87 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -25,6 +24,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index c4c4426809..45d6dbe743 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -20,21 +20,15 @@ from __future__ import print_function
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import enumerate_ops
from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.contrib.util import loader
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_io_ops
-from tensorflow.python.platform import resource_loader
from tensorflow.python.util import deprecation
-_dataset_ops = loader.load_op_library(
- resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
-
-
class Dataset(dataset_ops.Dataset):
"""Represents a potentially large set of elements.
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 51a2791072..238bb52b02 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,9 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
+from tensorflow.python.ops import gen_dataset_ops
def ignore_errors():
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 1c7c94b3c8..6df7b22fb6 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -17,12 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
def group_by_window(key_func,
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index ce23e95697..74a919c1ff 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -17,12 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index 32d2f42c93..d736029fb0 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training import saver
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index f22298b757..2e1c3153ca 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
@@ -26,6 +25,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 87bbbb7d19..5acaed48a3 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -19,11 +19,11 @@ from __future__ import print_function
import collections
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
class _ScanDataset(dataset_ops.Dataset):
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 8b8251f84b..a4b5ca16af 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -8271,6 +8271,29 @@ op {
}
}
op {
+ name: "DatasetToSingleElement"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "DebugGradientIdentity"
input_arg {
name: "input"
@@ -9249,6 +9272,69 @@ op {
}
}
op {
+ name: "DenseToSparseBatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "row_shape"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "DenseToSparseBatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "row_shape"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "DenseToSparseSetOperation"
input_arg {
name: "set1"
@@ -9742,6 +9828,18 @@ op {
}
}
op {
+ name: "DeserializeIterator"
+ input_arg {
+ name: "resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "serialized"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
name: "DeserializeManySparse"
input_arg {
name: "serialized_sparse"
@@ -13495,6 +13593,131 @@ op {
}
}
op {
+ name: "GroupByWindowDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "key_func_other_arguments"
+ type_list_attr: "Tkey_func_other_arguments"
+ }
+ input_arg {
+ name: "reduce_func_other_arguments"
+ type_list_attr: "Treduce_func_other_arguments"
+ }
+ input_arg {
+ name: "window_size_func_other_arguments"
+ type_list_attr: "Twindow_size_func_other_arguments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "key_func"
+ type: "func"
+ }
+ attr {
+ name: "reduce_func"
+ type: "func"
+ }
+ attr {
+ name: "window_size_func"
+ type: "func"
+ }
+ attr {
+ name: "Tkey_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Treduce_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Twindow_size_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "GroupByWindowDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "key_func_other_arguments"
+ type_list_attr: "Tkey_func_other_arguments"
+ }
+ input_arg {
+ name: "reduce_func_other_arguments"
+ type_list_attr: "Treduce_func_other_arguments"
+ }
+ input_arg {
+ name: "window_size_func_other_arguments"
+ type_list_attr: "Twindow_size_func_other_arguments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "key_func"
+ type: "func"
+ }
+ attr {
+ name: "reduce_func"
+ type: "func"
+ }
+ attr {
+ name: "window_size_func"
+ type: "func"
+ }
+ attr {
+ name: "Tkey_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Treduce_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Twindow_size_func_other_arguments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "HSVToRGB"
input_arg {
name: "images"
@@ -13915,6 +14138,53 @@ op {
}
}
op {
+ name: "IgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "IgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Imag"
input_arg {
name: "input"
@@ -15819,6 +16089,50 @@ op {
is_stateful: true
}
op {
+ name: "MapAndBatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "batch_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_batches"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "MapClear"
attr {
name: "capacity"
@@ -20557,6 +20871,54 @@ op {
}
}
op {
+ name: "ParallelInterleaveDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "sloppy"
+ type: DT_BOOL
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -30147,6 +30509,52 @@ op {
}
}
op {
+ name: "ScanDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ScatterAdd"
input_arg {
name: "ref"
@@ -31862,6 +32270,18 @@ op {
}
}
op {
+ name: "SerializeIterator"
+ input_arg {
+ name: "resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "serialized"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
name: "SerializeManySparse"
input_arg {
name: "sparse_indices"
@@ -37266,6 +37686,38 @@ op {
}
}
op {
+ name: "SqlDataset"
+ input_arg {
+ name: "driver_name"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "data_source_name"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "query"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
name: "Sqrt"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 8f5d8308a3..f512213964 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -141,6 +141,16 @@ count: A scalar representing the number of elements from the `input_dataset`
that should be skipped. If count is -1, skips everything.
)doc");
+REGISTER_OP("IgnoreErrorsDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+)doc");
+
REGISTER_OP("MapDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -174,6 +184,32 @@ num_parallel_calls: The number of concurrent invocations of `f` that process
elements from `input_dataset` in parallel.
)doc");
+REGISTER_OP("MapAndBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("batch_size: int64")
+ .Input("num_parallel_batches: int64")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset` and then
+batches `batch_size` of them.
+
+Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
+to `batch_size * num_parallel_batches` copies of `f` in parallel.
+
+batch_size: A scalar representing the number of elements to accumulate in a
+ batch. It determines the number of concurrent invocations of `f` that process
+ elements from `input_dataset` in parallel.
+num_parallel_batches: A scalar representing the number of batches to create in
+ parallel. Processing multiple batches in parallel benefits workloads prone to
+ stragglers.
+)doc");
+
REGISTER_OP("PrefetchDataset")
.Input("input_dataset: variant")
.Input("buffer_size: int64")
@@ -188,6 +224,21 @@ buffer_size: The maximum number of elements to buffer in an iterator over
this dataset.
)doc");
+REGISTER_OP("ScanDataset")
+ .Input("input_dataset: variant")
+ .Input("initial_state: Tstate")
+ .Input("other_arguments: Targuments")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Tstate: list(type) >= 1")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset successively reduces `f` over the elements of `input_dataset`.
+)doc");
+
REGISTER_OP("FlatMapDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -234,6 +285,59 @@ f: A function mapping elements of `input_dataset`, concatenated with
`output_types` and `output_shapes`.
)doc");
+REGISTER_OP("ParallelInterleaveDataset")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Input("sloppy: bool")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset`.
+
+The resulting dataset is similar to the `InterleaveDataset`, with the exception
+that if retrieving the next value from a dataset would cause the requester to
+block, it will skip that input dataset. This dataset is especially useful
+when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
+allows the training step to proceed so long as some data is available.
+
+!! WARNING !! This dataset is not deterministic!
+
+f: A function mapping elements of `input_dataset`, concatenated with
+ `other_arguments`, to a Dataset variant that contains elements matching
+ `output_types` and `output_shapes`.
+)doc");
+
+REGISTER_OP("GroupByWindowDataset")
+ .Input("input_dataset: variant")
+ .Input("key_func_other_arguments: Tkey_func_other_arguments")
+ .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
+ .Input(
+ "window_size_func_other_arguments: Twindow_size_func_other_arguments")
+ .Output("handle: variant")
+ .Attr("key_func: func")
+ .Attr("reduce_func: func")
+ .Attr("window_size_func: func")
+ .Attr("Tkey_func_other_arguments: list(type) >= 0")
+ .Attr("Treduce_func_other_arguments: list(type) >= 0")
+ .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that computes a windowed group-by on `input_dataset`.
+
+// TODO(mrry): Support non-int64 keys.
+
+key_func: A function mapping an element of `input_dataset`, concatenated
+ with `key_func_other_arguments` to a scalar value of type DT_INT64.
+)doc");
+
REGISTER_OP("FilterDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -304,6 +408,27 @@ padding_values: A list of scalars containing the padding value to use for
each of the outputs.
)doc");
+REGISTER_OP("DenseToSparseBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("batch_size: int64")
+ .Input("row_shape: int64")
+ .Output("handle: variant")
+ // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
+ .Attr("output_types: list(type) >= 1")
+ // NOTE(mrry): the 1st and 2nd elements will be vectors.
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that yields a SparseTensor for each element of the input.
+
+input_dataset: A handle to an input dataset. Must have a single component.
+batch_size: A scalar representing the number of elements to accumulate in a
+ batch.
+row_shape: A vector representing the dense shape of each row in the produced
+ SparseTensor. The shape may be partially specified, using `-1` to indicate
+ that a particular dimension should use the maximum size of all batch elements.
+)doc");
+
REGISTER_OP("RangeDataset")
.Input("start: int64")
.Input("stop: int64")
@@ -389,6 +514,24 @@ compression_type: A scalar containing either (i) the empty string (no
buffer_size: A scalar containing the number of bytes to buffer.
)doc");
+REGISTER_OP("SqlDataset")
+ .Input("driver_name: string")
+ .Input("data_source_name: string")
+ .Input("query: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that executes a SQL query and emits rows of the result set.
+
+driver_name: The database type. Currently, the only supported type is 'sqlite'.
+data_source_name: A connection string to connect to the database.
+query: A SQL query to execute.
+)doc");
+
REGISTER_OP("FixedLengthRecordDataset")
.Input("filenames: string")
.Input("header_bytes: int64")
@@ -519,6 +662,36 @@ REGISTER_OP("IteratorGetNext")
Gets the next output from the given iterator.
)doc");
+REGISTER_OP("DatasetToSingleElement")
+ .Input("dataset: variant")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs the single element from the given dataset.
+
+dataset: A handle to a dataset that contains a single element.
+components: The components of the single element of `input`.
+)doc");
+
REGISTER_OP("IteratorToStringHandle")
.Input("resource_handle: resource")
.Output("string_handle: string")
@@ -547,4 +720,28 @@ output_shapes: If specified, defines the shape of each tuple component in an
element produced by the resulting iterator.
)doc");
+REGISTER_OP("SerializeIterator")
+ .Input("resource_handle: resource")
+ .Output("serialized: variant")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Converts the given `resource_handle` representing an iterator to a variant tensor.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+ resource.
+)doc");
+
+REGISTER_OP("DeserializeIterator")
+ .Input("resource_handle: resource")
+ .Input("serialized: variant")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Converts the given variant tensor to an iterator and stores it in the given resource.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+ resource.
+)doc");
+
} // namespace tensorflow
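The `SerializeIterator` / `DeserializeIterator` ops registered above are exercised by the new kernel tests below. The following is a minimal sketch of the same save/restore round trip, assuming the private `_iterator_resource` handle that those tests also rely on; it is not part of this patch.

# Minimal sketch (not part of this patch) of the SerializeIterator /
# DeserializeIterator round trip, following the pattern used by the new tests:
# the iterator state is serialized to a variant tensor, written to a file, and
# later pushed back into the iterator resource.
import os
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops

path = os.path.join("/tmp", "iterator_state")  # hypothetical checkpoint path

iterator = tf.data.Dataset.range(10).make_initializable_iterator()
get_next = iterator.get_next()

# Save: serialize the iterator's state and write it to `path`.
state = gen_dataset_ops.serialize_iterator(iterator._iterator_resource)
save_op = io_ops.write_file(path, parsing_ops.serialize_tensor(state))

# Restore: read the file back and load the state into the iterator resource.
restored = parsing_ops.parse_tensor(io_ops.read_file(path), dtypes.variant)
restore_op = gen_dataset_ops.deserialize_iterator(iterator._iterator_resource,
                                                  restored)

with tf.Session() as sess:
  sess.run(iterator.initializer)
  print(sess.run(get_next))  # 0
  sess.run(save_op)          # checkpoint the iterator after one element
  print(sess.run(get_next))  # 1
  sess.run(restore_op)       # rewind to the saved state
  print(sess.run(get_next))  # 1 again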
diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py
index 60a44b5b14..2128ef4ae1 100644
--- a/tensorflow/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/kernel_tests/iterator_ops_test.py
@@ -17,12 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +33,9 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
@@ -533,6 +537,64 @@ class IteratorTest(test.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
+ def testIncorrectIteratorRestore(self):
+
+ def _path():
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _save_op(iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ _path(), parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(_path()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
+ def _build_range_dataset_graph():
+ start = 1
+ stop = 10
+ iterator = dataset_ops.Dataset.range(start,
+ stop).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = _save_op(iterator._iterator_resource)
+ restore_op = _restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ def _build_reader_dataset_graph():
+ filenames = ["test"] # Does not exist but we don't care in this test.
+ iterator = readers.FixedLengthRecordDataset(
+ filenames, 1, 0, 0).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next_op = iterator.get_next()
+ save_op = _save_op(iterator._iterator_resource)
+ restore_op = _restore_op(iterator._iterator_resource)
+ return init_op, get_next_op, save_op, restore_op
+
+ # Saving iterator for RangeDataset graph.
+ with ops.Graph().as_default() as g:
+ init_op, _, save_op, _ = _build_range_dataset_graph()
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(save_op)
+
+ # Attempt to restore the saved iterator into an IteratorResource of
+ # incompatible type. An iterator of RangeDataset has output type int64,
+ # while an iterator of FixedLengthRecordDataset has output type string.
+ # So an InvalidArgumentError should be raised by
+ # IteratorResource::set_iterator.
+ with ops.Graph().as_default() as g:
+ _, _, _, restore_op = _build_reader_dataset_graph()
+ with self.test_session(graph=g) as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(restore_op)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py
index 3c1685c951..0c530522b8 100644
--- a/tensorflow/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py
@@ -17,15 +17,32 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
class RangeDatasetTest(test.TestCase):
+ def tearDown(self):
+ # Remove all checkpoint files.
+ prefix = self._iterator_checkpoint_prefix()
+ pattern = prefix + "*"
+ files = gfile.Glob(pattern)
+ map(gfile.Remove, files)
+
def testStop(self):
stop = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator()
@@ -151,6 +168,319 @@ class RangeDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def _iterator_checkpoint_prefix(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_prefix(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
+ def testSaveRestore(self):
+
+ def _build_graph(start, stop):
+ iterator = dataset_ops.Dataset.range(start,
+ stop).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ break_point = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Saving and restoring in same session.
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testRestoreWithoutBuildingDatasetGraph(self):
+
+ def _build_graph(start, stop, num_epochs):
+ dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ num_epochs = 5
+ break_point = 5
+ break_epoch = 3
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for _ in range(break_epoch):
+ for i in range(start, stop):
+ self.assertEqual(i, sess.run(get_next))
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ # Create an empty IteratorResource and restore the Iterator into it.
+ output_types = dtypes.int64
+ output_shapes = tensor_shape.scalar()
+ iterator = iterator_ops.Iterator.from_structure(output_types,
+ output_shapes)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ get_next = iterator.get_next()
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ for _ in range(break_epoch + 1, num_epochs):
+ for i in range(start, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testRestoreInModifiedGraph(self):
+
+ def _build_graph(start, stop):
+ dataset = dataset_ops.Dataset.range(start, stop)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ stop_1 = 8
+ break_point = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ # Intentionally build a graph with a different value for stop to make sure
+ # the original dataset graph is actually getting loaded.
+ init_op, get_next, _, restore_op = _build_graph(start, stop_1)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testInitThenRestore(self):
+ # Note: Calling init_op before restore_op is redundant. This test just makes
+ # sure we do not fail if restore is called on an already initialized
+ # iterator resource.
+
+ def _build_graph(start, stop):
+ dataset = dataset_ops.Dataset.range(start, stop)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ break_point = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testMultipleSaves(self):
+
+ def _build_graph(start, stop):
+ iterator = dataset_ops.Dataset.range(start,
+ stop).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ start = 2
+ stop = 10
+ break_point1 = 5
+ break_point2 = 7
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point1):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
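+ # Continue from the first checkpoint, then overwrite it with a new
+ # checkpoint taken at break_point2.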
+ for i in range(break_point1, break_point2):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for i in range(break_point2, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testSaveRestoreWithRepeat(self):
+
+ def _build_graph(start, stop, num_epochs):
+ iterator = dataset_ops.Dataset.range(
+ start, stop).repeat(num_epochs).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ start = 2
+ stop = 10
+ num_epochs = 5
+ break_range = 5
+ break_epoch = 3
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(
+ start, stop, num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ # Note: No checkpoint has been saved yet, so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for _ in range(break_epoch - 1):
+ for i in range(start, stop):
+ self.assertEqual(i, sess.run(get_next))
+ for i in range(start, break_range):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
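+ # Finish the interrupted epoch, then run the remaining epochs to
+ # exhaustion.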
+ for i in range(break_range, stop):
+ self.assertEqual(i, sess.run(get_next))
+ for _ in range(break_epoch, num_epochs):
+ for i in range(start, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testSaveRestoreExhaustedIterator(self):
+
+ def _build_graph(start, stop, num_epochs):
+ iterator = dataset_ops.Dataset.range(
+ start, stop).repeat(num_epochs).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ start = 2
+ stop = 10
+ num_epochs = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(
+ start, stop, num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ # Note: No checkpoint has been saved yet, so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for _ in range(num_epochs):
+ for i in range(start, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
index 70b6ce442e..c8e7333b4b 100644
--- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
@@ -26,8 +26,13 @@ from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -267,6 +272,299 @@ class FixedLengthRecordReaderTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
+ def _iterator_checkpoint_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _save_op(self, iterator_resource):
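+ # Serializes the iterator state into a variant tensor, converts it to a
+ # string, and writes it to the checkpoint path on disk.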
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_path(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
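+ # Reads the serialized state back from disk, parses it into a variant
+ # tensor, and restores it into the given iterator resource.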
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
+ def _build_iterator_graph(self, num_epochs):
+ filenames = self._createFiles()
+ dataset = (readers.FixedLengthRecordDataset(
+ filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
+ .repeat(num_epochs))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next_op = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next_op, save_op, restore_op
+
+ def _restore_iterator(self):
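+ # Builds a bare string-producing iterator with no dataset attached, so
+ # restore can be exercised without rebuilding the dataset graph.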
+ output_types = dtypes.string
+ output_shapes = tensor_shape.scalar()
+ iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
+ get_next = iterator.get_next()
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return restore_op, get_next
+
+ def testSaveRestore(self):
+ num_epochs = 10
+ epoch_break = 5
+ file_break = self._num_files // 2
+ record_break = self._num_records // 2
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: No checkpoint has been saved yet, so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
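+ # The for/else clauses below break out of all three nested loops as soon
+ # as the save point is reached; if it is never reached, the final else
+ # checks that the iterator is exhausted.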
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch == epoch_break and f == file_break and
+ r == record_break):
+ sess.run(save_op)
+ break
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ else:
+ continue
+ break
+ else:
+ continue
+ break
+ else:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch < epoch_break or
+ (epoch == epoch_break and f < file_break) or
+ (epoch == epoch_break and f == file_break and
+ r < record_break)):
+ continue
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ def testInitThenRestore(self):
+ # Note: Calling init_op before restore_op is redundant. This test just makes
+ # sure we do not fail if restore is called on an already initialized
+ # iterator resource.
+ num_epochs = 10
+ epoch_break = 5
+ file_break = self._num_files // 2
+ record_break = self._num_records // 2
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: No checkpoint has been saved yet, so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch == epoch_break and f == file_break and
+ r == record_break):
+ sess.run(save_op)
+ break
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ else:
+ continue
+ break
+ else:
+ continue
+ break
+ else:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch < epoch_break or
+ (epoch == epoch_break and f < file_break) or
+ (epoch == epoch_break and f == file_break and
+ r < record_break)):
+ continue
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ def testRestoreInModifiedGraph(self):
+ num_epochs = 10
+ num_epochs_1 = 20
+ epoch_break = 5
+ file_break = self._num_files // 2
+ record_break = self._num_records // 2
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: No checkpoint has been saved yet, so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch == epoch_break and f == file_break and
+ r == record_break):
+ sess.run(save_op)
+ break
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ else:
+ continue
+ break
+ else:
+ continue
+ break
+ else:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs_1)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch < epoch_break or
+ (epoch == epoch_break and f < file_break) or
+ (epoch == epoch_break and f == file_break and
+ r < record_break)):
+ continue
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ def testRestoreWithoutBuildingDatasetGraph(self):
+ num_epochs = 10
+ epoch_break = 5
+ file_break = self._num_files // 2
+ record_break = self._num_records // 2
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: No checkpoint has been saved yet, so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch == epoch_break and f == file_break and
+ r == record_break):
+ sess.run(save_op)
+ break
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ else:
+ continue
+ break
+ else:
+ continue
+ break
+ else:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ with ops.Graph().as_default() as g:
+ restore_op, get_next_op = self._restore_iterator()
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for epoch in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ if (epoch < epoch_break or
+ (epoch == epoch_break and f < file_break) or
+ (epoch == epoch_break and f == file_break and
+ r < record_break)):
+ continue
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ def testRestoreUnusedIterator(self):
+ num_epochs = 10
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: No checkpoint has been saved yet, so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ # Save unused iterator.
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ for _ in range(num_epochs * self._num_files * self._num_records):
+ sess.run(get_next_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ def testRestoreExhaustedIterator(self):
+ num_epochs = 10
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(init_op)
+ # Note: No checkpoint has been saved yet, so a NotFoundError is
+ # raised.
+ with self.assertRaises(errors.NotFoundError):
+ sess.run(restore_op)
+ for _ in range(num_epochs):
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ self.assertEqual(self._record(f, r), sess.run(get_next_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
+ num_epochs=num_epochs)
+ with self.test_session(graph=g) as sess:
+ sess.run(restore_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
class TFRecordDatasetTest(test.TestCase):