aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake1
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake1
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake2
-rw-r--r--tensorflow/contrib/data/BUILD13
-rw-r--r--tensorflow/contrib/data/__init__.py2
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc232
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py2
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD40
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py2
-rw-r--r--tensorflow/contrib/data/python/ops/dataset_ops.py8
-rw-r--r--tensorflow/contrib/data/python/ops/error_ops.py2
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py2
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py2
-rw-r--r--tensorflow/contrib/data/python/ops/iterator_ops.py2
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py2
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py2
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt452
-rw-r--r--tensorflow/core/ops/dataset_ops.cc197
-rw-r--r--tensorflow/python/kernel_tests/iterator_ops_test.py62
-rw-r--r--tensorflow/python/kernel_tests/range_dataset_op_test.py330
-rw-r--r--tensorflow/python/kernel_tests/reader_dataset_ops_test.py298
23 files changed, 292 insertions, 1366 deletions
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index f978c8ccd5..5b62598aa5 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -70,6 +70,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc"
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 4a61ed7a35..03c168795c 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -81,6 +81,7 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t
GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc")
+GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 7636e9ba6e..a14b733158 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -776,6 +776,8 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py)
+GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops"
+ DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops"
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index eaede0e00e..7bcf5a5f4d 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -35,8 +35,19 @@ tf_custom_op_library(
],
)
+# TODO(mrry): Move the kernels out of the core library into this library.
+tf_custom_op_library(
+ name = "_dataset_ops.so",
+ srcs = [
+ "ops/dataset_ops.cc",
+ ],
+)
+
tf_gen_op_libs(
- op_lib_names = ["prefetching_ops"],
+ op_lib_names = [
+ "dataset_ops",
+ "prefetching_ops",
+ ],
)
filegroup(
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 824ac4298f..0c7e793689 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -41,8 +41,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import
+# pylint: disable=unused-import
from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder
from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch
from tensorflow.contrib.data.python.ops.batching import unbatch
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
new file mode 100644
index 0000000000..1574384cb2
--- /dev/null
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -0,0 +1,232 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+
+// The ops in this section can be composed to define an input
+// pipeline. Each op produces a DT_VARIANT tensor that represents
+// a DAG of "dataset" objects. An "dataset" object can be converted
+// to a stateful "iterator" by passing the "dataset" to the
+// "MakeIterator" op.
+//
+// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are
+// not presently serializable. To avoid issues with constant folding, ensure
+// that any "source dataset" ops (i.e. ops that output a dataset and do not
+// take one as input) are marked "stateful".
+
+REGISTER_OP("IgnoreErrorsDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+)doc");
+
+REGISTER_OP("MapAndBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("batch_size: int64")
+ .Input("num_parallel_batches: int64")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset` and then
+batches `batch_size` of them.
+
+Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
+to `batch_size * num_parallel_batches` copies of `f` in parallel.
+
+batch_size: A scalar representing the number of elements to accumulate in a
+ batch. It determines the number of concurrent invocations of `f` that process
+ elements from `input_dataset` in parallel.
+num_parallel_batches: A scalar representing the number of batches to create in
+ parallel. Processing multiple batches in parallel benefits workloads prone to
+ stragglers.
+)doc");
+
+REGISTER_OP("ScanDataset")
+ .Input("input_dataset: variant")
+ .Input("initial_state: Tstate")
+ .Input("other_arguments: Targuments")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Tstate: list(type) >= 1")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset successively reduces `f` over the elements of `input_dataset`.
+)doc");
+
+REGISTER_OP("ParallelInterleaveDataset")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Input("sloppy: bool")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset`.
+
+The resulting dataset is similar to the `InterleaveDataset`, with the exception
+that if retrieving the next value from a dataset would cause the requester to
+block, it will skip that input dataset. This dataset is especially useful
+when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
+allows the training step to proceed so long as some data is available.
+
+!! WARNING !! This dataset is not deterministic!
+
+f: A function mapping elements of `input_dataset`, concatenated with
+ `other_arguments`, to a Dataset variant that contains elements matching
+ `output_types` and `output_shapes`.
+)doc");
+
+REGISTER_OP("GroupByWindowDataset")
+ .Input("input_dataset: variant")
+ .Input("key_func_other_arguments: Tkey_func_other_arguments")
+ .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
+ .Input(
+ "window_size_func_other_arguments: Twindow_size_func_other_arguments")
+ .Output("handle: variant")
+ .Attr("key_func: func")
+ .Attr("reduce_func: func")
+ .Attr("window_size_func: func")
+ .Attr("Tkey_func_other_arguments: list(type) >= 0")
+ .Attr("Treduce_func_other_arguments: list(type) >= 0")
+ .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that computes a windowed group-by on `input_dataset`.
+
+// TODO(mrry): Support non-int64 keys.
+
+key_func: A function mapping an element of `input_dataset`, concatenated
+ with `key_func_other_arguments` to a scalar value of type DT_INT64.
+)doc");
+
+REGISTER_OP("DenseToSparseBatchDataset")
+ .Input("input_dataset: variant")
+ .Input("batch_size: int64")
+ .Input("row_shape: int64")
+ .Output("handle: variant")
+ // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
+ .Attr("output_types: list(type) >= 1")
+ // NOTE(mrry): the 1st and 2nd elements will be vectors.
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that yields a SparseTensor for each element of the input.
+
+input_dataset: A handle to an input dataset. Must have a single component.
+batch_size: A scalar representing the number of elements to accumulate in a
+ batch.
+row_shape: A vector representing the dense shape of each row in the produced
+ SparseTensor. The shape may be partially specified, using `-1` to indicate
+ that a particular dimension should use the maximum size of all batch elements.
+)doc");
+
+REGISTER_OP("SqlDataset")
+ .Input("driver_name: string")
+ .Input("data_source_name: string")
+ .Input("query: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a dataset that executes a SQL query and emits rows of the result set.
+
+driver_name: The database type. Currently, the only supported type is 'sqlite'.
+data_source_name: A connection string to connect to the database.
+query: A SQL query to execute.
+)doc");
+
+REGISTER_OP("DatasetToSingleElement")
+ .Input("dataset: variant")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Outputs the single element from the given dataset.
+
+dataset: A handle to a dataset that contains a single element.
+components: The components of the single element of `input`.
+)doc");
+
+REGISTER_OP("SerializeIterator")
+ .Input("resource_handle: resource")
+ .Output("serialized: variant")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Converts the given `resource_handle` representing an iterator to a variant tensor.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+ resource.
+)doc");
+
+REGISTER_OP("DeserializeIterator")
+ .Input("resource_handle: resource")
+ .Input("serialized: variant")
+ .SetShapeFn(shape_inference::NoOutputs)
+ .Doc(R"doc(
+Converts the given variant tensor to an iterator and stores it in the given resource.
+
+resource_handle: A handle to an iterator resource.
+serialized: A variant tensor storing the state of the iterator contained in the
+ resource.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index bda9a2a4a3..271d80a54b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -21,6 +21,7 @@ import os
import numpy as np
from tensorflow.contrib.data.python.ops import dataset_ops
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
@@ -33,7 +34,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index f59ac760dc..329dc80ba5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -21,6 +21,7 @@ import os
from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import enumerate_ops
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -29,7 +30,6 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 3ae8f71d77..8033f1d388 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -21,6 +21,7 @@ import gzip
import os
import zlib
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
@@ -33,7 +34,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 1b81cf5be9..727c5d1c38 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -12,20 +12,6 @@ load(
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
py_library(
- name = "dataset_ops",
- srcs = [
- "dataset_ops.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":transformation_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-py_library(
name = "iterator_ops",
srcs = [
"iterator_ops.py",
@@ -73,6 +59,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":gen_dataset_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dataset_ops_gen",
@@ -128,6 +115,31 @@ tf_custom_op_py_library(
],
)
+tf_gen_op_wrapper_py(
+ name = "gen_dataset_ops",
+ out = "gen_dataset_ops.py",
+ deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"],
+)
+
+tf_custom_op_py_library(
+ name = "dataset_ops",
+ srcs = ["dataset_ops.py"],
+ dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
+ kernels = [
+ "//tensorflow/contrib/data:dataset_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_dataset_ops",
+ ":transformation_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index abc9212a87..e6e5f716b6 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -24,7 +25,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index 45d6dbe743..c4c4426809 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -20,15 +20,21 @@ from __future__ import print_function
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import enumerate_ops
from tensorflow.contrib.data.python.ops import error_ops
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.contrib.util import loader
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_io_ops
+from tensorflow.python.platform import resource_loader
from tensorflow.python.util import deprecation
+_dataset_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
+
+
class Dataset(dataset_ops.Dataset):
"""Represents a potentially large set of elements.
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 238bb52b02..51a2791072 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,9 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
-from tensorflow.python.ops import gen_dataset_ops
def ignore_errors():
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 6df7b22fb6..1c7c94b3c8 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -17,12 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
def group_by_window(key_func,
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 74a919c1ff..ce23e95697 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -17,12 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index d736029fb0..32d2f42c93 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.training import saver
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 2e1c3153ca..f22298b757 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
@@ -25,7 +26,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index 5acaed48a3..87bbbb7d19 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -19,11 +19,11 @@ from __future__ import print_function
import collections
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
class _ScanDataset(dataset_ops.Dataset):
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index a4b5ca16af..8b8251f84b 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -8271,29 +8271,6 @@ op {
}
}
op {
- name: "DatasetToSingleElement"
- input_arg {
- name: "dataset"
- type: DT_VARIANT
- }
- output_arg {
- name: "components"
- type_list_attr: "output_types"
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "DebugGradientIdentity"
input_arg {
name: "input"
@@ -9272,69 +9249,6 @@ op {
}
}
op {
- name: "DenseToSparseBatchDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "batch_size"
- type: DT_INT64
- }
- input_arg {
- name: "row_shape"
- type: DT_INT64
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
- is_stateful: true
-}
-op {
- name: "DenseToSparseBatchDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "batch_size"
- type: DT_INT64
- }
- input_arg {
- name: "row_shape"
- type: DT_INT64
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "DenseToSparseSetOperation"
input_arg {
name: "set1"
@@ -9828,18 +9742,6 @@ op {
}
}
op {
- name: "DeserializeIterator"
- input_arg {
- name: "resource_handle"
- type: DT_RESOURCE
- }
- input_arg {
- name: "serialized"
- type: DT_VARIANT
- }
- is_stateful: true
-}
-op {
name: "DeserializeManySparse"
input_arg {
name: "serialized_sparse"
@@ -13593,131 +13495,6 @@ op {
}
}
op {
- name: "GroupByWindowDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "key_func_other_arguments"
- type_list_attr: "Tkey_func_other_arguments"
- }
- input_arg {
- name: "reduce_func_other_arguments"
- type_list_attr: "Treduce_func_other_arguments"
- }
- input_arg {
- name: "window_size_func_other_arguments"
- type_list_attr: "Twindow_size_func_other_arguments"
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "key_func"
- type: "func"
- }
- attr {
- name: "reduce_func"
- type: "func"
- }
- attr {
- name: "window_size_func"
- type: "func"
- }
- attr {
- name: "Tkey_func_other_arguments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "Treduce_func_other_arguments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "Twindow_size_func_other_arguments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
- is_stateful: true
-}
-op {
- name: "GroupByWindowDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "key_func_other_arguments"
- type_list_attr: "Tkey_func_other_arguments"
- }
- input_arg {
- name: "reduce_func_other_arguments"
- type_list_attr: "Treduce_func_other_arguments"
- }
- input_arg {
- name: "window_size_func_other_arguments"
- type_list_attr: "Twindow_size_func_other_arguments"
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "key_func"
- type: "func"
- }
- attr {
- name: "reduce_func"
- type: "func"
- }
- attr {
- name: "window_size_func"
- type: "func"
- }
- attr {
- name: "Tkey_func_other_arguments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "Treduce_func_other_arguments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "Twindow_size_func_other_arguments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "HSVToRGB"
input_arg {
name: "images"
@@ -14138,53 +13915,6 @@ op {
}
}
op {
- name: "IgnoreErrorsDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
- is_stateful: true
-}
-op {
- name: "IgnoreErrorsDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "Imag"
input_arg {
name: "input"
@@ -16089,50 +15819,6 @@ op {
is_stateful: true
}
op {
- name: "MapAndBatchDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "other_arguments"
- type_list_attr: "Targuments"
- }
- input_arg {
- name: "batch_size"
- type: DT_INT64
- }
- input_arg {
- name: "num_parallel_batches"
- type: DT_INT64
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "f"
- type: "func"
- }
- attr {
- name: "Targuments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "MapClear"
attr {
name: "capacity"
@@ -20871,54 +20557,6 @@ op {
}
}
op {
- name: "ParallelInterleaveDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "other_arguments"
- type_list_attr: "Targuments"
- }
- input_arg {
- name: "cycle_length"
- type: DT_INT64
- }
- input_arg {
- name: "block_length"
- type: DT_INT64
- }
- input_arg {
- name: "sloppy"
- type: DT_BOOL
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "f"
- type: "func"
- }
- attr {
- name: "Targuments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -30509,52 +30147,6 @@ op {
}
}
op {
- name: "ScanDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "initial_state"
- type_list_attr: "Tstate"
- }
- input_arg {
- name: "other_arguments"
- type_list_attr: "Targuments"
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "f"
- type: "func"
- }
- attr {
- name: "Tstate"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "Targuments"
- type: "list(type)"
- has_minimum: true
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "ScatterAdd"
input_arg {
name: "ref"
@@ -32270,18 +31862,6 @@ op {
}
}
op {
- name: "SerializeIterator"
- input_arg {
- name: "resource_handle"
- type: DT_RESOURCE
- }
- output_arg {
- name: "serialized"
- type: DT_VARIANT
- }
- is_stateful: true
-}
-op {
name: "SerializeManySparse"
input_arg {
name: "sparse_indices"
@@ -37686,38 +37266,6 @@ op {
}
}
op {
- name: "SqlDataset"
- input_arg {
- name: "driver_name"
- type: DT_STRING
- }
- input_arg {
- name: "data_source_name"
- type: DT_STRING
- }
- input_arg {
- name: "query"
- type: DT_STRING
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
- is_stateful: true
-}
-op {
name: "Sqrt"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index f512213964..8f5d8308a3 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -141,16 +141,6 @@ count: A scalar representing the number of elements from the `input_dataset`
that should be skipped. If count is -1, skips everything.
)doc");
-REGISTER_OP("IgnoreErrorsDataset")
- .Input("input_dataset: variant")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the elements of `input_dataset` ignoring errors.
-)doc");
-
REGISTER_OP("MapDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -184,32 +174,6 @@ num_parallel_calls: The number of concurrent invocations of `f` that process
elements from `input_dataset` in parallel.
)doc");
-REGISTER_OP("MapAndBatchDataset")
- .Input("input_dataset: variant")
- .Input("other_arguments: Targuments")
- .Input("batch_size: int64")
- .Input("num_parallel_batches: int64")
- .Output("handle: variant")
- .Attr("f: func")
- .Attr("Targuments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that applies `f` to the outputs of `input_dataset` and then
-batches `batch_size` of them.
-
-Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
-to `batch_size * num_parallel_batches` copies of `f` in parallel.
-
-batch_size: A scalar representing the number of elements to accumulate in a
- batch. It determines the number of concurrent invocations of `f` that process
- elements from `input_dataset` in parallel.
-num_parallel_batches: A scalar representing the number of batches to create in
- parallel. Processing multiple batches in parallel benefits workloads prone to
- stragglers.
-)doc");
-
REGISTER_OP("PrefetchDataset")
.Input("input_dataset: variant")
.Input("buffer_size: int64")
@@ -224,21 +188,6 @@ buffer_size: The maximum number of elements to buffer in an iterator over
this dataset.
)doc");
-REGISTER_OP("ScanDataset")
- .Input("input_dataset: variant")
- .Input("initial_state: Tstate")
- .Input("other_arguments: Targuments")
- .Output("handle: variant")
- .Attr("f: func")
- .Attr("Tstate: list(type) >= 1")
- .Attr("Targuments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset successively reduces `f` over the elements of `input_dataset`.
-)doc");
-
REGISTER_OP("FlatMapDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -285,59 +234,6 @@ f: A function mapping elements of `input_dataset`, concatenated with
`output_types` and `output_shapes`.
)doc");
-REGISTER_OP("ParallelInterleaveDataset")
- .Input("input_dataset: variant")
- .Input("other_arguments: Targuments")
- .Input("cycle_length: int64")
- .Input("block_length: int64")
- .Input("sloppy: bool")
- .Output("handle: variant")
- .Attr("f: func")
- .Attr("Targuments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that applies `f` to the outputs of `input_dataset`.
-
-The resulting dataset is similar to the `InterleaveDataset`, with the exception
-that if retrieving the next value from a dataset would cause the requester to
-block, it will skip that input dataset. This dataset is especially useful
-when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it
-allows the training step to proceed so long as some data is available.
-
-!! WARNING !! This dataset is not deterministic!
-
-f: A function mapping elements of `input_dataset`, concatenated with
- `other_arguments`, to a Dataset variant that contains elements matching
- `output_types` and `output_shapes`.
-)doc");
-
-REGISTER_OP("GroupByWindowDataset")
- .Input("input_dataset: variant")
- .Input("key_func_other_arguments: Tkey_func_other_arguments")
- .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
- .Input(
- "window_size_func_other_arguments: Twindow_size_func_other_arguments")
- .Output("handle: variant")
- .Attr("key_func: func")
- .Attr("reduce_func: func")
- .Attr("window_size_func: func")
- .Attr("Tkey_func_other_arguments: list(type) >= 0")
- .Attr("Treduce_func_other_arguments: list(type) >= 0")
- .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that computes a windowed group-by on `input_dataset`.
-
-// TODO(mrry): Support non-int64 keys.
-
-key_func: A function mapping an element of `input_dataset`, concatenated
- with `key_func_other_arguments` to a scalar value of type DT_INT64.
-)doc");
-
REGISTER_OP("FilterDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -408,27 +304,6 @@ padding_values: A list of scalars containing the padding value to use for
each of the outputs.
)doc");
-REGISTER_OP("DenseToSparseBatchDataset")
- .Input("input_dataset: variant")
- .Input("batch_size: int64")
- .Input("row_shape: int64")
- .Output("handle: variant")
- // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
- .Attr("output_types: list(type) >= 1")
- // NOTE(mrry): the 1st and 2nd elements will be vectors.
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that yields a SparseTensor for each element of the input.
-
-input_dataset: A handle to an input dataset. Must have a single component.
-batch_size: A scalar representing the number of elements to accumulate in a
- batch.
-row_shape: A vector representing the dense shape of each row in the produced
- SparseTensor. The shape may be partially specified, using `-1` to indicate
- that a particular dimension should use the maximum size of all batch elements.
-)doc");
-
REGISTER_OP("RangeDataset")
.Input("start: int64")
.Input("stop: int64")
@@ -514,24 +389,6 @@ compression_type: A scalar containing either (i) the empty string (no
buffer_size: A scalar containing the number of bytes to buffer.
)doc");
-REGISTER_OP("SqlDataset")
- .Input("driver_name: string")
- .Input("data_source_name: string")
- .Input("query: string")
- .Output("handle: variant")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
- // stateful to inhibit constant folding.
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that executes a SQL query and emits rows of the result set.
-
-driver_name: The database type. Currently, the only supported type is 'sqlite'.
-data_source_name: A connection string to connect to the database.
-query: A SQL query to execute.
-)doc");
-
REGISTER_OP("FixedLengthRecordDataset")
.Input("filenames: string")
.Input("header_bytes: int64")
@@ -662,36 +519,6 @@ REGISTER_OP("IteratorGetNext")
Gets the next output from the given iterator.
)doc");
-REGISTER_OP("DatasetToSingleElement")
- .Input("dataset: variant")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
- })
- .Doc(R"doc(
-Outputs the single element from the given dataset.
-
-dataset: A handle to a dataset that contains a single element.
-components: The components of the single element of `input`.
-)doc");
-
REGISTER_OP("IteratorToStringHandle")
.Input("resource_handle: resource")
.Output("string_handle: string")
@@ -720,28 +547,4 @@ output_shapes: If specified, defines the shape of each tuple component in an
element produced by the resulting iterator.
)doc");
-REGISTER_OP("SerializeIterator")
- .Input("resource_handle: resource")
- .Output("serialized: variant")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Converts the given `resource_handle` representing an iterator to a variant tensor.
-
-resource_handle: A handle to an iterator resource.
-serialized: A variant tensor storing the state of the iterator contained in the
- resource.
-)doc");
-
-REGISTER_OP("DeserializeIterator")
- .Input("resource_handle: resource")
- .Input("serialized: variant")
- .SetShapeFn(shape_inference::NoOutputs)
- .Doc(R"doc(
-Converts the given variant tensor to an iterator and stores it in the given resource.
-
-resource_handle: A handle to an iterator resource.
-serialized: A variant tensor storing the state of the iterator contained in the
- resource.
-)doc");
-
} // namespace tensorflow
diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py
index 2128ef4ae1..60a44b5b14 100644
--- a/tensorflow/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/kernel_tests/iterator_ops_test.py
@@ -17,14 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,9 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
@@ -537,64 +533,6 @@ class IteratorTest(test.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
- def testIncorrectIteratorRestore(self):
-
- def _path():
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _save_op(iterator_resource):
- iterator_state_variant = gen_dataset_ops.serialize_iterator(
- iterator_resource)
- save_op = io_ops.write_file(
- _path(), parsing_ops.serialize_tensor(iterator_state_variant))
- return save_op
-
- def _restore_op(iterator_resource):
- iterator_state_variant = parsing_ops.parse_tensor(
- io_ops.read_file(_path()), dtypes.variant)
- restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
- iterator_state_variant)
- return restore_op
-
- def _build_range_dataset_graph():
- start = 1
- stop = 10
- iterator = dataset_ops.Dataset.range(start,
- stop).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = _save_op(iterator._iterator_resource)
- restore_op = _restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- def _build_reader_dataset_graph():
- filenames = ["test"] # Does not exist but we don't care in this test.
- iterator = readers.FixedLengthRecordDataset(
- filenames, 1, 0, 0).make_initializable_iterator()
- init_op = iterator.initializer
- get_next_op = iterator.get_next()
- save_op = _save_op(iterator._iterator_resource)
- restore_op = _restore_op(iterator._iterator_resource)
- return init_op, get_next_op, save_op, restore_op
-
- # Saving iterator for RangeDataset graph.
- with ops.Graph().as_default() as g:
- init_op, _, save_op, _ = _build_range_dataset_graph()
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- sess.run(save_op)
-
- # Attempt to restore the saved iterator into an IteratorResource of
- # incompatible type. An iterator of RangeDataset has output type int64,
- # while an iterator of FixedLengthRecordDataset has output type string.
- # So an InvalidArgumentError should be raised by
- # IteratorResource::set_iterator.
- with ops.Graph().as_default() as g:
- _, _, _, restore_op = _build_reader_dataset_graph()
- with self.test_session(graph=g) as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(restore_op)
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py
index 0c530522b8..3c1685c951 100644
--- a/tensorflow/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py
@@ -17,32 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os
-
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
class RangeDatasetTest(test.TestCase):
- def tearDown(self):
- # Remove all checkpoint files.
- prefix = self._iterator_checkpoint_prefix()
- pattern = prefix + "*"
- files = gfile.Glob(pattern)
- map(gfile.Remove, files)
-
def testStop(self):
stop = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator()
@@ -168,319 +151,6 @@ class RangeDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def _iterator_checkpoint_prefix(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _save_op(self, iterator_resource):
- iterator_state_variant = gen_dataset_ops.serialize_iterator(
- iterator_resource)
- save_op = io_ops.write_file(
- self._iterator_checkpoint_prefix(),
- parsing_ops.serialize_tensor(iterator_state_variant))
- return save_op
-
- def _restore_op(self, iterator_resource):
- iterator_state_variant = parsing_ops.parse_tensor(
- io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant)
- restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
- iterator_state_variant)
- return restore_op
-
- def testSaveRestore(self):
-
- def _build_graph(start, stop):
- iterator = dataset_ops.Dataset.range(start,
- stop).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- # Saving and restoring in different sessions.
- start = 2
- stop = 10
- break_point = 5
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next, _, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Saving and restoring in same session.
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testRestoreWithoutBuildingDatasetGraph(self):
-
- def _build_graph(start, stop, num_epochs):
- dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs)
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- # Saving and restoring in different sessions.
- start = 2
- stop = 10
- num_epochs = 5
- break_point = 5
- break_epoch = 3
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for _ in range(break_epoch):
- for i in range(start, stop):
- self.assertEqual(i, sess.run(get_next))
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- # Create an empty IteratorResource and restore the Iterator into it.
- output_types = dtypes.int64
- output_shapes = tensor_shape.scalar()
- iterator = iterator_ops.Iterator.from_structure(output_types,
- output_shapes)
- restore_op = self._restore_op(iterator._iterator_resource)
- get_next = iterator.get_next()
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- for _ in range(break_epoch + 1, num_epochs):
- for i in range(start, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testRestoreInModifiedGraph(self):
-
- def _build_graph(start, stop):
- dataset = dataset_ops.Dataset.range(start, stop)
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- # Saving and restoring in different sessions.
- start = 2
- stop = 10
- stop_1 = 8
- break_point = 5
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- # Intentionally build a graph with a different value for stop to make sure
- # the original dataset graph is actually getting loaded.
- init_op, get_next, _, restore_op = _build_graph(start, stop_1)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testInitThenRestore(self):
- # Note: Calling init_op before restore_op is redundant. This test just makes
- # sure we do not fail if restore is called on an already initialized
- # iterator resource.
-
- def _build_graph(start, stop):
- dataset = dataset_ops.Dataset.range(start, stop)
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- # Saving and restoring in different sessions.
- start = 2
- stop = 10
- break_point = 5
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next, _, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testMultipleSaves(self):
-
- def _build_graph(start, stop):
- iterator = dataset_ops.Dataset.range(start,
- stop).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- start = 2
- stop = 10
- break_point1 = 5
- break_point2 = 7
-
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point1):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for i in range(break_point1, break_point2):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
-
- break_point2 = 7
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, restore_op = _build_graph(start, stop)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for i in range(break_point2, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testSaveRestoreWithRepeat(self):
-
- def _build_graph(start, stop, num_epochs):
- iterator = dataset_ops.Dataset.range(
- start, stop).repeat(num_epochs).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- start = 2
- stop = 10
- num_epochs = 5
- break_range = 5
- break_epoch = 3
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, restore_op = _build_graph(
- start, stop, num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for _ in range(break_epoch - 1):
- for i in range(start, stop):
- self.assertEqual(i, sess.run(get_next))
- for i in range(start, break_range):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for i in range(break_range, stop):
- self.assertEqual(i, sess.run(get_next))
- for _ in range(break_epoch, num_epochs):
- for i in range(start, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testSaveRestoreExhaustedIterator(self):
-
- def _build_graph(start, stop, num_epochs):
- iterator = dataset_ops.Dataset.range(
- start, stop).repeat(num_epochs).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- start = 2
- stop = 10
- num_epochs = 5
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, restore_op = _build_graph(
- start, stop, num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for _ in range(num_epochs):
- for i in range(start, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
index c8e7333b4b..70b6ce442e 100644
--- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py
@@ -26,13 +26,8 @@ from tensorflow.python.data.ops import readers
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.util import compat
@@ -272,299 +267,6 @@ class FixedLengthRecordReaderTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(iterator.get_next())
- def _iterator_checkpoint_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _save_op(self, iterator_resource):
- iterator_state_variant = gen_dataset_ops.serialize_iterator(
- iterator_resource)
- save_op = io_ops.write_file(
- self._iterator_checkpoint_path(),
- parsing_ops.serialize_tensor(iterator_state_variant))
- return save_op
-
- def _restore_op(self, iterator_resource):
- iterator_state_variant = parsing_ops.parse_tensor(
- io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant)
- restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
- iterator_state_variant)
- return restore_op
-
- def _build_iterator_graph(self, num_epochs):
- filenames = self._createFiles()
- dataset = (readers.FixedLengthRecordDataset(
- filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
- .repeat(num_epochs))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next_op = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next_op, save_op, restore_op
-
- def _restore_iterator(self):
- output_types = dtypes.string
- output_shapes = tensor_shape.scalar()
- iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
- get_next = iterator.get_next()
- restore_op = self._restore_op(iterator._iterator_resource)
- return restore_op, get_next
-
- def testSaveRestore(self):
- num_epochs = 10
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- sess.run(save_op)
- break
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- def testInitThenRestore(self):
- # Note: Calling init_op before restore_op is redundant. This test just makes
- # sure we do not fail if restore is called on an already initialized
- # iterator resource.
- num_epochs = 10
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- sess.run(save_op)
- break
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- def testRestoreInModifiedGraph(self):
- num_epochs = 10
- num_epochs_1 = 20
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- sess.run(save_op)
- break
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs_1)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- def testRestoreWithoutBuildingDatasetGraph(self):
- num_epochs = 10
- epoch_break = 5
- file_break = self._num_files // 2
- record_break = self._num_records // 2
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch == epoch_break and f == file_break and
- r == record_break):
- sess.run(save_op)
- break
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- else:
- continue
- break
- else:
- continue
- break
- else:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- with ops.Graph().as_default() as g:
- restore_op, get_next_op = self._restore_iterator()
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for epoch in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- if (epoch < epoch_break or
- (epoch == epoch_break and f < file_break) or
- (epoch == epoch_break and f == file_break and
- r < record_break)):
- continue
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- def testRestoreUnusedIterator(self):
- num_epochs = 10
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- # Save unused iterator.
- sess.run(save_op)
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- for _ in range(num_epochs * self._num_files * self._num_records):
- sess.run(get_next_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- def testRestoreExhaustedIterator(self):
- num_epochs = 10
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(init_op)
- # Note: There is no checkpoint saved currently so a NotFoundError is
- # raised.
- with self.assertRaises(errors.NotFoundError):
- sess.run(restore_op)
- for _ in range(num_epochs):
- for f in range(self._num_files):
- for r in range(self._num_records):
- self.assertEqual(self._record(f, r), sess.run(get_next_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
- num_epochs=num_epochs)
- with self.test_session(graph=g) as sess:
- sess.run(restore_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
class TFRecordDatasetTest(test.TestCase):