From 18135df3a56d0fb0a4f8e93d7b8332e4de3283e2 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 Nov 2017 22:33:54 -0800 Subject: Automated g4 rollback of changelist 174912490 PiperOrigin-RevId: 174961746 --- tensorflow/contrib/cmake/tf_core_kernels.cmake | 1 - tensorflow/contrib/cmake/tf_core_ops.cmake | 1 - tensorflow/contrib/cmake/tf_python.cmake | 2 - tensorflow/contrib/data/BUILD | 13 +- tensorflow/contrib/data/__init__.py | 2 +- tensorflow/contrib/data/ops/dataset_ops.cc | 232 ----------- .../data/python/kernel_tests/iterator_ops_test.py | 2 +- .../python/kernel_tests/range_dataset_op_test.py | 2 +- .../python/kernel_tests/reader_dataset_ops_test.py | 2 +- tensorflow/contrib/data/python/ops/BUILD | 40 +- tensorflow/contrib/data/python/ops/batching.py | 2 +- tensorflow/contrib/data/python/ops/dataset_ops.py | 8 +- tensorflow/contrib/data/python/ops/error_ops.py | 2 +- tensorflow/contrib/data/python/ops/grouping.py | 2 +- .../contrib/data/python/ops/interleave_ops.py | 2 +- tensorflow/contrib/data/python/ops/iterator_ops.py | 2 +- tensorflow/contrib/data/python/ops/readers.py | 2 +- tensorflow/contrib/data/python/ops/scan_ops.py | 2 +- tensorflow/core/ops/compat/ops_history.v1.pbtxt | 452 +++++++++++++++++++++ tensorflow/core/ops/dataset_ops.cc | 197 +++++++++ .../python/kernel_tests/iterator_ops_test.py | 62 +++ .../python/kernel_tests/range_dataset_op_test.py | 330 +++++++++++++++ .../python/kernel_tests/reader_dataset_ops_test.py | 298 ++++++++++++++ 23 files changed, 1366 insertions(+), 292 deletions(-) delete mode 100644 tensorflow/contrib/data/ops/dataset_ops.cc diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 5b62598aa5..f978c8ccd5 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -70,7 +70,6 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 03c168795c..4a61ed7a35 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -81,7 +81,6 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index a14b733158..7636e9ba6e 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -776,8 +776,6 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 7bcf5a5f4d..eaede0e00e 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -35,19 +35,8 @@ tf_custom_op_library( ], ) -# TODO(mrry): Move the kernels out of the core library into this library. -tf_custom_op_library( - name = "_dataset_ops.so", - srcs = [ - "ops/dataset_ops.cc", - ], -) - tf_gen_op_libs( - op_lib_names = [ - "dataset_ops", - "prefetching_ops", - ], + op_lib_names = ["prefetching_ops"], ) filegroup( diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 0c7e793689..824ac4298f 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -41,8 +41,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function - # pylint: disable=unused-import + from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch from tensorflow.contrib.data.python.ops.batching import unbatch diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc deleted file mode 100644 index 1574384cb2..0000000000 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ /dev/null @@ -1,232 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_def_builder.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -// -------------------------------------------------------------------------- - -// The ops in this section can be composed to define an input -// pipeline. Each op produces a DT_VARIANT tensor that represents -// a DAG of "dataset" objects. An "dataset" object can be converted -// to a stateful "iterator" by passing the "dataset" to the -// "MakeIterator" op. -// -// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are -// not presently serializable. To avoid issues with constant folding, ensure -// that any "source dataset" ops (i.e. ops that output a dataset and do not -// take one as input) are marked "stateful". - -REGISTER_OP("IgnoreErrorsDataset") - .Input("input_dataset: variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that contains the elements of `input_dataset` ignoring errors. -)doc"); - -REGISTER_OP("MapAndBatchDataset") - .Input("input_dataset: variant") - .Input("other_arguments: Targuments") - .Input("batch_size: int64") - .Input("num_parallel_batches: int64") - .Output("handle: variant") - .Attr("f: func") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that applies `f` to the outputs of `input_dataset` and then -batches `batch_size` of them. - -Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up -to `batch_size * num_parallel_batches` copies of `f` in parallel. - -batch_size: A scalar representing the number of elements to accumulate in a - batch. It determines the number of concurrent invocations of `f` that process - elements from `input_dataset` in parallel. -num_parallel_batches: A scalar representing the number of batches to create in - parallel. Processing multiple batches in parallel benefits workloads prone to - stragglers. -)doc"); - -REGISTER_OP("ScanDataset") - .Input("input_dataset: variant") - .Input("initial_state: Tstate") - .Input("other_arguments: Targuments") - .Output("handle: variant") - .Attr("f: func") - .Attr("Tstate: list(type) >= 1") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset successively reduces `f` over the elements of `input_dataset`. -)doc"); - -REGISTER_OP("ParallelInterleaveDataset") - .Input("input_dataset: variant") - .Input("other_arguments: Targuments") - .Input("cycle_length: int64") - .Input("block_length: int64") - .Input("sloppy: bool") - .Output("handle: variant") - .Attr("f: func") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that applies `f` to the outputs of `input_dataset`. - -The resulting dataset is similar to the `InterleaveDataset`, with the exception -that if retrieving the next value from a dataset would cause the requester to -block, it will skip that input dataset. This dataset is especially useful -when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it -allows the training step to proceed so long as some data is available. - -!! WARNING !! This dataset is not deterministic! - -f: A function mapping elements of `input_dataset`, concatenated with - `other_arguments`, to a Dataset variant that contains elements matching - `output_types` and `output_shapes`. -)doc"); - -REGISTER_OP("GroupByWindowDataset") - .Input("input_dataset: variant") - .Input("key_func_other_arguments: Tkey_func_other_arguments") - .Input("reduce_func_other_arguments: Treduce_func_other_arguments") - .Input( - "window_size_func_other_arguments: Twindow_size_func_other_arguments") - .Output("handle: variant") - .Attr("key_func: func") - .Attr("reduce_func: func") - .Attr("window_size_func: func") - .Attr("Tkey_func_other_arguments: list(type) >= 0") - .Attr("Treduce_func_other_arguments: list(type) >= 0") - .Attr("Twindow_size_func_other_arguments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that computes a windowed group-by on `input_dataset`. - -// TODO(mrry): Support non-int64 keys. - -key_func: A function mapping an element of `input_dataset`, concatenated - with `key_func_other_arguments` to a scalar value of type DT_INT64. -)doc"); - -REGISTER_OP("DenseToSparseBatchDataset") - .Input("input_dataset: variant") - .Input("batch_size: int64") - .Input("row_shape: int64") - .Output("handle: variant") - // NOTE(mrry): the 0th and 2nd elements will be DT_INT64. - .Attr("output_types: list(type) >= 1") - // NOTE(mrry): the 1st and 2nd elements will be vectors. - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that yields a SparseTensor for each element of the input. - -input_dataset: A handle to an input dataset. Must have a single component. -batch_size: A scalar representing the number of elements to accumulate in a - batch. -row_shape: A vector representing the dense shape of each row in the produced - SparseTensor. The shape may be partially specified, using `-1` to indicate - that a particular dimension should use the maximum size of all batch elements. -)doc"); - -REGISTER_OP("SqlDataset") - .Input("driver_name: string") - .Input("data_source_name: string") - .Input("query: string") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked - // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that executes a SQL query and emits rows of the result set. - -driver_name: The database type. Currently, the only supported type is 'sqlite'. -data_source_name: A connection string to connect to the database. -query: A SQL query to execute. -)doc"); - -REGISTER_OP("DatasetToSingleElement") - .Input("dataset: variant") - .Output("components: output_types") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast(i), output_shape_handle); - } - return Status::OK(); - }) - .Doc(R"doc( -Outputs the single element from the given dataset. - -dataset: A handle to a dataset that contains a single element. -components: The components of the single element of `input`. -)doc"); - -REGISTER_OP("SerializeIterator") - .Input("resource_handle: resource") - .Output("serialized: variant") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Converts the given `resource_handle` representing an iterator to a variant tensor. - -resource_handle: A handle to an iterator resource. -serialized: A variant tensor storing the state of the iterator contained in the - resource. -)doc"); - -REGISTER_OP("DeserializeIterator") - .Input("resource_handle: resource") - .Input("serialized: variant") - .SetShapeFn(shape_inference::NoOutputs) - .Doc(R"doc( -Converts the given variant tensor to an iterator and stores it in the given resource. - -resource_handle: A handle to an iterator resource. -serialized: A variant tensor storing the state of the iterator contained in the - resource. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 271d80a54b..bda9a2a4a3 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -21,7 +21,6 @@ import os import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -34,6 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 329dc80ba5..f59ac760dc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -21,7 +21,6 @@ import os from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -30,6 +29,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 8033f1d388..3ae8f71d77 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,7 +21,6 @@ import gzip import os import zlib -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 @@ -34,6 +33,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 727c5d1c38..1b81cf5be9 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -11,6 +11,20 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +py_library( + name = "dataset_ops", + srcs = [ + "dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":transformation_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + py_library( name = "iterator_ops", srcs = [ @@ -59,7 +73,6 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":gen_dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dataset_ops_gen", @@ -115,31 +128,6 @@ tf_custom_op_py_library( ], ) -tf_gen_op_wrapper_py( - name = "gen_dataset_ops", - out = "gen_dataset_ops.py", - deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"], -) - -tf_custom_op_py_library( - name = "dataset_ops", - srcs = ["dataset_ops.py"], - dso = ["//tensorflow/contrib/data:_dataset_ops.so"], - kernels = [ - "//tensorflow/contrib/data:dataset_ops_op_lib", - ], - srcs_version = "PY2AND3", - deps = [ - ":gen_dataset_ops", - ":transformation_ops", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:platform", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - ], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index e6e5f716b6..abc9212a87 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -25,6 +24,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index c4c4426809..45d6dbe743 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -20,21 +20,15 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.contrib.data.python.ops import error_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import grouping -from tensorflow.contrib.util import loader from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_io_ops -from tensorflow.python.platform import resource_loader from tensorflow.python.util import deprecation -_dataset_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("../../_dataset_ops.so")) - - class Dataset(dataset_ops.Dataset): """Represents a potentially large set of elements. diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 51a2791072..238bb52b02 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -17,9 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest +from tensorflow.python.ops import gen_dataset_ops def ignore_errors(): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 1c7c94b3c8..6df7b22fb6 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -17,12 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops def group_by_window(key_func, diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index ce23e95697..74a919c1ff 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,12 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index 32d2f42c93..d736029fb0 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -17,8 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.training import saver diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index f22298b757..2e1c3153ca 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -18,7 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest @@ -26,6 +25,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 87bbbb7d19..5acaed48a3 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -19,11 +19,11 @@ from __future__ import print_function import collections -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import function from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops class _ScanDataset(dataset_ops.Dataset): diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 8b8251f84b..a4b5ca16af 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -8270,6 +8270,29 @@ op { } } } +op { + name: "DatasetToSingleElement" + input_arg { + name: "dataset" + type: DT_VARIANT + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "DebugGradientIdentity" input_arg { @@ -9248,6 +9271,69 @@ op { } } } +op { + name: "DenseToSparseBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "row_shape" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "DenseToSparseBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "row_shape" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "DenseToSparseSetOperation" input_arg { @@ -9741,6 +9827,18 @@ op { } } } +op { + name: "DeserializeIterator" + input_arg { + name: "resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "serialized" + type: DT_VARIANT + } + is_stateful: true +} op { name: "DeserializeManySparse" input_arg { @@ -13494,6 +13592,131 @@ op { } } } +op { + name: "GroupByWindowDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "key_func_other_arguments" + type_list_attr: "Tkey_func_other_arguments" + } + input_arg { + name: "reduce_func_other_arguments" + type_list_attr: "Treduce_func_other_arguments" + } + input_arg { + name: "window_size_func_other_arguments" + type_list_attr: "Twindow_size_func_other_arguments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "key_func" + type: "func" + } + attr { + name: "reduce_func" + type: "func" + } + attr { + name: "window_size_func" + type: "func" + } + attr { + name: "Tkey_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "Treduce_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "Twindow_size_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "GroupByWindowDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "key_func_other_arguments" + type_list_attr: "Tkey_func_other_arguments" + } + input_arg { + name: "reduce_func_other_arguments" + type_list_attr: "Treduce_func_other_arguments" + } + input_arg { + name: "window_size_func_other_arguments" + type_list_attr: "Twindow_size_func_other_arguments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "key_func" + type: "func" + } + attr { + name: "reduce_func" + type: "func" + } + attr { + name: "window_size_func" + type: "func" + } + attr { + name: "Tkey_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "Treduce_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "Twindow_size_func_other_arguments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "HSVToRGB" input_arg { @@ -13914,6 +14137,53 @@ op { } } } +op { + name: "IgnoreErrorsDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "IgnoreErrorsDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "Imag" input_arg { @@ -15818,6 +16088,50 @@ op { } is_stateful: true } +op { + name: "MapAndBatchDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "batch_size" + type: DT_INT64 + } + input_arg { + name: "num_parallel_batches" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "MapClear" attr { @@ -20556,6 +20870,54 @@ op { type: "type" } } +op { + name: "ParallelInterleaveDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + input_arg { + name: "sloppy" + type: DT_BOOL + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "ParallelMapDataset" input_arg { @@ -30146,6 +30508,52 @@ op { } } } +op { + name: "ScanDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "initial_state" + type_list_attr: "Tstate" + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Tstate" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} op { name: "ScatterAdd" input_arg { @@ -31861,6 +32269,18 @@ op { } } } +op { + name: "SerializeIterator" + input_arg { + name: "resource_handle" + type: DT_RESOURCE + } + output_arg { + name: "serialized" + type: DT_VARIANT + } + is_stateful: true +} op { name: "SerializeManySparse" input_arg { @@ -37265,6 +37685,38 @@ op { } } } +op { + name: "SqlDataset" + input_arg { + name: "driver_name" + type: DT_STRING + } + input_arg { + name: "data_source_name" + type: DT_STRING + } + input_arg { + name: "query" + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "Sqrt" input_arg { diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 8f5d8308a3..f512213964 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -141,6 +141,16 @@ count: A scalar representing the number of elements from the `input_dataset` that should be skipped. If count is -1, skips everything. )doc"); +REGISTER_OP("IgnoreErrorsDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that contains the elements of `input_dataset` ignoring errors. +)doc"); + REGISTER_OP("MapDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") @@ -174,6 +184,32 @@ num_parallel_calls: The number of concurrent invocations of `f` that process elements from `input_dataset` in parallel. )doc"); +REGISTER_OP("MapAndBatchDataset") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("batch_size: int64") + .Input("num_parallel_batches: int64") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset` and then +batches `batch_size` of them. + +Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up +to `batch_size * num_parallel_batches` copies of `f` in parallel. + +batch_size: A scalar representing the number of elements to accumulate in a + batch. It determines the number of concurrent invocations of `f` that process + elements from `input_dataset` in parallel. +num_parallel_batches: A scalar representing the number of batches to create in + parallel. Processing multiple batches in parallel benefits workloads prone to + stragglers. +)doc"); + REGISTER_OP("PrefetchDataset") .Input("input_dataset: variant") .Input("buffer_size: int64") @@ -188,6 +224,21 @@ buffer_size: The maximum number of elements to buffer in an iterator over this dataset. )doc"); +REGISTER_OP("ScanDataset") + .Input("input_dataset: variant") + .Input("initial_state: Tstate") + .Input("other_arguments: Targuments") + .Output("handle: variant") + .Attr("f: func") + .Attr("Tstate: list(type) >= 1") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset successively reduces `f` over the elements of `input_dataset`. +)doc"); + REGISTER_OP("FlatMapDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") @@ -234,6 +285,59 @@ f: A function mapping elements of `input_dataset`, concatenated with `output_types` and `output_shapes`. )doc"); +REGISTER_OP("ParallelInterleaveDataset") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Input("sloppy: bool") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that applies `f` to the outputs of `input_dataset`. + +The resulting dataset is similar to the `InterleaveDataset`, with the exception +that if retrieving the next value from a dataset would cause the requester to +block, it will skip that input dataset. This dataset is especially useful +when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it +allows the training step to proceed so long as some data is available. + +!! WARNING !! This dataset is not deterministic! + +f: A function mapping elements of `input_dataset`, concatenated with + `other_arguments`, to a Dataset variant that contains elements matching + `output_types` and `output_shapes`. +)doc"); + +REGISTER_OP("GroupByWindowDataset") + .Input("input_dataset: variant") + .Input("key_func_other_arguments: Tkey_func_other_arguments") + .Input("reduce_func_other_arguments: Treduce_func_other_arguments") + .Input( + "window_size_func_other_arguments: Twindow_size_func_other_arguments") + .Output("handle: variant") + .Attr("key_func: func") + .Attr("reduce_func: func") + .Attr("window_size_func: func") + .Attr("Tkey_func_other_arguments: list(type) >= 0") + .Attr("Treduce_func_other_arguments: list(type) >= 0") + .Attr("Twindow_size_func_other_arguments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that computes a windowed group-by on `input_dataset`. + +// TODO(mrry): Support non-int64 keys. + +key_func: A function mapping an element of `input_dataset`, concatenated + with `key_func_other_arguments` to a scalar value of type DT_INT64. +)doc"); + REGISTER_OP("FilterDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") @@ -304,6 +408,27 @@ padding_values: A list of scalars containing the padding value to use for each of the outputs. )doc"); +REGISTER_OP("DenseToSparseBatchDataset") + .Input("input_dataset: variant") + .Input("batch_size: int64") + .Input("row_shape: int64") + .Output("handle: variant") + // NOTE(mrry): the 0th and 2nd elements will be DT_INT64. + .Attr("output_types: list(type) >= 1") + // NOTE(mrry): the 1st and 2nd elements will be vectors. + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that yields a SparseTensor for each element of the input. + +input_dataset: A handle to an input dataset. Must have a single component. +batch_size: A scalar representing the number of elements to accumulate in a + batch. +row_shape: A vector representing the dense shape of each row in the produced + SparseTensor. The shape may be partially specified, using `-1` to indicate + that a particular dimension should use the maximum size of all batch elements. +)doc"); + REGISTER_OP("RangeDataset") .Input("start: int64") .Input("stop: int64") @@ -389,6 +514,24 @@ compression_type: A scalar containing either (i) the empty string (no buffer_size: A scalar containing the number of bytes to buffer. )doc"); +REGISTER_OP("SqlDataset") + .Input("driver_name: string") + .Input("data_source_name: string") + .Input("query: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that executes a SQL query and emits rows of the result set. + +driver_name: The database type. Currently, the only supported type is 'sqlite'. +data_source_name: A connection string to connect to the database. +query: A SQL query to execute. +)doc"); + REGISTER_OP("FixedLengthRecordDataset") .Input("filenames: string") .Input("header_bytes: int64") @@ -519,6 +662,36 @@ REGISTER_OP("IteratorGetNext") Gets the next output from the given iterator. )doc"); +REGISTER_OP("DatasetToSingleElement") + .Input("dataset: variant") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast(i), output_shape_handle); + } + return Status::OK(); + }) + .Doc(R"doc( +Outputs the single element from the given dataset. + +dataset: A handle to a dataset that contains a single element. +components: The components of the single element of `input`. +)doc"); + REGISTER_OP("IteratorToStringHandle") .Input("resource_handle: resource") .Output("string_handle: string") @@ -547,4 +720,28 @@ output_shapes: If specified, defines the shape of each tuple component in an element produced by the resulting iterator. )doc"); +REGISTER_OP("SerializeIterator") + .Input("resource_handle: resource") + .Output("serialized: variant") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Converts the given `resource_handle` representing an iterator to a variant tensor. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. +)doc"); + +REGISTER_OP("DeserializeIterator") + .Input("resource_handle: resource") + .Input("serialized: variant") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Converts the given variant tensor to an iterator and stores it in the given resource. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py index 60a44b5b14..2128ef4ae1 100644 --- a/tensorflow/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/kernel_tests/iterator_ops_test.py @@ -17,12 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -31,7 +33,9 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops @@ -533,6 +537,64 @@ class IteratorTest(test.TestCase): target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) + def testIncorrectIteratorRestore(self): + + def _path(): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + _path(), parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def _build_range_dataset_graph(): + start = 1 + stop = 10 + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + def _build_reader_dataset_graph(): + filenames = ["test"] # Does not exist but we don't care in this test. + iterator = readers.FixedLengthRecordDataset( + filenames, 1, 0, 0).make_initializable_iterator() + init_op = iterator.initializer + get_next_op = iterator.get_next() + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) + return init_op, get_next_op, save_op, restore_op + + # Saving iterator for RangeDataset graph. + with ops.Graph().as_default() as g: + init_op, _, save_op, _ = _build_range_dataset_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(save_op) + + # Attempt to restore the saved iterator into an IteratorResource of + # incompatible type. An iterator of RangeDataset has output type int64, + # while an iterator of FixedLengthRecordDataset has output type string. + # So an InvalidArgumentError should be raised by + # IteratorResource::set_iterator. + with ops.Graph().as_default() as g: + _, _, _, restore_op = _build_reader_dataset_graph() + with self.test_session(graph=g) as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(restore_op) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py index 3c1685c951..0c530522b8 100644 --- a/tensorflow/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py @@ -17,15 +17,32 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import test class RangeDatasetTest(test.TestCase): + def tearDown(self): + # Remove all checkpoint files. + prefix = self._iterator_checkpoint_prefix() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) + def testStop(self): stop = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator() @@ -151,6 +168,319 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def _iterator_checkpoint_prefix(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def testSaveRestore(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Saving and restoring in same session. + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testRestoreWithoutBuildingDatasetGraph(self): + + def _build_graph(start, stop, num_epochs): + dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + num_epochs = 5 + break_point = 5 + break_epoch = 3 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_epoch): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + # Create an empty IteratorResource and restore the Iterator into it. + output_types = dtypes.int64 + output_shapes = tensor_shape.scalar() + iterator = iterator_ops.Iterator.from_structure(output_types, + output_shapes) + restore_op = self._restore_op(iterator._iterator_resource) + get_next = iterator.get_next() + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + for _ in range(break_epoch + 1, num_epochs): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testRestoreInModifiedGraph(self): + + def _build_graph(start, stop): + dataset = dataset_ops.Dataset.range(start, stop) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + stop_1 = 8 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + # Intentionally build a graph with a different value for stop to make sure + # the original dataset graph is actually getting loaded. + init_op, get_next, _, restore_op = _build_graph(start, stop_1) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testInitThenRestore(self): + # Note: Calling init_op before restore_op is redundant. This test just makes + # sure we do not fail if restore is called on an already initialized + # iterator resource. + + def _build_graph(start, stop): + dataset = dataset_ops.Dataset.range(start, stop) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testMultipleSaves(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + start = 2 + stop = 10 + break_point1 = 5 + break_point2 = 7 + + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point1): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_point1, break_point2): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + break_point2 = 7 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_point2, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSaveRestoreWithRepeat(self): + + def _build_graph(start, stop, num_epochs): + iterator = dataset_ops.Dataset.range( + start, stop).repeat(num_epochs).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + start = 2 + stop = 10 + num_epochs = 5 + break_range = 5 + break_epoch = 3 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph( + start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for _ in range(break_epoch - 1): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + for i in range(start, break_range): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_range, stop): + self.assertEqual(i, sess.run(get_next)) + for _ in range(break_epoch, num_epochs): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSaveRestoreExhaustedIterator(self): + + def _build_graph(start, stop, num_epochs): + iterator = dataset_ops.Dataset.range( + start, stop).repeat(num_epochs).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + start = 2 + stop = 10 + num_epochs = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, restore_op = _build_graph( + start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for _ in range(num_epochs): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py index 70b6ce442e..c8e7333b4b 100644 --- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py @@ -26,8 +26,13 @@ from tensorflow.python.data.ops import readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -267,6 +272,299 @@ class FixedLengthRecordReaderTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) + def _iterator_checkpoint_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_path(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def _build_iterator_graph(self, num_epochs): + filenames = self._createFiles() + dataset = (readers.FixedLengthRecordDataset( + filenames, self._record_bytes, self._header_bytes, self._footer_bytes) + .repeat(num_epochs)) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next_op = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next_op, save_op, restore_op + + def _restore_iterator(self): + output_types = dtypes.string + output_shapes = tensor_shape.scalar() + iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) + get_next = iterator.get_next() + restore_op = self._restore_op(iterator._iterator_resource) + return restore_op, get_next + + def testSaveRestore(self): + num_epochs = 10 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testInitThenRestore(self): + # Note: Calling init_op before restore_op is redundant. This test just makes + # sure we do not fail if restore is called on an already initialized + # iterator resource. + num_epochs = 10 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreInModifiedGraph(self): + num_epochs = 10 + num_epochs_1 = 20 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs_1) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreWithoutBuildingDatasetGraph(self): + num_epochs = 10 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + restore_op, get_next_op = self._restore_iterator() + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreUnusedIterator(self): + num_epochs = 10 + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + # Save unused iterator. + sess.run(save_op) + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for _ in range(num_epochs * self._num_files * self._num_records): + sess.run(get_next_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreExhaustedIterator(self): + num_epochs = 10 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for _ in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + class TFRecordDatasetTest(test.TestCase): -- cgit v1.2.3