23 files changed, 292 insertions(+), 1366 deletions(-)
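This change moves the experimental tf.contrib.data op registrations (IgnoreErrorsDataset, MapAndBatchDataset, ScanDataset, ParallelInterleaveDataset, GroupByWindowDataset, DenseToSparseBatchDataset, SqlDataset, DatasetToSingleElement, SerializeIterator, and DeserializeIterator) out of tensorflow/core/ops/dataset_ops.cc and into a new contrib op library, tensorflow/contrib/data/ops/dataset_ops.cc. That library is built as a custom op shared object (_dataset_ops.so) and wrapped by a generated gen_dataset_ops Python module, so contrib Python code stops importing the core-generated wrappers. A minimal sketch of the resulting loading pattern (it mirrors the dataset_ops.py hunk below, and assumes it runs from a module two directory levels below where _dataset_ops.so is packaged, as dataset_ops.py is):

# Loading the shared object registers the contrib dataset ops (for example
# MapAndBatchDataset and SerializeIterator) with the TensorFlow runtime.
from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader

_dataset_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("../../_dataset_ops.so"))

# The generated wrapper module then exposes one Python function per
# registered op (e.g. gen_dataset_ops.serialize_iterator).
from tensorflow.contrib.data.python.ops import gen_dataset_ops

The BUILD changes below package _dataset_ops.so as a dso dependency of the contrib dataset_ops py library, so that get_path_to_datafile can resolve the shared object relative to the importing module.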
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index f978c8ccd5..5b62598aa5 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -70,6 +70,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 4a61ed7a35..03c168795c 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -81,6 +81,7 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/t GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 7636e9ba6e..a14b733158 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -776,6 +776,8 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index eaede0e00e..7bcf5a5f4d 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -35,8 +35,19 @@ tf_custom_op_library( ], ) +# TODO(mrry): Move the kernels out of the core library into this library. 
+tf_custom_op_library(
+    name = "_dataset_ops.so",
+    srcs = [
+        "ops/dataset_ops.cc",
+    ],
+)
+
 tf_gen_op_libs(
-    op_lib_names = ["prefetching_ops"],
+    op_lib_names = [
+        "dataset_ops",
+        "prefetching_ops",
+    ],
 )
 
 filegroup(
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 824ac4298f..0c7e793689 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -41,8 +41,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=unused-import
+
+# pylint: disable=unused-import
 from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder
 from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch
 from tensorflow.contrib.data.python.ops.batching import unbatch
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
new file mode 100644
index 0000000000..1574384cb2
--- /dev/null
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -0,0 +1,232 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def_builder.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+// --------------------------------------------------------------------------
+
+// The ops in this section can be composed to define an input
+// pipeline. Each op produces a DT_VARIANT tensor that represents
+// a DAG of "dataset" objects. A "dataset" object can be converted
+// to a stateful "iterator" by passing the "dataset" to the
+// "MakeIterator" op.
+//
+// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are
+// not presently serializable. To avoid issues with constant folding, ensure
+// that any "source dataset" ops (i.e. ops that output a dataset and do not
+// take one as input) are marked "stateful".
+
+REGISTER_OP("IgnoreErrorsDataset")
+    .Input("input_dataset: variant")
+    .Output("handle: variant")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+)doc");
+
+REGISTER_OP("MapAndBatchDataset")
+    .Input("input_dataset: variant")
+    .Input("other_arguments: Targuments")
+    .Input("batch_size: int64")
+    .Input("num_parallel_batches: int64")
+    .Output("handle: variant")
+    .Attr("f: func")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset` and then
+batches `batch_size` of them.
+
+Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up
+to `batch_size * num_parallel_batches` copies of `f` in parallel.
+
+batch_size: A scalar representing the number of elements to accumulate in a
+  batch. It determines the number of concurrent invocations of `f` that process
+  elements from `input_dataset` in parallel.
+num_parallel_batches: A scalar representing the number of batches to create in
+  parallel. Processing multiple batches in parallel benefits workloads prone to
+  stragglers.
+)doc");
+
+REGISTER_OP("ScanDataset")
+    .Input("input_dataset: variant")
+    .Input("initial_state: Tstate")
+    .Input("other_arguments: Targuments")
+    .Output("handle: variant")
+    .Attr("f: func")
+    .Attr("Tstate: list(type) >= 1")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that successively reduces `f` over the elements of `input_dataset`.
+)doc");
+
+REGISTER_OP("ParallelInterleaveDataset")
+    .Input("input_dataset: variant")
+    .Input("other_arguments: Targuments")
+    .Input("cycle_length: int64")
+    .Input("block_length: int64")
+    .Input("sloppy: bool")
+    .Output("handle: variant")
+    .Attr("f: func")
+    .Attr("Targuments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that applies `f` to the outputs of `input_dataset`.
+
+The resulting dataset is similar to the `InterleaveDataset`, with the exception
+that if retrieving the next value from a dataset would cause the requester to
+block, it will skip that input dataset. This dataset is especially useful
+when loading data from variable-latency datastores (e.g. HDFS, GCS), as it
+allows the training step to proceed so long as some data is available.
+
+!! WARNING !! This dataset is not deterministic!
+
+f: A function mapping elements of `input_dataset`, concatenated with
+  `other_arguments`, to a Dataset variant that contains elements matching
+  `output_types` and `output_shapes`.
+)doc");
+
+REGISTER_OP("GroupByWindowDataset")
+    .Input("input_dataset: variant")
+    .Input("key_func_other_arguments: Tkey_func_other_arguments")
+    .Input("reduce_func_other_arguments: Treduce_func_other_arguments")
+    .Input(
+        "window_size_func_other_arguments: Twindow_size_func_other_arguments")
+    .Output("handle: variant")
+    .Attr("key_func: func")
+    .Attr("reduce_func: func")
+    .Attr("window_size_func: func")
+    .Attr("Tkey_func_other_arguments: list(type) >= 0")
+    .Attr("Treduce_func_other_arguments: list(type) >= 0")
+    .Attr("Twindow_size_func_other_arguments: list(type) >= 0")
+    .Attr("output_types: list(type) >= 1")
+    .Attr("output_shapes: list(shape) >= 1")
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+Creates a dataset that computes a windowed group-by on `input_dataset`.
+
+// TODO(mrry): Support non-int64 keys.
+
+key_func: A function mapping an element of `input_dataset`, concatenated
+  with `key_func_other_arguments`, to a scalar value of type DT_INT64.
+)doc");
+
+REGISTER_OP("DenseToSparseBatchDataset")
+    .Input("input_dataset: variant")
+    .Input("batch_size: int64")
+    .Input("row_shape: int64")
+    .Output("handle: variant")
+    // NOTE(mrry): the 0th and 2nd elements will be DT_INT64.
+    .Attr("output_types: list(type) >= 1")
+    // NOTE(mrry): the 1st and 2nd elements will be vectors.
+ .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that yields a SparseTensor for each element of the input. + +input_dataset: A handle to an input dataset. Must have a single component. +batch_size: A scalar representing the number of elements to accumulate in a + batch. +row_shape: A vector representing the dense shape of each row in the produced + SparseTensor. The shape may be partially specified, using `-1` to indicate + that a particular dimension should use the maximum size of all batch elements. +)doc"); + +REGISTER_OP("SqlDataset") + .Input("driver_name: string") + .Input("data_source_name: string") + .Input("query: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that executes a SQL query and emits rows of the result set. + +driver_name: The database type. Currently, the only supported type is 'sqlite'. +data_source_name: A connection string to connect to the database. +query: A SQL query to execute. +)doc"); + +REGISTER_OP("DatasetToSingleElement") + .Input("dataset: variant") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + return Status::OK(); + }) + .Doc(R"doc( +Outputs the single element from the given dataset. + +dataset: A handle to a dataset that contains a single element. +components: The components of the single element of `input`. +)doc"); + +REGISTER_OP("SerializeIterator") + .Input("resource_handle: resource") + .Output("serialized: variant") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Converts the given `resource_handle` representing an iterator to a variant tensor. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. +)doc"); + +REGISTER_OP("DeserializeIterator") + .Input("resource_handle: resource") + .Input("serialized: variant") + .SetShapeFn(shape_inference::NoOutputs) + .Doc(R"doc( +Converts the given variant tensor to an iterator and stores it in the given resource. + +resource_handle: A handle to an iterator resource. +serialized: A variant tensor storing the state of the iterator contained in the + resource. 
+)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index bda9a2a4a3..271d80a54b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -21,6 +21,7 @@ import os import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session @@ -33,7 +34,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index f59ac760dc..329dc80ba5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -21,6 +21,7 @@ import os from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -29,7 +30,6 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 3ae8f71d77..8033f1d388 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,6 +21,7 @@ import gzip import os import zlib +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 @@ -33,7 +34,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 1b81cf5be9..727c5d1c38 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -12,20 +12,6 @@ load( load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") 
py_library( - name = "dataset_ops", - srcs = [ - "dataset_ops.py", - ], - srcs_version = "PY2AND3", - deps = [ - ":transformation_ops", - "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", - "//tensorflow/python/data/util:nest", - ], -) - -py_library( name = "iterator_ops", srcs = [ "iterator_ops.py", @@ -73,6 +59,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":gen_dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dataset_ops_gen", @@ -128,6 +115,31 @@ tf_custom_op_py_library( ], ) +tf_gen_op_wrapper_py( + name = "gen_dataset_ops", + out = "gen_dataset_ops.py", + deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"], +) + +tf_custom_op_py_library( + name = "dataset_ops", + srcs = ["dataset_ops.py"], + dso = ["//tensorflow/contrib/data:_dataset_ops.so"], + kernels = [ + "//tensorflow/contrib/data:dataset_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_dataset_ops", + ":transformation_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index abc9212a87..e6e5f716b6 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -24,7 +25,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import math_ops diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 45d6dbe743..c4c4426809 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -20,15 +20,21 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import enumerate_ops from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import grouping +from tensorflow.contrib.util import loader from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_io_ops +from tensorflow.python.platform import resource_loader from tensorflow.python.util import deprecation +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) + + class Dataset(dataset_ops.Dataset): """Represents a potentially large set of elements. 
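The dataset_ops.py hunk above is the core of the Python-side change: contrib now loads _dataset_ops.so itself and resolves the moved ops through the contrib-generated gen_dataset_ops module. To make the iterator serialization ops concrete, here is a condensed sketch of the save/restore pattern used by the _save_op/_restore_op helpers in the tests later in this diff (the checkpoint path is arbitrary, and iterator._iterator_resource is the same private handle those helpers reach into):

import os

from tensorflow.contrib.data.python.ops import dataset_ops
from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops

path = os.path.join("/tmp", "iterator")  # arbitrary checkpoint location

iterator = dataset_ops.Dataset.range(10).make_initializable_iterator()

# SerializeIterator snapshots the iterator's state into a DT_VARIANT tensor,
# which can be serialized to a string and written to a file.
state = gen_dataset_ops.serialize_iterator(iterator._iterator_resource)
save_op = io_ops.write_file(path, parsing_ops.serialize_tensor(state))

# Reading the file back and parsing it as a variant lets DeserializeIterator
# restore that state into the (already created) iterator resource.
restored = parsing_ops.parse_tensor(io_ops.read_file(path), dtypes.variant)
restore_op = gen_dataset_ops.deserialize_iterator(
    iterator._iterator_resource, restored)

Running save_op after a few get_next() calls and restore_op in a fresh session resumes iteration from the saved point; the relocated tests exercise exactly this, including the failure modes (NotFoundError when no checkpoint file exists yet, and InvalidArgumentError when restoring state into an iterator with an incompatible element type).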
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 238bb52b02..51a2791072 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -17,9 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest -from tensorflow.python.ops import gen_dataset_ops def ignore_errors(): diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index 6df7b22fb6..1c7c94b3c8 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -17,12 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops def group_by_window(key_func, diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 74a919c1ff..ce23e95697 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -17,12 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py index d736029fb0..32d2f42c93 100644 --- a/tensorflow/contrib/data/python/ops/iterator_ops.py +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -17,8 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.training import saver diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 2e1c3153ca..f22298b757 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest @@ -25,7 +26,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import 
sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 5acaed48a3..87bbbb7d19 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -19,11 +19,11 @@ from __future__ import print_function import collections +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import gen_dataset_ops class _ScanDataset(dataset_ops.Dataset): diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index a4b5ca16af..8b8251f84b 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -8271,29 +8271,6 @@ op { } } op { - name: "DatasetToSingleElement" - input_arg { - name: "dataset" - type: DT_VARIANT - } - output_arg { - name: "components" - type_list_attr: "output_types" - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} -op { name: "DebugGradientIdentity" input_arg { name: "input" @@ -9272,69 +9249,6 @@ op { } } op { - name: "DenseToSparseBatchDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - input_arg { - name: "batch_size" - type: DT_INT64 - } - input_arg { - name: "row_shape" - type: DT_INT64 - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } - is_stateful: true -} -op { - name: "DenseToSparseBatchDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - input_arg { - name: "batch_size" - type: DT_INT64 - } - input_arg { - name: "row_shape" - type: DT_INT64 - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} -op { name: "DenseToSparseSetOperation" input_arg { name: "set1" @@ -9828,18 +9742,6 @@ op { } } op { - name: "DeserializeIterator" - input_arg { - name: "resource_handle" - type: DT_RESOURCE - } - input_arg { - name: "serialized" - type: DT_VARIANT - } - is_stateful: true -} -op { name: "DeserializeManySparse" input_arg { name: "serialized_sparse" @@ -13593,131 +13495,6 @@ op { } } op { - name: "GroupByWindowDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - input_arg { - name: "key_func_other_arguments" - type_list_attr: "Tkey_func_other_arguments" - } - input_arg { - name: "reduce_func_other_arguments" - type_list_attr: "Treduce_func_other_arguments" - } - input_arg { - name: "window_size_func_other_arguments" - type_list_attr: "Twindow_size_func_other_arguments" - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "key_func" - type: "func" - } - attr { - name: "reduce_func" - type: 
"func" - } - attr { - name: "window_size_func" - type: "func" - } - attr { - name: "Tkey_func_other_arguments" - type: "list(type)" - has_minimum: true - } - attr { - name: "Treduce_func_other_arguments" - type: "list(type)" - has_minimum: true - } - attr { - name: "Twindow_size_func_other_arguments" - type: "list(type)" - has_minimum: true - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } - is_stateful: true -} -op { - name: "GroupByWindowDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - input_arg { - name: "key_func_other_arguments" - type_list_attr: "Tkey_func_other_arguments" - } - input_arg { - name: "reduce_func_other_arguments" - type_list_attr: "Treduce_func_other_arguments" - } - input_arg { - name: "window_size_func_other_arguments" - type_list_attr: "Twindow_size_func_other_arguments" - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "key_func" - type: "func" - } - attr { - name: "reduce_func" - type: "func" - } - attr { - name: "window_size_func" - type: "func" - } - attr { - name: "Tkey_func_other_arguments" - type: "list(type)" - has_minimum: true - } - attr { - name: "Treduce_func_other_arguments" - type: "list(type)" - has_minimum: true - } - attr { - name: "Twindow_size_func_other_arguments" - type: "list(type)" - has_minimum: true - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} -op { name: "HSVToRGB" input_arg { name: "images" @@ -14138,53 +13915,6 @@ op { } } op { - name: "IgnoreErrorsDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } - is_stateful: true -} -op { - name: "IgnoreErrorsDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} -op { name: "Imag" input_arg { name: "input" @@ -16089,50 +15819,6 @@ op { is_stateful: true } op { - name: "MapAndBatchDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - input_arg { - name: "other_arguments" - type_list_attr: "Targuments" - } - input_arg { - name: "batch_size" - type: DT_INT64 - } - input_arg { - name: "num_parallel_batches" - type: DT_INT64 - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "f" - type: "func" - } - attr { - name: "Targuments" - type: "list(type)" - has_minimum: true - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} -op { name: "MapClear" attr { name: "capacity" @@ -20871,54 +20557,6 @@ op { } } op { - name: "ParallelInterleaveDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - input_arg { - name: "other_arguments" - type_list_attr: "Targuments" - } - input_arg { - name: "cycle_length" - type: DT_INT64 - } - input_arg { - name: "block_length" - type: DT_INT64 - } - input_arg { - 
name: "sloppy" - type: DT_BOOL - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "f" - type: "func" - } - attr { - name: "Targuments" - type: "list(type)" - has_minimum: true - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} -op { name: "ParallelMapDataset" input_arg { name: "input_dataset" @@ -30509,52 +30147,6 @@ op { } } op { - name: "ScanDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - input_arg { - name: "initial_state" - type_list_attr: "Tstate" - } - input_arg { - name: "other_arguments" - type_list_attr: "Targuments" - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "f" - type: "func" - } - attr { - name: "Tstate" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "Targuments" - type: "list(type)" - has_minimum: true - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} -op { name: "ScatterAdd" input_arg { name: "ref" @@ -32270,18 +31862,6 @@ op { } } op { - name: "SerializeIterator" - input_arg { - name: "resource_handle" - type: DT_RESOURCE - } - output_arg { - name: "serialized" - type: DT_VARIANT - } - is_stateful: true -} -op { name: "SerializeManySparse" input_arg { name: "sparse_indices" @@ -37686,38 +37266,6 @@ op { } } op { - name: "SqlDataset" - input_arg { - name: "driver_name" - type: DT_STRING - } - input_arg { - name: "data_source_name" - type: DT_STRING - } - input_arg { - name: "query" - type: DT_STRING - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } - is_stateful: true -} -op { name: "Sqrt" input_arg { name: "x" diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index f512213964..8f5d8308a3 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -141,16 +141,6 @@ count: A scalar representing the number of elements from the `input_dataset` that should be skipped. If count is -1, skips everything. )doc"); -REGISTER_OP("IgnoreErrorsDataset") - .Input("input_dataset: variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that contains the elements of `input_dataset` ignoring errors. -)doc"); - REGISTER_OP("MapDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") @@ -184,32 +174,6 @@ num_parallel_calls: The number of concurrent invocations of `f` that process elements from `input_dataset` in parallel. )doc"); -REGISTER_OP("MapAndBatchDataset") - .Input("input_dataset: variant") - .Input("other_arguments: Targuments") - .Input("batch_size: int64") - .Input("num_parallel_batches: int64") - .Output("handle: variant") - .Attr("f: func") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that applies `f` to the outputs of `input_dataset` and then -batches `batch_size` of them. 
- -Unlike a "MapDataset", which applies `f` sequentially, this dataset invokes up -to `batch_size * num_parallel_batches` copies of `f` in parallel. - -batch_size: A scalar representing the number of elements to accumulate in a - batch. It determines the number of concurrent invocations of `f` that process - elements from `input_dataset` in parallel. -num_parallel_batches: A scalar representing the number of batches to create in - parallel. Processing multiple batches in parallel benefits workloads prone to - stragglers. -)doc"); - REGISTER_OP("PrefetchDataset") .Input("input_dataset: variant") .Input("buffer_size: int64") @@ -224,21 +188,6 @@ buffer_size: The maximum number of elements to buffer in an iterator over this dataset. )doc"); -REGISTER_OP("ScanDataset") - .Input("input_dataset: variant") - .Input("initial_state: Tstate") - .Input("other_arguments: Targuments") - .Output("handle: variant") - .Attr("f: func") - .Attr("Tstate: list(type) >= 1") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset successively reduces `f` over the elements of `input_dataset`. -)doc"); - REGISTER_OP("FlatMapDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") @@ -285,59 +234,6 @@ f: A function mapping elements of `input_dataset`, concatenated with `output_types` and `output_shapes`. )doc"); -REGISTER_OP("ParallelInterleaveDataset") - .Input("input_dataset: variant") - .Input("other_arguments: Targuments") - .Input("cycle_length: int64") - .Input("block_length: int64") - .Input("sloppy: bool") - .Output("handle: variant") - .Attr("f: func") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that applies `f` to the outputs of `input_dataset`. - -The resulting dataset is similar to the `InterleaveDataset`, with the exception -that if retrieving the next value from a dataset would cause the requester to -block, it will skip that input dataset. This dataset is especially useful -when loading data from a variable-latency datastores (e.g. HDFS, GCS), as it -allows the training step to proceed so long as some data is available. - -!! WARNING !! This dataset is not deterministic! - -f: A function mapping elements of `input_dataset`, concatenated with - `other_arguments`, to a Dataset variant that contains elements matching - `output_types` and `output_shapes`. -)doc"); - -REGISTER_OP("GroupByWindowDataset") - .Input("input_dataset: variant") - .Input("key_func_other_arguments: Tkey_func_other_arguments") - .Input("reduce_func_other_arguments: Treduce_func_other_arguments") - .Input( - "window_size_func_other_arguments: Twindow_size_func_other_arguments") - .Output("handle: variant") - .Attr("key_func: func") - .Attr("reduce_func: func") - .Attr("window_size_func: func") - .Attr("Tkey_func_other_arguments: list(type) >= 0") - .Attr("Treduce_func_other_arguments: list(type) >= 0") - .Attr("Twindow_size_func_other_arguments: list(type) >= 0") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that computes a windowed group-by on `input_dataset`. - -// TODO(mrry): Support non-int64 keys. 
- -key_func: A function mapping an element of `input_dataset`, concatenated - with `key_func_other_arguments` to a scalar value of type DT_INT64. -)doc"); - REGISTER_OP("FilterDataset") .Input("input_dataset: variant") .Input("other_arguments: Targuments") @@ -408,27 +304,6 @@ padding_values: A list of scalars containing the padding value to use for each of the outputs. )doc"); -REGISTER_OP("DenseToSparseBatchDataset") - .Input("input_dataset: variant") - .Input("batch_size: int64") - .Input("row_shape: int64") - .Output("handle: variant") - // NOTE(mrry): the 0th and 2nd elements will be DT_INT64. - .Attr("output_types: list(type) >= 1") - // NOTE(mrry): the 1st and 2nd elements will be vectors. - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that yields a SparseTensor for each element of the input. - -input_dataset: A handle to an input dataset. Must have a single component. -batch_size: A scalar representing the number of elements to accumulate in a - batch. -row_shape: A vector representing the dense shape of each row in the produced - SparseTensor. The shape may be partially specified, using `-1` to indicate - that a particular dimension should use the maximum size of all batch elements. -)doc"); - REGISTER_OP("RangeDataset") .Input("start: int64") .Input("stop: int64") @@ -514,24 +389,6 @@ compression_type: A scalar containing either (i) the empty string (no buffer_size: A scalar containing the number of bytes to buffer. )doc"); -REGISTER_OP("SqlDataset") - .Input("driver_name: string") - .Input("data_source_name: string") - .Input("query: string") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked - // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that executes a SQL query and emits rows of the result set. - -driver_name: The database type. Currently, the only supported type is 'sqlite'. -data_source_name: A connection string to connect to the database. -query: A SQL query to execute. -)doc"); - REGISTER_OP("FixedLengthRecordDataset") .Input("filenames: string") .Input("header_bytes: int64") @@ -662,36 +519,6 @@ REGISTER_OP("IteratorGetNext") Gets the next output from the given iterator. )doc"); -REGISTER_OP("DatasetToSingleElement") - .Input("dataset: variant") - .Output("components: output_types") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector<PartialTensorShape> output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast<int>(i), output_shape_handle); - } - return Status::OK(); - }) - .Doc(R"doc( -Outputs the single element from the given dataset. - -dataset: A handle to a dataset that contains a single element. 
-components: The components of the single element of `input`. -)doc"); - REGISTER_OP("IteratorToStringHandle") .Input("resource_handle: resource") .Output("string_handle: string") @@ -720,28 +547,4 @@ output_shapes: If specified, defines the shape of each tuple component in an element produced by the resulting iterator. )doc"); -REGISTER_OP("SerializeIterator") - .Input("resource_handle: resource") - .Output("serialized: variant") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Converts the given `resource_handle` representing an iterator to a variant tensor. - -resource_handle: A handle to an iterator resource. -serialized: A variant tensor storing the state of the iterator contained in the - resource. -)doc"); - -REGISTER_OP("DeserializeIterator") - .Input("resource_handle: resource") - .Input("serialized: variant") - .SetShapeFn(shape_inference::NoOutputs) - .Doc(R"doc( -Converts the given variant tensor to an iterator and stores it in the given resource. - -resource_handle: A handle to an iterator resource. -serialized: A variant tensor storing the state of the iterator contained in the - resource. -)doc"); - } // namespace tensorflow diff --git a/tensorflow/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/kernel_tests/iterator_ops_test.py index 2128ef4ae1..60a44b5b14 100644 --- a/tensorflow/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/python/kernel_tests/iterator_ops_test.py @@ -17,14 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import numpy as np from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.data.ops import readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -33,9 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops @@ -537,64 +533,6 @@ class IteratorTest(test.TestCase): target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) - def testIncorrectIteratorRestore(self): - - def _path(): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - _path(), parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(_path()), dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def _build_range_dataset_graph(): - start = 1 - stop = 10 - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = _save_op(iterator._iterator_resource) - restore_op = _restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, 
restore_op - - def _build_reader_dataset_graph(): - filenames = ["test"] # Does not exist but we don't care in this test. - iterator = readers.FixedLengthRecordDataset( - filenames, 1, 0, 0).make_initializable_iterator() - init_op = iterator.initializer - get_next_op = iterator.get_next() - save_op = _save_op(iterator._iterator_resource) - restore_op = _restore_op(iterator._iterator_resource) - return init_op, get_next_op, save_op, restore_op - - # Saving iterator for RangeDataset graph. - with ops.Graph().as_default() as g: - init_op, _, save_op, _ = _build_range_dataset_graph() - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(save_op) - - # Attempt to restore the saved iterator into an IteratorResource of - # incompatible type. An iterator of RangeDataset has output type int64, - # while an iterator of FixedLengthRecordDataset has output type string. - # So an InvalidArgumentError should be raised by - # IteratorResource::set_iterator. - with ops.Graph().as_default() as g: - _, _, _, restore_op = _build_reader_dataset_graph() - with self.test_session(graph=g) as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(restore_op) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/kernel_tests/range_dataset_op_test.py index 0c530522b8..3c1685c951 100644 --- a/tensorflow/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/python/kernel_tests/range_dataset_op_test.py @@ -17,32 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os - from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile from tensorflow.python.platform import test class RangeDatasetTest(test.TestCase): - def tearDown(self): - # Remove all checkpoint files. 
- prefix = self._iterator_checkpoint_prefix() - pattern = prefix + "*" - files = gfile.Glob(pattern) - map(gfile.Remove, files) - def testStop(self): stop = array_ops.placeholder(dtypes.int64, shape=[]) iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator() @@ -168,319 +151,6 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def _iterator_checkpoint_prefix(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(self, iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - self._iterator_checkpoint_prefix(), - parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(self, iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def testSaveRestore(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. - start = 2 - stop = 10 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Saving and restoring in same session. - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRestoreWithoutBuildingDatasetGraph(self): - - def _build_graph(start, stop, num_epochs): - dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. 
- start = 2 - stop = 10 - num_epochs = 5 - break_point = 5 - break_epoch = 3 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for _ in range(break_epoch): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - # Create an empty IteratorResource and restore the Iterator into it. - output_types = dtypes.int64 - output_shapes = tensor_shape.scalar() - iterator = iterator_ops.Iterator.from_structure(output_types, - output_shapes) - restore_op = self._restore_op(iterator._iterator_resource) - get_next = iterator.get_next() - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - for _ in range(break_epoch + 1, num_epochs): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRestoreInModifiedGraph(self): - - def _build_graph(start, stop): - dataset = dataset_ops.Dataset.range(start, stop) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. - start = 2 - stop = 10 - stop_1 = 8 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - # Intentionally build a graph with a different value for stop to make sure - # the original dataset graph is actually getting loaded. - init_op, get_next, _, restore_op = _build_graph(start, stop_1) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testInitThenRestore(self): - # Note: Calling init_op before restore_op is redundant. This test just makes - # sure we do not fail if restore is called on an already initialized - # iterator resource. - - def _build_graph(start, stop): - dataset = dataset_ops.Dataset.range(start, stop) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. 
- start = 2 - stop = 10 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testMultipleSaves(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - start = 2 - stop = 10 - break_point1 = 5 - break_point2 = 7 - - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point1): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point1, break_point2): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - break_point2 = 7 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point2, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSaveRestoreWithRepeat(self): - - def _build_graph(start, stop, num_epochs): - iterator = dataset_ops.Dataset.range( - start, stop).repeat(num_epochs).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - start = 2 - stop = 10 - num_epochs = 5 - break_range = 5 - break_epoch = 3 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph( - start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. 
- with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for _ in range(break_epoch - 1): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - for i in range(start, break_range): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_range, stop): - self.assertEqual(i, sess.run(get_next)) - for _ in range(break_epoch, num_epochs): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSaveRestoreExhaustedIterator(self): - - def _build_graph(start, stop, num_epochs): - iterator = dataset_ops.Dataset.range( - start, stop).repeat(num_epochs).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - start = 2 - stop = 10 - num_epochs = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph( - start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. - with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for _ in range(num_epochs): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py index c8e7333b4b..70b6ce442e 100644 --- a/tensorflow/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_dataset_ops_test.py @@ -26,13 +26,8 @@ from tensorflow.python.data.ops import readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -272,299 +267,6 @@ class FixedLengthRecordReaderTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) - def _iterator_checkpoint_path(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(self, iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - self._iterator_checkpoint_path(), - parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(self, 
-  def _build_iterator_graph(self, num_epochs):
-    filenames = self._createFiles()
-    dataset = (readers.FixedLengthRecordDataset(
-        filenames, self._record_bytes, self._header_bytes, self._footer_bytes)
-               .repeat(num_epochs))
-    iterator = dataset.make_initializable_iterator()
-    init_op = iterator.initializer
-    get_next_op = iterator.get_next()
-    save_op = self._save_op(iterator._iterator_resource)
-    restore_op = self._restore_op(iterator._iterator_resource)
-    return init_op, get_next_op, save_op, restore_op
-
-  def _restore_iterator(self):
-    output_types = dtypes.string
-    output_shapes = tensor_shape.scalar()
-    iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes)
-    get_next = iterator.get_next()
-    restore_op = self._restore_op(iterator._iterator_resource)
-    return restore_op, get_next
-
-  def testSaveRestore(self):
-    num_epochs = 10
-    epoch_break = 5
-    file_break = self._num_files // 2
-    record_break = self._num_records // 2
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs)
-      with self.test_session(graph=g) as sess:
-        sess.run(init_op)
-        # Note: There is no checkpoint saved currently so a NotFoundError is
-        # raised.
-        with self.assertRaises(errors.NotFoundError):
-          sess.run(restore_op)
-        for epoch in range(num_epochs):
-          for f in range(self._num_files):
-            for r in range(self._num_records):
-              if (epoch == epoch_break and f == file_break and
-                  r == record_break):
-                sess.run(save_op)
-                break
-              self.assertEqual(self._record(f, r), sess.run(get_next_op))
-            else:
-              continue
-            break
-          else:
-            continue
-          break
-        else:
-          with self.assertRaises(errors.OutOfRangeError):
-            sess.run(get_next_op)
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs)
-      with self.test_session(graph=g) as sess:
-        sess.run(restore_op)
-        for epoch in range(num_epochs):
-          for f in range(self._num_files):
-            for r in range(self._num_records):
-              if (epoch < epoch_break or
-                  (epoch == epoch_break and f < file_break) or
-                  (epoch == epoch_break and f == file_break and
-                   r < record_break)):
-                continue
-              self.assertEqual(self._record(f, r), sess.run(get_next_op))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next_op)
-
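[Note] The triple `else: continue` / `break` ladder in `testSaveRestore` above is Python's standard idiom for breaking out of nested loops: an inner `break` (taken at the save point) skips that loop's `else` suite and triggers the enclosing `break`, while the outermost `else` runs only when no save point was hit. A tiny self-contained illustration:

```python
# for/else: the else suite runs only if the loop finished without a break.
def scan(rows, target):
  for row in rows:
    for value in row:
      if value == target:
        break          # found it: exit the inner loop...
    else:
      continue         # inner loop ended normally: try the next row
    break              # ...and propagate the break to the outer loop
  else:
    print("never found %r" % (target,))

scan([[1, 2], [3, 4]], 3)  # stops at 3; prints nothing
scan([[1, 2], [3, 4]], 9)  # prints "never found 9"
```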
-  def testInitThenRestore(self):
-    # Note: Calling init_op before restore_op is redundant. This test just makes
-    # sure we do not fail if restore is called on an already initialized
-    # iterator resource.
-    num_epochs = 10
-    epoch_break = 5
-    file_break = self._num_files // 2
-    record_break = self._num_records // 2
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs)
-      with self.test_session(graph=g) as sess:
-        sess.run(init_op)
-        # Note: There is no checkpoint saved currently so a NotFoundError is
-        # raised.
-        with self.assertRaises(errors.NotFoundError):
-          sess.run(restore_op)
-        for epoch in range(num_epochs):
-          for f in range(self._num_files):
-            for r in range(self._num_records):
-              if (epoch == epoch_break and f == file_break and
-                  r == record_break):
-                sess.run(save_op)
-                break
-              self.assertEqual(self._record(f, r), sess.run(get_next_op))
-            else:
-              continue
-            break
-          else:
-            continue
-          break
-        else:
-          with self.assertRaises(errors.OutOfRangeError):
-            sess.run(get_next_op)
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs)
-      with self.test_session(graph=g) as sess:
-        sess.run(init_op)
-        sess.run(restore_op)
-        for epoch in range(num_epochs):
-          for f in range(self._num_files):
-            for r in range(self._num_records):
-              if (epoch < epoch_break or
-                  (epoch == epoch_break and f < file_break) or
-                  (epoch == epoch_break and f == file_break and
-                   r < record_break)):
-                continue
-              self.assertEqual(self._record(f, r), sess.run(get_next_op))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next_op)
-
-  def testRestoreInModifiedGraph(self):
-    num_epochs = 10
-    num_epochs_1 = 20
-    epoch_break = 5
-    file_break = self._num_files // 2
-    record_break = self._num_records // 2
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs)
-      with self.test_session(graph=g) as sess:
-        sess.run(init_op)
-        # Note: There is no checkpoint saved currently so a NotFoundError is
-        # raised.
-        with self.assertRaises(errors.NotFoundError):
-          sess.run(restore_op)
-        for epoch in range(num_epochs):
-          for f in range(self._num_files):
-            for r in range(self._num_records):
-              if (epoch == epoch_break and f == file_break and
-                  r == record_break):
-                sess.run(save_op)
-                break
-              self.assertEqual(self._record(f, r), sess.run(get_next_op))
-            else:
-              continue
-            break
-          else:
-            continue
-          break
-        else:
-          with self.assertRaises(errors.OutOfRangeError):
-            sess.run(get_next_op)
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs_1)
-      with self.test_session(graph=g) as sess:
-        sess.run(restore_op)
-        for epoch in range(num_epochs):
-          for f in range(self._num_files):
-            for r in range(self._num_records):
-              if (epoch < epoch_break or
-                  (epoch == epoch_break and f < file_break) or
-                  (epoch == epoch_break and f == file_break and
-                   r < record_break)):
-                continue
-              self.assertEqual(self._record(f, r), sess.run(get_next_op))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next_op)
-
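[Note] Worth calling out in `testRestoreInModifiedGraph`: the second graph is built with `num_epochs_1 = 20`, yet after `restore_op` runs, iteration still raises `OutOfRangeError` where the original 10-epoch pipeline would end. As the test asserts, the serialized iterator state, not the dataset definition in the rebuilt graph, determines how many elements remain; `testRestoreWithoutBuildingDatasetGraph` below pushes this further by restoring into a bare `Iterator.from_structure` with no dataset graph at all.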
-  def testRestoreWithoutBuildingDatasetGraph(self):
-    num_epochs = 10
-    epoch_break = 5
-    file_break = self._num_files // 2
-    record_break = self._num_records // 2
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs)
-      with self.test_session(graph=g) as sess:
-        sess.run(init_op)
-        # Note: There is no checkpoint saved currently so a NotFoundError is
-        # raised.
-        with self.assertRaises(errors.NotFoundError):
-          sess.run(restore_op)
-        for epoch in range(num_epochs):
-          for f in range(self._num_files):
-            for r in range(self._num_records):
-              if (epoch == epoch_break and f == file_break and
-                  r == record_break):
-                sess.run(save_op)
-                break
-              self.assertEqual(self._record(f, r), sess.run(get_next_op))
-            else:
-              continue
-            break
-          else:
-            continue
-          break
-        else:
-          with self.assertRaises(errors.OutOfRangeError):
-            sess.run(get_next_op)
-
-    with ops.Graph().as_default() as g:
-      restore_op, get_next_op = self._restore_iterator()
-      with self.test_session(graph=g) as sess:
-        sess.run(restore_op)
-        for epoch in range(num_epochs):
-          for f in range(self._num_files):
-            for r in range(self._num_records):
-              if (epoch < epoch_break or
-                  (epoch == epoch_break and f < file_break) or
-                  (epoch == epoch_break and f == file_break and
-                   r < record_break)):
-                continue
-              self.assertEqual(self._record(f, r), sess.run(get_next_op))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next_op)
-
-  def testRestoreUnusedIterator(self):
-    num_epochs = 10
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs)
-      with self.test_session(graph=g) as sess:
-        sess.run(init_op)
-        # Note: There is no checkpoint saved currently so a NotFoundError is
-        # raised.
-        with self.assertRaises(errors.NotFoundError):
-          sess.run(restore_op)
-        # Save unused iterator.
-        sess.run(save_op)
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs)
-      with self.test_session(graph=g) as sess:
-        sess.run(restore_op)
-        for _ in range(num_epochs * self._num_files * self._num_records):
-          sess.run(get_next_op)
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next_op)
-
-  def testRestoreExhaustedIterator(self):
-    num_epochs = 10
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs)
-      with self.test_session(graph=g) as sess:
-        sess.run(init_op)
-        # Note: There is no checkpoint saved currently so a NotFoundError is
-        # raised.
-        with self.assertRaises(errors.NotFoundError):
-          sess.run(restore_op)
-        for _ in range(num_epochs):
-          for f in range(self._num_files):
-            for r in range(self._num_records):
-              self.assertEqual(self._record(f, r), sess.run(get_next_op))
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next_op)
-        sess.run(save_op)
-
-    with ops.Graph().as_default() as g:
-      init_op, get_next_op, save_op, restore_op = self._build_iterator_graph(
-          num_epochs=num_epochs)
-      with self.test_session(graph=g) as sess:
-        sess.run(restore_op)
-        with self.assertRaises(errors.OutOfRangeError):
-          sess.run(get_next_op)
-
 
 class TFRecordDatasetTest(test.TestCase):