diff options
Diffstat (limited to 'tensorflow/contrib')
113 files changed, 958 insertions, 4380 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index ae5ca32bcf..98dff965a9 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -112,26 +112,14 @@ py_library( "//tensorflow/python:util", "//tensorflow/python/estimator:estimator_py", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + select({ - "//tensorflow:with_kafka_support_windows_override": [], - "//tensorflow:with_kafka_support": [ - "//tensorflow/contrib/kafka", - ], - "//conditions:default": [], - }) + select({ - "//tensorflow:with_aws_support_windows_override": [], - "//tensorflow:with_aws_support": [ - "//tensorflow/contrib/kinesis", - ], - "//conditions:default": [], - }) + if_not_windows_cuda([ - "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols - ]) + if_not_windows([ - ]) + select({ "//tensorflow:linux_s390x": [], "//tensorflow:windows": [], "//conditions:default": [ "//tensorflow/contrib/bigtable", "//tensorflow/contrib/cloud:cloud_py", + "//tensorflow/contrib/fused_conv:fused_conv_py", # unresolved symbols, need to export more symbols + "//tensorflow/contrib/kafka", + "//tensorflow/contrib/kinesis", "//tensorflow/contrib/tensorrt:init_py", "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", ], @@ -144,7 +132,6 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", "//tensorflow/contrib/coder:all_kernels", - "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/hadoop:dataset_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", @@ -159,20 +146,14 @@ cc_library( ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([ "//tensorflow/contrib/nccl:nccl_kernels", ]) + select({ - "//tensorflow:with_kafka_support_windows_override": [], - "//tensorflow:with_kafka_support": [ + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//conditions:default": [ "//tensorflow/contrib/kafka:dataset_kernels", - ], - "//conditions:default": [], - }) + select({ - "//tensorflow:with_aws_support_windows_override": [], - "//tensorflow:with_aws_support": [ "//tensorflow/contrib/kinesis:dataset_kernels", + "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", ], - "//conditions:default": [], - }) + if_not_windows([ - "//tensorflow/contrib/tensorrt:trt_engine_op_kernel", - ]), + }), ) cc_library( @@ -181,8 +162,6 @@ cc_library( deps = [ "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", "//tensorflow/contrib/coder:all_ops", - "//tensorflow/contrib/data:dataset_ops_op_lib", - "//tensorflow/contrib/data:indexed_dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/hadoop:dataset_ops_op_lib", @@ -198,18 +177,12 @@ cc_library( "//tensorflow/contrib/text:all_ops", "//tensorflow/contrib/tpu:all_ops", ] + select({ - "//tensorflow:with_kafka_support_windows_override": [], - "//tensorflow:with_kafka_support": [ + "//tensorflow:linux_s390x": [], + "//tensorflow:windows": [], + "//conditions:default": [ "//tensorflow/contrib/kafka:dataset_ops_op_lib", - ], - "//conditions:default": [], - }) + select({ - "//tensorflow:with_aws_support_windows_override": [], - "//tensorflow:with_aws_support": [ "//tensorflow/contrib/kinesis:dataset_ops_op_lib", + "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", ], - "//conditions:default": [], - }) + if_not_windows([ - "//tensorflow/contrib/tensorrt:trt_engine_op_op_lib", - ]), + }), ) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py index 6b6fe9663a..839eedd3a8 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -188,9 +188,8 @@ class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): # Train for a few steps. est.train(input_fn=_train_input_fn, steps=1000) - # 10 steps for dnn + 3 for 1 tree of depth 3 + 1 after the tree finished - # + 1 for resource variables. - self._assert_checkpoint(est.model_dir, global_step=15) + # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished + self._assert_checkpoint(est.model_dir, global_step=14) res = est.evaluate(input_fn=_eval_input_fn, steps=1) self.assertLess(0.5, res["auc"]) est.predict(input_fn=_eval_input_fn) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index d7b14e00ba..c155128c0e 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -238,8 +238,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase): output_leaf_index=False) classifier.fit(input_fn=_train_input_fn, steps=15) - # When no override of global steps, 6 steps were used. - self._assert_checkpoint(classifier.model_dir, global_step=6) + # When no override of global steps, 5 steps were used. + self._assert_checkpoint(classifier.model_dir, global_step=5) def testOverridesGlobalSteps(self): learner_config = learner_pb2.LearnerConfig() diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index c7eb2493a8..8531e97f90 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -402,13 +402,13 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._num_quantiles = num_quantiles - self._max_tree_depth = variables.Variable( + self._max_tree_depth = variables.VariableV1( initial_value=self._learner_config.constraints.max_tree_depth) - self._attempted_trees = variables.Variable( + self._attempted_trees = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), trainable=False, name="attempted_trees") - self._finalized_trees = variables.Variable( + self._finalized_trees = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), trainable=False, name="finalized_trees") @@ -770,28 +770,28 @@ class GradientBoostedDecisionTreeModel(object): fc_name_idx += 1 # Create ensemble stats variables. - num_layer_examples = variables.Variable( + num_layer_examples = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), name="num_layer_examples", trainable=False) - num_layer_steps = variables.Variable( + num_layer_steps = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), name="num_layer_steps", trainable=False) - num_layers = variables.Variable( + num_layers = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), name="num_layers", trainable=False) - active_tree = variables.Variable( + active_tree = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), name="active_tree", trainable=False) - active_layer = variables.Variable( + active_layer = variables.VariableV1( initial_value=array_ops.zeros([], dtypes.int64), name="active_layer", trainable=False) # Variable that becomes false once bias centering is done. - continue_centering = variables.Variable( + continue_centering = variables.VariableV1( initial_value=self._center_bias, name="continue_centering", trainable=False) diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 9d9941f696..6d20a2e7f4 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -239,7 +239,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -503,7 +503,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -607,7 +607,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -711,7 +711,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -783,7 +783,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -847,7 +847,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1090,7 +1090,7 @@ class GbdtTest(test_util.TensorFlowTestCase): weights = array_ops.ones([batch_size, 1], dtypes.float32) partition_ids = array_ops.zeros([batch_size], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1194,7 +1194,7 @@ class GbdtTest(test_util.TensorFlowTestCase): weights = array_ops.ones([batch_size, 1], dtypes.float32) partition_ids = array_ops.zeros([batch_size], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1299,7 +1299,7 @@ class GbdtTest(test_util.TensorFlowTestCase): weights = array_ops.ones([batch_size, 1], dtypes.float32) partition_ids = array_ops.zeros([batch_size], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1405,7 +1405,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1524,7 +1524,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, @@ -1656,7 +1656,7 @@ class GbdtTest(test_util.TensorFlowTestCase): predictions = array_ops.constant( [[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32) partition_ids = array_ops.zeros([4], dtypes.int32) - ensemble_stamp = variables.Variable( + ensemble_stamp = variables.VariableV1( initial_value=0, name="ensemble_stamp", trainable=False, diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index c0763f4c0e..2975b167ec 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -132,7 +132,6 @@ tensorflow/contrib/cudnn_rnn/python tensorflow/contrib/cudnn_rnn/python/layers tensorflow/contrib/cudnn_rnn/python/ops tensorflow/contrib/data -tensorflow/contrib/data/kernels tensorflow/contrib/data/python tensorflow/contrib/data/python/kernel_tests tensorflow/contrib/data/python/kernel_tests/serialization diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 9f710613dd..38f1c65a4d 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -4,17 +4,6 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load( - "//tensorflow:tensorflow.bzl", - "tf_custom_op_library", - "tf_gen_op_libs", - "if_not_windows", -) -load( - "//tensorflow/core:platform/default/build_config_root.bzl", - "if_static", -) - py_library( name = "data", srcs = ["__init__.py"], @@ -25,30 +14,3 @@ py_library( "//tensorflow/python:util", ], ) - -cc_library( - name = "lib_proto_parsing_for_dataset_ops", - deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]), -) - -tf_custom_op_library( - name = "_dataset_ops.so", - srcs = [ - "ops/dataset_ops.cc", - "ops/indexed_dataset_ops.cc", - ], - deps = [ - "//tensorflow/contrib/data/kernels:dataset_kernels", - "//tensorflow/contrib/data/kernels:indexed_dataset", - ] + if_static( - extra_deps = [":lib_proto_parsing_for_dataset_ops"], - otherwise = [], - ), -) - -tf_gen_op_libs( - op_lib_names = [ - "dataset_ops", - "indexed_dataset_ops", - ], -) diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD deleted file mode 100644 index ec6cb37193..0000000000 --- a/tensorflow/contrib/data/kernels/BUILD +++ /dev/null @@ -1,139 +0,0 @@ -# Description: -# Contains kernels for datasets and iterators. -package(default_visibility = ["//tensorflow:internal"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "indexed_dataset_headers", - hdrs = ["indexed_dataset.h"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], -) - -cc_library( - name = "indexed_dataset", - srcs = [ - "identity_indexed_dataset.cc", - "indexed_dataset.cc", - ], - deps = [ - ":indexed_dataset_headers", - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "prefetching_kernels", - srcs = ["prefetching_kernels.cc"], - deps = [ - "//tensorflow/core:core_cpu_headers_lib", - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "directed_interleave_dataset_op", - srcs = ["directed_interleave_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "csv_dataset_op", - srcs = ["csv_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "ignore_errors_dataset_op", - srcs = ["ignore_errors_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "lmdb_dataset_op", - srcs = ["lmdb_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@lmdb", - "@protobuf_archive//:protobuf_headers", - ], -) - -cc_library( - name = "threadpool_dataset_op", - srcs = ["threadpool_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "unique_dataset_op", - srcs = ["unique_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "assert_next_dataset_op", - srcs = ["assert_next_dataset_op.cc"], - deps = [ - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], - alwayslink = 1, -) - -cc_library( - name = "dataset_kernels", - deps = [ - ":assert_next_dataset_op", - ":csv_dataset_op", - ":directed_interleave_dataset_op", - ":ignore_errors_dataset_op", - ":indexed_dataset", - ":lmdb_dataset_op", - ":prefetching_kernels", - ":threadpool_dataset_op", - ":unique_dataset_op", - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - "@protobuf_archive//:protobuf_headers", - ], -) diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc deleted file mode 100644 index c19a609780..0000000000 --- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc +++ /dev/null @@ -1,155 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include <map> - -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/tensor.h" - -namespace tensorflow { -namespace data { -namespace { - -// See documentation in ../ops/dataset_ops.cc for a high-level -// description of the following op. -class AssertNextDatasetOp : public UnaryDatasetOpKernel { - public: - explicit AssertNextDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - } - - protected: - void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) override { - std::vector<string> transformations; - OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations", - &transformations)); - *output = - new Dataset(ctx, input, transformations, output_types_, output_shapes_); - } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - const std::vector<string>& transformations, - const DataTypeVector& output_types, - const std::vector<PartialTensorShape>& output_shapes) - : DatasetBase(DatasetContext(ctx)), - input_(input), - transformations_(transformations), - output_types_(output_types), - output_shapes_(output_shapes) { - input_->Ref(); - } - - ~Dataset() override { input_->Unref(); } - - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Assert")})); - } - - const DataTypeVector& output_dtypes() const override { - return output_types_; - } - const std::vector<PartialTensorShape>& output_shapes() const override { - return output_shapes_; - } - - string DebugString() const override { - return "AssertNextDatasetOp::Dataset"; - } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - Node* transformations_node = nullptr; - TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); - TF_RETURN_IF_ERROR(b->AddDataset( - this, {input_graph_node, transformations_node}, output)); - return Status::OK(); - } - - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} - - Status Initialize(IteratorContext* ctx) override { - std::vector<string> tokens = - str_util::Split(prefix(), ':', str_util::SkipEmpty()); - if (dataset()->transformations_.size() > tokens.size() - 2) { - return errors::InvalidArgument( - "Asserted next ", dataset()->transformations_.size(), - " transformations but encountered only ", tokens.size() - 2, "."); - } - int n = tokens.size(); - for (size_t i = 0; i < dataset()->transformations_.size(); ++i) { - if (dataset()->transformations_[i] != tokens[n - 2 - i]) { - return errors::InvalidArgument( - "Asserted ", dataset()->transformations_[i], - " transformation at offset ", i, " but encountered ", - tokens[n - 2 - i], " transformation instead."); - } - } - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); - } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return Status::OK(); - } - - private: - std::unique_ptr<IteratorBase> input_impl_; - }; - - const DatasetBase* input_; - const std::vector<string> transformations_; - const DataTypeVector output_types_; - const std::vector<PartialTensorShape> output_shapes_; - }; - - DataTypeVector output_types_; - std::vector<PartialTensorShape> output_shapes_; -}; - -REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU), - AssertNextDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc deleted file mode 100644 index 21ec50fb6b..0000000000 --- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc +++ /dev/null @@ -1,859 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// See docs in ../ops/parsing_ops.cc. -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/lib/io/inputstream_interface.h" -#include "tensorflow/core/lib/io/random_inputstream.h" -#include "tensorflow/core/lib/io/zlib_compression_options.h" -#include "tensorflow/core/lib/io/zlib_inputstream.h" - -namespace tensorflow { -namespace data { -namespace { - -class CSVDatasetOp : public DatasetOpKernel { - public: - explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - } - - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - const Tensor* filenames_tensor; - OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); - OP_REQUIRES( - ctx, filenames_tensor->dims() <= 1, - errors::InvalidArgument("`filenames` must be a scalar or a vector.")); - - string compression_type; - OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type", - &compression_type)); - - OpInputList record_defaults_list; - OP_REQUIRES_OK(ctx, - ctx->input_list("record_defaults", &record_defaults_list)); - for (int i = 0; i < record_defaults_list.size(); ++i) { - OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1, - errors::InvalidArgument( - "Each record default should be at most rank 1")); - OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, - errors::InvalidArgument( - "There should only be 1 default per field but field ", i, - " has ", record_defaults_list[i].NumElements())); - } - - const Tensor* select_cols_tensor; - OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor)); - OP_REQUIRES(ctx, select_cols_tensor->dims() == 1, - errors::InvalidArgument("`select_cols` must be a vector.")); - - int64 buffer_size; - OP_REQUIRES_OK( - ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); - OP_REQUIRES(ctx, buffer_size > 0, - errors::InvalidArgument("buffer_size should be positive")); - - string delim; - OP_REQUIRES_OK(ctx, - ParseScalarArgument<string>(ctx, "field_delim", &delim)); - OP_REQUIRES(ctx, delim.size() == 1, - errors::InvalidArgument("field_delim should be only 1 char")); - - bool header; - OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "header", &header)); - - bool use_quote_delim; - OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "use_quote_delim", - &use_quote_delim)); - string na_value; - OP_REQUIRES_OK(ctx, - ParseScalarArgument<string>(ctx, "na_value", &na_value)); - - std::vector<Tensor> record_defaults; - record_defaults.reserve(record_defaults_list.size()); - for (const Tensor& t : record_defaults_list) { - record_defaults.push_back(t); - } - - std::vector<string> filenames; - filenames.reserve(filenames_tensor->NumElements()); - for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat<string>()(i)); - } - - io::ZlibCompressionOptions zlib_compression_options = - io::ZlibCompressionOptions::DEFAULT(); - if (compression_type == "ZLIB") { - zlib_compression_options = io::ZlibCompressionOptions::DEFAULT(); - } else if (compression_type == "GZIP") { - zlib_compression_options = io::ZlibCompressionOptions::GZIP(); - } else { - OP_REQUIRES(ctx, compression_type.empty(), - errors::InvalidArgument( - "Unsupported compression_type: ", compression_type, ".")); - } - zlib_compression_options.input_buffer_size = buffer_size; - - std::vector<int64> select_cols; - select_cols.reserve(select_cols_tensor->NumElements()); - for (int i = 0; i < select_cols_tensor->NumElements(); ++i) { - select_cols.push_back(select_cols_tensor->flat<int64>()(i)); - } - OP_REQUIRES( - ctx, output_types_.size() == select_cols.size() || select_cols.empty(), - errors::InvalidArgument("select_cols should match output size")); - for (int i = 1; i < select_cols.size(); i++) { - OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i], - errors::InvalidArgument( - "select_cols should be strictly increasing indices")); - } - OP_REQUIRES( - ctx, select_cols.empty() || select_cols.front() >= 0, - errors::InvalidArgument("select_cols should be non-negative indices")); - - *output = new Dataset(ctx, std::move(filenames), header, - std::move(compression_type), zlib_compression_options, - output_types_, output_shapes_, - std::move(record_defaults), std::move(select_cols), - use_quote_delim, delim[0], std::move(na_value)); - } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header, - string compression_type, io::ZlibCompressionOptions options, - const DataTypeVector& output_types, - const std::vector<PartialTensorShape>& output_shapes, - std::vector<Tensor> record_defaults, std::vector<int64> select_cols, - bool use_quote_delim, char delim, string na_value) - : DatasetBase(DatasetContext(ctx)), - filenames_(std::move(filenames)), - header_(header), - out_type_(output_types), - output_shapes_(output_shapes), - record_defaults_(std::move(record_defaults)), - select_cols_(std::move(select_cols)), - use_quote_delim_(use_quote_delim), - delim_(delim), - na_value_(std::move(na_value)), - use_compression_(!compression_type.empty()), - compression_type_(std::move(compression_type)), - options_(options) {} - - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::CSV")})); - } - - const DataTypeVector& output_dtypes() const override { return out_type_; } - - const std::vector<PartialTensorShape>& output_shapes() const override { - return output_shapes_; - } - - string DebugString() const override { return "CSVDatasetOp::Dataset"; } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* filenames = nullptr; - Node* compression_type = nullptr; - Node* buffer_size = nullptr; - Node* header = nullptr; - Node* delim = nullptr; - Node* use_quote_delim = nullptr; - Node* na_value = nullptr; - Node* select_cols = nullptr; - - std::vector<Node*> record_defaults; - record_defaults.reserve(record_defaults_.size()); - for (const Tensor& t : record_defaults_) { - Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); - record_defaults.emplace_back(node); - } - - TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); - TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type)); - TF_RETURN_IF_ERROR( - b->AddScalar(options_.input_buffer_size, &buffer_size)); - TF_RETURN_IF_ERROR(b->AddScalar(header_, &header)); - - string delim_string(1, delim_); - TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim)); - TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim)); - TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value)); - TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols)); - - TF_RETURN_IF_ERROR(b->AddDataset( - this, - {std::make_pair(0, filenames), std::make_pair(1, compression_type), - std::make_pair(2, buffer_size), std::make_pair(3, header), - std::make_pair(4, delim), std::make_pair(5, use_quote_delim), - std::make_pair(6, na_value), - std::make_pair(7, select_cols)}, // Single tensor inputs - {std::make_pair(8, record_defaults)}, // Tensor list inputs - {}, output)); - return Status::OK(); - } - - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} - - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - bool select_all = dataset()->select_cols_.empty(); - do { - // We are currently processing a file, so try to read the next record - if (input_stream_) { - Status s = ReadRecord(ctx, out_tensors, select_all, - dataset()->select_cols_); - if (s.ok()) { - // Validate output - if (out_tensors->size() != dataset()->out_type_.size()) { - return errors::InvalidArgument( - "Expect ", dataset()->out_type_.size(), " fields but have ", - out_tensors->size(), " in record"); - } - - *end_of_sequence = false; - return s; - } - if (!errors::IsOutOfRange(s)) { - // Not at the end of file, return OK or non-EOF errors to caller. - *end_of_sequence = false; - return s; - } - // We have reached the end of the current file, so maybe - // move on to next file. - ResetStreamsLocked(); - ++current_file_index_; - } - // Iteration ends when there are no more files to process. - if (current_file_index_ == dataset()->filenames_.size()) { - *end_of_sequence = true; - return Status::OK(); - } - TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); - } while (true); - } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), - current_file_index_)); - // `input_stream_` is empty if - // 1. GetNext has not been called even once. - // 2. All files have been read and the iterator has been exhausted. - if (input_stream_ && num_buffer_reads_ > 0) { - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_)); - // If num_buffer_reads_ == 0, the buffer hasn't been filled even once. - TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"), - num_buffer_reads_)); - } - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - ResetStreamsLocked(); - int64 current_file_index; - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"), - ¤t_file_index)); - current_file_index_ = size_t(current_file_index); - // The keys "pos" and "num_buffer_reads" are written only if - // the iterator was saved with an open, partially read file. - if (reader->Contains(full_name("pos"))) { - int64 pos, num_buffer_reads; - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos)); - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"), - &num_buffer_reads)); - - TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); - - num_buffer_reads_ = size_t(num_buffer_reads - 1); - - // Restores the most recently held buffer - Status s = input_stream_->SkipNBytes( - num_buffer_reads_ * dataset()->options_.input_buffer_size); - if (!s.ok() && !errors::IsOutOfRange(s)) { - // We might get out of range error here if the size of the file - // is not an exact multiple of the buffer size, and the last buffer - // read is < buffer_size. This is valid and we do not surface the - // error. - return s; - } - - Status s2 = FillBuffer(&buffer_); - if (!s2.ok() && !errors::IsOutOfRange(s2)) { - return s2; - } - pos_ = size_t(pos); - } - return Status::OK(); - } - - private: - // Reads an entire CSV row from the input stream, either from the - // existing buffer or by filling the buffer as needed. Converts extracted - // fields to output tensors as we go. - // - // When this function is called, pos_ should be the index of the first - // character of the record in buffer_, or past the end of the buffer. - // Note: ctx and out_tensors are only used in this function - // when fields are included in the record. - Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors, - bool select_all, const std::vector<int64>& selected) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (pos_ >= buffer_.size()) { - // At the end of the file, this will return errors::OutOfRange - TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); - pos_ = 0; - } - - // The first character may be \n if this is the continuation of a - // \r\n linebreak between this and the previous record. If so, skip it. - - bool end_of_record = false; // Keep track of when we find \n, \r or EOF - size_t num_parsed = 0; - size_t num_selected_parsed = 0; - - Status result; - - while (!end_of_record) { // Read till we reach \n, \r or EOF - bool include = - select_all || (num_selected_parsed < selected.size() && - selected[num_selected_parsed] == num_parsed); - - // Don't fail fast, so that the next call to GetNext may still return - // a valid record - result.Update( - ParseOneField(ctx, out_tensors, &end_of_record, include)); - - num_parsed++; - if (include) num_selected_parsed++; - } - - return result; - } - - // Parses one field from position pos_ in the buffer. Fields are - // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of - // the next field. - Status ParseOneField(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_record, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (pos_ >= buffer_.size()) { - // If we get here, this means the previous field's end coincided - // with the end of the buffer. We can fill the buffer without abandon. - Status s = FillBuffer(&buffer_); - - if (errors::IsOutOfRange(s)) { - // Reached EOF, and last field is empty - *end_of_record = true; - if (include) { - return FieldToOutput(ctx, StringPiece(), out_tensors); - } else { - return Status::OK(); - } - } else if (!s.ok()) { - return s; // Surface other errors back to caller - } - - pos_ = 0; - } - - if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { - return ParseQuotedField(ctx, out_tensors, end_of_record, include); - } - - return ParseUnquotedField(ctx, out_tensors, end_of_record, include); - } - - // For keeping track of relevant parts of a field from a previous buffer - struct Piece { - size_t start; - size_t len; - string buffer; - - Piece(string buffer, size_t start, size_t len) - : start(start), len(len), buffer(std::move(buffer)) {} - }; - - // Given that pos_ exceeds the buffer, saves the relevant part of the - // current buffer (if necessary), fills the buffer, and resets indices to - // 0. - Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces, - size_t* start, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - string temp_buffer; - - buffer_.swap(temp_buffer); - if (include && pos_ > *start) { - earlier_pieces->push_back( - Piece(std::move(temp_buffer), *start, pos_ - *start)); - } - pos_ = 0; - *start = 0; - return FillBuffer(&buffer_); - } - - // Parses unquoted field from position pos_ in the buffer. Continually - // reads from buffer until end of field is reached (delim, CRLF, or EOF). - // Advances pos_ to keep track of our position in the buffer as we go, - // stopping at the first character of the next field. - Status ParseQuotedField(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_record, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - std::vector<Piece> earlier_pieces; - size_t start = pos_; - pos_++; // Starting quotation mark - - Status parse_result; - while (true) { // Each iter reads 1 char, filling buffer if necessary - if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); - if (errors::IsOutOfRange(s)) { - return errors::InvalidArgument( - "Reached end of file without closing quoted field in " - "record"); - } else if (!s.ok()) { - return s; // Surface all other errors to caller - } - } - - char ch = buffer_[pos_]; - if (ch == '"') { - // When we encounter a quote, we look ahead to the next character to - // decide what to do - pos_++; - if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); - if (errors::IsOutOfRange(s)) { - // This was the last field. We are done - *end_of_record = true; - parse_result.Update(QuotedFieldToOutput( - ctx, StringPiece(), out_tensors, earlier_pieces, include)); - return parse_result; - } else if (!s.ok()) { - return s; - } - } - - char next = buffer_[pos_]; - pos_++; - if (next == dataset()->delim_) { - parse_result.Update(QuotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - 1 - start), - out_tensors, earlier_pieces, include)); - return parse_result; - - } else if (next == '\n' || next == '\r') { - *end_of_record = true; - parse_result.Update(QuotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - 1 - start), - out_tensors, earlier_pieces, include)); - if (next == '\r') SkipNewLineIfNecessary(); - return parse_result; - } else if (next != '"') { - // Take note of the error, but keep going to end of field. - include = false; // So we don't get funky errors when trying to - // unescape the quotes. - parse_result.Update(errors::InvalidArgument( - "Quote inside a string has to be escaped by another quote")); - } - - } else { - pos_++; - } - } - } - - // Converts quoted field to an output tensor, removing the starting - // and ending quotes from it and unescaping double quotations if - // necessary. - Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, - std::vector<Tensor>* out_tensors, - const std::vector<Piece>& earlier_pieces, - bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!include) return Status::OK(); - - if (earlier_pieces.empty()) { - if (field.find('\"', 1) == field.size() - 1) { - // `field` contains no escaped quotation marks. - // Exclude framing quotation marks - field.remove_prefix(1); - field.remove_suffix(1); - return FieldToOutput(ctx, field, out_tensors); - } - } - string field_complete; - size_t str_len = field.size(); - for (const Piece& p : earlier_pieces) { - str_len += p.len; - } - field_complete.reserve(str_len); - - // This bool flips every time we see a quote, so that we skip the second - // quote of every pair of adjacent quotes in the field. We need to track - // this across iterations of the for loop because adjacent double quotes - // may be in different buffers. Initialize to true because we also skip - // the opening quotation mark of the quoted field. - bool skip_next_quote = true; - for (const Piece& p : earlier_pieces) { - AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), - &field_complete, &skip_next_quote); - } - AppendUnescapedPiece(field, &field_complete, &skip_next_quote); - StringPiece result = StringPiece(field_complete); - result.remove_suffix(1); // Skip final quote - - return FieldToOutput(ctx, result, out_tensors); - } - - void AppendUnescapedPiece(StringPiece piece, string* field_complete, - bool* skip_next_quote) { - size_t from = 0; - size_t found = piece.find('\"', from); - while (found != string::npos) { - if (!*skip_next_quote) { - // This is the first quote in a pair of adjacent double quotes - field_complete->append(piece.data() + from, found + 1 - from); - } - *skip_next_quote = !*skip_next_quote; - from = found + 1; - found = piece.find('\"', from); - } - // Include the chunk after the last quotation mark in the string - if (from < piece.size()) { - field_complete->append(piece.data() + from, piece.size() - from); - } - } - - // Parses unquoted field from position pos_ in the buffer. Continually - // reads from buffer until end of field is reached (delim, CRLF, or EOF). - // Advances pos_ to keep track of our position in the buffer as we go, - // stopping at the first character of the next field. - Status ParseUnquotedField(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_record, bool include) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - std::vector<Piece> earlier_pieces; - size_t start = pos_; - Status parse_result; - - while (true) { // Each iter reads 1 char, filling buffer if necessary - if (pos_ >= buffer_.size()) { - Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); - // Handle errors - if (errors::IsOutOfRange(s)) { - // Whatever we have is the last field of the last record - *end_of_record = true; - parse_result.Update(UnquotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include)); - return parse_result; - } else if (!s.ok()) { - return s; // Surface all other errors to caller - } - } - - char ch = buffer_[pos_]; - - if (ch == dataset()->delim_) { - parse_result.Update(UnquotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include)); - pos_++; - return parse_result; - } - if (ch == '\n' || ch == '\r') { - // need special case to skip over first \n of record if the line - // breaks are \r\n - parse_result.Update(UnquotedFieldToOutput( - ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, - earlier_pieces, include)); - *end_of_record = true; - pos_++; - if (ch == '\r') SkipNewLineIfNecessary(); - return parse_result; - } - if (dataset()->use_quote_delim_ && ch == '"') { - // Take note of the error, but keep going to end of field. - parse_result.Update(errors::InvalidArgument( - "Unquoted fields cannot have quotes inside")); - } - // Otherwise, go to next character - pos_++; - } - } - - Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - result->clear(); - ++num_buffer_reads_; - Status s = input_stream_->ReadNBytes( - dataset()->options_.input_buffer_size, result); - - if (errors::IsOutOfRange(s) && !result->empty()) { - // Ignore OutOfRange error when ReadNBytes read < N bytes. - return Status::OK(); - } - return s; - } - - // Given a field, converts it to the right output tensor type - Status FieldToOutput(IteratorContext* ctx, StringPiece field, - std::vector<Tensor>* out_tensors) { - size_t output_idx = out_tensors->size(); - if (output_idx >= dataset()->out_type_.size()) { - // We can get here if we're selecting all columns, but the number of - // fields exceeds the number of defaults provided - return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), - " fields but have more in record"); - } - const DataType& dtype = dataset()->out_type_[output_idx]; - Tensor component(ctx->allocator({}), dtype, {}); - if ((field.empty() || field == dataset()->na_value_) && - dataset()->record_defaults_[output_idx].NumElements() != 1) { - // If the field is empty or NA value, and default is not given, - // report error. - return errors::InvalidArgument("Field ", output_idx, - " is required but missing in record!"); - } - - switch (dtype) { - // For each case, if the field is empty, we use the default. - // Otherwise, we convert it to the right type. - case DT_INT32: { - if (field.empty() || field == dataset()->na_value_) { - component.scalar<int32>()() = - dataset()->record_defaults_[output_idx].flat<int32>()(0); - } else { - int32 value; - if (!strings::safe_strto32(field, &value)) { - return errors::InvalidArgument( - "Field ", output_idx, - " in record is not a valid int32: ", field); - } - component.scalar<int32>()() = value; - } - break; - } - case DT_INT64: { - if (field.empty() || field == dataset()->na_value_) { - component.scalar<int64>()() = - dataset()->record_defaults_[output_idx].flat<int64>()(0); - } else { - int64 value; - if (!strings::safe_strto64(field, &value)) { - return errors::InvalidArgument( - "Field ", output_idx, - " in record is not a valid int64: ", field); - } - component.scalar<int64>()() = value; - } - break; - } - case DT_FLOAT: { - if (field.empty() || field == dataset()->na_value_) { - component.scalar<float>()() = - dataset()->record_defaults_[output_idx].flat<float>()(0); - } else { - float value; - if (!strings::safe_strtof(field, &value)) { - return errors::InvalidArgument( - "Field ", output_idx, - " in record is not a valid float: ", field); - } - component.scalar<float>()() = value; - } - break; - } - case DT_DOUBLE: { - if (field.empty() || field == dataset()->na_value_) { - component.scalar<double>()() = - dataset()->record_defaults_[output_idx].flat<double>()(0); - } else { - double value; - if (!strings::safe_strtod(field, &value)) { - return errors::InvalidArgument( - "Field ", output_idx, - " in record is not a valid double: ", field); - } - component.scalar<double>()() = value; - } - break; - } - case DT_STRING: { - if (field.empty() || field == dataset()->na_value_) { - component.scalar<string>()() = - dataset()->record_defaults_[output_idx].flat<string>()(0); - } else { - component.scalar<string>()() = string(field); - } - break; - } - default: - return errors::InvalidArgument("csv: data type ", dtype, - " not supported in field ", - output_idx); - } - out_tensors->push_back(std::move(component)); - return Status::OK(); - } - - // Records can be delimited by "\r\n" line breaks. When we encounter a - // '\r', we have to check the next character to see if it is part of the - // linebreak, and ignore it if so. - void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (pos_ >= buffer_.size()) { - Status s = FillBuffer(&buffer_); - pos_ = 0; - // If we failed to fill buffer, it doesn't matter because we're done - // with the record - if (!s.ok()) return; - } - if (buffer_[pos_] == '\n') { - pos_++; - } - } - - // Given a string field, and its index in the output, - // converts it to a Tensor of the right type and adds it to the - // out_tensors vector. - Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, - std::vector<Tensor>* out_tensors, - const std::vector<Piece>& earlier_pieces, - bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (!include) return Status::OK(); - - if (earlier_pieces.empty()) { - return FieldToOutput(ctx, field, out_tensors); - } - - size_t str_len = field.size(); - for (const Piece& p : earlier_pieces) { - str_len += p.len; - } - string field_complete; - field_complete.reserve(str_len); - - for (const Piece& p : earlier_pieces) { - field_complete.append(p.buffer, p.start, p.len); - } - - field_complete.append(field.data(), field.size()); - return FieldToOutput(ctx, field_complete, out_tensors); - } - - // Sets up reader streams to read from the file at `current_file_index_`. - Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (current_file_index_ >= dataset()->filenames_.size()) { - return errors::InvalidArgument( - "current_file_index_:", current_file_index_, - " >= filenames_.size():", dataset()->filenames_.size()); - } - - // Actually move on to next file. - TF_RETURN_IF_ERROR(env->NewRandomAccessFile( - dataset()->filenames_[current_file_index_], &file_)); - random_access_input_stream_ = - std::make_shared<io::RandomAccessInputStream>(file_.get(), false); - - if (dataset()->use_compression_) { - input_stream_ = std::make_shared<io::ZlibInputStream>( - random_access_input_stream_.get(), - dataset()->options_.input_buffer_size, - dataset()->options_.input_buffer_size, dataset()->options_); - } else { - input_stream_ = random_access_input_stream_; - } - buffer_.clear(); - pos_ = 0; - num_buffer_reads_ = 0; - if (dataset()->header_) { - // Read one line, but don't include it. Pass nullptrs as dummy - // pointers to objects that shouldn't be invoked anyway - // We need to process this as a record here instead of just finding - // the first newline because it might contain quoted fields with - // newlines in the header as well - std::vector<int64> empty; - Status s = ReadRecord(nullptr, nullptr, false, empty); - if (!s.ok()) { - return errors::InvalidArgument("Can't read header of file"); - } - } - return Status::OK(); - } - - // Resets all reader streams. - void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - input_stream_.reset(); - file_.reset(); - } - - mutex mu_; - string buffer_ GUARDED_BY(mu_); // Maintain our own buffer - size_t pos_ GUARDED_BY( - mu_); // Index into the buffer must be maintained between iters - size_t num_buffer_reads_ GUARDED_BY(mu_); - std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_ - GUARDED_BY(mu_); - std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_); - size_t current_file_index_ GUARDED_BY(mu_) = 0; - std::unique_ptr<RandomAccessFile> file_ - GUARDED_BY(mu_); // must outlive input_stream_ - }; // class Iterator - - const std::vector<string> filenames_; - const bool header_; - const DataTypeVector out_type_; - const std::vector<PartialTensorShape> output_shapes_; - const std::vector<Tensor> record_defaults_; - const std::vector<int64> select_cols_; - const bool use_quote_delim_; - const char delim_; - const string na_value_; - const bool use_compression_; - const string compression_type_; - const io::ZlibCompressionOptions options_; - }; // class Dataset - - DataTypeVector output_types_; - std::vector<PartialTensorShape> output_shapes_; -}; // class CSVDatasetOp - -// Register the kernel implementation for CSVDataset. -REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc deleted file mode 100644 index a5321620bf..0000000000 --- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc +++ /dev/null @@ -1,280 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/hash/hash.h" - -namespace tensorflow { -namespace data { -namespace { - -// See documentation in ../ops/dataset_ops.cc for a high-level -// description of the following op. - -class DirectedInterleaveDatasetOp : public DatasetOpKernel { - public: - explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx) - : DatasetOpKernel(ctx) {} - - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - DatasetBase* selector_input; - OP_REQUIRES_OK(ctx, - GetDatasetFromVariantTensor(ctx->input(0), &selector_input)); - - OP_REQUIRES( - ctx, - selector_input->output_dtypes().size() == 1 && - selector_input->output_dtypes()[0] == DT_INT64 && - selector_input->output_shapes().size() == 1 && - selector_input->output_shapes()[0].IsCompatibleWith( - PartialTensorShape({})), - errors::InvalidArgument( - "The selector input must be a dataset of scalar int64 elements.")); - - std::vector<DatasetBase*> data_inputs; - for (size_t i = 1; i < ctx->num_inputs(); ++i) { - DatasetBase* input; - OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input)); - data_inputs.push_back(input); - - OP_REQUIRES( - ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(), - errors::InvalidArgument( - "All inputs must have the same output_dtypes. First input " - "has types ", - DataTypeVectorString(data_inputs[0]->output_dtypes()), - ", and input ", i - 1, " has types ", - DataTypeVectorString(input->output_dtypes()))); - } - *output = new Dataset(ctx, selector_input, std::move(data_inputs)); - } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, const DatasetBase* selector_input, - std::vector<DatasetBase*> data_inputs) - : DatasetBase(DatasetContext(ctx)), - selector_input_(selector_input), - data_inputs_(std::move(data_inputs)) { - selector_input_->Ref(); - - output_shapes_ = data_inputs_[0]->output_shapes(); - data_inputs_[0]->Ref(); - for (size_t i = 1; i < data_inputs_.size(); ++i) { - const DatasetBase* data_input = data_inputs_[i]; - data_input->Ref(); - for (size_t j = 0; j < output_shapes_.size(); ++j) { - output_shapes_[j] = MostSpecificCompatibleShape( - output_shapes_[j], data_input->output_shapes()[j]); - } - } - } - - ~Dataset() override { - selector_input_->Unref(); - for (DatasetBase* data_input : data_inputs_) { - data_input->Unref(); - } - } - - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>(new Iterator( - {this, strings::StrCat(prefix, "::DirectedInterleave")})); - } - - const DataTypeVector& output_dtypes() const override { - return data_inputs_[0]->output_dtypes(); - } - - const std::vector<PartialTensorShape>& output_shapes() const override { - return output_shapes_; - } - - string DebugString() const override { - return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); - } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* selector_input_node; - TF_RETURN_IF_ERROR( - b->AddInputDataset(ctx, selector_input_, &selector_input_node)); - std::vector<Node*> data_input_nodes(data_inputs_.size()); - for (size_t i = 0; i < data_inputs_.size(); ++i) { - TF_RETURN_IF_ERROR( - b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i])); - } - TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}}, - {{1, data_input_nodes}}, {}, output)); - return Status::OK(); - } - - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params), - num_active_inputs_(params.dataset->data_inputs_.size()) {} - - Status Initialize(IteratorContext* ctx) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( - ctx, strings::StrCat(prefix(), ".selector"), - &selector_input_impl_)); - data_input_impls_.resize(dataset()->data_inputs_.size()); - for (size_t i = 0; i < data_input_impls_.size(); ++i) { - const DatasetBase* data_input = dataset()->data_inputs_[i]; - TF_RETURN_IF_ERROR(data_input->MakeIterator( - ctx, strings::StrCat(prefix(), "[", i, "]"), - &data_input_impls_[i])); - } - return Status::OK(); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - if (!selector_input_impl_) { - *end_of_sequence = true; - return Status::OK(); - } - - while (true) { - std::vector<Tensor> selector_result; - *end_of_sequence = false; - TF_RETURN_IF_ERROR(selector_input_impl_->GetNext( - ctx, &selector_result, end_of_sequence)); - if (*end_of_sequence) { - selector_input_impl_.reset(); - for (auto& data_input_impl : data_input_impls_) { - data_input_impl.reset(); - } - return Status::OK(); - } - - int64 selected_input = selector_result[0].scalar<int64>()(); - if (selected_input < 0 || selected_input > data_input_impls_.size()) { - return errors::InvalidArgument( - "Selector index out of range: ", selected_input, - " >= ", data_input_impls_.size()); - } - - if (data_input_impls_[selected_input]) { - bool end_of_selected_input = false; - TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext( - ctx, out_tensors, &end_of_selected_input)); - - if (!end_of_selected_input) { - return Status::OK(); - } - - data_input_impls_[selected_input].reset(); - --num_active_inputs_; - - if (num_active_inputs_ == 0) { - selector_input_impl_.reset(); - *end_of_sequence = true; - return Status::OK(); - } - } - - LOG(WARNING) << "DirectedInterleave selected an exhausted input: " - << selected_input; - } - } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - if (selector_input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_)); - } else { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("selector_input_impl_empty"), "")); - } - for (size_t i = 0; i < data_input_impls_.size(); ++i) { - const auto& data_input_impl = data_input_impls_[i]; - if (data_input_impl) { - TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl)); - } else { - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name(strings::StrCat("data_input_impl_empty[", i, "]")), - "")); - } - } - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - if (!reader->Contains(full_name("selector_input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_)); - } else { - selector_input_impl_.reset(); - } - for (size_t i = 0; i < data_input_impls_.size(); ++i) { - if (!reader->Contains(full_name( - strings::StrCat("data_input_impl_empty[", i, "]")))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i])); - } else { - data_input_impls_[i].reset(); - } - } - return Status::OK(); - } - - private: - mutex mu_; - std::unique_ptr<IteratorBase> selector_input_impl_ GUARDED_BY(mu_); - std::vector<std::unique_ptr<IteratorBase>> data_input_impls_ - GUARDED_BY(mu_); - int64 num_active_inputs_ GUARDED_BY(mu_); - }; - - static PartialTensorShape MostSpecificCompatibleShape( - const PartialTensorShape& ts1, const PartialTensorShape& ts2) { - PartialTensorShape output_tensorshape; - if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) - return output_tensorshape; - auto dims1 = ts1.dim_sizes(); - auto dims2 = ts2.dim_sizes(); - for (int d = 0; d < ts1.dims(); d++) { - if (dims1[d] == dims2[d]) - output_tensorshape.Concatenate(dims1[d]); - else - output_tensorshape.Concatenate(-1); - } - return output_tensorshape; - } - - const DatasetBase* const selector_input_; - const std::vector<DatasetBase*> data_inputs_; - std::vector<PartialTensorShape> output_shapes_; - }; -}; - -REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU), - DirectedInterleaveDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc deleted file mode 100644 index c3cb45dbf7..0000000000 --- a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc +++ /dev/null @@ -1,155 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/data/kernels/indexed_dataset.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace tensorflow { -namespace data { -namespace { - -class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel { - public: - using IndexedDatasetOpKernel::IndexedDatasetOpKernel; - - void MakeIndexedDataset(OpKernelContext* ctx, - IndexedDataset** output) override { - uint64 size = -1; - OP_REQUIRES_OK(ctx, ParseScalarArgument<uint64>(ctx, "size", &size)); - OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0")); - *output = new Dataset(ctx, size); - } - - class Dataset : public IndexedDataset { - public: - Dataset(OpKernelContext* ctx, uint64 size) - : IndexedDataset(DatasetContext(ctx)), size_(size) {} - - Status MaterializeDataset( - std::shared_ptr<MaterializedIndexedDataset>* materialized) override { - materialized->reset(new Materialized(this)); - return Status::OK(); - } - - const DataTypeVector& output_dtypes() const override { - static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64}); - return *dtypes; - } - - const std::vector<PartialTensorShape>& output_shapes() const override { - static std::vector<PartialTensorShape>* shapes = - new std::vector<PartialTensorShape>({{}}); - return *shapes; - } - - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>(new Iterator( - {this, strings::StrCat(prefix, "::IdentityIndexedDataset")})); - } - - string DebugString() const override { - return "IdentityIndexedDataset::Dataset"; - } - - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** node) const override { - return errors::Unimplemented( - "identity_indexed_dataset.AsGraphDefInternal"); - } - - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - if (cur_ < dataset()->size_) { - Tensor result_tensor(ctx->allocator({}), DT_UINT64, {}); - result_tensor.scalar<uint64>()() = cur_++; - out_tensors->emplace_back(std::move(result_tensor)); - *end_of_sequence = false; - return Status::OK(); - } - *end_of_sequence = true; - return Status::OK(); - } - - private: - mutex mu_; - uint64 cur_ GUARDED_BY(mu_); - }; - - class Materialized : public MaterializedIndexedDataset { - public: - explicit Materialized(Dataset* dataset) : dataset_(dataset) { - dataset->Ref(); - } - - ~Materialized() override { - // TODO(saeta): Pull this into MaterializedIndexedDataset - dataset_->Unref(); - } - - const DataTypeVector& output_dtypes() const override { - return dataset_->output_dtypes(); - } - - const std::vector<PartialTensorShape>& output_shapes() const override { - return dataset_->output_shapes(); - } - - Status Get(IteratorContext&& ctx, uint64 index, - std::vector<Tensor>* out_tensors) const override { - LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index - << ")"; - if (index >= dataset_->size_) { - // Note: use InvalidArgument instead of OutOfRange error because many - // things consider OutOfRange to be a "clean termination" error. - return errors::InvalidArgument( - "Index ", index, - " is out of range for this dataset. (Size is: ", dataset_->size_, - ".)"); - } - Tensor result_tensor(ctx.allocator({}), DT_UINT64, {}); - result_tensor.scalar<uint64>()() = index; - out_tensors->emplace_back(std::move(result_tensor)); - return Status::OK(); - } - - Status Size(uint64* size) const override { - *size = dataset_->size_; - return Status::OK(); - } - - private: - const Dataset* const dataset_; // Not owned. - }; - - const uint64 size_; - std::shared_ptr<Materialized> materialized_; - }; -}; - -REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU), - IdentityIndexedDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc deleted file mode 100644 index beec344534..0000000000 --- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/random/random.h" - -namespace tensorflow { -namespace data { -namespace { - -// See documentation in ../ops/dataset_ops.cc for a high-level -// description of the following op. - -class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { - public: - explicit IgnoreErrorsDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) {} - - void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) override { - *output = new Dataset(ctx, input); - } - - private: - class Dataset : public DatasetBase { - public: - explicit Dataset(OpKernelContext* ctx, const DatasetBase* input) - : DatasetBase(DatasetContext(ctx)), input_(input) { - input_->Ref(); - } - - ~Dataset() override { input_->Unref(); } - - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")})); - } - - const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); - } - const std::vector<PartialTensorShape>& output_shapes() const override { - return input_->output_shapes(); - } - - string DebugString() const override { - return "IgnoreErrorsDatasetOp::Dataset"; - } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); - return Status::OK(); - } - - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} - - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - { - tf_shared_lock l(mu_); - if (!input_impl_) { - *end_of_sequence = true; - return Status::OK(); - } - Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); - while (!s.ok()) { - out_tensors->clear(); - s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); - } - } - if (*end_of_sequence) { - mutex_lock l(mu_); - input_impl_.reset(); - } - return Status::OK(); - } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - if (input_impl_) - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - else - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("input_impls_empty"), "")); - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - if (reader->Contains(full_name("input_impls_empty"))) - input_impl_.reset(); - else - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return Status::OK(); - } - - private: - mutex mu_; - std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); - }; - - const DatasetBase* const input_; - }; -}; - -REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU), - IgnoreErrorsDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/contrib/data/kernels/indexed_dataset.cc deleted file mode 100644 index ced8ab0d60..0000000000 --- a/tensorflow/contrib/data/kernels/indexed_dataset.cc +++ /dev/null @@ -1,373 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/contrib/data/kernels/indexed_dataset.h" - -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/gtl/cleanup.h" - -namespace tensorflow { -namespace data { -namespace { - -Status VerifyTypesMatch(const DataTypeVector& expected, - const DataTypeVector& received) { - if (expected.size() != received.size()) { - return errors::InvalidArgument( - "Number of components does not match: expected ", expected.size(), - " types but got ", received.size(), "."); - } - for (size_t i = 0; i < expected.size(); ++i) { - if (expected[i] != received[i]) { - return errors::InvalidArgument("Data type mismatch at component ", i, - ": expected ", DataTypeString(expected[i]), - " but got ", DataTypeString(received[i]), - "."); - } - } - return Status::OK(); -} - -Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, - const std::vector<PartialTensorShape>& received) { - if (expected.size() != received.size()) { - return errors::InvalidArgument( - "Number of components does not match: expected ", expected.size(), - " shapes but got ", received.size(), "."); - } - for (size_t i = 0; i < expected.size(); ++i) { - if (!expected[i].IsCompatibleWith(received[i])) { - return errors::InvalidArgument("Incompatible shapes at component ", i, - ": expected ", expected[i].DebugString(), - " but got ", received[i].DebugString(), - "."); - } - } - - return Status::OK(); -} - -class MaterializedDatasetResource : public ResourceBase { - public: - MaterializedDatasetResource( - const DataTypeVector& output_dtypes, - const std::vector<PartialTensorShape>& output_shapes) - : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {} - - string DebugString() override { - return "Materialized IndexedDataset resource"; - } - - Status Get(IteratorContext&& ctx, uint64 index, - std::vector<Tensor>* out_tensors) { - std::shared_ptr<MaterializedIndexedDataset> captured(materialized_); - if (captured) { - return captured->Get(std::move(ctx), index, out_tensors); - } else { - return errors::FailedPrecondition( - "Get() failed because the MaterializedIndexedDataset has not been " - "initialized. Ensure that you have run the materialization operation " - "for this MaterializedIndexedDataset before retrieving elements."); - } - } - - // TODO(saeta): Implement Save and Restore - - const DataTypeVector& output_dtypes() const { return output_dtypes_; } - const std::vector<PartialTensorShape>& output_shapes() const { - return output_shapes_; - } - - Status set_materialized_dataset( - const std::shared_ptr<MaterializedIndexedDataset>& dataset) { - if (dataset) { - TF_RETURN_IF_ERROR( - VerifyTypesMatch(output_dtypes_, dataset->output_dtypes())); - TF_RETURN_IF_ERROR( - VerifyShapesCompatible(output_shapes_, dataset->output_shapes())); - } - materialized_ = dataset; - return Status::OK(); - } - - private: - std::shared_ptr<MaterializedIndexedDataset> materialized_; - const DataTypeVector output_dtypes_; - const std::vector<PartialTensorShape> output_shapes_; -}; - -// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT -// tensor. Objects of the wrapper class own a reference on an instance of an -// `IndexedTensor` and the wrapper's copy constructor and desctructor take care -// of managing the reference count. -// -// NOTE: This is not a feature-complete implementation of the DT_VARIANT -// specification. In particular, we cannot currently serialize an arbitrary -// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not -// implemented. -// -// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just -// use `tensorflow::DatasetVariantWrapper`. -class IndexedDatasetVariantWrapper { - public: - IndexedDatasetVariantWrapper() : dataset_(nullptr) {} - - // Transfers ownership of `dataset` to `*this`. - explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset) - : dataset_(dataset) {} - - IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other) - : dataset_(other.dataset_) { - if (dataset_) dataset_->Ref(); - } - - ~IndexedDatasetVariantWrapper() { - if (dataset_) dataset_->Unref(); - } - - IndexedDataset* get() const { return dataset_; } - - string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; } - string DebugString() const { - if (dataset_) { - return dataset_->DebugString(); - } else { - return "<Uninitialized IndexedDatasetVariantWrapper>"; - } - } - - void Encode(VariantTensorData* data) const { - LOG(ERROR) << "The Encode() method is not implemented for " - "IndexedDatasetVariantWrapper objects."; - } - - bool Decode(const VariantTensorData& data) { - LOG(ERROR) << "The Decode() method is not implemented for " - "IndexedDatasetVariantWrapper objects."; - return false; - } - - private: - IndexedDataset* const dataset_; // Owns one reference. -}; - -} // namespace - -Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor, - IndexedDataset** out_dataset) { - if (!(tensor.dtype() == DT_VARIANT || - TensorShapeUtils::IsScalar(tensor.shape()))) { - return errors::InvalidArgument( - "IndexedDataset tensor must be a scalar of dtype DT_VARIANT."); - } - const Variant& variant = tensor.scalar<Variant>()(); - const IndexedDatasetVariantWrapper* wrapper = - variant.get<IndexedDatasetVariantWrapper>(); - if (wrapper == nullptr) { - return errors::InvalidArgument("Tensor must be an IndexedDataset object."); - } - *out_dataset = wrapper->get(); - if (*out_dataset == nullptr) { - return errors::Internal("Read uninitialized IndexedDataset variant."); - } - return Status::OK(); -} - -Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset, - Tensor* tensor) { - if (!(tensor->dtype() == DT_VARIANT || - TensorShapeUtils::IsScalar(tensor->shape()))) { - return errors::InvalidArgument( - "Dataset tensor must be a scalar of dtype DT_VARIANT."); - } - tensor->scalar<Variant>()() = IndexedDatasetVariantWrapper(dataset); - return Status::OK(); -} - -void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) { - IndexedDataset* dataset = nullptr; - MakeIndexedDataset(ctx, &dataset); - - if (ctx->status().ok()) { - OP_REQUIRES(ctx, dataset != nullptr, - errors::Internal("MakeIndexedDataset did not correctly " - "construct the IndexedDataset")); - Tensor* output = nullptr; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); - OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output)); - } -} - -namespace { - -class MaterializedHandleOp : public OpKernel { - public: - explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - } - - ~MaterializedHandleOp() override { - if (resource_ != nullptr) { - resource_->Unref(); - if (cinfo_.resource_is_private_to_kernel()) { - if (!cinfo_.resource_manager() - ->template Delete<MaterializedDatasetResource>( - cinfo_.container(), cinfo_.name()) - .ok()) { - // Do nothing; the resource can have been deleted by session resets. - // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h - } - } - } - } - - void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { - { - mutex_lock l(mu_); - if (resource_ == nullptr) { - ResourceMgr* mgr = context->resource_manager(); - OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); - - MaterializedDatasetResource* resource; - OP_REQUIRES_OK(context, - mgr->LookupOrCreate<MaterializedDatasetResource>( - cinfo_.container(), cinfo_.name(), &resource, - [this](MaterializedDatasetResource** ret) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new MaterializedDatasetResource( - output_dtypes_, output_shapes_); - return Status::OK(); - })); - Status s = VerifyResource(resource); - if (TF_PREDICT_FALSE(!s.ok())) { - resource->Unref(); - context->SetStatus(s); - return; - } - - resource_ = resource; - } - } - OP_REQUIRES_OK(context, MakeResourceHandleToOutput( - context, 0, cinfo_.container(), cinfo_.name(), - MakeTypeIndex<MaterializedDatasetResource>())); - } - - private: - // During the first Compute(), resource is either created or looked up using - // shared_name. In the latter case, the resource found should be verified if - // it is compatible with this op's configuration. The verification may fail in - // cases such as two graphs asking queues of the same shared name to have - // inconsistent capacities. - Status VerifyResource(MaterializedDatasetResource* resource) { - TF_RETURN_IF_ERROR( - VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); - TF_RETURN_IF_ERROR( - VerifyShapesCompatible(output_shapes_, resource->output_shapes())); - return Status::OK(); - } - - mutex mu_; - ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. - MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr; - DataTypeVector output_dtypes_; - std::vector<PartialTensorShape> output_shapes_; -}; - -// TODO(saeta): Make async. -class MaterializeDatasetOp : public OpKernel { - public: - explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - IndexedDataset* dataset; - OP_REQUIRES_OK(ctx, - GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset)); - - MaterializedDatasetResource* materialized_resource; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), - &materialized_resource)); - core::ScopedUnref unref(materialized_resource); - std::shared_ptr<MaterializedIndexedDataset> materialized; - OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized)); - OP_REQUIRES_OK( - ctx, materialized_resource->set_materialized_dataset(materialized)); - } -}; - -// TODO(saeta): Make async -class IndexedDatasetGet : public OpKernel { - public: - explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {} - - void Compute(OpKernelContext* ctx) override { - MaterializedDatasetResource* materialized_resource; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), - &materialized_resource)); - auto cleanup = gtl::MakeCleanup([materialized_resource] { - materialized_resource->Unref(); // Note: can't use core::ScopedUnref. - }); - - const Tensor* index_t; - OP_REQUIRES_OK(ctx, ctx->input("index", &index_t)); - // TODO(saeta): Support batch reads (indexes should be non-scalar!) - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()), - errors::InvalidArgument("index must be a scalar")); - const uint64 index = index_t->scalar<uint64>()(); - - std::vector<Tensor> out_tensors; - Status s = - materialized_resource->Get(IteratorContext(ctx), index, &out_tensors); - - // Note: Unref materialized_resource to avoid destruction races. (Important - // in a [future] async op implementation.) - cleanup.release()(); - - if (!s.ok()) { - ctx->SetStatus(s); - } else { - auto expected_shapes = materialized_resource->output_shapes(); - auto expected_types = materialized_resource->output_dtypes(); - for (size_t i = 0; i < out_tensors.size(); ++i) { - OP_REQUIRES( - ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()), - errors::Internal( - "Materialized dataset output at index ", i, - " is incompatible with the expected shape. (Expected: ", - expected_shapes[i], ", got: ", out_tensors[i].shape(), ")")); - OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i], - errors::Internal("Materialized dataset output at index ", i, - " was not the expected dtype. (Expected: ", - expected_types[i], - ", got: ", out_tensors[i].dtype(), ")")); - ctx->set_output(i, out_tensors[i]); - } - } - } -}; - -REGISTER_KERNEL_BUILDER( - Name("MaterializedIndexDatasetHandle").Device(DEVICE_CPU), - MaterializedHandleOp); -REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU), - MaterializeDatasetOp); -REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU), - IndexedDatasetGet); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/contrib/data/kernels/indexed_dataset.h deleted file mode 100644 index 7aa2d3fdbc..0000000000 --- a/tensorflow/contrib/data/kernels/indexed_dataset.h +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_ -#define TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_ - -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/op_kernel.h" - -namespace tensorflow { -namespace data { - -// TODO(saeta): Urgh, this is ugly. -class MaterializedIndexedDataset { - public: - virtual ~MaterializedIndexedDataset() = default; - - // Retrieve the element at a given index. The output tensors are stored in - // out_tensors. - // - // If `index` is greater than `Size()`, tensorflow::errors::OutOfRangeError is - // returned. - // - // Get is thread-safe. - virtual Status Get(IteratorContext&& ctx, uint64 index, - std::vector<Tensor>* out_tensors) const = 0; - - // Size determines the number of elements in this IndexedDataset. - // - // Size is thread-safe. - virtual Status Size(uint64* size) const = 0; - - // Returns a vector of DataType values, representing the respective - // element types of each tuple component in the outputs of this dataset. - virtual const DataTypeVector& output_dtypes() const = 0; - - // Returns a vector of tensor shapes, representing the respective - // (and possibly partially defined) shapes of each tuple component - // in the outputs of this dataset. - virtual const std::vector<PartialTensorShape>& output_shapes() const = 0; -}; - -// IndexedDataset represents a dataset that supports random access in addition -// to iterator-based sequential access. -// -// Note: IndexedDatasets are HIGHLY experimental at this time. Expect -// significant (backwards incompatible) changes! -class IndexedDataset : public DatasetBase { - public: - IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {} - - // Materialize (if necessary) the dataset, and return a pointer. - // TODO(saeta): Add in `IteratorContext* ctx` when materializing. - virtual Status MaterializeDataset( - std::shared_ptr<MaterializedIndexedDataset>* materialized) = 0; -}; - -// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the -// rest of the TensorFlow runtime. -// -// Most IndexedDataset's will be private members of classes inheriting from this -// class. -class IndexedDatasetOpKernel : public OpKernel { - public: - IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} - void Compute(OpKernelContext* ctx) final; - - protected: - // Subclasses should implement this method. It will be called during Compute - // execution. - virtual void MakeIndexedDataset(OpKernelContext* ctx, - IndexedDataset** output) = 0; - - template <typename T> - Status ParseScalarArgument(OpKernelContext* ctx, - const StringPiece& argument_name, T* output) { - const Tensor* argument_t; - TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); - if (!TensorShapeUtils::IsScalar(argument_t->shape())) { - return errors::InvalidArgument(argument_name, " must be a scalar"); - } - *output = argument_t->scalar<T>()(); - return Status::OK(); - } -}; - -// Validates and extracts an `IndexedDataset` object from `tensor`. -// -// `tensor` must have been written by a call to -// `StoreIndexedDatasetInVariantTensor` -// -// The retrieved pointer isa borrowed reference to the dataset, which is owned -// by the tensor. The consumer must either acquire its own reference to the -// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not -// destroyed or mutated while the retrieved pointer is in use. -Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor, - IndexedDataset** out_dataset); - -// Stores an `IndexedDataset` object in `tensor.` -// -// The ownership of `dataset` is transferred to `tensor`. -Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset, - Tensor* tensor); - -} // namespace data -} // namespace tensorflow - -#endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_ diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc deleted file mode 100644 index d233c1f8ec..0000000000 --- a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc +++ /dev/null @@ -1,217 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include <sys/stat.h> - -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/lib/io/buffered_inputstream.h" -#include "tensorflow/core/platform/file_system.h" - -#include "lmdb.h" // NOLINT(build/include) - -namespace tensorflow { -namespace data { -namespace { - -class LMDBDatasetOp : public DatasetOpKernel { - public: - using DatasetOpKernel::DatasetOpKernel; - void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - const Tensor* filenames_tensor; - OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); - OP_REQUIRES( - ctx, filenames_tensor->dims() <= 1, - errors::InvalidArgument("`filenames` must be a scalar or a vector.")); - - std::vector<string> filenames; - filenames.reserve(filenames_tensor->NumElements()); - for (int i = 0; i < filenames_tensor->NumElements(); ++i) { - filenames.push_back(filenames_tensor->flat<string>()(i)); - } - - *output = new Dataset(ctx, filenames); - } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, const std::vector<string>& filenames) - : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {} - - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::LMDB")})); - } - - const DataTypeVector& output_dtypes() const override { - static DataTypeVector* dtypes = - new DataTypeVector({DT_STRING, DT_STRING}); - return *dtypes; - } - - const std::vector<PartialTensorShape>& output_shapes() const override { - static std::vector<PartialTensorShape>* shapes = - new std::vector<PartialTensorShape>({{}, {}}); - return *shapes; - } - - string DebugString() const override { return "LMDBDatasetOp::Dataset"; } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* filenames = nullptr; - TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); - TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); - return Status::OK(); - } - - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} - - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - do { - if (mdb_cursor_) { - Tensor key_tensor(ctx->allocator({}), DT_STRING, {}); - key_tensor.scalar<string>()() = string( - static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size); - out_tensors->emplace_back(std::move(key_tensor)); - - Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); - value_tensor.scalar<string>()() = - string(static_cast<const char*>(mdb_value_.mv_data), - mdb_value_.mv_size); - out_tensors->emplace_back(std::move(value_tensor)); - - int val; - val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT); - if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { - return errors::InvalidArgument(mdb_strerror(val)); - } - if (val == MDB_NOTFOUND) { - ResetStreamsLocked(); - ++current_file_index_; - } - *end_of_sequence = false; - return Status::OK(); - } - if (current_file_index_ == dataset()->filenames_.size()) { - *end_of_sequence = true; - return Status::OK(); - } - - TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); - } while (true); - } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - return errors::Unimplemented( - "Checkpointing is currently not supported for LMDBDataset."); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - return errors::Unimplemented( - "Checkpointing is currently not supported for LMDBDataset."); - } - - private: - Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (current_file_index_ >= dataset()->filenames_.size()) { - return errors::InvalidArgument( - "current_file_index_:", current_file_index_, - " >= filenames_.size():", dataset()->filenames_.size()); - } - const string& filename = dataset()->filenames_[current_file_index_]; - - int val = mdb_env_create(&mdb_env_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; - - struct stat source_stat; - if (stat(filename.c_str(), &source_stat) == 0 && - (source_stat.st_mode & S_IFREG)) { - flags |= MDB_NOSUBDIR; - } - val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_); - if (val != MDB_SUCCESS) { - return errors::InvalidArgument(mdb_strerror(val)); - } - val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST); - if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { - return errors::InvalidArgument(mdb_strerror(val)); - } - if (val == MDB_NOTFOUND) { - ResetStreamsLocked(); - } - return Status::OK(); - } - void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - if (mdb_env_ != nullptr) { - if (mdb_cursor_) { - mdb_cursor_close(mdb_cursor_); - mdb_cursor_ = nullptr; - } - mdb_dbi_close(mdb_env_, mdb_dbi_); - mdb_txn_abort(mdb_txn_); - mdb_env_close(mdb_env_); - mdb_txn_ = nullptr; - mdb_dbi_ = 0; - mdb_env_ = nullptr; - } - } - mutex mu_; - size_t current_file_index_ GUARDED_BY(mu_) = 0; - MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr; - MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr; - MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0; - MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr; - - MDB_val mdb_key_ GUARDED_BY(mu_); - MDB_val mdb_value_ GUARDED_BY(mu_); - }; - - const std::vector<string> filenames_; - }; -}; - -REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc deleted file mode 100644 index 96f1dd0059..0000000000 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ /dev/null @@ -1,481 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include <deque> - -#include "tensorflow/core/common_runtime/process_function_library_runtime.h" -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_op_kernel.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/lib/random/random.h" -#include "tensorflow/core/util/device_name_utils.h" - -namespace tensorflow { -namespace data { -namespace { - -struct BufferElement { - // The producer sets `status` if getting the input element fails. - Status status; - // The buffered data element. - std::vector<Tensor> value; -}; - -using FunctionBufferCallback = std::function<void(const BufferElement&)>; - -class FunctionBufferingResource : public ResourceBase { - public: - FunctionBufferingResource(FunctionLibraryRuntime* lib, - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, - const NameAttrList& func, int64 buffer_size, - const string& source_device, - const string& target_device, - const std::vector<Tensor>& func_args, - const DataTypeVector& output_types) - : lib_(lib), - pflr_(std::move(pflr)), - func_(func), - buffer_size_(buffer_size), - source_device_(source_device), - target_device_(target_device), - func_args_(func_args), - output_types_(output_types), - handle_(kInvalidHandle), - is_buffering_(false), - end_of_sequence_(false), - cancelled_(false) {} - - ~FunctionBufferingResource() override { - Cancel(); - } - - string DebugString() override { - return strings::StrCat("FunctionBufferingResource. Size: ", buffer_size_, - "; target_device: ", target_device_); - } - - // Instantiates the function the first time it's called. After that it caches - // the handle. - Status Instantiate() LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - // Re-use existing handle if it's been set, effectively caching it. - if (handle_ != kInvalidHandle) { - return Status::OK(); - } - AttrValueMap attr_values = func_.attr(); - FunctionLibraryRuntime::InstantiateOptions opts; - opts.target = target_device_; - return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts, - &handle_); - } - - // Returns true if we've got to the end of the sequence and exhausted the - // buffer. - bool Finished() LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - return end_of_sequence_ && buffer_.empty(); - } - - // Cancels any buffering / prefetching going on. - void Cancel() LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - cancelled_ = true; - while (is_buffering_) { - cond_var_.wait(l); - } - } - - // Cancels all pending operations and then clears out the state. - void Reset() LOCKS_EXCLUDED(mu_) { - Cancel(); - mutex_lock l(mu_); - buffer_.clear(); - requests_.clear(); - is_buffering_ = false; - end_of_sequence_ = false; - cancelled_ = false; - } - - // If the buffer has anything, runs `callback` on the first element in the - // buffer, else schedules the `callback` to be called. Requires `args` and - // `lib` in case more function calls need to be scheduled. - void MaybeGet(FunctionBufferCallback callback) LOCKS_EXCLUDED(mu_) { - bool start_buffering = false; - bool produced_output = false; - BufferElement buffer_element; - { - mutex_lock l(mu_); - if (!is_buffering_ && !end_of_sequence_) { - start_buffering = true; - } - if (!buffer_.empty()) { - produced_output = true; - std::swap(buffer_element, buffer_.front()); - buffer_.pop_front(); - } else { - produced_output = false; - requests_.push_back(std::move(callback)); - } - } - if (produced_output) { - callback(buffer_element); - } - if (start_buffering) { - FillBuffer(); - } - } - - private: - void FillBuffer() LOCKS_EXCLUDED(mu_) { - FunctionLibraryRuntime::Handle handle; - std::vector<FunctionBufferCallback> cancellation_callbacks; - std::vector<BufferElement> cancellation_buffer_elements; - bool cancelled = false; - { - mutex_lock l(mu_); - handle = handle_; - if (cancelled_) { - cancelled = true; - // Run through and fulfill all pending requests, if possible. - while (!requests_.empty()) { - if (!buffer_.empty()) { - cancellation_buffer_elements.push_back(std::move(buffer_.front())); - buffer_.pop_front(); - cancellation_callbacks.push_back(std::move(requests_.front())); - requests_.pop_front(); - } else { - LOG(ERROR) << "Buffer ran out of elements and we couldn't satisfy: " - << requests_.size() << " requests"; - break; - } - } - is_buffering_ = false; - } else { - is_buffering_ = true; - } - } - if (cancelled) { - for (int i = 0; i < cancellation_callbacks.size(); ++i) { - cancellation_callbacks[i](cancellation_buffer_elements[i]); - } - cond_var_.notify_all(); - return; - } - FunctionLibraryRuntime::Options opts; - // Copied from CapturedFunction::generate_step_id(); - opts.step_id = -std::abs(static_cast<int64>(random::New64())); - opts.source_device = source_device_; - AllocatorAttributes arg_alloc_attr; - arg_alloc_attr.set_on_host(true); - opts.args_alloc_attrs.push_back(arg_alloc_attr); - for (const auto& dtype : output_types_) { - AllocatorAttributes ret_alloc_attrs; - if (DataTypeAlwaysOnHost(dtype)) { - ret_alloc_attrs.set_on_host(true); - } - opts.rets_alloc_attrs.push_back(ret_alloc_attrs); - } - if (opts.source_device != target_device_) { - opts.remote_execution = true; - } - opts.create_rendezvous = true; - auto* rets = new std::vector<Tensor>; - lib_->Run(opts, handle, func_args_, rets, - [this, rets](const Status& status) { - FunctionBufferCallback callback = nullptr; - BufferElement buffer_front; - bool restart_buffering = false; - { - mutex_lock l(mu_); - BufferElement buffer_element; - buffer_element.status = status; - if (status.ok()) { - buffer_element.value.swap(*rets); - } else { - end_of_sequence_ = true; - is_buffering_ = false; - } - buffer_.push_back(std::move(buffer_element)); - if (!requests_.empty()) { - buffer_front = std::move(buffer_.front()); - buffer_.pop_front(); - callback = std::move(requests_.front()); - requests_.pop_front(); - } - if (buffer_.size() < buffer_size_ && !end_of_sequence_) { - restart_buffering = true; - } else { - // When the buffer is full, we don't want to call - // FillBuffer() unless we're in cancellation phase in which - // case FillBuffer() will do the final cleanup post - // cancellation. - if (cancelled_) { - restart_buffering = true; - } - is_buffering_ = false; - } - } - if (callback != nullptr) { - callback(buffer_front); - } - if (restart_buffering) { - FillBuffer(); - } - }); - } - - mutex mu_; - FunctionLibraryRuntime* lib_; - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; - NameAttrList func_; - const int64 buffer_size_; - const string source_device_; - const string target_device_; - const std::vector<Tensor> func_args_; - const DataTypeVector output_types_; - FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); - std::deque<BufferElement> buffer_ GUARDED_BY(mu_); - std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_); - bool is_buffering_ GUARDED_BY(mu_); - bool end_of_sequence_ GUARDED_BY(mu_); - bool cancelled_ GUARDED_BY(mu_); - condition_variable cond_var_; -}; - -class FunctionBufferResourceHandleOp : public OpKernel { - public: - explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx) - : OpKernel(ctx), flib_def_(nullptr) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); - } - - ~FunctionBufferResourceHandleOp() override { - if (cinfo_.resource_is_private_to_kernel()) { - if (!cinfo_.resource_manager() - ->Delete<FunctionBufferingResource>(cinfo_.container(), - cinfo_.name()) - .ok()) { - // Do nothing; the resource can have been deleted by session resets. - } - } - } - - void Compute(OpKernelContext* ctx) override { - const Tensor* string_arg; - OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg)); - std::vector<Tensor> func_args; - func_args.push_back(*string_arg); - - const string& source_device = ctx->device()->name(); - - // Obtain and canonicalize target_device. - const Tensor* target_arg; - OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg)); - string target_device; - OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName( - target_arg->scalar<string>()(), source_device, - &target_device)); - - FunctionLibraryRuntime* lib = ctx->function_library(); - OP_REQUIRES(ctx, lib != nullptr, - errors::Internal("No function library is provided.")); - - mutex_lock l(mu_); - if (!initialized_) { - OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); - FunctionLibraryRuntime* clone_lib; - std::unique_ptr<ProcessFunctionLibraryRuntime> pflr; - OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib)); - // Create the resource. - FunctionBufferingResource* buffer; - OP_REQUIRES_OK( - ctx, - ctx->resource_manager()->LookupOrCreate<FunctionBufferingResource>( - cinfo_.container(), cinfo_.name(), &buffer, - [clone_lib, &pflr, &source_device, &target_device, func_args, - this](FunctionBufferingResource** ptr) { - *ptr = new FunctionBufferingResource( - clone_lib, std::move(pflr), func_, buffer_size_, - source_device, target_device, func_args, output_types_); - return Status::OK(); - })); - core::ScopedUnref s(buffer); - OP_REQUIRES_OK(ctx, buffer->Instantiate()); - initialized_ = true; - } - - OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( - ctx, 0, cinfo_.container(), cinfo_.name(), - MakeTypeIndex<FunctionBufferingResource>())); - } - - private: - mutex mu_; - ContainerInfo cinfo_ GUARDED_BY(mu_); - bool initialized_ GUARDED_BY(mu_) = false; - std::unique_ptr<FunctionLibraryDefinition> flib_def_; - NameAttrList func_; - int64 buffer_size_; - string container_; - string name_; - DataTypeVector output_types_; -}; - -REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") - .Device(DEVICE_CPU) - .HostMemory("resource") - .HostMemory("string_arg") - .HostMemory("target_device"), - FunctionBufferResourceHandleOp); -REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") - .Device(DEVICE_GPU) - .HostMemory("resource") - .HostMemory("string_arg") - .HostMemory("target_device"), - FunctionBufferResourceHandleOp); -#if TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") - .Device(DEVICE_SYCL) - .HostMemory("resource") - .HostMemory("string_arg") - .HostMemory("target_device"), - FunctionBufferResourceHandleOp); -#endif // TENSORFLOW_USE_SYCL - -// Prefetches and fills up a buffer by calling a function that provides the -// elements to buffer. -class FunctionBufferingResourceGetNextOp : public AsyncOpKernel { - public: - explicit FunctionBufferingResourceGetNextOp(OpKernelConstruction* ctx) - : AsyncOpKernel(ctx) {} - - ~FunctionBufferingResourceGetNextOp() override {} - - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - ResourceHandle handle; - OP_REQUIRES_OK_ASYNC( - ctx, HandleFromInput(ctx, "function_buffer_resource", &handle), done); - FunctionBufferingResource* buffer = nullptr; - OP_REQUIRES_OK_ASYNC( - ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer), - done); - - if (buffer->Finished()) { - buffer->Unref(); - ctx->SetStatus(errors::OutOfRange("end_of_sequence")); - done(); - return; - } - - FunctionBufferCallback callback = - [ctx, buffer, done](const BufferElement& buffer_element) { - Status s = buffer_element.status; - if (!s.ok()) { - ctx->SetStatus(s); - buffer->Unref(); - done(); - return; - } - for (size_t i = 0; i < buffer_element.value.size(); ++i) { - ctx->set_output(i, buffer_element.value[i]); - } - buffer->Unref(); - done(); - }; - buffer->MaybeGet(std::move(callback)); - } -}; - -REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext") - .Device(DEVICE_CPU) - .HostMemory("function_buffer_resource"), - FunctionBufferingResourceGetNextOp); -REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext") - .Device(DEVICE_GPU) - .HostMemory("function_buffer_resource"), - FunctionBufferingResourceGetNextOp); -#if TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext") - .Device(DEVICE_SYCL) - .HostMemory("function_buffer_resource"), - FunctionBufferingResourceGetNextOp); -#endif // TENSORFLOW_USE_SYCL - -// Resets the FunctionBufferingResource, cancelling all pending requests and -// clearing out the buffer. -class FunctionBufferingResourceResetOp : public OpKernel { - public: - explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx) - : OpKernel(ctx) {} - - ~FunctionBufferingResourceResetOp() override {} - - void Compute(OpKernelContext* ctx) override { - ResourceHandle handle; - OP_REQUIRES_OK(ctx, - HandleFromInput(ctx, "function_buffer_resource", &handle)); - FunctionBufferingResource* buffer = nullptr; - OP_REQUIRES_OK( - ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer)); - core::ScopedUnref s(buffer); - - buffer->Reset(); - } -}; - -REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") - .Device(DEVICE_CPU) - .HostMemory("function_buffer_resource"), - FunctionBufferingResourceResetOp); -REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") - .Device(DEVICE_GPU) - .HostMemory("function_buffer_resource"), - FunctionBufferingResourceResetOp); -#if TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset") - .Device(DEVICE_SYCL) - .HostMemory("function_buffer_resource"), - FunctionBufferingResourceResetOp); -#endif // TENSORFLOW_USE_SYCL - -class IteratorGetDeviceOp : public OpKernel { - public: - using OpKernel::OpKernel; - - void Compute(OpKernelContext* ctx) override { - // NOTE(mrry): We do not currently Validate that the handle - // corresponds to a real IteratorResource, because that symbol is - // not exposed from the framework library. - Tensor* device_name_t; - OP_REQUIRES_OK(ctx, - ctx->allocate_output(0, TensorShape({}), &device_name_t)); - // NOTE(mrry): Since the operation's input is a resource, we must be - // colocated with it, and so we can simply return the current device's - // name without looking at the input. - device_name_t->scalar<string>()() = ctx->device()->name(); - } -}; - -REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU), - IteratorGetDeviceOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc deleted file mode 100644 index 30fa97a636..0000000000 --- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc +++ /dev/null @@ -1,219 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/util/work_sharder.h" - -namespace tensorflow { -namespace data { -namespace { - -class ThreadPoolResource : public ResourceBase { - public: - ThreadPoolResource(Env* env, const ThreadOptions& thread_options, - const string& name, int num_threads, bool low_latency_hint, - int max_intra_op_parallelism) - : thread_pool_(env, thread_options, name, num_threads, low_latency_hint), - max_intra_op_parallelism_(max_intra_op_parallelism) {} - - // Schedules fn() for execution in the pool of threads. - void Schedule(std::function<void()> fn) { - if (max_intra_op_parallelism_ < 0) { - thread_pool_.Schedule(std::move(fn)); - } else { - thread_pool_.Schedule(std::bind( - [this](std::function<void()> bound_fn) { - // TODO(mrry): Consider moving this thread-local configuration to - // the threads themselves. - ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_); - bound_fn(); - }, - std::move(fn))); - } - } - - string DebugString() override { return "ThreadPoolResource"; } - - private: - thread::ThreadPool thread_pool_; - const int max_intra_op_parallelism_; -}; - -// Creates a handle to a ThreadPool resource. Note that we don't use -// ResourceOpKernel here because the ThreadPoolResource constructor requires -// access to `OpKernelContext::env()`, which isn't provided by -// `ResourceOpKernel<T>::CreateResource()`. -class ThreadPoolHandleOp : public OpKernel { - public: - explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism", - &max_intra_op_parallelism_)); - OP_REQUIRES( - ctx, num_threads_ > 0, - errors::InvalidArgument("`num_threads` must be greater than zero.")); - } - - // The resource is deleted from the resource manager only when it is private - // to kernel. Ideally the resource should be deleted when it is no longer held - // by anyone, but it would break backward compatibility. - ~ThreadPoolHandleOp() override { - if (cinfo_.resource_is_private_to_kernel()) { - if (!cinfo_.resource_manager() - ->Delete<ThreadPoolResource>(cinfo_.container(), cinfo_.name()) - .ok()) { - // Do nothing; the resource can have been deleted by session resets. - } - } - } - - void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); - if (!initialized_) { - ResourceMgr* mgr = ctx->resource_manager(); - OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); - ThreadPoolResource* resource; - OP_REQUIRES_OK(ctx, mgr->LookupOrCreate<ThreadPoolResource>( - cinfo_.container(), cinfo_.name(), &resource, - [this, ctx](ThreadPoolResource** ret) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new ThreadPoolResource( - ctx->env(), {}, display_name_, - num_threads_, max_intra_op_parallelism_, - false /* low_latency_hint */); - return Status::OK(); - })); - initialized_ = true; - } - OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( - ctx, 0, cinfo_.container(), cinfo_.name(), - MakeTypeIndex<ThreadPoolResource>())); - } - - private: - mutex mu_; - ContainerInfo cinfo_ GUARDED_BY(mu_); - bool initialized_ GUARDED_BY(mu_) = false; - string display_name_; - int num_threads_; - int max_intra_op_parallelism_; -}; - -class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { - public: - explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) {} - - void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) override { - ThreadPoolResource* threadpool_resource; - OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), - &threadpool_resource)); - core::ScopedUnref unref_iterator(threadpool_resource); - - *output = new Dataset(ctx, input, threadpool_resource); - } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, - ThreadPoolResource* threadpool) - : DatasetBase(DatasetContext(ctx)), - input_(input), - threadpool_(threadpool) { - input_->Ref(); - threadpool_->Ref(); - } - - ~Dataset() override { - input_->Unref(); - threadpool_->Unref(); - } - - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::ThreadPool")})); - } - - const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); - } - const std::vector<PartialTensorShape>& output_shapes() const override { - return input_->output_shapes(); - } - - string DebugString() const override { - return "ThreadPoolDatasetOp::Dataset"; - } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - return errors::Unimplemented("%s does not support serialization", - DebugString()); - } - - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} - - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - ThreadPoolResource* pool = dataset()->threadpool_; - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = [pool](std::function<void()> c) { - pool->Schedule(std::move(c)); - }; - params.stats_aggregator_getter = ctx->stats_aggregator_getter(); - params.lib = ctx->lib(); - params.function_library = ctx->function_library(); - params.allocator_getter = ctx->allocator_getter(); - IteratorContext threadpool_ctx(params); - return input_impl_->GetNext(&threadpool_ctx, out_tensors, - end_of_sequence); - } - - private: - std::unique_ptr<IteratorBase> input_impl_; - }; - - const DatasetBase* const input_; - ThreadPoolResource* const threadpool_; - }; -}; - -REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU), - ThreadPoolHandleOp); -REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU), - ThreadPoolDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc deleted file mode 100644 index 57fc5697a4..0000000000 --- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc +++ /dev/null @@ -1,223 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/dataset.h" -#include "tensorflow/core/framework/partial_tensor_shape.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/lib/hash/hash.h" - -namespace tensorflow { -namespace data { -namespace { - -// See documentation in ../ops/dataset_ops.cc for a high-level -// description of the following op. - -class UniqueDatasetOp : public UnaryDatasetOpKernel { - public: - explicit UniqueDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx) {} - - void MakeDataset(OpKernelContext* ctx, DatasetBase* input, - DatasetBase** output) override { - OP_REQUIRES(ctx, input->output_dtypes().size() == 1, - errors::InvalidArgument("UniqueDataset only supports " - "inputs with a single component.")); - - DataType input_dtype = input->output_dtypes()[0]; - OP_REQUIRES(ctx, - input_dtype == DT_INT32 || input_dtype == DT_INT64 || - input_dtype == DT_STRING, - errors::InvalidArgument( - "UniqueDataset only supports inputs with a single " - "`tf.int32`, `tf.int64`, or `tf.string` component.")); - - *output = new Dataset(ctx, input); - } - - private: - class Dataset : public DatasetBase { - public: - Dataset(OpKernelContext* ctx, const DatasetBase* input) - : DatasetBase(DatasetContext(ctx)), input_(input) { - input_->Ref(); - } - - ~Dataset() override { input_->Unref(); } - - std::unique_ptr<IteratorBase> MakeIteratorInternal( - const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Unique")})); - } - - const DataTypeVector& output_dtypes() const override { - return input_->output_dtypes(); - } - - const std::vector<PartialTensorShape>& output_shapes() const override { - return input_->output_shapes(); - } - - string DebugString() const override { - return strings::StrCat("UniqueDatasetOp::Dataset"); - } - - protected: - Status AsGraphDefInternal(SerializationContext* ctx, - DatasetGraphDefBuilder* b, - Node** output) const override { - Node* input_graph_node = nullptr; - TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); - return Status::OK(); - } - - private: - class Iterator : public DatasetIterator<Dataset> { - public: - explicit Iterator(const typename Iterator::Params& params) - : DatasetIterator<Dataset>(params) {} - - Status Initialize(IteratorContext* ctx) override { - return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); - } - - Status GetNextInternal(IteratorContext* ctx, - std::vector<Tensor>* out_tensors, - bool* end_of_sequence) override { - mutex_lock l(mu_); - bool saw_new_value; - do { - saw_new_value = false; - out_tensors->clear(); - TF_RETURN_IF_ERROR( - input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); - if (*end_of_sequence) { - break; - } - DCHECK_EQ(1, out_tensors->size()); - saw_new_value = unique_elements_.insert((*out_tensors)[0]).second; - } while (!saw_new_value); - return Status::OK(); - } - - protected: - Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - if (input_impl_) { - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - } else { - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("input_impl_empty"), "")); - } - TF_RETURN_IF_ERROR(writer->WriteScalar( - full_name("unique_elements_size"), unique_elements_.size())); - size_t i = 0; - for (const Tensor& t : unique_elements_) { - TF_RETURN_IF_ERROR(writer->WriteTensor( - full_name(strings::StrCat("unique_elements[", i++, "]")), t)); - } - return Status::OK(); - } - - Status RestoreInternal(IteratorContext* ctx, - IteratorStateReader* reader) override { - mutex_lock l(mu_); - if (!reader->Contains(full_name("input_impl_empty"))) { - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - } else { - input_impl_.reset(); - } - int64 num_unique_elements; - unique_elements_.clear(); - TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"), - &num_unique_elements)); - for (int64 i = 0; i < num_unique_elements; ++i) { - Tensor unique_element; - TF_RETURN_IF_ERROR(reader->ReadTensor( - full_name(strings::StrCat("unique_elements[", i, "]")), - &unique_element)); - auto insert_result = unique_elements_.insert(unique_element); - if (!insert_result.second) { - return errors::InvalidArgument( - "Checkpoint contained two unique elements with the same " - "value."); - } - } - return Status::OK(); - } - - private: - struct TensorHash { - size_t operator()(const Tensor& t) const { - if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) { - return Hash64(t.tensor_data().data(), t.tensor_data().size()); - } else { - DCHECK_EQ(DT_STRING, t.dtype()); - auto flat_t = t.flat<string>(); - uint64 hash = 0; - for (int64 i = 0; i < t.NumElements(); ++i) { - hash = Hash64Combine(hash, Hash64(flat_t(i))); - } - return static_cast<size_t>(hash); - } - } - }; - - struct TensorKeyEqual { - bool operator()(const Tensor& lhs, const Tensor& rhs) const { - if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) { - return false; - } - switch (lhs.dtype()) { -#define HANDLE_TYPE(T) \ - case T: \ - do { \ - auto lhs_flat = lhs.flat<EnumToDataType<T>::Type>(); \ - auto rhs_flat = rhs.flat<EnumToDataType<T>::Type>(); \ - for (int64 i = 0; i < lhs.NumElements(); ++i) { \ - if (lhs_flat(i) != rhs_flat(i)) { \ - return false; \ - } \ - } \ - return true; \ - } while (0) - - HANDLE_TYPE(DT_INT32); - HANDLE_TYPE(DT_INT64); - HANDLE_TYPE(DT_STRING); - default: - LOG(FATAL) << "UniqueDataset unhandled data type: " - << DataTypeString(lhs.dtype()); - } - } - }; - - mutex mu_; - std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); - std::unordered_set<Tensor, TensorHash, TensorKeyEqual> unique_elements_ - GUARDED_BY(mu_); - }; - - const DatasetBase* const input_; - }; -}; - -REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU), - UniqueDatasetOp); - -} // namespace -} // namespace data -} // namespace tensorflow diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc deleted file mode 100644 index d1a771f005..0000000000 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ /dev/null @@ -1,208 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("DirectedInterleaveDataset") - .Input("selector_input_dataset: variant") - .Input("data_input_datasets: N * variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .Attr("N: int >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -A substitute for `InterleaveDataset` on a fixed list of `N` datasets. - -selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines - which of the `N` data inputs should produce the next output element. -data_input_datasets: `N` datasets with the same type that will be interleaved - according to the values of `selector_input_dataset`. -)doc"); - -REGISTER_OP("CSVDataset") - .Input("filenames: string") - .Input("compression_type: string") - .Input("buffer_size: int64") - .Input("header: bool") - .Input("field_delim: string") - .Input("use_quote_delim: bool") - .Input("na_value: string") - .Input("select_cols: int64") - .Input("record_defaults: output_types") - .Output("handle: variant") - .Attr("output_types: list({float,double,int32,int64,string}) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked - // stateful to inhibit constant folding. - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - // `filenames` must be a scalar or a vector. - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); - // `compression_type`, `buffer_size`, `header`, `field_delim`, - // `use_quote_delim`, `na_value` must be scalars - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); - // `select_cols` must be a vector - TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); - // `record_defaults` must be lists of scalars - for (size_t i = 8; i < c->num_inputs(); ++i) { - shape_inference::ShapeHandle v; - TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); - if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { - return errors::InvalidArgument( - "Shape of a default must be a length-0 or length-1 vector, or a " - "scalar."); - } - } - return shape_inference::ScalarShape(c); - }); - -REGISTER_OP("IgnoreErrorsDataset") - .Input("input_dataset: variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that contains the elements of `input_dataset` ignoring errors. -)doc"); - -REGISTER_OP("UniqueDataset") - .Input("input_dataset: variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that contains the unique elements of `input_dataset`. -)doc"); - -REGISTER_OP("IteratorGetDevice") - .Input("resource: resource") - .Output("device: string") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Returns the name of the device on which `resource` has been placed. -)doc"); - -REGISTER_OP("FunctionBufferingResource") - .Input("string_arg: string") - .Input("target_device: string") - .Output("resource: resource") - .Attr("shared_name: string") - .Attr("container: string") - .Attr("f: func") - .Attr("buffer_size: int") - .Attr("output_types: list(type)") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Creates a resource that fills up a buffer by making function calls. - -string_arg: String argument to the function call. -target_device: Target device to execute the function on. -resource: Handle to the resource created. -f: Function to be executed. -buffer_size: Size of the buffer. -container: If non-empty, this resource is placed in the given container. - Otherwise, a default container is used. -shared_name: If non-empty, this resource will be shared under the given name - across multiple sessions. -output_types: The type list for the return values. -)doc"); - -REGISTER_OP("FunctionBufferingResourceGetNext") - .Input("function_buffer_resource: resource") - .Attr("output_types: list(type)") - .Output("output: output_types") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Gets the next element from a FunctionBufferingResource. - -function_buffer_resource: The FunctionBufferingResource handle. -output: A list of return values. -output_types: The type list for the return values. -)doc"); - -REGISTER_OP("FunctionBufferingResourceReset") - .Input("function_buffer_resource: resource") - .SetShapeFn(shape_inference::UnknownShape) - .Doc(R"doc( -Resets the FunctionBufferingResource. - -function_buffer_resource: The FunctionBufferingResource handle. -)doc"); - -REGISTER_OP("ThreadPoolDataset") - .Input("input_dataset: variant") - .Input("thread_pool: resource") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape) - .Doc(R"doc( -Creates a dataset that uses a custom thread pool to compute `input_dataset`. - -handle: A resource produced by the ThreadPoolHandle op. -)doc"); - -REGISTER_OP("ThreadPoolHandle") - .Output("handle: resource") - .SetShapeFn(shape_inference::ScalarShape) - .Attr("num_threads: int") - .Attr("max_intra_op_parallelism: int = 1") - .Attr("display_name: string") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Doc(R"doc( -Creates a custom thread pool with the given number of threads. - -handle: A resource that can be consumed by one or more ThreadPoolDataset ops. -num_threads: The number of threads in the thread pool. -max_intra_op_parallelism: The maximum degree of parallelism to use within - operations that execute on this threadpool. -display_name: A human-readable name for the threads that may be visible in - some visualizations. -)doc"); - -REGISTER_OP("AssertNextDataset") - .Input("input_dataset: variant") - .Input("transformations: string") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - // transformations should be a vector. - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); - return shape_inference::ScalarShape(c); - }); - -REGISTER_OP("LMDBDataset") - .Input("filenames: string") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked - // stateful to inhibit constant folding. - .SetShapeFn(shape_inference::ScalarShape); - -} // namespace tensorflow diff --git a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc deleted file mode 100644 index cd9b7c68a0..0000000000 --- a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" - -namespace tensorflow { - -REGISTER_OP("IdentityIndexedDataset") - .Input("size: uint64") - .Output("handle: variant") - .SetIsStateful() - .SetShapeFn( - shape_inference::ScalarShape); // TODO(saeta): check input shapes. - -/////////////////////////////////////////////////////////////////////////////// -// IndexedDataset Internals -/////////////////////////////////////////////////////////////////////////////// - -// Creates the handle. -REGISTER_OP("MaterializedIndexDatasetHandle") - .Output("handle: resource") - .Attr("container: string") - .Attr("shared_name: string") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); - -// Actually materialize the materialize handle. -REGISTER_OP("IndexedDatasetMaterialize") - .Input("dataset: variant") - .Input("materialized: resource") - .SetShapeFn(shape_inference::NoOutputs); - -namespace { - -Status GetShapeFn(shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle unused; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); - std::vector<PartialTensorShape> output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - if (output_shapes.size() != c->num_outputs()) { - return errors::InvalidArgument( - "`output_shapes` must be the same length as `output_types` (", - output_shapes.size(), " vs. ", c->num_outputs()); - } - for (size_t i = 0; i < output_shapes.size(); ++i) { - shape_inference::ShapeHandle output_shape_handle; - TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( - output_shapes[i], &output_shape_handle)); - c->set_output(static_cast<int>(i), output_shape_handle); - } - return Status::OK(); -} - -} // namespace - -REGISTER_OP("IndexedDatasetGet") - .Input("materialized: resource") - .Input("index: uint64") - .Output("components: output_types") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(GetShapeFn) - .Doc(R"doc( -Gets the element at `index` from `materialized` IndexedDataset. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index ce52c990ce..33784afa3f 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -31,6 +31,7 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -54,6 +55,7 @@ py_test( "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -77,6 +79,7 @@ py_test( "//tensorflow/python:platform", "//tensorflow/python:platform_test", "//tensorflow/python:session", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:readers", "//tensorflow/python/eager:context", "//third_party/py/numpy", @@ -97,6 +100,7 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", ], @@ -112,6 +116,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:random_seed", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -130,6 +135,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -139,12 +145,12 @@ py_test( name = "indexed_dataset_ops_test", srcs = ["indexed_dataset_ops_test.py"], deps = [ - "//tensorflow/contrib/data/python/ops:contrib_op_loader", - "//tensorflow/contrib/data/python/ops:gen_dataset_ops", "//tensorflow/contrib/data/python/ops:indexed_dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -170,6 +176,7 @@ py_test( "//tensorflow/python:script_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@six_archive//:six", ], @@ -189,6 +196,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:variables", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/estimator:estimator_py", ], @@ -215,6 +223,7 @@ py_test( "//tensorflow/python:platform", "//tensorflow/python:platform_test", "//tensorflow/python:session", + "//tensorflow/python/data/kernel_tests:test_base", "//third_party/py/numpy", ], ) @@ -240,6 +249,7 @@ py_test( "//tensorflow/python:io_ops", "//tensorflow/python:math_ops", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -259,6 +269,7 @@ py_test( "//tensorflow/python:io_ops", "//tensorflow/python:math_ops", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -283,6 +294,7 @@ py_test( "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", "//tensorflow/python:session", + "//tensorflow/python/data/kernel_tests:test_base", ], ) @@ -301,6 +313,7 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", @@ -316,6 +329,7 @@ cuda_py_test( "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", @@ -341,6 +355,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -366,6 +381,7 @@ py_library( "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:readers", ], @@ -412,6 +428,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:string_ops", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -434,6 +451,7 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", "//third_party/py/numpy", @@ -454,6 +472,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -471,6 +490,7 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:math_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -490,6 +510,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python/data/kernel_tests:test_base", "@org_sqlite//:python", ], ) @@ -534,6 +555,7 @@ py_library( deps = [ "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", + "//tensorflow/python/data/kernel_tests:test_base", ], ) @@ -550,6 +572,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:script_ops", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -568,6 +591,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -588,6 +612,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:math_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -605,17 +630,8 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:lib", "//tensorflow/python:util", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:readers", ], ) - -py_library( - name = "test_utils", - srcs = ["test_utils.py"], - deps = [ - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python/data/util:nest", - ], -) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index e2508de9e9..fed7de5f2b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -40,12 +41,8 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class BatchDatasetTest(test.TestCase, parameterized.TestCase): +class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) def testDenseToSparseBatchDataset(self): components = np.random.randint(12, size=(100,)).astype(np.int32) @@ -723,7 +720,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase): self.assertAllEqual([element for _ in range(10)], sess.run(get_next)) -class RestructuredDatasetTest(test.TestCase): +class RestructuredDatasetTest(test_base.DatasetTestBase): def test_assert_element_shape(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 48971f2ccc..ae401f786c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -22,6 +22,7 @@ import random import numpy as np from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -35,7 +36,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -class GroupByReducerTest(test.TestCase): +class GroupByReducerTest(test_base.DatasetTestBase): def checkResults(self, dataset, shapes, values): self.assertEqual(shapes, dataset.output_shapes) @@ -198,7 +199,7 @@ class GroupByReducerTest(test.TestCase): self.assertEqual(y, 45) -class GroupByWindowTest(test.TestCase): +class GroupByWindowTest(test_base.DatasetTestBase): def testSimple(self): components = np.random.randint(100, size=(200,)).astype(np.int64) @@ -345,7 +346,7 @@ class GroupByWindowTest(test.TestCase): # NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. # Currently, they use a constant batch size, though should be made to use a # different batch size per key. -class BucketTest(test.TestCase): +class BucketTest(test_base.DatasetTestBase): def _dynamicPad(self, bucket, window, window_size): # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a @@ -570,7 +571,7 @@ def _get_record_shape(sparse): return tensor_shape.TensorShape([None]) -class BucketBySequenceLength(test.TestCase): +class BucketBySequenceLength(test_base.DatasetTestBase): def testBucket(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py index f8e74e4583..5b3c512b64 100644 --- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py @@ -30,6 +30,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import readers from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -43,37 +44,7 @@ from tensorflow.python.platform import test @test_util.run_all_in_graph_and_eager_modes -class CsvDatasetOpTest(test.TestCase): - - def _get_next(self, dataset): - # Returns a no argument function whose result is fed to self.evaluate to - # yield the next element - it = dataset.make_one_shot_iterator() - if context.executing_eagerly(): - return it.get_next - else: - get_next = it.get_next() - return lambda: get_next - - def _assert_datasets_equal(self, ds1, ds2): - assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, ' - '%s') % (ds1.output_shapes, - ds2.output_shapes) - assert ds1.output_types == ds2.output_types - assert ds1.output_classes == ds2.output_classes - next1 = self._get_next(ds1) - next2 = self._get_next(ds2) - # Run through datasets and check that outputs match, or errors match. - while True: - try: - op1 = self.evaluate(next1()) - except (errors.OutOfRangeError, ValueError) as e: - # If op1 throws an exception, check that op2 throws same exception. - with self.assertRaises(type(e)): - self.evaluate(next2()) - break - op2 = self.evaluate(next2()) - self.assertAllEqual(op1, op2) +class CsvDatasetOpTest(test_base.DatasetTestBase): def _setup_files(self, inputs, linebreak='\n', compression_type=None): filenames = [] @@ -108,7 +79,7 @@ class CsvDatasetOpTest(test.TestCase): """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv).""" dataset_actual, dataset_expected = self._make_test_datasets( inputs, **kwargs) - self._assert_datasets_equal(dataset_actual, dataset_expected) + self.assertDatasetsEqual(dataset_actual, dataset_expected) def _verify_output_or_err(self, dataset, @@ -116,7 +87,7 @@ class CsvDatasetOpTest(test.TestCase): expected_err_re=None): if expected_err_re is None: # Verify that output is expected, without errors - nxt = self._get_next(dataset) + nxt = self.getNext(dataset) expected_output = [[ v.encode('utf-8') if isinstance(v, str) else v for v in op ] for op in expected_output] @@ -128,7 +99,7 @@ class CsvDatasetOpTest(test.TestCase): else: # Verify that OpError is produced as expected with self.assertRaisesOpError(expected_err_re): - nxt = self._get_next(dataset) + nxt = self.getNext(dataset) while True: try: self.evaluate(nxt()) @@ -354,7 +325,7 @@ class CsvDatasetOpTest(test.TestCase): inputs = [['1,,3,4', '5,6,,8']] ds_actual, ds_expected = self._make_test_datasets( inputs, record_defaults=record_defaults) - self._assert_datasets_equal( + self.assertDatasetsEqual( ds_actual.repeat(5).prefetch(1), ds_expected.repeat(5).prefetch(1)) @@ -377,7 +348,7 @@ class CsvDatasetOpTest(test.TestCase): ds = readers.make_csv_dataset( file_path, batch_size=1, shuffle=False, num_epochs=1) - nxt = self._get_next(ds) + nxt = self.getNext(ds) result = list(self.evaluate(nxt()).values()) diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index a2ab3de52e..722e87e555 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import batching +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -25,7 +26,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class DatasetConstructorTest(test.TestCase): +class DatasetConstructorTest(test_base.DatasetTestBase): def testRestructureDataset(self): components = (array_ops.placeholder(dtypes.int32), diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py index eb110324d1..bc10c21472 100644 --- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py @@ -20,13 +20,14 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import random_seed from tensorflow.python.platform import test -class DirectedInterleaveDatasetTest(test.TestCase): +class DirectedInterleaveDatasetTest(test_base.DatasetTestBase): def testBasic(self): selector_dataset = dataset_ops.Dataset.range(10).repeat(100) diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py index f3968cdc15..cc22ea1df7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import get_single_element from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class GetSingleElementTest(test.TestCase, parameterized.TestCase): +class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ("Zero", 0, 1), diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py index 9c508d686d..d4d3d4adb2 100644 --- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py @@ -19,29 +19,30 @@ from __future__ import print_function import unittest -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import indexed_dataset_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.platform import test -class IndexedDatasetOpsTest(test.TestCase): +class IndexedDatasetOpsTest(test_base.DatasetTestBase): def testLowLevelIndexedDatasetOps(self): - identity = gen_dataset_ops.identity_indexed_dataset( + identity = ged_ops.experimental_identity_indexed_dataset( ops.convert_to_tensor(16, dtype=dtypes.uint64)) - handle = gen_dataset_ops.materialized_index_dataset_handle( + handle = ged_ops.experimental_materialized_index_dataset_handle( container="", shared_name="", output_types=[dtypes.uint64], output_shapes=[[]]) - materialize = gen_dataset_ops.indexed_dataset_materialize(identity, handle) + materialize = ged_ops.experimental_indexed_dataset_materialize( + identity, handle) index = array_ops.placeholder(dtypes.uint64) - get_op = gen_dataset_ops.indexed_dataset_get( + get_op = ged_ops.experimental_indexed_dataset_get( handle, index, output_types=[dtypes.uint64], output_shapes=[[]]) with self.cached_session() as sess: diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index b9e74dfddb..28bd670ab5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -25,6 +25,7 @@ import time from six.moves import zip_longest from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -36,7 +37,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class ParallelInterleaveDatasetTest(test.TestCase): +class ParallelInterleaveDatasetTest(test_base.DatasetTestBase): def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 7e2326bd17..58a1d7c93b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import iterator_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn @@ -33,7 +34,7 @@ from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import training_util -class CheckpointInputPipelineHookTest(test.TestCase): +class CheckpointInputPipelineHookTest(test_base.DatasetTestBase): @staticmethod def _model_fn(features, labels, mode, config): diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py index 1cc5ddc9a2..d2a72272db 100644 --- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py @@ -22,6 +22,7 @@ import os import shutil from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -31,7 +32,7 @@ from tensorflow.python.util import compat prefix_path = "tensorflow/core/lib" -class LMDBDatasetTest(test.TestCase): +class LMDBDatasetTest(test_base.DatasetTestBase): def setUp(self): super(LMDBDatasetTest, self).setUp() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index e8519381d6..385c4ef6ea 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -29,6 +29,7 @@ from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import optimization from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops @@ -41,7 +42,7 @@ from tensorflow.python.util import compat _NUMPY_RANDOM_SEED = 42 -class MapDatasetTest(test.TestCase): +class MapDatasetTest(test_base.DatasetTestBase): def testMapIgnoreError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py index 25aea0393f..751e6d5b30 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py @@ -21,6 +21,7 @@ import time from tensorflow.contrib.data.python.ops import map_defun from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -33,7 +34,8 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class MapDefunTest(test.TestCase): + +class MapDefunTest(test_base.DatasetTestBase): def testMapDefunSimple(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD index 1ae92bdeff..d7b5edcd9a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD @@ -15,6 +15,7 @@ py_test( "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", ], ) @@ -31,6 +32,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -57,7 +59,6 @@ py_test( srcs = ["map_vectorization_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/kernel_tests:test_utils", "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:check_ops", "//tensorflow/python:client_testlib", @@ -67,6 +68,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:session", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -85,6 +87,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -102,6 +105,7 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "@absl_py//absl/testing:parameterized", ], @@ -121,6 +125,7 @@ py_test( "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -137,6 +142,7 @@ py_test( "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], @@ -151,6 +157,7 @@ py_test( "//tensorflow/contrib/data/python/ops:optimization", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", + "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py index d10da80442..fe1b5280ba 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py @@ -18,12 +18,13 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.platform import test -class AssertNextDatasetTest(test.TestCase): +class AssertNextDatasetTest(test_base.DatasetTestBase): def testAssertNext(self): dataset = dataset_ops.Dataset.from_tensors(0).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py index 9518c2e1ad..b43efb5c7c 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -31,7 +32,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -class HoistRandomUniformTest(test.TestCase, parameterized.TestCase): +class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase): @staticmethod def map_functions(): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py index e75edf6086..e9e3fc81e5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -28,7 +29,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase): +class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase): @staticmethod def map_functions(): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py index dd547db086..f7907eb890 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from absl.testing import parameterized from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -30,7 +31,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -class MapParallelizationTest(test.TestCase, parameterized.TestCase): +class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase): @staticmethod def map_functions(): diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py index 5b493f44c9..a5ea85f454 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py @@ -22,9 +22,9 @@ import time from absl.testing import parameterized import numpy as np -from tensorflow.contrib.data.python.kernel_tests import test_utils from tensorflow.contrib.data.python.ops import optimization from tensorflow.python.client import session +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): +class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase): def _get_test_datasets(self, base_dataset, @@ -85,7 +85,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): [3, 4]]).repeat(5) unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn, num_parallel_calls) - self._assert_datasets_equal(unoptimized, optimized) + self.assertDatasetsEqual(unoptimized, optimized) def testOptimizationBadMapFn(self): # Test map functions that give an error @@ -112,7 +112,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): # TODO(rachelim): when this optimization works, turn on expect_optimized unoptimized, optimized = self._get_test_datasets( base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_equal(optimized, unoptimized) + self.assertDatasetsEqual(optimized, unoptimized) def testOptimizationIgnoreStateful(self): @@ -124,7 +124,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): [3, 4]]).repeat(5) unoptimized, optimized = self._get_test_datasets( base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_raise_same_error( + self.assertDatasetsRaiseSameError( unoptimized, optimized, errors.InvalidArgumentError, [("OneShotIterator", "OneShotIterator_1", 1), ("IteratorGetNext", "IteratorGetNext_1", 1)]) @@ -138,7 +138,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False) unoptimized, optimized = self._get_test_datasets( base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_equal(unoptimized, optimized) + self.assertDatasetsEqual(unoptimized, optimized) def testOptimizationIgnoreRaggedMap(self): # Don't optimize when the output of the map fn shapes are unknown. @@ -148,7 +148,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase): base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True) unoptimized, optimized = self._get_test_datasets( base_dataset, map_fn, expect_optimized=False) - self._assert_datasets_raise_same_error( + self.assertDatasetsRaiseSameError( unoptimized, optimized, errors.InvalidArgumentError, [("OneShotIterator", "OneShotIterator_1", 1), ("IteratorGetNext", "IteratorGetNext_1", 1)]) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py index 3b62a7e468..33c250ab2a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py @@ -23,12 +23,13 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class ModelDatasetTest(test.TestCase): +class ModelDatasetTest(test_base.DatasetTestBase): def testModelMap(self): k = 1024 * 1024 diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py index 507feda3ad..b9e60cfa4e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -26,7 +27,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class NoopEliminationTest(test.TestCase): +class NoopEliminationTest(test_base.DatasetTestBase): def testNoopElimination(self): a = constant_op.constant(1, dtype=dtypes.int64) diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py index a3fb824ce9..04f499f8c5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import optimization +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -28,7 +29,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.platform import test -class OptimizeDatasetTest(test.TestCase): +class OptimizeDatasetTest(test_base.DatasetTestBase): def testOptimizationDefault(self): dataset = dataset_ops.Dataset.range(10).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py index c4623bca73..66ccaceea5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes @@ -72,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors, i += 1 -class ParseExampleTest(test.TestCase): +class ParseExampleTest(test_base.DatasetTestBase): def _test(self, input_tensor, diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py index 33a64ea767..7a6a7a709a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -22,6 +22,7 @@ import threading from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.core.protobuf import config_pb2 from tensorflow.python.compat import compat +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op @@ -35,7 +36,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test -class PrefetchingKernelsOpsTest(test.TestCase): +class PrefetchingKernelsOpsTest(test_base.DatasetTestBase): def setUp(self): self._event = threading.Event() @@ -244,7 +245,7 @@ class PrefetchingKernelsOpsTest(test.TestCase): sess.run(destroy_op) -class PrefetchToDeviceTest(test.TestCase): +class PrefetchToDeviceTest(test_base.DatasetTestBase): def testPrefetchToDevice(self): host_dataset = dataset_ops.Dataset.range(10) @@ -445,7 +446,7 @@ class PrefetchToDeviceTest(test.TestCase): sess.run(next_element) -class CopyToDeviceTest(test.TestCase): +class CopyToDeviceTest(test_base.DatasetTestBase): def testCopyToDevice(self): host_dataset = dataset_ops.Dataset.range(10) diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index db8fe6aa1b..2e901587f4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import counter from tensorflow.contrib.data.python.ops import enumerate_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -27,7 +28,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test -class RangeDatasetTest(test.TestCase): +class RangeDatasetTest(test_base.DatasetTestBase): def testEnumerateDataset(self): components = (["a", "b"], [1, 2], [37.0, 38]) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index ed75b27a44..66ed547b6d 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.data.util import nest from tensorflow.python.framework import constant_op @@ -242,7 +243,7 @@ class ReadBatchFeaturesTest( self.assertEqual(32, shape[0]) -class MakeCsvDatasetTest(test.TestCase): +class MakeCsvDatasetTest(test_base.DatasetTestBase): def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs): return readers.make_csv_dataset( diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py index 08b9f03816..f443b5501b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py @@ -25,6 +25,7 @@ import zlib from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op @@ -32,11 +33,10 @@ from tensorflow.python.framework import dtypes from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops -from tensorflow.python.platform import test from tensorflow.python.util import compat -class FixedLengthRecordDatasetTestBase(test.TestCase): +class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase): """Base class for setting up and testing FixedLengthRecordDataset.""" def setUp(self): @@ -63,7 +63,7 @@ class FixedLengthRecordDatasetTestBase(test.TestCase): return filenames -class ReadBatchFeaturesTestBase(test.TestCase): +class ReadBatchFeaturesTestBase(test_base.DatasetTestBase): """Base class for setting up and testing `make_batched_feature_dataset`.""" def setUp(self): @@ -273,7 +273,7 @@ class ReadBatchFeaturesTestBase(test.TestCase): self.assertAllEqual(expected_batch[i], actual_batch[i]) -class TextLineDatasetTestBase(test.TestCase): +class TextLineDatasetTestBase(test_base.DatasetTestBase): """Base class for setting up and testing TextLineDataset.""" def _lineText(self, f, l): @@ -313,7 +313,7 @@ class TextLineDatasetTestBase(test.TestCase): return filenames -class TFRecordDatasetTestBase(test.TestCase): +class TFRecordDatasetTestBase(test_base.DatasetTestBase): """Base class for setting up and testing TFRecordDataset.""" def setUp(self): diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 16b1441baa..32474bd411 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -24,6 +24,7 @@ import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.data.python.ops import resampling +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -57,7 +58,7 @@ def _time_resampling( return end_time - start_time -class ResampleTest(test.TestCase, parameterized.TestCase): +class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ("InitialDistributionKnown", True), diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index dde678bd54..bdf80eae4e 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -22,6 +22,7 @@ import itertools import numpy as np from tensorflow.contrib.data.python.ops import scan_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op @@ -33,7 +34,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class ScanDatasetTest(test.TestCase): +class ScanDatasetTest(test_base.DatasetTestBase): def _counting_dataset(self, start, scan_fn): return dataset_ops.Dataset.from_tensors(0).repeat().apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py index 14cd3e9c4a..a10f85263a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -90,6 +91,16 @@ class StatsDatasetSerializationTest( lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), None, num_outputs) + def _build_dataset_stats_aggregator(self): + stats_aggregator = stats_ops.StatsAggregator() + return dataset_ops.Dataset.range(10).apply( + stats_ops.set_stats_aggregator(stats_aggregator)) + + def test_set_stats_aggregator_not_support_checkpointing(self): + with self.assertRaisesRegexp(errors.UnimplementedError, + "does not support checkpointing"): + self.run_core_tests(self._build_dataset_stats_aggregator, None, 10) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 440e48db30..c97002a255 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -20,13 +20,14 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import shuffle_ops +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.platform import test -class ShuffleAndRepeatTest(test.TestCase): +class ShuffleAndRepeatTest(test_base.DatasetTestBase): def _build_ds(self, seed, count=5, num_elements=20): return dataset_ops.Dataset.range(num_elements).apply( diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py index 90d18dca2a..c5a7862322 100644 --- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py @@ -21,6 +21,7 @@ from absl.testing import parameterized import numpy as np from tensorflow.contrib.data.python.ops import sliding +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -30,7 +31,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class SlideDatasetTest(test.TestCase, parameterized.TestCase): +class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.named_parameters( ("1", 20, 14, 7, 1), @@ -197,11 +198,6 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase): sliding.sliding_window_batch( window_size=1, stride=1, window_shift=1, window_stride=1)) - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - def testSlideSparse(self): def _sparse(i): diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py index 1f5c725a92..319a2ea263 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py @@ -24,12 +24,13 @@ import os import sqlite3 from tensorflow.contrib.data.python.ops import readers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class SqlDatasetTestBase(test.TestCase): +class SqlDatasetTestBase(test_base.DatasetTestBase): """Base class for setting up and testing SqlDataset.""" def _createSqlDataset(self, output_types, num_repeats=1): @@ -92,5 +93,3 @@ class SqlDatasetTestBase(test.TestCase): 9007199254740992.0)]) conn.commit() conn.close() - - diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py index b1b4c23510..80f2625927 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py @@ -19,10 +19,10 @@ from __future__ import print_function from tensorflow.core.framework import summary_pb2 -from tensorflow.python.platform import test +from tensorflow.python.data.kernel_tests import test_base -class StatsDatasetTestBase(test.TestCase): +class StatsDatasetTestBase(test_base.DatasetTestBase): """Base class for testing statistics gathered in `StatsAggregator`.""" def _assertSummaryContains(self, summary_str, tag): diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py deleted file mode 100644 index 4c3353fe40..0000000000 --- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Test utilities for tf.data functionality.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import re - -from tensorflow.python.data.util import nest -from tensorflow.python.framework import errors -from tensorflow.python.platform import test - - -class DatasetTestBase(test.TestCase): - """Base class for dataset tests.""" - - def _assert_datasets_equal(self, dataset1, dataset2): - # TODO(rachelim): support sparse tensor outputs - next1 = dataset1.make_one_shot_iterator().get_next() - next2 = dataset2.make_one_shot_iterator().get_next() - with self.cached_session() as sess: - while True: - try: - op1 = sess.run(next1) - except errors.OutOfRangeError: - with self.assertRaises(errors.OutOfRangeError): - sess.run(next2) - break - op2 = sess.run(next2) - - op1 = nest.flatten(op1) - op2 = nest.flatten(op2) - assert len(op1) == len(op2) - for i in range(len(op1)): - self.assertAllEqual(op1[i], op2[i]) - - def _assert_datasets_raise_same_error(self, - dataset1, - dataset2, - exception_class, - replacements=None): - # We are defining next1 and next2 in the same line so that we get identical - # file:line_number in the error messages - # pylint: disable=line-too-long - next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next() - # pylint: enable=line-too-long - with self.cached_session() as sess: - try: - sess.run(next1) - raise ValueError( - "Expected dataset to raise an error of type %s, but it did not." % - repr(exception_class)) - except exception_class as e: - expected_message = e.message - for old, new, count in replacements: - expected_message = expected_message.replace(old, new, count) - # Check that the first segment of the error messages are the same. - with self.assertRaisesRegexp(exception_class, - re.escape(expected_message)): - sess.run(next2) diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py index 8d335e87d5..08de3a9143 100644 --- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import threadpool from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -31,7 +32,8 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase): +class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase, + parameterized.TestCase): @parameterized.named_parameters( ("1", 1, None), diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py index f994c8563f..8856ce5afb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -25,7 +26,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class UniqueDatasetTest(test.TestCase): +class UniqueDatasetTest(test_base.DatasetTestBase): def _testSimpleHelper(self, dtype, test_cases): """Test the `unique()` transformation on a list of test cases. diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py index 8b7b3ac0f7..79134c7bc6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -31,7 +32,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class WindowDatasetTest(test.TestCase, parameterized.TestCase): +class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): def _structuredDataset(self, structure, shape, dtype): if structure is None: diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py index 867ee2ba37..fca546a570 100644 --- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os from tensorflow.contrib.data.python.ops import writers +from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers from tensorflow.python.framework import dtypes @@ -30,7 +31,7 @@ from tensorflow.python.platform import test from tensorflow.python.util import compat -class TFRecordWriterTest(test.TestCase): +class TFRecordWriterTest(test_base.DatasetTestBase): def setUp(self): super(TFRecordWriterTest, self).setUp() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index a14781cd93..5cd1ed542b 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -78,7 +78,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":batching", - ":gen_dataset_ops", ":interleave_ops", ":optimization", ":parsing_ops", @@ -86,6 +85,7 @@ py_library( "//tensorflow/python:constant_op", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python:lib", "//tensorflow/python:platform", @@ -148,8 +148,7 @@ py_library( srcs = ["error_ops.py"], srcs_version = "PY2AND3", deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", @@ -179,12 +178,11 @@ py_library( srcs = ["interleave_ops.py"], srcs_version = "PY2AND3", deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", ":random_ops", "//tensorflow/contrib/stateless", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:util", @@ -199,9 +197,8 @@ py_library( srcs = ["optimization.py"], srcs_version = "PY2AND3", deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", @@ -304,8 +301,7 @@ py_library( srcs = ["threadpool.py"], srcs_version = "PY2AND3", deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:resource_variable_ops", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", @@ -321,9 +317,8 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", "//tensorflow/python:dtypes", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", @@ -342,47 +337,11 @@ py_library( ], ) -tf_gen_op_wrapper_py( - name = "gen_dataset_ops", - out = "gen_dataset_ops.py", - deps = [ - "//tensorflow/contrib/data:dataset_ops_op_lib", - "//tensorflow/contrib/data:indexed_dataset_ops_op_lib", - ], -) - -tf_kernel_library( - name = "dataset_ops_kernels", - deps = [ - "//tensorflow/contrib/data/kernels:dataset_kernels", - "//tensorflow/core:framework", - ], - alwayslink = 1, -) - -tf_custom_op_py_library( - name = "contrib_op_loader", - srcs = ["contrib_op_loader.py"], - dso = ["//tensorflow/contrib/data:_dataset_ops.so"], - kernels = [ - ":dataset_ops_kernels", - "//tensorflow/contrib/data:indexed_dataset_ops_op_lib", - "//tensorflow/contrib/data:dataset_ops_op_lib", - ], - srcs_version = "PY2AND3", - deps = [ - ":gen_dataset_ops", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:platform", - ], -) - py_library( name = "indexed_dataset_ops", srcs = ["indexed_dataset_ops.py"], deps = [ - ":contrib_op_loader", - ":gen_dataset_ops", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", @@ -394,7 +353,7 @@ py_library( name = "prefetching_ops", srcs = ["prefetching_ops.py"], deps = [ - ":contrib_op_loader", + "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", diff --git a/tensorflow/contrib/data/python/ops/contrib_op_loader.py b/tensorflow/contrib/data/python/ops/contrib_op_loader.py deleted file mode 100644 index 8f495a9dc9..0000000000 --- a/tensorflow/contrib/data/python/ops/contrib_op_loader.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Python helper for loading contrib ops and kernels.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.util import loader -from tensorflow.python.platform import resource_loader - -_dataset_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index 615dbcabd4..f962e623ee 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -17,9 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops def ignore_errors(): @@ -60,7 +59,7 @@ class _IgnoreErrorsDataset(dataset_ops.UnaryDataset): self._input_dataset = input_dataset def _as_variant_tensor(self): - return gen_dataset_ops.ignore_errors_dataset( + return gen_experimental_dataset_ops.experimental_ignore_errors_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access **dataset_ops.flat_structure(self)) diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py index cc76ab0850..9c06474a2f 100644 --- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py @@ -19,14 +19,13 @@ from __future__ import print_function import abc -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops class MaterializedIndexedDataset(object): @@ -57,7 +56,7 @@ class MaterializedIndexedDataset(object): A tensor containing the values corresponding to `index`. """ # TODO(saeta): nest.pack_sequence_as(...) - return gen_dataset_ops.indexed_dataset_get( + return ged_ops.experimental_indexed_dataset_get( self._materialized_resource, index, output_types=nest.flatten( @@ -90,16 +89,18 @@ class IndexedDataset(dataset_ops.Dataset): container = "" if shared_name is None: shared_name = "" - materialized_resource = gen_dataset_ops.materialized_index_dataset_handle( - container=container, - shared_name=shared_name, - output_types=nest.flatten( - sparse.as_dense_types(self.output_types, self.output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_types(self.output_shapes, self.output_classes))) + materialized_resource = ( + ged_ops.experimental_materialized_index_dataset_handle( + container=container, + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_types(self.output_shapes, + self.output_classes)))) with ops.colocate_with(materialized_resource): - materializer = gen_dataset_ops.indexed_dataset_materialize( + materializer = ged_ops.experimental_indexed_dataset_materialize( self._as_variant_tensor(), materialized_resource) return MaterializedIndexedDataset(materialized_resource, materializer, self.output_classes, self.output_types, @@ -170,7 +171,7 @@ class IdentityIndexedDataset(IndexedDataset): return tensor_shape.scalar() def _as_variant_tensor(self): - return gen_dataset_ops.identity_indexed_dataset(self._size) + return ged_ops.experimental_identity_indexed_dataset(self._size) def _inputs(self): return [] diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index bfa3fdf543..1ee9db1aa8 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -18,8 +18,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import stateless -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.contrib.data.python.ops import random_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers @@ -28,6 +26,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.ops import math_ops from tensorflow.python.util import deprecation @@ -167,10 +166,12 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset): def _as_variant_tensor(self): # pylint: disable=protected-access - return gen_dataset_ops.directed_interleave_dataset( - self._selector_input._as_variant_tensor(), - [data_input._as_variant_tensor() for data_input in self._data_inputs], - **dataset_ops.flat_structure(self)) + return ( + gen_experimental_dataset_ops.experimental_directed_interleave_dataset( + self._selector_input._as_variant_tensor(), [ + data_input._as_variant_tensor() + for data_input in self._data_inputs + ], **dataset_ops.flat_structure(self))) # pylint: enable=protected-access def _inputs(self): diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py index 3eb172acd5..30348ede36 100644 --- a/tensorflow/contrib/data/python/ops/optimization.py +++ b/tensorflow/contrib/data/python/ops/optimization.py @@ -17,12 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops # A constant that can be used to enable auto-tuning. AUTOTUNE = -1 @@ -54,7 +53,7 @@ def model(): Returns: A `Dataset` transformation function, which can be passed to - @{tf.data.Dataset.apply}. + `tf.data.Dataset.apply`. """ def _apply_fn(dataset): @@ -97,7 +96,7 @@ class _AssertNextDataset(dataset_ops.UnaryDataset): transformations, dtype=dtypes.string, name="transformations") def _as_variant_tensor(self): - return contrib_gen_dataset_ops.assert_next_dataset( + return gen_experimental_dataset_ops.experimental_assert_next_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._transformations, **dataset_ops.flat_structure(self)) diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index 58395879e6..46f82e453a 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -19,8 +19,6 @@ from __future__ import print_function import warnings -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest @@ -32,7 +30,8 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.ops import resource_variable_ops @@ -64,7 +63,7 @@ def function_buffering_resource(string_arg, """ if shared_name is None: shared_name = "" - return gen_dataset_ops.function_buffering_resource( + return ged_ops.experimental_function_buffering_resource( string_arg=string_arg, target_device=target_device, shared_name=shared_name, @@ -78,14 +77,14 @@ def function_buffering_resource(string_arg, def function_buffering_resource_get_next(function_buffer_resource, output_types, name=None): - return gen_dataset_ops.function_buffering_resource_get_next( + return ged_ops.experimental_function_buffering_resource_get_next( function_buffer_resource=function_buffer_resource, output_types=output_types, name=name) def function_buffering_resource_reset(function_buffer_resource, name=None): - return gen_dataset_ops.function_buffering_resource_reset( + return ged_ops.experimental_function_buffering_resource_reset( function_buffer_resource=function_buffer_resource, name=name) @@ -136,7 +135,7 @@ class _PrefetchToDeviceIterator(object): ret = remote_iterator.get_next() return nest.flatten(sparse.serialize_sparse_tensors(ret)) - iterator_device = gen_dataset_ops.iterator_get_device( + iterator_device = ged_ops.experimental_iterator_get_device( self._input_iterator._iterator_resource) with ops.device(device): @@ -162,10 +161,11 @@ class _PrefetchToDeviceIterator(object): if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) - flat_ret = gen_dataset_ops.function_buffering_resource_get_next( + flat_ret = ged_ops.experimental_function_buffering_resource_get_next( self._buffering_resource, - output_types=nest.flatten(sparse.as_dense_types( - self.output_types, self.output_classes)), name=name) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + name=name) ret = sparse.deserialize_sparse_tensors( nest.pack_sequence_as(self.output_types, flat_ret), @@ -219,7 +219,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): buffer_size): with ops.device("/device:CPU:0"): super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset) - input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle( + input_iterator_handle = gen_dataset_ops.iterator_to_string_handle( self._resource) self._device = device @@ -238,7 +238,8 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): self._buffering_resource = function_buffering_resource( f=_prefetch_fn, output_types=self._flat_output_types, - target_device=gen_dataset_ops.iterator_get_device(self._resource), + target_device=ged_ops.experimental_iterator_get_device( + self._resource), string_arg=input_iterator_handle, buffer_size=buffer_size, shared_name=iterator_ops._generate_shared_name( @@ -252,7 +253,7 @@ class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator): # TODO(b/77291417): Fix with context.execution_mode(context.SYNC): with ops.device(self._device): - ret = gen_dataset_ops.function_buffering_resource_get_next( + ret = ged_ops.experimental_function_buffering_resource_get_next( function_buffer_resource=self._buffering_resource, output_types=self._flat_output_types) return sparse.deserialize_sparse_tensors( @@ -409,12 +410,12 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset): """ # pylint: disable=protected-access ds_variant = self._input_dataset._as_variant_tensor() - resource = core_gen_dataset_ops.anonymous_iterator( + resource = gen_dataset_ops.anonymous_iterator( output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) with ops.control_dependencies( - [core_gen_dataset_ops.make_iterator(ds_variant, resource)]): - return core_gen_dataset_ops.iterator_to_string_handle(resource) + [gen_dataset_ops.make_iterator(ds_variant, resource)]): + return gen_dataset_ops.iterator_to_string_handle(resource) @function.Defun() def _remote_init_func(): @@ -463,7 +464,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset): Returns: Tensor constant 0 """ - iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2( + iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( string_handle, output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) @@ -504,7 +505,7 @@ class _CopyToDeviceDataset(dataset_ops.UnaryDataset): def _as_variant_tensor(self): with ops.device(self._target_device): - return core_gen_dataset_ops.generator_dataset( + return gen_dataset_ops.generator_dataset( self._init_captured_args, self._next_captured_args, self._finalize_captured_args, diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index d9d06e2703..360971e200 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -23,7 +23,6 @@ import csv import numpy as np from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.contrib.data.python.ops import optimization from tensorflow.contrib.data.python.ops import parsing_ops @@ -38,6 +37,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import file_io from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.platform import gfile from tensorflow.python.util import deprecation @@ -629,7 +629,7 @@ class CsvDataset(dataset_ops.DatasetSource): def _as_variant_tensor(self): # Constructs graph node for the dataset op. - return contrib_gen_dataset_ops.csv_dataset( + return gen_experimental_dataset_ops.experimental_csv_dataset( filenames=self._filenames, record_defaults=self._record_defaults, buffer_size=self._buffer_size, @@ -1013,7 +1013,7 @@ class LMDBDataset(dataset_ops.DatasetSource): filenames, dtype=dtypes.string, name="filenames") def _as_variant_tensor(self): - return contrib_gen_dataset_ops.lmdb_dataset( + return gen_experimental_dataset_ops.experimental_lmdb_dataset( self._filenames, output_types=nest.flatten(self.output_types), output_shapes=nest.flatten(self.output_shapes)) diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py index 9d165ad52a..f73c3fd9cb 100644 --- a/tensorflow/contrib/data/python/ops/threadpool.py +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -19,10 +19,9 @@ from __future__ import print_function import threading -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import context +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops from tensorflow.python.ops import resource_variable_ops _uid_counter = 0 @@ -47,7 +46,7 @@ class PrivateThreadPool(object): """Creates a `PrivateThreadPool` with the given number of threads.""" if context.executing_eagerly(): shared_name = _generate_shared_name("privatethreadpool") - self._resource = gen_dataset_ops.thread_pool_handle( + self._resource = ged_ops.experimental_thread_pool_handle( num_threads=num_threads, max_intra_op_parallelism=max_intra_op_parallelism, display_name=display_name, @@ -55,7 +54,7 @@ class PrivateThreadPool(object): self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device=context.context().device_name) else: - self._resource = gen_dataset_ops.thread_pool_handle( + self._resource = ged_ops.experimental_thread_pool_handle( num_threads=num_threads, max_intra_op_parallelism=max_intra_op_parallelism, display_name=display_name) @@ -70,7 +69,7 @@ class _ThreadPoolDataset(dataset_ops.UnaryDataset): self._thread_pool = thread_pool def _as_variant_tensor(self): - return gen_dataset_ops.thread_pool_dataset( + return ged_ops.experimental_thread_pool_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._thread_pool._resource, # pylint: disable=protected-access **dataset_ops.flat_structure(self)) diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index bad67a580d..ed363a7090 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -17,10 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes +from tensorflow.python.ops import gen_experimental_dataset_ops def unique(): @@ -61,7 +60,7 @@ class _UniqueDataset(dataset_ops.UnaryDataset): "`tf.int32`, `tf.int64`, or `tf.string` component.") def _as_variant_tensor(self): - return gen_dataset_ops.unique_dataset( + return gen_experimental_dataset_ops.experimental_unique_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access **dataset_ops.flat_structure(self)) diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index 3b50a48336..06940a90d5 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -17,7 +17,6 @@ tf_proto_library( name = "generic_tree_model", srcs = ["generic_tree_model.proto"], cc_api_version = 2, - java_api_version = 2, visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md index 91a27f97b7..2e025765e4 100644 --- a/tensorflow/contrib/distribute/README.md +++ b/tensorflow/contrib/distribute/README.md @@ -231,7 +231,8 @@ The same `input_fn` will be used for all workers if you use important to shuffle your dataset in your `input_fn`. `MirroredStrategy` will insert a `tf.dataset.Dataset.shard` call in you -`input_fn`. As a result, each worker gets a fraction of your input data. +`input_fn` if `auto_shard_dataset` is set to `True`. As a result, each worker +gets a fraction of your input data. ### Performance Tips diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD index e329b964c4..422983dbef 100644 --- a/tensorflow/contrib/distribute/python/BUILD +++ b/tensorflow/contrib/distribute/python/BUILD @@ -22,6 +22,7 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":input_ops", + ":prefetching_ops_v2", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:device_util", @@ -29,7 +30,6 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:training", "//tensorflow/python:util", - "//tensorflow/python/data/ops:multi_device_iterator_ops", "//tensorflow/python/eager:context", "//tensorflow/python/training/checkpointable:base", "@six_archive//:six", @@ -648,6 +648,32 @@ cuda_py_test( ) py_library( + name = "prefetching_ops_v2", + srcs = ["prefetching_ops_v2.py"], + deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_ops", + "//tensorflow/python:experimental_dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + +cuda_py_test( + name = "prefetching_ops_v2_test", + srcs = ["prefetching_ops_v2_test.py"], + additional_deps = [ + ":prefetching_ops_v2", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +py_library( name = "input_ops", srcs = ["input_ops.py"], visibility = ["//tensorflow:internal"], diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py index c900b41e14..9809204f8f 100644 --- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py +++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py @@ -216,7 +216,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy): """Configures the object. Args: - session_config: a @{tf.ConfigProto} + session_config: a `tf.ConfigProto` cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the cluster configurations. task_type: the current task type, such as "worker". diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py index 244d1fcec8..82ca041cc2 100644 --- a/tensorflow/contrib/distribute/python/combinations.py +++ b/tensorflow/contrib/distribute/python/combinations.py @@ -59,6 +59,7 @@ from tensorflow.python.training import adagrad from tensorflow.python.training import adam from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import gradient_descent +from tensorflow.python.training import rmsprop from tensorflow.python.util import tf_inspect @@ -354,6 +355,8 @@ gradient_descent_optimizer_v1_fn = NamedObject( "GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2)) adagrad_optimizer_v1_fn = NamedObject( "AdagradV1", lambda: adagrad.AdagradOptimizer(0.001)) +rmsprop_optimizer_v1_fn = NamedObject( + "RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001)) optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn] diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index a0b8bde132..3aab2c521f 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -173,13 +173,42 @@ def batch_wrapper(dataset, batch_size, distribution): return dataset.batch(batch_size) -def all_combinations(): +def get_model(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + return model + + +def get_dataset(distribution): + inputs = np.zeros((10, 3), dtype=np.float32) + targets = np.zeros((10, 4), dtype=np.float32) + dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = dataset.repeat(100) + dataset = batch_wrapper(dataset, 10, distribution) + return dataset + + +strategies = [combinations.default_strategy, + combinations.one_device_strategy, + combinations.mirrored_strategy_with_gpu_and_cpu, + combinations.mirrored_strategy_with_two_gpus, + combinations.tpu_strategy_one_step] + + +def strategy_combinations(): return combinations.combine( - distribution=[combinations.default_strategy, - combinations.one_device_strategy, - combinations.mirrored_strategy_with_gpu_and_cpu, - combinations.mirrored_strategy_with_two_gpus, - combinations.tpu_strategy_one_step], + distribution=strategies, + mode=['graph']) + + +def strategy_and_optimizer_combinations(): + return combinations.combine( + distribution=strategies, + optimizer=[combinations.adagrad_optimizer_v1_fn, + combinations.adam_optimizer_v1_fn, + combinations.gradient_descent_optimizer_v1_fn, + combinations.rmsprop_optimizer_v1_fn], mode=['graph']) @@ -360,9 +389,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): def test_calling_model_with_numpy_arrays(self): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' @@ -392,23 +419,17 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): # with batch_size model.predict(inputs, batch_size=8) - @combinations.generate(all_combinations()) + @combinations.generate(strategy_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae', keras.metrics.CategoricalAccuracy()] model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = batch_wrapper(dataset, 10, distribution) + dataset = get_dataset(distribution) # Call fit with validation data model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, @@ -461,23 +482,17 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1) - @combinations.generate(all_combinations()) + @combinations.generate(strategy_combinations()) def test_fit_eval_and_predict_methods_on_dataset(self, distribution): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' metrics = ['mae', keras.metrics.CategoricalAccuracy()] model.compile(optimizer, loss, metrics=metrics, distribute=distribution) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = batch_wrapper(dataset, 10, distribution) + dataset = get_dataset(distribution) model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) model.evaluate(dataset, steps=2, verbose=1) @@ -486,11 +501,23 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0, validation_data=dataset, validation_steps=2) + @combinations.generate(strategy_and_optimizer_combinations()) + def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer): + with self.cached_session(): + model = get_model() + + loss = 'mse' + model.compile(optimizer(), loss, distribute=distribution) + + dataset = get_dataset(distribution) + + model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) + model.evaluate(dataset, steps=2, verbose=1) + model.predict(dataset, steps=2) + def test_unsupported_features(self): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' @@ -500,11 +527,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): model.compile(optimizer, loss, metrics=metrics, distribute=strategy) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) + dataset = get_dataset(strategy) # Test with validation split with self.assertRaisesRegexp( @@ -541,9 +564,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): def test_calling_with_unsupported_predefined_callbacks(self): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = gradient_descent.GradientDescentOptimizer(0.001) loss = 'mse' @@ -552,11 +573,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): '/device:GPU:0']) model.compile(optimizer, loss, metrics=metrics, distribute=strategy) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) - dataset = dataset.repeat(100) - dataset = dataset.batch(10) + dataset = get_dataset(strategy) def schedule(_): return 0.001 @@ -580,9 +597,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): def test_dataset_input_shape_validation(self): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' @@ -616,17 +631,13 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): mode=['graph'])) def test_dataset_input_shape_fully_defined(self, distribution): with self.cached_session(): - x = keras.layers.Input(shape=(3,), name='input') - y = keras.layers.Dense(4, name='dense')(x) - model = keras.Model(x, y) + model = get_model() optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001) loss = 'mse' model.compile(optimizer, loss, distribute=distribution) - inputs = np.zeros((10, 3), dtype=np.float32) - targets = np.zeros((10, 4), dtype=np.float32) - dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) + dataset = get_dataset(distribution) # Input shapes are not fully known. Batch dimension is unknown as we are # not using the drop_remainder argument. dataset = dataset.repeat(100).batch(10) @@ -698,7 +709,7 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase): class NormalizationLayerWithDistributionStrategyTest( test.TestCase, parameterized.TestCase): - @combinations.generate(all_combinations()) + @combinations.generate(strategy_combinations()) def test_batchnorm_correctness(self, distribution): with self.cached_session(): model = keras.models.Sequential() @@ -726,7 +737,7 @@ class NormalizationLayerWithDistributionStrategyTest( class CorrectnessWithDistributionStrategyTest(test.TestCase, parameterized.TestCase): - @combinations.generate(all_combinations()) + @combinations.generate(strategy_combinations()) def test_metric_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') @@ -756,7 +767,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase, history = model.fit(x=train_dataset, epochs=1, steps_per_epoch=10) self.assertEqual(history.history['binary_accuracy'], [1.0]) - @combinations.generate(all_combinations()) + @combinations.generate(strategy_combinations()) def test_correctness(self, distribution): with self.cached_session(): keras.backend.set_image_data_format('channels_last') diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py index f7773aff4f..8163494c8e 100644 --- a/tensorflow/contrib/distribute/python/metrics_v1_test.py +++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py @@ -86,11 +86,10 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): with ops.Graph().as_default(), distribution.scope(): iterator = distribution.distribute_dataset( - dataset_fn).make_initializable_iterator() + dataset_fn).make_one_shot_iterator() value, update = distribution.call_for_each_tower( metric_fn, iterator.get_next()) update = distribution.group(update) - self.evaluate(iterator.initializer) self.evaluate(variables.local_variables_initializer()) # TODO(josh11b): Once we switch to using a global batch size for input, # replace "distribution.num_towers" with "1". diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py index d082d5c419..ba147e7824 100644 --- a/tensorflow/contrib/distribute/python/minimize_loss_test.py +++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py @@ -41,14 +41,6 @@ from tensorflow.python.ops.losses import losses_impl class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): - def _get_iterator(self, ds): - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() - self.evaluate(iterator.initializer) - return iterator - @combinations.generate( combinations.times( combinations.distributions_and_v1_optimizers(), @@ -70,7 +62,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution.call_for_each_tower( model_fn, *inputs, run_concurrently=layer.built)) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.run_steps_on_dataset( @@ -106,7 +99,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.group( @@ -165,7 +159,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution.call_for_each_tower( model_fn, *inputs, run_concurrently=layer.built)) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.run_steps_on_dataset( @@ -249,7 +244,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) return control_flow_ops.group(fetches) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.run_steps_on_dataset( @@ -342,7 +338,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): distribution.call_for_each_tower( model_fn, x, y, run_concurrently=False)) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return distribution.run_steps_on_dataset( @@ -435,7 +432,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): output=loss) return distribution.group(train_op) - iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): initial_loss = lambda: constant_op.constant(1e7) diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 945f450387..4d7516063c 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -347,6 +347,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): set, the `configure` method will try to find the best one. prefetch_on_device: optional boolean to specify whether to prefetch input data to devices. + auto_shard_dataset: whether to auto-shard the dataset when there are + multiple workers. """ def __init__(self, @@ -354,11 +356,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): num_gpus=None, num_gpus_per_worker=None, cross_tower_ops=None, - prefetch_on_device=None): + prefetch_on_device=None, + auto_shard_dataset=False): super(MirroredStrategy, self).__init__() self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device + self._auto_shard_dataset = auto_shard_dataset # Rememeber num GPUs which might be needed by `configure` method. if num_gpus is not None and num_gpus_per_worker is not None: raise ValueError( @@ -477,13 +481,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if self._cluster_spec: return values.MultiWorkerDataset( partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, - self._prefetch_on_device) + self._prefetch_on_device, self._auto_shard_dataset) else: return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), - self._devices, - self._prefetch_on_device, - source_device=device_util.resolve("/device:CPU:0")) + self._call_dataset_fn(dataset_fn), self._devices, + self._prefetch_on_device) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. def _run_steps_on_dataset(self, fn, iterator, iterations, diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py index 04c712ce1d..f51e543624 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py @@ -300,15 +300,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase): dist = mirrored_strategy.MirroredStrategy( ["/device:GPU:0", "/device:CPU:0"]) - ds = dist.distribute_dataset( - lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() - self.evaluate([iterator.initializer]) - - features = iterator.get_next() + features = dist.distribute_dataset( + lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) + ).make_one_shot_iterator().get_next() with dist.scope(): result = dist.call_for_each_tower( diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py index 17b7ab74f6..7644acedc9 100644 --- a/tensorflow/contrib/distribute/python/monitor.py +++ b/tensorflow/contrib/distribute/python/monitor.py @@ -51,7 +51,6 @@ class Monitor(object): else: if session is None: raise ValueError("Should provide a `session` in Graph mode.") - session.run(step_callable._iterator.initializer) # pylint: disable=protected-access self._run_step = session.make_callable(step_callable()) session.run(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py index 3064433129..6e9ba37a19 100644 --- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py +++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py @@ -42,11 +42,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): model_fn, dataset_fn, layer = minimize_loss_example( optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) - ds = distribution.distribute_dataset(dataset_fn) - if context.executing_eagerly(): - iterator = ds.make_one_shot_iterator() - else: - iterator = ds.make_initializable_iterator() + iterator = distribution.distribute_dataset( + dataset_fn).make_one_shot_iterator() def run_step(): return control_flow_ops.group(distribution.unwrap( @@ -55,7 +52,6 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): if not context.executing_eagerly(): with self.cached_session() as sess: - sess.run(iterator.initializer) run_step = sess.make_callable(run_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py new file mode 100644 index 0000000000..8d949943b7 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -0,0 +1,232 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Extension of prefetching_ops to support more than one device.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.util import nest as data_nest +from tensorflow.python.data.util import sparse +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops +from tensorflow.python.util import nest + + +# pylint: disable=protected-access +class _PrefetchToDeviceIterator(object): + """A replacement for `tf.data.Iterator` that prefetches to another device. + + Args: + input_dataset: The input dataset. + one_shot: If true, we make a one shot iterator that's already initialized. + devices: Devices on which to prefetch. + buffer_size: Size of the prefetching buffer. + shared_name: (Optional.) If non-empty, the returned iterator will be shared + under the given name across multiple sessions that share the same devices + (e.g. when using a remote server). Only used if one_shot is False. + + Returns: + An Iterator type object. + """ + + def __init__(self, + input_dataset, + one_shot, + devices, + buffer_size, + shared_name=None): + self._input_dataset = input_dataset + self._get_next_call_count = 0 + self._one_shot = one_shot + if shared_name is None: + shared_name = "" + self._devices = devices + + if self._one_shot: + self._input_iterator = input_dataset.make_one_shot_iterator() + else: + self._input_iterator = iterator_ops.Iterator.from_structure( + self._input_dataset.output_types, self._input_dataset.output_shapes, + shared_name, self._input_dataset.output_classes) + input_iterator_handle = self._input_iterator.string_handle() + + @function.Defun(dtypes.string) + def _prefetch_fn(handle): + """Prefetches one element from `input_iterator`.""" + remote_iterator = iterator_ops.Iterator.from_string_handle( + handle, self._input_iterator.output_types, + self._input_iterator.output_shapes, + self._input_iterator.output_classes) + ret = remote_iterator.get_next() + return nest.flatten(sparse.serialize_sparse_tensors(ret)) + + target_device = ged_ops.experimental_iterator_get_device( + self._input_iterator._iterator_resource) + self._buffering_resources = [] + for device in nest.flatten(self._devices): + with ops.device(device): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_prefetch_fn, + output_types=data_nest.flatten( + sparse.as_dense_types(self._input_dataset.output_types, + self._input_dataset.output_classes)), + target_device=target_device, + string_arg=input_iterator_handle, + buffer_size=buffer_size, + shared_name=shared_name) + self._buffering_resources.append(buffer_resource_handle) + + if not self._one_shot: + reset_ops = [] + for buffer_resource in self._buffering_resources: + reset_ops.append( + ged_ops.experimental_function_buffering_resource_reset( + buffer_resource)) + with ops.control_dependencies(reset_ops): + self._initializer = self._input_iterator.make_initializer( + self._input_dataset) + + def get_next(self, name=None): + """See `tf.data.Iterator.get_next`.""" + self._get_next_call_count += 1 + if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: + warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) + + flat_result = [] + # TODO(priyag): This will fail if the input size (typically number of + # batches) is not divisible by number of devices. + # How do we handle that more gracefully / let the user know? + for buffer_resource in self._buffering_resources: + flat_ret = ged_ops.experimental_function_buffering_resource_get_next( + buffer_resource, + output_types=data_nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + name=name) + + ret = sparse.deserialize_sparse_tensors( + data_nest.pack_sequence_as(self.output_types, flat_ret), + self.output_types, self.output_shapes, self.output_classes) + + for tensor, shape in zip( + data_nest.flatten(ret), data_nest.flatten(self.output_shapes)): + if isinstance(tensor, ops.Tensor): + tensor.set_shape(shape) + flat_result.append(ret) + + return nest.pack_sequence_as(self._devices, flat_result) + + @property + def initializer(self): + if self._one_shot: + raise NotImplementedError("Can't initialize a one_shot_iterator") + return self._initializer + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +# pylint: enable=protected-access + + +class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset): + """A `Dataset` whose iterator prefetches elements to other device(s).""" + + def __init__(self, input_dataset, devices, buffer_size): + super(_PrefetchToDeviceDataset, self).__init__(input_dataset) + self._input_dataset = input_dataset + self._devices = devices + self._buffer_size = buffer_size if buffer_size is not None else 1 + + def make_one_shot_iterator(self): + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=True, + devices=self._devices, + buffer_size=self._buffer_size) + + def make_initializable_iterator(self, shared_name=None): + if context.executing_eagerly(): + raise RuntimeError( + "make_initializable_iterator is not supported when eager " + "execution is enabled.") + + return _PrefetchToDeviceIterator( + self._input_dataset, + one_shot=False, + devices=self._devices, + buffer_size=self._buffer_size, + shared_name=shared_name) + + def _as_variant_tensor(self): + # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset + # transformation methods is called. + # TODO(mrry): Investigate support for chaining further transformations after + # the prefetch, including GPU support. + raise NotImplementedError("`prefetch_to_devices()` must be the last " + "transformation in a dataset pipeline.") + + # TODO(priyag): Fix the output types, shapes and classes to match the result + # of get_next (which has the additional nesting layer of devices now). + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def prefetch_to_devices(devices, buffer_size=None): + """A transformation that prefetches dataset values to the given `devices`. + + NOTE: Although the transformation creates a `tf.data.Dataset`, the + transformation must be the final `Dataset` in the input pipeline. + + Args: + devices: A nested structure of devices on which to prefetch the data. It can + be a single device name, or a tuple or list of device names. + buffer_size: (Optional.) The number of elements to buffer on each device. + Defaults to an automatically chosen value. + + Returns: + A `Dataset` transformation function, which can be passed to + `tf.data.Dataset.apply`. + """ + + def _apply_fn(dataset): + return _PrefetchToDeviceDataset(dataset, devices, buffer_size) + + return _apply_fn diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py new file mode 100644 index 0000000000..16799104e8 --- /dev/null +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py @@ -0,0 +1,90 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for prefetching_ops_v2.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distribute.python import prefetching_ops_v2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.platform import test + + +class PrefetchingOpsV2Test(test.TestCase): + + def testPrefetchToOneDevice(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices("/gpu:0")) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.cached_session() as sess: + for i in range(10): + self.assertEqual(i, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testPrefetchToTwoDevicesInAList(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) + + iterator = device_dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + output = [] + # TODO(rohanj): Modify test to go till the end of the dataset when we + # switch to MultiDeviceIterator. + with self.cached_session() as sess: + for _ in range(4): + result = sess.run(next_element) + self.assertEqual(2, len(result)) + output.extend(result) + self.assertEquals(set(range(8)), set(output)) + + def testPrefetchToTwoDevicesWithReinit(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) + + iterator = device_dataset.make_initializable_iterator() + next_element = iterator.get_next() + + # TODO(rohanj): Modify test to go till the end of the dataset when we + # switch to MultiDeviceIterator. + with self.cached_session() as sess: + sess.run(iterator.initializer) + for _ in range(4): + sess.run(next_element) + sess.run(iterator.initializer) + for _ in range(4): + sess.run(next_element) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py index 23bf36184f..1b5a4f64e5 100644 --- a/tensorflow/contrib/distribute/python/step_fn.py +++ b/tensorflow/contrib/distribute/python/step_fn.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.eager import backprop -from tensorflow.python.eager import context from tensorflow.python.training import optimizer as optimizer_lib @@ -51,11 +50,7 @@ class StandardInputStep(Step): def __init__(self, dataset_fn, distribution): super(StandardInputStep, self).__init__(distribution) self._distributed_input = distribution.distribute_dataset(dataset_fn) - if context.executing_eagerly(): - self._iterator = self._distributed_input.make_one_shot_iterator() - else: - # TODO(priyag): Expose initializer via some initializer property. - self._iterator = self._distributed_input.make_initializable_iterator() + self._iterator = self._distributed_input.make_one_shot_iterator() class StandardSingleLossStep(StandardInputStep): diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py index 1ff9b9ceec..f1ada49fa3 100644 --- a/tensorflow/contrib/distribute/python/step_fn_test.py +++ b/tensorflow/contrib/distribute/python/step_fn_test.py @@ -50,7 +50,6 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): run_step = single_loss_step else: with self.cached_session() as sess: - sess.run(single_loss_step._iterator.initializer) run_step = sess.make_callable(single_loss_step()) self.evaluate(variables.global_variables_initializer()) diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index a0cd029f51..4955ded4d5 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -26,7 +26,7 @@ import weakref import six from tensorflow.contrib.distribute.python import input_ops -from tensorflow.python.data.ops import multi_device_iterator_ops +from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops @@ -683,7 +683,7 @@ class PerDeviceDataIterator(object): def get_next(self, name=None): """Scatter the input across devices.""" if self._prefetch_on_device: - data_list = self._iterator.get_next() + data_list = self._iterator.get_next(name=name) index = dict(zip(self._devices, data_list)) else: batch = self._iterator.get_next(name=name) @@ -703,26 +703,21 @@ class PerDeviceDataIterator(object): class PerDeviceDataset(object): """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" - def __init__( - self, - dataset, - devices, - prefetch_on_device=None, - source_device="/cpu:0", - ): + def __init__(self, dataset, devices, prefetch_on_device=None): self._devices = devices - self._source_device = source_device if source_device is not None else "/cpu:0" # Default to using prefetching in graph mode, unless specified. - # TODO(rohanj): Enable prefetching in eager mode. + # TODO(priyag): Enable prefetching in eager mode. self._prefetch_on_device = prefetch_on_device if self._prefetch_on_device is None: self._prefetch_on_device = not context.executing_eagerly() assert not (self._prefetch_on_device and context.executing_eagerly()), ( "Prefetching is only supported in graph mode currently") - self._dataset = dataset - if not self._prefetch_on_device: + if self._prefetch_on_device: + self._dataset = dataset.apply( + prefetching_ops_v2.prefetch_to_devices(self._devices)) + else: # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. # Possibly not an issue when we start using PartitionedDataset. @@ -730,33 +725,15 @@ class PerDeviceDataset(object): def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerDeviceDataset.""" - # Graph mode prefetching with one shot iterator is disabled. - if not context.executing_eagerly(): - raise ValueError("Cannot create a one shot iterator. Please use " - "`make_initializable_iterator()` instead.") - # Eager mode prefetching would error out in constructor. Only remaining - # cases are non-prefetching eager / graph mode. We delegate to - # PerDeviceDataIterator to handle them. dataset_iterator = self._dataset.make_one_shot_iterator() - return PerDeviceDataIterator( - dataset_iterator, self._devices, prefetch_on_device=False) + return PerDeviceDataIterator(dataset_iterator, self._devices, + self._prefetch_on_device) def make_initializable_iterator(self): """Get an initializable iterator for the distributed PerDeviceDataset.""" - # Eager mode generates already initialized iterators. Hence we cannot create - # an initializable iterator. - if context.executing_eagerly(): - raise ValueError("Cannot create initializable iterator in Eager mode. " - "Please use `make_one_shot_iterator` instead.") - if self._prefetch_on_device: - dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, self._devices, source_device=self._source_device) - else: - dataset_iterator = self._dataset.make_initializable_iterator() - return PerDeviceDataIterator( - dataset_iterator, - self._devices, - prefetch_on_device=self._prefetch_on_device) + dataset_iterator = self._dataset.make_initializable_iterator() + return PerDeviceDataIterator(dataset_iterator, self._devices, + self._prefetch_on_device) class MultiWorkerDataIterator(object): @@ -816,7 +793,8 @@ class MultiWorkerDataset(object): eager mode. """ - def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None): + def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None, + auto_shard=False): """Initialize the MultiWorkerDataset object. Args: @@ -824,6 +802,7 @@ class MultiWorkerDataset(object): worker_device_map: a dict mapping from each worker to a list of devices that belong to this worker. prefetch_on_device: whether to prefetch to devices. + auto_shard: whether to auto-shard the dataset. """ self._worker_device_map = worker_device_map self._datasets = {} @@ -833,13 +812,11 @@ class MultiWorkerDataset(object): six.iteritems(worker_device_map)): with ops.device(worker): worker_input = dataset_fn() - worker_input = input_ops.auto_shard_dataset( - worker_input, len(worker_device_map), i) + if auto_shard: + worker_input = input_ops.auto_shard_dataset( + worker_input, len(worker_device_map), i) self._datasets[worker] = PerDeviceDataset( - worker_input, - worker_devices, - source_device=worker, - prefetch_on_device=prefetch_on_device) + worker_input, worker_devices, prefetch_on_device=prefetch_on_device) def make_one_shot_iterator(self): iterators = {} diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py index 002d61f46e..ae3e134333 100644 --- a/tensorflow/contrib/distribute/python/values_test.py +++ b/tensorflow/contrib/distribute/python/values_test.py @@ -349,11 +349,7 @@ class PerDeviceDatasetTest(test.TestCase): def _test_iterator_no_prefetch(self, devices, dataset, expected_values): per_device_dataset = values.PerDeviceDataset( dataset, devices, prefetch_on_device=False) - if context.executing_eagerly(): - iterator = per_device_dataset.make_one_shot_iterator() - else: - iterator = per_device_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) + iterator = per_device_dataset.make_one_shot_iterator() for expected_value in expected_values: next_element = iterator.get_next() @@ -370,14 +366,21 @@ class PerDeviceDatasetTest(test.TestCase): if not context.executing_eagerly(): per_device_dataset = values.PerDeviceDataset( dataset, devices, prefetch_on_device=True) - iterator = per_device_dataset.make_initializable_iterator() - self.evaluate([iterator.initializer]) + iterator = per_device_dataset.make_one_shot_iterator() + # With prefetching, we cannot guarantee which input ends up on which + # device, so we verify that the complete set seen on all devices is + # correct, and equal numbers are distributed to each device. + combined_actual = [] + combined_expected = [] for expected_value in expected_values: next_element = iterator.get_next() - computed_value = self.evaluate( - [values.select_device(d, next_element) for d in devices]) - self.assertEqual(expected_value, computed_value) + combined_actual.extend( + self.evaluate( + [values.select_device(d, next_element) for d in devices])) + combined_expected.extend(expected_value) + + self.assertEqual(set(combined_expected), set(combined_actual)) with self.assertRaises(errors.OutOfRangeError): next_element = iterator.get_next() diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index e344d7a23b..510f292508 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -154,6 +154,8 @@ tf_py_test( ], tags = [ "no_pip", # b/38283730 + "noasan", # b/116875897 + "nomsan", "notsan", # Flaky: b/30756419 ], ) @@ -177,7 +179,11 @@ tf_py_test( "//tensorflow/python:random_seed", "//tensorflow/python:variables", ], - tags = ["notsan"], # b/62863147 + tags = [ + "noasan", # b/116875897 + "nomsan", + "notsan", # b/62863147 + ], ) py_library( @@ -276,6 +282,7 @@ tf_py_test( "manual", "noasan", # times out b/63678675 "nomsan", + "notsan", # b/116875897 ], ) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index f320b53d94..f3ebe3b245 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -26,6 +26,14 @@ config_setting( }, ) +# Enables inclusion of TensorFlow kernels via the TF Lite Flex delegate. +# WARNING: This build flag is experimental and subject to change. +config_setting( + name = "with_tflite_flex", + define_values = {"with_tflite_flex": "true"}, + visibility = ["//visibility:public"], +) + cc_library( name = "schema_fbs_version", hdrs = ["version.h"], @@ -157,6 +165,10 @@ cc_library( "stderr_reporter.h", ], copts = tflite_copts(), + defines = select({ + ":with_tflite_flex": ["TFLITE_FLEX"], + "//conditions:default": [], + }), linkopts = [ ] + select({ "//tensorflow:android": [ @@ -180,7 +192,12 @@ cc_library( "//tensorflow/contrib/lite/nnapi:nnapi_lib", "//tensorflow/contrib/lite/profiling:profiler", "//tensorflow/contrib/lite/schema:schema_fbs", - ], + ] + select({ + ":with_tflite_flex": [ + "//tensorflow/contrib/lite/delegates/flex:delegate", + ], + "//conditions:default": [], + }), ) cc_library( diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD index 4d2437e7d3..d180cb4785 100644 --- a/tensorflow/contrib/lite/examples/android/BUILD +++ b/tensorflow/contrib/lite/examples/android/BUILD @@ -28,6 +28,7 @@ android_binary( srcs = glob([ "app/src/main/java/**/*.java", ]), + aapt_version = "aapt", # Package assets from assets dir as well as all model targets. # Remove undesired models (and corresponding Activities in source) # to reduce APK size. diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl index db837cf29e..9d2aead266 100644 --- a/tensorflow/contrib/lite/java/aar_with_jni.bzl +++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl @@ -3,12 +3,12 @@ load("@build_bazel_rules_android//android:rules.bzl", "android_binary") def aar_with_jni(name, android_library): - # Generate dummy AndroidManifest.xml for dummy apk usage - # (dummy apk is generated by <name>_dummy_app_for_so target below) - native.genrule( - name = name + "_binary_manifest_generator", - outs = [name + "_generated_AndroidManifest.xml"], - cmd = """ + # Generate dummy AndroidManifest.xml for dummy apk usage + # (dummy apk is generated by <name>_dummy_app_for_so target below) + native.genrule( + name = name + "_binary_manifest_generator", + outs = [name + "_generated_AndroidManifest.xml"], + cmd = """ cat > $(OUTS) <<EOF <manifest xmlns:android="http://schemas.android.com/apk/res/android" @@ -17,27 +17,28 @@ cat > $(OUTS) <<EOF </manifest> EOF """, - ) + ) - # Generate dummy apk including .so files and later we extract out - # .so files and throw away the apk. - android_binary( - name = name + "_dummy_app_for_so", - manifest = name + "_generated_AndroidManifest.xml", - custom_package = "dummy.package.for.so", - deps = [android_library], - # In some platforms we don't have an Android SDK/NDK and this target - # can't be built. We need to prevent the build system from trying to - # use the target in that case. - tags = ["manual"], - ) + # Generate dummy apk including .so files and later we extract out + # .so files and throw away the apk. + android_binary( + name = name + "_dummy_app_for_so", + aapt_version = "aapt", + manifest = name + "_generated_AndroidManifest.xml", + custom_package = "dummy.package.for.so", + deps = [android_library], + # In some platforms we don't have an Android SDK/NDK and this target + # can't be built. We need to prevent the build system from trying to + # use the target in that case. + tags = ["manual"], + ) - native.genrule( - name = name, - srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"], - outs = [name + ".aar"], - tags = ["manual"], - cmd = """ + native.genrule( + name = name, + srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"], + outs = [name + ".aar"], + tags = ["manual"], + cmd = """ cp $(location {}.aar) $(location :{}.aar) chmod +w $(location :{}.aar) origdir=$$PWD @@ -46,4 +47,4 @@ unzip $$origdir/$(location :{}_dummy_app_for_so_unsigned.apk) "lib/*" cp -r lib jni zip -r $$origdir/$(location :{}.aar) jni/*/*.so """.format(android_library, name, name, name, name), - ) + ) diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD index 220d6c2159..5ad738389e 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD @@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2.0 android_binary( name = "TfLiteCameraDemo", srcs = glob(["java/**/*.java"]), + aapt_version = "aapt", assets = [ "//tensorflow/contrib/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt", "@tflite_mobilenet//:mobilenet_quant_v1_224.tflite", diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD index b2e3a9bd7d..058240aada 100644 --- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD +++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD @@ -8,6 +8,7 @@ android_binary( srcs = [ "OvicBenchmarkerActivity.java", ], + aapt_version = "aapt", assets = [ "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata", "//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt", diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index 765c3a03ef..689cea03e7 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -37,10 +37,6 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) { : nullptr; } -inline Dims<4> GetTensorDims(std::vector<int32_t> data) { - return GetTensorDims(data.data(), data.size()); -} - inline RuntimeShape GetTensorShape(std::vector<int32_t> data) { return RuntimeShape(data.size(), data.data()); } diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h index 5e688ce452..9f5b33d217 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h @@ -86,35 +86,6 @@ inline const bool* GetTensorData(const TfLiteTensor* tensor) { return tensor != nullptr ? tensor->data.b : nullptr; } -// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object -// even if the original tensors were not 4D. We should consider rewriting them -// to take a more generic 'shape' object. -inline Dims<4> GetTensorDims(const int data[], const int size) { - Dims<4> d; - for (int i = 0; i < 4; ++i) { - int src = size - i - 1; - if (src >= 0) { - d.sizes[i] = data[src]; - } else { - d.sizes[i] = 1; - } - } - d.strides[0] = 1; - for (int i = 1; i < 4; i++) { - d.strides[i] = d.strides[i - 1] * d.sizes[i - 1]; - } - return d; -} - -inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { - if (tensor == nullptr) { - return Dims<4>(); - } - - auto* dims = tensor->dims; - return GetTensorDims(dims->data, dims->size); -} - inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) { if (tensor == nullptr) { return RuntimeShape(); diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc index bf2068d320..2ed73ba82d 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_test.cc @@ -21,28 +21,32 @@ namespace { using ::testing::ElementsAre; -TEST(TensorTest, GetTensorDims4D) { - Dims<4> d = GetTensorDims({2, 3, 4, 5}); - EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 2)); - EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +TEST(TensorTest, GetTensorShape4D) { + RuntimeShape d = GetTensorShape({2, 3, 4, 5}); + EXPECT_THAT( + std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()), + ElementsAre(2, 3, 4, 5)); } -TEST(TensorTest, GetTensorDims3D) { - Dims<4> d = GetTensorDims({3, 4, 5}); - EXPECT_THAT(d.sizes, ElementsAre(5, 4, 3, 1)); - EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 60)); +TEST(TensorTest, GetTensorShape3D) { + RuntimeShape d = GetTensorShape({3, 4, 5}); + EXPECT_THAT( + std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()), + ElementsAre(3, 4, 5)); } -TEST(TensorTest, GetTensorDims2D) { - Dims<4> d = GetTensorDims({4, 5}); - EXPECT_THAT(d.sizes, ElementsAre(5, 4, 1, 1)); - EXPECT_THAT(d.strides, ElementsAre(1, 5, 20, 20)); +TEST(TensorTest, GetTensorShape2D) { + RuntimeShape d = GetTensorShape({4, 5}); + EXPECT_THAT( + std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()), + ElementsAre(4, 5)); } -TEST(TensorTest, GetTensorDims1D) { - Dims<4> d = GetTensorDims({5}); - EXPECT_THAT(d.sizes, ElementsAre(5, 1, 1, 1)); - EXPECT_THAT(d.strides, ElementsAre(1, 5, 5, 5)); +TEST(TensorTest, GetTensorShape1D) { + RuntimeShape d = GetTensorShape({5}); + EXPECT_THAT( + std::vector<int32>(d.DimsData(), d.DimsData() + d.DimensionsCount()), + ElementsAre(5)); } } // namespace diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD index f18a2ca07a..2e5033dab1 100644 --- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -20,6 +20,7 @@ filegroup( android_binary( name = "SmartReplyDemo", srcs = glob(["java/**/*.java"]), + aapt_version = "aapt", assets = [":assets"], assets_dir = "", custom_package = "com.example.android.smartreply", diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index d962a5e12d..36125c198e 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -133,7 +133,8 @@ $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*benchmark*.cc) \ $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*benchmark*.cc) \ $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*benchmark*.cc) \ $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*/*benchmark*.cc) \ -tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc +tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc \ +tensorflow/contrib/makefile/downloads/absl/absl/hash/internal/print_hash_of.cc ABSL_CC_SRCS := $(filter-out $(ABSL_CC_EXCLUDE_SRCS), $(ABSL_CC_ALL_SRCS)) diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index f4ac70eb1a..6a67c6295d 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -377,6 +377,11 @@ py_test( size = "large", srcs = ["python/training/shampoo_test.py"], srcs_version = "PY2AND3", + tags = [ + "noasan", # b/116875897 + "nomsan", + "notsan", + ], deps = [ ":opt_py", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py index 05bcf2cfa3..a2fd8fbd87 100644 --- a/tensorflow/contrib/opt/python/training/shampoo_test.py +++ b/tensorflow/contrib/opt/python/training/shampoo_test.py @@ -54,9 +54,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np_2 = np.random.rand(size) with self.cached_session() as sess: - global_step = variables.Variable( + global_step = variables.VariableV1( 0, dtype=dtypes.int64, use_resource=use_resource_var) - var = variables.Variable( + var = variables.VariableV1( init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) grad = constant_op.constant(grad_np, dtype=dtypes.float32) grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) @@ -105,9 +105,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np_2 = np.random.rand(size[0], size[1]) with self.cached_session() as sess: - global_step = variables.Variable( + global_step = variables.VariableV1( 0, dtype=dtypes.int64, use_resource=use_resource_var) - var = variables.Variable( + var = variables.VariableV1( init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) grad = constant_op.constant(grad_np, dtype=dtypes.float32) grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) @@ -164,9 +164,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np_2 = np.random.rand(size[0], size[1], size[2]) with self.cached_session() as sess: - global_step = variables.Variable( + global_step = variables.VariableV1( 0, dtype=dtypes.int64, use_resource=use_resource_var) - var = variables.Variable( + var = variables.VariableV1( init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) grad = constant_op.constant(grad_np, dtype=dtypes.float32) grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) @@ -254,9 +254,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np_2 = np.random.rand(size) with self.cached_session() as sess: - global_step = variables.Variable( + global_step = variables.VariableV1( 0, dtype=dtypes.int64, use_resource=use_resource_var) - var = variables.Variable( + var = variables.VariableV1( init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) grad = constant_op.constant(grad_np, dtype=dtypes.float32) grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) @@ -310,9 +310,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np_2 = np.random.rand(size[0], size[1]) with self.cached_session() as sess: - global_step = variables.Variable( + global_step = variables.VariableV1( 0, dtype=dtypes.int64, use_resource=use_resource_var) - var = variables.Variable( + var = variables.VariableV1( init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) grad = constant_op.constant(grad_np, dtype=dtypes.float32) grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) @@ -383,9 +383,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np_2 = np.random.rand(sample_size_2, size[1]) with self.cached_session() as sess: - global_step = variables.Variable( + global_step = variables.VariableV1( 0, dtype=dtypes.int64, use_resource=use_resource_var) - var = variables.Variable( + var = variables.VariableV1( init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) grad = ops.IndexedSlices( constant_op.constant(grad_np, dtype=dtypes.float32), @@ -463,9 +463,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): grad_np = np.random.rand(sample_size, size[1], size[2]) with self.cached_session() as sess: - global_step = variables.Variable( + global_step = variables.VariableV1( 0, dtype=dtypes.int64, use_resource=use_resource_var) - var = variables.Variable( + var = variables.VariableV1( init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) grad = ops.IndexedSlices( constant_op.constant(grad_np, dtype=dtypes.float32), @@ -533,9 +533,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): gbar_weight = 0.1 with self.cached_session() as sess: - global_step = variables.Variable( + global_step = variables.VariableV1( 0, dtype=dtypes.int64, use_resource=use_resource_var) - var = variables.Variable( + var = variables.VariableV1( init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) grad = constant_op.constant(grad_np, dtype=dtypes.float32) grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32) @@ -628,9 +628,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): mat_g3 = np.zeros_like(mat_g3_a) with self.cached_session() as sess: - global_step = variables.Variable( + global_step = variables.VariableV1( 0, dtype=dtypes.int64, use_resource=use_resource_var) - var = variables.Variable( + var = variables.VariableV1( init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) grad = array_ops.placeholder(dtypes.float32, shape=size) @@ -705,9 +705,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase): mat_g3 = np.zeros_like(mat_g3_a) with self.cached_session() as sess: - global_step = variables.Variable( + global_step = variables.VariableV1( 0, dtype=dtypes.int64, use_resource=use_resource_var) - var = variables.Variable( + var = variables.VariableV1( init_var_np, dtype=dtypes.float32, use_resource=use_resource_var) grad = array_ops.placeholder(dtypes.float32, shape=size) diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index c230919168..cb1f707028 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -159,7 +159,12 @@ py_test( ], shard_count = 4, srcs_version = "PY2AND3", - tags = ["no_pip_gpu"], # b/63391119 + tags = [ + "no_pip_gpu", # b/63391119 + "noasan", # b/116875897 + "nomsan", + "notsan", + ], deps = [ ":estimators", ":feature_keys", diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py index 647455ae42..04d17bc123 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -104,7 +104,7 @@ class EvaluationMetricsTests(test.TestCase): "ticker": array_ops.reshape( math_ops.cast( - variables.Variable( + variables.VariableV1( name="ticker", initial_value=0, dtype=dtypes.int64, diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py index 766466968a..6ce6b779a2 100644 --- a/tensorflow/contrib/tpu/__init__.py +++ b/tensorflow/contrib/tpu/__init__.py @@ -55,7 +55,9 @@ @@TPUDistributionStrategy @@keras_to_tpu_model + @@AsyncCheckpointSaverHook +@@TPUInMemoryEvalHook """ from __future__ import absolute_import @@ -65,6 +67,7 @@ from __future__ import print_function # pylint: disable=wildcard-import,unused-import from tensorflow.contrib.tpu.python import profiler from tensorflow.contrib.tpu.python.ops.tpu_ops import * +from tensorflow.contrib.tpu.python.tpu.async_checkpoint import * from tensorflow.contrib.tpu.python.tpu.bfloat16 import * from tensorflow.contrib.tpu.python.tpu.device_assignment import * from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc index 1bd1a31e11..bc1a0c5284 100644 --- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc +++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc @@ -104,18 +104,9 @@ Status RegisterPerTableLoadOpsForAlgorithmBody( } } { - auto* table_id_attr = op_def->add_attr(); - table_id_attr->set_name("table_id"); - table_id_attr->set_type("int"); - table_id_attr->set_has_minimum(true); - table_id_attr->set_minimum(-1); - table_id_attr->mutable_default_value()->set_i(-1); - } - { auto* table_name_attr = op_def->add_attr(); table_name_attr->set_name("table_name"); table_name_attr->set_type("string"); - table_name_attr->mutable_default_value()->set_s(""); } { auto* num_shards_attr = op_def->add_attr(); @@ -147,11 +138,9 @@ parameters that are loaded from a checkpoint before a training loop is executed. %s table_name: Name of this table; must match a name in the - EmbeddingLayerConfiguration proto (overrides table_id). + EmbeddingLayerConfiguration proto. num_shards: Number of shards into which the embedding tables are divided. shard_id: Identifier of shard for this operation. -table_id: Index of this table in the EmbeddingLayerConfiguration proto - (deprecated). )doc", parameter_descriptions.c_str())); op_def->set_is_commutative(false); @@ -160,14 +149,10 @@ table_id: Index of this table in the EmbeddingLayerConfiguration proto auto shape_inference_function = [state_variable_specs, is_debug_op](shape_inference::InferenceContext* c) -> Status { - int table_id; - TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); string table_name; TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name)); - // Exactly one must be non-default. - if ((table_id >= 0) == (!table_name.empty())) { - return errors::InvalidArgument( - "exactly one of table_id or table_name must be non-default"); + if (table_name.empty()) { + return errors::InvalidArgument("table_name attribute must be set"); } int num_shards; TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards)); @@ -241,18 +226,9 @@ Status RegisterPerTableRetrieveOpsForAlgorithmBody( } } { - auto* table_id_attr = op_def->add_attr(); - table_id_attr->set_name("table_id"); - table_id_attr->set_type("int"); - table_id_attr->set_has_minimum(true); - table_id_attr->set_minimum(-1); - table_id_attr->mutable_default_value()->set_i(-1); - } - { auto* table_name_attr = op_def->add_attr(); table_name_attr->set_name("table_name"); table_name_attr->set_type("string"); - table_name_attr->mutable_default_value()->set_s(""); } { auto* num_shards_attr = op_def->add_attr(); @@ -283,11 +259,9 @@ the correct embedding table configuration. For example, this op is used to retrieve updated parameters before saving a checkpoint. %s table_name: Name of this table; must match a name in the - EmbeddingLayerConfiguration proto (overrides table_id). + EmbeddingLayerConfiguration proto. num_shards: Number of shards into which the embedding tables are divided. shard_id: Identifier of shard for this operation. -table_id: Index of this table in the EmbeddingLayerConfiguration proto - (deprecated). )doc", parameter_descriptions.c_str())); op_def->set_is_commutative(false); @@ -296,14 +270,10 @@ table_id: Index of this table in the EmbeddingLayerConfiguration proto auto shape_inference_function = [state_variable_specs, is_debug_op](shape_inference::InferenceContext* c) -> Status { - int table_id; - TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id)); string table_name; TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name)); - // Exactly one must be non-default. - if ((table_id >= 0) == (!table_name.empty())) { - return errors::InvalidArgument( - "exactly one of table_id or table_name must be non-default"); + if (table_name.empty()) { + return errors::InvalidArgument("table_name must be non-empty"); } int num_shards; TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards)); diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index b498599962..8e6e9aa0cd 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -156,8 +156,7 @@ bool NewSession(const string& service_addr, channel_args)); NewProfileSessionResponse new_session_response; TF_QCHECK_OK(FromGrpcStatus( - stub->NewSession(&context, new_session_request, &new_session_response))) - << new_session_response.error_message(); + stub->NewSession(&context, new_session_request, &new_session_response))); std::cout << "Profile session succeed for host(s):" << str_util::Join(hostnames, ",") << std::endl; diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto index b25d06dda8..292108f949 100644 --- a/tensorflow/contrib/tpu/profiler/op_profile.proto +++ b/tensorflow/contrib/tpu/profiler/op_profile.proto @@ -66,8 +66,8 @@ message Metrics { // - it does not reveal the peak core FLOPS of the hardware double flops = 2; - // The VMEM bandwidth used to load operands from HBM, as a fraction of - // thereotical VMEM bandwidth on the specific hardware. + // The memory bandwidth used to load operands, as a fraction of + // thereotical memory bandwidth on the specific hardware. double memory_bandwidth = 3; double raw_time = 11; // Elapsed core-time in picoseconds. diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto index fc1320501b..a43f45554f 100644 --- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto +++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto @@ -22,13 +22,22 @@ message LearningRate { } } +// Each optimizer's parameter proto has a link to its documentation and CPU +// implementation (if available) for user reference. + +// https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer +// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L151 message AdagradParameters { float initial_accumulator = 1; } +// https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer +// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L423 message StochasticGradientDescentParameters { } +// https://www.tensorflow.org/api_docs/python/tf/train/FtrlOptimizer +// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L192 message FtrlParameters { float l1 = 1; float l2 = 2; @@ -41,21 +50,38 @@ message FtrlParameters { // learning rate feature instead, setting the learning rate to: // user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) // Here, t is the current timestep. +// +// https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer // https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54 +// +// Note that the code by default implements the lazy version of Adam +// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer) +// unless the use_non_lazy_adam parameter is set, in which case it implements +// the normal version of Adam that updates all parameters in the embedding +// table, even for entries that are not used in the current minibatch +// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If +// use_non_lazy_adam is enabled, use_gradient_accumulation is also required in +// order to get correct results; a warning will be printed otherwise (which may +// change to an error in the future). message AdamParameters { float beta1 = 3; float beta2 = 4; float epsilon = 5; float initial_m = 6; float initial_v = 7; + bool use_non_lazy_adam = 8; } +// https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer +// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L271 message MomentumParameters { float momentum = 1; bool use_nesterov = 2; float initial_accum = 3; } +// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer +// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L356 message RmsPropParameters { float rho = 1; float momentum = 2; @@ -64,6 +90,8 @@ message RmsPropParameters { float initial_mom = 5; } +// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer +// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L372 message CenteredRmsPropParameters { float rho = 1; float momentum = 2; @@ -73,6 +101,7 @@ message CenteredRmsPropParameters { float initial_mg = 6; } +// Variant of algorithm in http://proceedings.mlr.press/v44/shamir15.pdf message MdlAdagradLightParameters { float l2 = 1; float lr_power = 2; @@ -91,6 +120,8 @@ message MdlAdagradLightParameters { float initial_benefit = 15; } +// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer +// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L68 message AdadeltaParameters { float rho = 1; float epsilon = 2; @@ -98,6 +129,8 @@ message AdadeltaParameters { float initial_update = 4; } +// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer +// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L164 message ProximalAdagradParameters { float l1 = 1; float l2 = 2; diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py index e06a720e82..20b7ba0997 100644 --- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py +++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ====================================== - """Hook for asynchronous checkpointing. This hook dispatches checkpoint writing operations in a separate thread to @@ -28,18 +27,16 @@ import threading import time from tensorflow.core.util.event_pb2 import SessionLog - from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks -from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util from tensorflow.python.training.session_run_hook import SessionRunArgs from tensorflow.python.training.summary_io import SummaryWriterCache -class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook): +class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): """Saves checkpoints every N steps or seconds.""" def __init__(self, @@ -67,7 +64,7 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook): ValueError: One of `save_steps` or `save_secs` should be set. ValueError: At most one of `saver` or `scaffold` should be set. """ - logging.info("Create CheckpointSaverHook.") + logging.info("Create AsyncCheckpointSaverHook.") if saver is not None and scaffold is not None: raise ValueError("You cannot provide both saver and scaffold.") self._saver = saver @@ -144,6 +141,10 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook): def _save(self, session, step, asynchronous=True): """Saves the latest checkpoint, returns should_stop.""" + # Skip saving on step 0 + if step == 0: + return + def _save_fn(): """Run the saver process.""" logging.info("Saving checkpoints for %d into %s.", step, self._save_path) @@ -162,7 +163,6 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook): end_time - start_time) logging.info("Checkpoint finished for %d into %s.", step, self._save_path) - logging.info("Saving checkpoints for %d into %s.", step, self._save_path) for l in self._listeners: l.before_save(session, step) diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py index 956d0142a3..696656e840 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_support.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py @@ -959,7 +959,16 @@ class TPUFunction(object): # Compute our outfeed depending on the execution mode if is_training: - self._cloned_model._make_train_function() + if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer): + # For Keras optimizer, we try to place the variable weights on the TPU + # device. Keras creates optimizer variables (e.g. momentum values for + # the Momentum optimizer) when _make_train_function is invoked. + with keras_tpu_variables.replicated_variable_for_optimizer( + self._tpu_assignment.num_towers): + self._cloned_model._make_train_function() + else: + self._cloned_model._make_train_function() + self._outfeed_spec = [ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) for tensor in self._cloned_model.train_function.outputs diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py index 170977d8ab..598da7418e 100644 --- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py +++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py @@ -25,10 +25,15 @@ from __future__ import print_function import contextlib +import numpy as np + from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops +from tensorflow.python.keras import backend from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -285,3 +290,51 @@ def replicated_scope(num_replicas): return variable_scope.variable_scope( "", custom_getter=_replicated_variable_getter) + + +@contextlib.contextmanager +def replicated_variable_for_optimizer(num_replicas): + """Context manager for optimizer weights. Overrides K.variable.""" + if num_replicas == 1: + yield + return + + try: + old_v = backend.variable + + def opt_variable(value, dtype=None, name=None, constraint=None): + """Instantiates a variable and returns it.""" + if dtype is None: + dtype = backend.floatx() + + variables = [] + for i in range(num_replicas): + # Keras holds the variables in optimizer class instance , so the name + # does not matter here. ResourceVariable constructor will find a unique + # name (including name=None) for each replica. + with ops.device("device:TPU:{}".format(i)): + v = resource_variable_ops.ResourceVariable( + value, + dtype=dtypes_module.as_dtype(dtype), + name=name, + constraint=constraint) + variables.append(v) + name = "replicate_{}_{}".format("variable" if name is None else name, + ops.uid()) + v = ReplicatedVariable(name, variables) + + # pylint: disable=protected-access + + if isinstance(value, np.ndarray): + v._keras_shape = value.shape + elif hasattr(value, "shape"): + v._keras_shape = backend.int_shape(value) + v._uses_learning_phase = False + backend.track_variable(v) + return v + + backend.variable = opt_variable + yield + + finally: + backend.variable = old_v diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index 23c54511ca..545cee637f 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -231,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote `metric_fn` runs on CPU to generate metrics and `tensors` represents the `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`. To be precise, TPU evaluation expects a slightly different signature from the - @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a + `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The `tensors` usually specify the model logits, which are transferred back from @@ -254,7 +254,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote sending tensors from TPU to CPU. To reduce the overhead, try reducing the size of the tensors. The `tensors` are concatenated along their major (batch) dimension, and so must be >= rank 1. The `host_call` is useful for writing - summaries with @{tf.contrib.summary.create_file_writer}. + summaries with `tf.contrib.summary.create_file_writer`. """ def __new__(cls, @@ -404,12 +404,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): self._feed_error = None self._finished = False + self._should_initialize_tpu = True def begin(self): logging.info('TPU job name %s', self._master_job) self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - self._init_ops = [tpu.initialize_system(job=self._master_job)] - self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] + if self._should_initialize_tpu: + self._init_ops = [tpu.initialize_system(job=self._master_job)] + self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] + else: + self._init_ops = [] + self._finalize_ops = [] summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() self._init_ops.extend(summary_writer_init_ops) diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index ddf8365d61..b565ebd073 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -313,6 +313,5 @@ tf_proto_library( name = "protos_all", srcs = glob(["**/*.proto"]), cc_api_version = 2, - java_api_version = 2, visibility = ["//visibility:public"], ) |