aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
authorGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-09-26 11:54:30 +0800
committerGravatar Cao Zongyan <zongyan.cao@alibaba-inc.com>2018-09-26 11:54:30 +0800
commit35174f46b973c66a2e6894a12b3018d60e8414ec (patch)
tree5bdae0172159bc02ec3a470722bf959b14dd47ba /tensorflow/core/kernels
parentf0886f7269de900d226455d4831722f6fc94a71b (diff)
parent6666516f390f125ed70ddbd4e6f89b83d953c408 (diff)
Merge remote-tracking branch 'origin'
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--tensorflow/core/kernels/BUILD104
-rw-r--r--tensorflow/core/kernels/bias_op.cc13
-rw-r--r--tensorflow/core/kernels/bincount_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/boosted_trees/BUILD16
-rw-r--r--tensorflow/core/kernels/boosted_trees/boosted_trees.proto13
-rw-r--r--tensorflow/core/kernels/boosted_trees/prediction_ops.cc38
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantile_ops.cc453
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/BUILD4
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h96
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.cc26
-rw-r--r--tensorflow/core/kernels/conv_2d.h45
-rw-r--r--tensorflow/core/kernels/conv_3d.h43
-rw-r--r--tensorflow/core/kernels/conv_grad_filter_ops.cc3
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc6
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.cc11
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.h10
-rw-r--r--tensorflow/core/kernels/conv_grad_ops_3d.cc1330
-rw-r--r--tensorflow/core/kernels/conv_ops.cc19
-rw-r--r--tensorflow/core/kernels/conv_ops_3d.cc20
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu.h6
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc81
-rw-r--r--tensorflow/core/kernels/data/BUILD29
-rw-r--r--tensorflow/core/kernels/data/batch_dataset_op.cc1
-rw-r--r--tensorflow/core/kernels/data/cache_dataset_ops.cc4
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc156
-rw-r--r--tensorflow/core/kernels/data/captured_function.h25
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc37
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h10
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc61
-rw-r--r--tensorflow/core/kernels/data/flat_map_dataset_op.cc13
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc44
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc51
-rw-r--r--tensorflow/core/kernels/data/interleave_dataset_op.cc12
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc41
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc82
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc14
-rw-r--r--tensorflow/core/kernels/data/map_defun_op.cc234
-rw-r--r--tensorflow/core/kernels/data/model_dataset_op.cc183
-rw-r--r--tensorflow/core/kernels/data/multi_device_iterator_ops.cc633
-rw-r--r--tensorflow/core/kernels/data/optional_ops.cc22
-rw-r--r--tensorflow/core/kernels/data/padded_batch_dataset_op.cc1
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc207
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc55
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc61
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc7
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.cc13
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc38
-rw-r--r--tensorflow/core/kernels/data/scan_dataset_op.cc19
-rw-r--r--tensorflow/core/kernels/data/tensor_dataset_op.cc6
-rw-r--r--tensorflow/core/kernels/data/window_dataset_op.cc215
-rw-r--r--tensorflow/core/kernels/decode_bmp_op.cc7
-rw-r--r--tensorflow/core/kernels/decode_csv_op.cc3
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc8
-rw-r--r--tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h311
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions.h41
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc31
-rw-r--r--tensorflow/core/kernels/eigen_cuboid_convolution.h1438
-rw-r--r--tensorflow/core/kernels/eigen_spatial_convolutions.h342
-rw-r--r--tensorflow/core/kernels/eigen_volume_patch.h1
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op.cc197
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op.h58
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc38
-rw-r--r--tensorflow/core/kernels/fuzzing/BUILD2
-rw-r--r--tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc45
-rw-r--r--tensorflow/core/kernels/gather_functor.h1
-rw-r--r--tensorflow/core/kernels/histogram_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/list_kernels.cc12
-rw-r--r--tensorflow/core/kernels/list_kernels.cu.cc3
-rw-r--r--tensorflow/core/kernels/logging_ops.cc54
-rw-r--r--tensorflow/core/kernels/logging_ops_test.cc22
-rw-r--r--tensorflow/core/kernels/mirror_pad_op.h1
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops_test.cc407
-rw-r--r--tensorflow/core/kernels/multinomial_op.cc2
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc4
-rw-r--r--tensorflow/core/kernels/queue_base.h4
-rw-r--r--tensorflow/core/kernels/queue_ops.cc2
-rw-r--r--tensorflow/core/kernels/random_op.cc10
-rw-r--r--tensorflow/core/kernels/reduction_gpu_kernels.cu.h10
-rw-r--r--tensorflow/core/kernels/reduction_ops_max.cc2
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc2
-rw-r--r--tensorflow/core/kernels/reverse_sequence_op.cc5
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc1
-rw-r--r--tensorflow/core/kernels/searchsorted_op.cc249
-rw-r--r--tensorflow/core/kernels/searchsorted_op.h52
-rw-r--r--tensorflow/core/kernels/searchsorted_op_gpu.cu.cc126
-rw-r--r--tensorflow/core/kernels/shape_op_test.cc10
-rw-r--r--tensorflow/core/kernels/split_lib_gpu.cu.cc1
-rw-r--r--tensorflow/core/kernels/split_op.cc7
-rw-r--r--tensorflow/core/kernels/stack_ops.cc26
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc1
-rw-r--r--tensorflow/core/kernels/string_format_op.cc65
-rw-r--r--tensorflow/core/kernels/string_format_op_test.cc66
-rw-r--r--tensorflow/core/kernels/string_length_op.cc23
-rw-r--r--tensorflow/core/kernels/string_util.cc63
-rw-r--r--tensorflow/core/kernels/string_util.h45
-rw-r--r--tensorflow/core/kernels/substr_op.cc50
-rw-r--r--tensorflow/core/kernels/substr_op_test.cc105
-rw-r--r--tensorflow/core/kernels/tensor_array.cc3
-rw-r--r--tensorflow/core/kernels/tensor_array.h3
-rw-r--r--tensorflow/core/kernels/tensor_array_ops.cc5
-rw-r--r--tensorflow/core/kernels/topk_op_gpu.cu.cc6
-rw-r--r--tensorflow/core/kernels/training_op_helpers.cc1
-rw-r--r--tensorflow/core/kernels/unravel_index_op.cc10
-rw-r--r--tensorflow/core/kernels/where_op_gpu.cu.h8
105 files changed, 7212 insertions, 1394 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index c3c6013d83..1a3db2c7cd 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -30,6 +30,7 @@ load(
"//tensorflow:tensorflow.bzl",
"if_android",
"tf_cc_test",
+ "tf_cc_test_mkl",
"tf_cc_tests",
"tf_cc_binary",
"tf_copts",
@@ -50,6 +51,10 @@ load(
"tf_kernel_tests_linkstatic",
)
load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
+load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
"if_mkl_ml",
@@ -212,6 +217,19 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "extract_volume_patches_op",
+ prefix = "extract_volume_patches_op",
+ deps = [
+ ":bounds_check",
+ ":eigen_helpers",
+ ":ops_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
cc_library(
name = "conv_3d",
hdrs = ["conv_3d.h"],
@@ -617,6 +635,7 @@ cc_library(
":diag_op",
":edit_distance_op",
":extract_image_patches_op",
+ ":extract_volume_patches_op",
":gather_nd_op",
":gather_op",
":guarantee_const_op",
@@ -636,6 +655,7 @@ cc_library(
":reshape_op",
":reverse_op",
":reverse_sequence_op",
+ ":searchsorted_op",
":shape_ops",
":slice_op",
":snapshot_op",
@@ -869,6 +889,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "searchsorted_op",
+ prefix = "searchsorted_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "inplace_ops",
prefix = "inplace_ops",
deps = ARRAY_DEPS,
@@ -1105,7 +1131,7 @@ tf_cuda_cc_test(
name = "depthwise_conv_ops_test",
size = "small",
srcs = ["depthwise_conv_ops_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":conv_ops",
":image",
@@ -2702,6 +2728,7 @@ cc_library(
)
LOGGING_DEPS = [
+ "@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -2759,6 +2786,7 @@ tf_cc_tests(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
@@ -4396,6 +4424,7 @@ cc_library(
":reduce_join_op",
":regex_full_match_op",
":regex_replace_op",
+ ":string_format_op",
":string_join_op",
":string_length_op",
":string_split_op",
@@ -4405,8 +4434,16 @@ cc_library(
],
)
+cc_library(
+ name = "string_util",
+ srcs = ["string_util.cc"],
+ hdrs = ["string_util.h"],
+ deps = ["//tensorflow/core:lib"],
+)
+
STRING_DEPS = [
":bounds_check",
+ ":string_util",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -4427,6 +4464,30 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "string_format_op",
+ prefix = "string_format_op",
+ deps = STRING_DEPS + ["@com_google_absl//absl/strings"],
+)
+
+tf_cc_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.cc"],
+ deps = [
+ ":string_format_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "string_join_op",
prefix = "string_join_op",
deps = STRING_DEPS,
@@ -4504,6 +4565,25 @@ tf_kernel_library(
deps = STRING_DEPS,
)
+tf_cc_test(
+ name = "substr_op_test",
+ size = "small",
+ srcs = ["substr_op_test.cc"],
+ deps = [
+ ":substr_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "as_string_op",
prefix = "as_string_op",
@@ -5094,6 +5174,7 @@ filegroup(
"spacetobatch_functor.h",
"spacetodepth_op.h",
"spectrogram.h",
+ "string_util.h",
"tensor_array.h",
"tile_functor.h",
"tile_ops_cpu_impl.h",
@@ -5262,6 +5343,7 @@ filegroup(
"spectrogram_op.cc",
"stack_ops.cc",
"string_join_op.cc",
+ "string_util.cc",
"summary_op.cc",
"tensor_array.cc",
"tensor_array_ops.cc",
@@ -6209,6 +6291,26 @@ tf_mkl_kernel_library(
] + mkl_deps(),
)
+tf_cc_test_mkl(
+ name = "mkl_conv_ops_test",
+ size = "small",
+ srcs = ["mkl_conv_ops_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_mkl_kernel_library(
name = "mkl_tfconv_op",
prefix = "mkl_tfconv",
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 7b28c8e91f..e15ea82e7d 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -134,8 +134,8 @@ class BiasOp : public BinaryOp<T> {
if (data_format_ == FORMAT_NCHW) {
int32 batch, height, width, channel;
GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel);
- Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
- Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
+ Eigen::DSizes<Eigen::Index, 4> four_dims(1, channel, 1, 1);
+ Eigen::DSizes<Eigen::Index, 4> broad_cast_dims(batch, 1, height, width);
const Device& d = context->eigen_device<Device>();
output->tensor<T, 4>().device(d) =
input.tensor<T, 4>() +
@@ -247,14 +247,14 @@ class BiasGradOp : public OpKernel {
OP_REQUIRES(context, output_backprop.dims() == 4,
errors::InvalidArgument(
"NCHW format supports only 4D input/output tensor."));
- Eigen::DSizes<int, 4> four_dims(batch, channel, height, width);
+ Eigen::DSizes<Eigen::Index, 4> four_dims(batch, channel, height, width);
#ifdef EIGEN_HAS_INDEX_LIST
using idx0 = Eigen::type2index<0>;
using idx2 = Eigen::type2index<2>;
using idx3 = Eigen::type2index<3>;
Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
#else
- Eigen::array<int, 3> reduction_axes = {0, 2, 3};
+ Eigen::array<Eigen::Index, 3> reduction_axes = {0, 2, 3};
#endif
output->template flat<T>().device(context->eigen_device<Device>()) =
output_backprop.flat<T>()
@@ -263,11 +263,12 @@ class BiasGradOp : public OpKernel {
.sum(reduction_axes)
.template cast<T>(); // End of code by intel_tf.
} else {
- Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
+ Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width,
+ channel);
#ifdef EIGEN_HAS_INDEX_LIST
Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#else
- Eigen::array<int, 1> reduction_axis = {0};
+ Eigen::array<Eigen::Index, 1> reduction_axis = {0};
#endif
output->template flat<T>().device(context->eigen_device<Device>()) =
output_backprop.flat<T>()
diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
index 6074b3e1f6..7d09e9b820 100644
--- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
index 4910021c63..4e8bfa02fc 100644
--- a/tensorflow/core/kernels/boosted_trees/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -15,7 +15,9 @@ load(
tf_proto_library(
name = "boosted_trees_proto",
- srcs = ["boosted_trees.proto"],
+ srcs = [
+ "boosted_trees.proto",
+ ],
cc_api_version = 2,
visibility = ["//visibility:public"],
)
@@ -87,9 +89,21 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "quantile_ops",
+ srcs = ["quantile_ops.cc"],
+ deps = [
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
+ ],
+)
+
+tf_kernel_library(
name = "boosted_trees_ops",
deps = [
":prediction_ops",
+ ":quantile_ops",
":resource_ops",
":stats_ops",
":training_ops",
diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
index c9664f0c1c..1ab72af059 100644
--- a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+++ b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
@@ -11,6 +11,7 @@ message Node {
oneof node {
Leaf leaf = 1;
BucketizedSplit bucketized_split = 2;
+ CategoricalSplit categorical_split = 3;
}
NodeMetadata metadata = 777;
}
@@ -57,6 +58,18 @@ message BucketizedSplit {
int32 right_id = 4;
}
+message CategoricalSplit {
+ // Categorical feature column and split describing the rule feature value ==
+ // value.
+ int32 feature_id = 1;
+ int32 value = 2;
+
+ // Node children indexing into a contiguous
+ // vector of nodes starting from the root.
+ int32 left_id = 3;
+ int32 right_id = 4;
+}
+
// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index b2efa06941..4ae26fb95b 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -334,30 +334,34 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
// Proto to store debug outputs, per example.
boosted_trees::DebugOutput example_debug_info;
// Initial bias prediction. E.g., prediction based off training mean.
- example_debug_info.add_logits_path(resource->GetTreeWeight(0) *
- resource->node_value(0, 0));
+ float tree_logit =
+ resource->GetTreeWeight(0) * resource->node_value(0, 0);
+ example_debug_info.add_logits_path(tree_logit);
int32 node_id = 0;
int32 tree_id = 0;
int32 feature_id;
- float tree_logit;
float past_trees_logit = 0; // Sum of leaf logits from prior trees.
- // Populate proto.
+ // Go through each tree and populate proto.
while (tree_id <= last_tree) {
- // Feature id used to split.
- feature_id = resource->feature_id(tree_id, node_id);
- example_debug_info.add_feature_ids(feature_id);
- // Get logit after split.
- node_id = resource->next_node(tree_id, node_id, i,
- batch_bucketized_features);
- tree_logit = resource->GetTreeWeight(tree_id) *
- resource->node_value(tree_id, node_id);
- // Output logit incorporates sum of leaf logits from prior trees.
- example_debug_info.add_logits_path(tree_logit + past_trees_logit);
- if (resource->is_leaf(tree_id, node_id)) {
- // Move onto other trees.
- past_trees_logit += tree_logit;
+ if (resource->is_leaf(tree_id, node_id)) { // Move onto other trees.
+ // Accumulate tree_logits only if the leaf is non-root, but do so
+ // for bias tree.
+ if (tree_id == 0 || node_id > 0) {
+ past_trees_logit += tree_logit;
+ }
++tree_id;
node_id = 0;
+ } else { // Add to proto.
+ // Feature id used to split.
+ feature_id = resource->feature_id(tree_id, node_id);
+ example_debug_info.add_feature_ids(feature_id);
+ // Get logit after split.
+ node_id = resource->next_node(tree_id, node_id, i,
+ batch_bucketized_features);
+ tree_logit = resource->GetTreeWeight(tree_id) *
+ resource->node_value(tree_id, node_id);
+ // Output logit incorporates sum of leaf logits from prior trees.
+ example_debug_info.add_logits_path(tree_logit + past_trees_logit);
}
}
// Set output as serialized proto containing debug info.
diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
new file mode 100644
index 0000000000..d1840941c1
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
@@ -0,0 +1,453 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+const char* const kExampleWeightsName = "example_weights";
+const char* const kMaxElementsName = "max_elements";
+const char* const kGenerateQuantiles = "generate_quantiles";
+const char* const kNumBucketsName = "num_buckets";
+const char* const kEpsilonName = "epsilon";
+const char* const kBucketBoundariesName = "bucket_boundaries";
+const char* const kBucketsName = "buckets";
+const char* const kSummariesName = "summaries";
+const char* const kNumStreamsName = "num_streams";
+const char* const kNumFeaturesName = "num_features";
+const char* const kFloatFeaturesName = "float_values";
+const char* const kResourceHandleName = "quantile_stream_resource_handle";
+
+using QuantileStreamResource = BoostedTreesQuantileStreamResource;
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+using QuantileSummary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using QuantileSummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float,
+ float>::SummaryEntry;
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateBoundaries(const QuantileStream& stream,
+ const int64 num_boundaries) {
+ std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
+
+ // Uniquify elements as we may get dupes.
+ auto end_it = std::unique(boundaries.begin(), boundaries.end());
+ boundaries.resize(std::distance(boundaries.begin(), end_it));
+ return boundaries;
+}
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateQuantiles(const QuantileStream& stream,
+ const int64 num_quantiles) {
+ // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values
+ // will be returned.
+ std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1);
+ CHECK_EQ(boundaries.size(), num_quantiles);
+ return boundaries;
+}
+
+std::vector<float> GetBuckets(const int32 feature,
+ const OpInputList& buckets_list) {
+ const auto& buckets = buckets_list[feature].flat<float>();
+ std::vector<float> buckets_vector(buckets.data(),
+ buckets.data() + buckets.size());
+ return buckets_vector;
+}
+
+REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource);
+
+REGISTER_KERNEL_BUILDER(
+ Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU),
+ IsResourceInitialized<BoostedTreesQuantileStreamResource>);
+
+class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel {
+ public:
+ explicit BoostedTreesCreateQuantileStreamResourceOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Only create one, if one does not exist already. Report status for all
+ // other exceptions. If one already exists, it unrefs the new one.
+ // An epsilon value of zero could cause perfoamance issues and is therefore,
+ // disallowed.
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+ OP_REQUIRES(
+ context, epsilon > 0,
+ errors::InvalidArgument("An epsilon value of zero is not allowed."));
+
+ const Tensor* num_streams_t;
+ OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t));
+ int64 num_streams = num_streams_t->scalar<int64>()();
+
+ auto result =
+ new QuantileStreamResource(epsilon, max_elements_, num_streams);
+ auto status = CreateResource(context, HandleFromInput(context, 0), result);
+ if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
+ OP_REQUIRES(context, false, status);
+ }
+ }
+
+ private:
+ // An upper bound on the number of entries that the summaries might have
+ // for a feature.
+ int64 max_elements_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU),
+ BoostedTreesCreateQuantileStreamResourceOp);
+
+class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesMakeQuantileSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // Read float features list;
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+
+ // Parse example weights and get batch size.
+ const Tensor* example_weights_t;
+ OP_REQUIRES_OK(context,
+ context->input(kExampleWeightsName, &example_weights_t));
+ auto example_weights = example_weights_t->flat<float>();
+ const int64 batch_size = example_weights.size();
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+
+ OpOutputList summaries_output_list;
+ OP_REQUIRES_OK(
+ context, context->output_list(kSummariesName, &summaries_output_list));
+
+ auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
+ // Iterating features.
+ for (int64 index = begin; index < end; index++) {
+ const auto feature_values = float_features_list[index].flat<float>();
+ QuantileStream stream(epsilon, batch_size + 1);
+ // Run quantile summary generation.
+ for (int64 j = 0; j < batch_size; j++) {
+ stream.PushEntry(feature_values(j), example_weights(j));
+ }
+ stream.Finalize();
+ const auto summary_entry_list = stream.GetFinalSummary().GetEntryList();
+ Tensor* output_t;
+ OP_REQUIRES_OK(
+ context,
+ summaries_output_list.allocate(
+ index,
+ TensorShape({static_cast<int64>(summary_entry_list.size()), 4}),
+ &output_t));
+ auto output = output_t->matrix<float>();
+ for (auto row = 0; row < summary_entry_list.size(); row++) {
+ const auto& entry = summary_entry_list[row];
+ output(row, 0) = entry.value;
+ output(row, 1) = entry.weight;
+ output(row, 2) = entry.min_rank;
+ output(row, 3) = entry.max_rank;
+ }
+ }
+ };
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * batch_size;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_summary_gen);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU),
+ BoostedTreesMakeQuantileSummariesOp);
+
+class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceAddSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ OpInputList summaries_list;
+ OP_REQUIRES_OK(context,
+ context->input_list(kSummariesName, &summaries_list));
+ int32 num_streams = stream_resource->num_streams();
+ CHECK_EQ(static_cast<int>(num_streams), summaries_list.size());
+
+ auto do_quantile_add_summary = [&](const int64 begin, const int64 end) {
+ // Iterating all features.
+ for (int64 feature_idx = begin; feature_idx < end; ++feature_idx) {
+ const Tensor& summaries = summaries_list[feature_idx];
+ const auto summary_values = summaries.matrix<float>();
+ const auto& tensor_shape = summaries.shape();
+ const int64 entries_size = tensor_shape.dim_size(0);
+ CHECK_EQ(tensor_shape.dim_size(1), 4);
+ std::vector<QuantileSummaryEntry> summary_entries;
+ summary_entries.reserve(entries_size);
+ for (int64 i = 0; i < entries_size; i++) {
+ float value = summary_values(i, 0);
+ float weight = summary_values(i, 1);
+ float min_rank = summary_values(i, 2);
+ float max_rank = summary_values(i, 3);
+ QuantileSummaryEntry entry(value, weight, min_rank, max_rank);
+ summary_entries.push_back(entry);
+ }
+ stream_resource->stream(feature_idx)->PushSummary(summary_entries);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_add_summary);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceAddSummariesOp);
+
+class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceFlushOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const Tensor* num_buckets_t;
+ OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
+ const int64 num_buckets = num_buckets_t->scalar<int64>()();
+ const int64 num_streams = stream_resource->num_streams();
+
+ auto do_quantile_flush = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; ++stream_idx) {
+ QuantileStream* stream = stream_resource->stream(stream_idx);
+ stream->Finalize();
+ stream_resource->set_boundaries(
+ generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets)
+ : GenerateBoundaries(*stream, num_buckets),
+ stream_idx);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_flush);
+
+ stream_resource->set_buckets_ready(true);
+ }
+
+ private:
+ bool generate_quantiles_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceFlushOp);
+
+class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
+ : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const int64 num_streams = stream_resource->num_streams();
+ CHECK_EQ(num_features_, num_streams);
+ OpOutputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+
+ auto do_quantile_get_buckets = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; stream_idx++) {
+ const auto& boundaries = stream_resource->boundaries(stream_idx);
+ Tensor* bucket_boundaries_t = nullptr;
+ OP_REQUIRES_OK(context,
+ bucket_boundaries_list.allocate(
+ stream_idx, {static_cast<int64>(boundaries.size())},
+ &bucket_boundaries_t));
+ auto* quantiles_flat = bucket_boundaries_t->flat<float>().data();
+ memcpy(quantiles_flat, boundaries.data(),
+ sizeof(float) * boundaries.size());
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_get_buckets);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceGetBucketBoundariesOp);
+
+// Given the calculated quantiles thresholds and input data, this operation
+// converts the input features into the buckets (categorical values), depending
+// on which quantile they fall into.
+class BoostedTreesBucketizeOp : public OpKernel {
+ public:
+ explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ // Read float features list;
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+ OpInputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+ OP_REQUIRES(context,
+ tensorflow::TensorShapeUtils::IsVector(
+ bucket_boundaries_list[0].shape()),
+ errors::InvalidArgument(
+ strings::Printf("Buckets should be flat vectors.")));
+ OpOutputList buckets_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
+
+ auto do_quantile_get_quantiles = [&](const int64 begin, const int64 end) {
+ // Iterating over all resources
+ for (int64 feature_idx = begin; feature_idx < end; feature_idx++) {
+ const Tensor& values_tensor = float_features_list[feature_idx];
+ const int64 num_values = values_tensor.dim_size(0);
+
+ Tensor* output_t = nullptr;
+ OP_REQUIRES_OK(
+ context, buckets_list.allocate(
+ feature_idx, TensorShape({num_values, 1}), &output_t));
+ auto output = output_t->matrix<int32>();
+
+ const std::vector<float>& bucket_boundaries_vector =
+ GetBuckets(feature_idx, bucket_boundaries_list);
+ CHECK(!bucket_boundaries_vector.empty())
+ << "Got empty buckets for feature " << feature_idx;
+ auto flat_values = values_tensor.flat<float>();
+ for (int64 instance = 0; instance < num_values; instance++) {
+ const float value = flat_values(instance);
+ auto bucket_iter =
+ std::lower_bound(bucket_boundaries_vector.begin(),
+ bucket_boundaries_vector.end(), value);
+ if (bucket_iter == bucket_boundaries_vector.end()) {
+ --bucket_iter;
+ }
+ const int32 bucket = static_cast<int32>(
+ bucket_iter - bucket_boundaries_vector.begin());
+ // Bucket id.
+ output(instance, 0) = bucket;
+ }
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_features_;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_get_quantiles);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU),
+ BoostedTreesBucketizeOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
index 3163c63949..12d9473776 100644
--- a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
@@ -1,5 +1,5 @@
# Description:
-# This directory contains common utilities used in boosted_trees.
+# This directory contains common quantile utilities used in boosted_trees.
package(
default_visibility = ["//tensorflow:internal"],
)
@@ -16,6 +16,7 @@ cc_library(
name = "weighted_quantiles",
srcs = [],
hdrs = [
+ "quantile_stream_resource.h",
"weighted_quantiles_buffer.h",
"weighted_quantiles_stream.h",
"weighted_quantiles_summary.h",
@@ -23,6 +24,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
],
)
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
new file mode 100644
index 0000000000..1c31724272
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
@@ -0,0 +1,96 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+
+#include <vector>
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+
+// Quantile Stream Resource for a list of streams sharing the same number of
+// quantiles, maximum elements, and epsilon.
+class BoostedTreesQuantileStreamResource : public ResourceBase {
+ public:
+ BoostedTreesQuantileStreamResource(const float epsilon,
+ const int64 max_elements,
+ const int64 num_streams)
+ : are_buckets_ready_(false),
+ epsilon_(epsilon),
+ num_streams_(num_streams),
+ max_elements_(max_elements) {
+ streams_.reserve(num_streams_);
+ boundaries_.reserve(num_streams_);
+ for (int64 idx = 0; idx < num_streams; ++idx) {
+ streams_.push_back(QuantileStream(epsilon, max_elements));
+ boundaries_.push_back(std::vector<float>());
+ }
+ }
+
+ string DebugString() override { return "QuantileStreamResource"; }
+
+ tensorflow::mutex* mutex() { return &mu_; }
+
+ QuantileStream* stream(const int64 index) { return &streams_[index]; }
+
+ const std::vector<float>& boundaries(const int64 index) {
+ return boundaries_[index];
+ }
+
+ void set_boundaries(const std::vector<float>& boundaries, const int64 index) {
+ boundaries_[index] = boundaries;
+ }
+
+ float epsilon() const { return epsilon_; }
+ int64 num_streams() const { return num_streams_; }
+
+ bool are_buckets_ready() const { return are_buckets_ready_; }
+ void set_buckets_ready(const bool are_buckets_ready) {
+ are_buckets_ready_ = are_buckets_ready;
+ }
+
+ private:
+ ~BoostedTreesQuantileStreamResource() override {}
+
+ // Mutex for the whole resource.
+ tensorflow::mutex mu_;
+
+ // Quantile streams.
+ std::vector<QuantileStream> streams_;
+
+ // Stores the boundaries. Same size as streams_.
+ std::vector<std::vector<float>> boundaries_;
+
+ // Whether boundaries are created. Initially boundaries are empty until
+ // set_boundaries are called.
+ bool are_buckets_ready_;
+
+ const float epsilon_;
+ const int64 num_streams_;
+ // An upper-bound for the number of elements.
+ int64 max_elements_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BoostedTreesQuantileStreamResource);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index cc90bb2f45..2798722536 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -60,14 +60,26 @@ int32 BoostedTreesEnsembleResource::next_node(
DCHECK_LT(tree_id, tree_ensemble_->trees_size());
DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- const auto& split = node.bucketized_split();
- if (bucketized_features[split.feature_id()](index_in_batch) <=
- split.threshold()) {
- return split.left_id();
- } else {
- return split.right_id();
+
+ switch (node.node_case()) {
+ case boosted_trees::Node::kBucketizedSplit: {
+ const auto& split = node.bucketized_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) <=
+ split.threshold())
+ ? split.left_id()
+ : split.right_id();
+ }
+ case boosted_trees::Node::kCategoricalSplit: {
+ const auto& split = node.categorical_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) ==
+ split.value())
+ ? split.left_id()
+ : split.right_id();
+ }
+ default:
+ DCHECK(false) << "Node type " << node.node_case() << " not supported.";
}
+ return -1;
}
float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index de9b69828e..639c3062cc 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -137,17 +137,16 @@ struct MatMulConvFunctor {
}
};
-// Shuffles a filter tensor from:
-// [<spatial_dims>, in, out]
-// to:
-// [out, in, <spatial_dims>]
+// Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format.
+//
+// Note: Currently OIHW is the only supported destination format. Support for
+// OHWI format will be added in a follow-up change.
template <typename Device, typename T, typename IndexType, int NDIMS>
struct TransformFilter {
- void operator()(const Device& d,
+ void operator()(const Device& d, FilterTensorFormat dst_filter_format,
typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
typename TTypes<T, NDIMS, IndexType>::Tensor out) {
- // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together
- // to speed up the shuffle operation.
+ // Merge the spatial dimensions together to speed up the shuffle operation.
Eigen::DSizes<IndexType, 3> merged_dims;
merged_dims[0] = in.dimension(0); // spatial dimensions
for (int i = 1; i < NDIMS - 2; ++i) {
@@ -156,16 +155,30 @@ struct TransformFilter {
merged_dims[1] = in.dimension(NDIMS - 2); // input filters
merged_dims[2] = in.dimension(NDIMS - 1); // output filters
+ CHECK(dst_filter_format == FORMAT_OIHW)
+ << "Unsupported destination filter format: "
+ << ToString(dst_filter_format);
+ // Source filter format is FORMAT_HWIO and spatial dimensions HW are merged
+ // in the beginning.
+ Eigen::DSizes<IndexType, 3> shuffling_perm =
+ Eigen::DSizes<IndexType, 3>(2, 1, 0);
+
Eigen::DSizes<IndexType, NDIMS> expanded_dims;
- expanded_dims[0] = in.dimension(NDIMS - 1); // output filters
- expanded_dims[1] = in.dimension(NDIMS - 2); // input filters
- for (int i = 0; i < NDIMS - 2; ++i) { // spatial dimensions
- expanded_dims[i + 2] = in.dimension(i);
+ int out_index = 0;
+ for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) {
+ if (shuffling_perm[merged_dim] == 0) {
+ for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) {
+ expanded_dims[out_index++] = in.dimension(spatial_dim);
+ }
+ } else {
+ constexpr int kLastSpatialDim = NDIMS - 3;
+ expanded_dims[out_index++] =
+ in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]);
+ }
}
- out.device(d) = in.reshape(merged_dims)
- .shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0))
- .reshape(expanded_dims);
+ out.device(d) =
+ in.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims);
}
};
@@ -282,7 +295,9 @@ struct SwapDimension0And2InTensor3 {
const gtl::ArraySlice<int64>& input_dims, T* out);
};
-// Reverses the effect of TransformFilter above.
+// Transforms back filter from OIHW to HWOI format to reverse effect of
+// TransformFilter above.
+// TODO(hinsu): Support reverse transformation from filter format OHWI as well.
template <typename Device, typename T, int NDIMS>
struct ReverseTransformFilter {
void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h
index 02e3655ad1..b819c6f910 100644
--- a/tensorflow/core/kernels/conv_3d.h
+++ b/tensorflow/core/kernels/conv_3d.h
@@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_CONV_3D_H_
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
namespace tensorflow {
@@ -28,6 +29,14 @@ namespace functor {
template <typename Device, typename T>
struct CuboidConvolution;
+// Backward input pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardInput;
+
+// Backward filter pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardFilter;
+
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename T>
@@ -42,6 +51,40 @@ struct CuboidConvolution<CPUDevice, T> {
}
};
+template <typename T>
+struct CuboidConvolutionBackwardInput<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T, 5>::Tensor input_backward,
+ typename TTypes<T, 5>::ConstTensor filter,
+ typename TTypes<T, 5>::ConstTensor output_backward,
+ int stride_planes, int stride_rows, int stride_cols) {
+ // Need to swap the order of plane/row/col strides when calling Eigen.
+ input_backward.device(d) = Eigen::CuboidConvolutionBackwardInput(
+ filter, output_backward,
+ input_backward.dimension(3), // input_planes
+ input_backward.dimension(2), // input_rows
+ input_backward.dimension(1), // input_cols
+ stride_cols, stride_rows, stride_planes);
+ }
+};
+
+template <typename T>
+struct CuboidConvolutionBackwardFilter<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T, 5>::Tensor filter_backward,
+ typename TTypes<T, 5>::ConstTensor input,
+ typename TTypes<T, 5>::ConstTensor output_backward,
+ int stride_planes, int stride_rows, int stride_cols) {
+ // Need to swap the order of plane/row/col strides when calling Eigen.
+ filter_backward.device(d) = Eigen::CuboidConvolutionBackwardKernel(
+ input, output_backward,
+ filter_backward.dimension(2), // kernel_planes
+ filter_backward.dimension(1), // kernel_rows
+ filter_backward.dimension(0), // kernel_cols
+ stride_cols, stride_rows, stride_planes);
+ }
+};
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 63b1bcda43..9e86a16b66 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -1018,7 +1018,8 @@ namespace functor {
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index d664a11e73..43bb5ea56c 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -901,7 +901,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
&transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
Tensor transformed_out_backprop;
@@ -1090,7 +1091,8 @@ namespace functor {
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index fc0a2f123f..507720c998 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -41,6 +41,17 @@ limitations under the License.
namespace tensorflow {
+// Compute padding for the given spatial dimension.
+int ConvBackpropDimensions::SpatialPadding(const Padding& padding,
+ int dim) const {
+ return (padding == VALID)
+ ? 0
+ : std::max<int>(
+ 0, static_cast<int>((output_size(dim) - 1) * stride(dim) +
+ (filter_size(dim) - 1) * dilation(dim) +
+ 1 - input_size(dim)));
+}
+
// The V2 version computes windowed output size with arbitrary dilation_rate,
// while the original version only handles the cases where dilation_rates equal
// to 1.
diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h
index 535586d53a..9551959463 100644
--- a/tensorflow/core/kernels/conv_grad_ops.h
+++ b/tensorflow/core/kernels/conv_grad_ops.h
@@ -234,6 +234,16 @@ struct ConvBackpropDimensions {
// Input and output feature depth.
int64 in_depth, out_depth;
+
+ // Convenience access methods for spatial dimensions properties.
+ int64 input_size(int dim) const { return spatial_dims[dim].input_size; }
+ int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; }
+ int64 output_size(int dim) const { return spatial_dims[dim].output_size; }
+ int64 stride(int dim) const { return spatial_dims[dim].stride; }
+ int64 dilation(int dim) const { return spatial_dims[dim].dilation; }
+
+ // Compute padding for the given spatial dimension.
+ int SpatialPadding(const Padding& padding, int dim) const;
};
// Common code between implementations of Conv?DBackpropInput and
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 15f1bf9aba..bab91f5e86 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -32,111 +33,130 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
#endif
+namespace {
+
+// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
+// conv_grad_input_ops_3d.cc.
+
+// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
+
+// "Depth" is already used for the channel dimension, so for the third spatial
+// dimension in this file we use "plane", although in NDHWC layout it's
+// indicated with a "D".
+
+// Returns in 'im_data' (assumed to be zero-initialized) image patch in storage
+// order (planes, height, width, depth), constructed from patches in 'col_data',
+// which is required to be in storage order (out_planes * out_height *
+// out_width, filter_planes, filter_height, filter_width, in_depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Col2im(const T* col_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* im_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ T* im_patch_data =
+ im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ for (int i = 0; i < depth; ++i) {
+ im_patch_data[i] += col_data[i];
+ }
+ }
+ im_patch_data += depth;
+ col_data += depth;
+ }
+ // Jump over remaining number of depth.
+ im_patch_data += depth * (width - filter_w);
+ }
+ // Jump over remaining number of (depth * width).
+ im_patch_data += (depth * width) * (height - filter_h);
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+// Returns in 'col_data', image patches in storage order (planes, height, width,
+// depth) extracted from image at 'input_data', which is required to be in
+// storage order (batch, planes, height, width, depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Im2col(const T* input_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* col_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ memcpy(col_data,
+ input_data +
+ (ip * height * width + ih * width + iw) * depth,
+ sizeof(T) * depth);
+ } else {
+ // This should be simply padded with zero.
+ memset(col_data, 0, sizeof(T) * depth);
+ }
+ col_data += depth;
+ }
+ }
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+} // namespace
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-// TODO(mjanusz): Get rid of the macro and return shapes directly.
-#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \
- const Tensor& out_backprop = context->input(2); \
- OP_REQUIRES( \
- context, input_shape.dims() == 5, \
- errors::InvalidArgument(label, ": input must be 5-dimensional")); \
- OP_REQUIRES( \
- context, filter_shape.dims() == 5, \
- errors::InvalidArgument(label, ": filter must be 5-dimensional")); \
- OP_REQUIRES( \
- context, out_backprop.dims() == 5, \
- errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
- const int64 batch = input_shape.dim_size(0); \
- OP_REQUIRES( \
- context, batch == out_backprop.dim_size(0), \
- errors::InvalidArgument( \
- label, ": input and out_backprop must have the same batch size")); \
- const std::array<int64, 3> input_size = { \
- {GetTensorDim(input_shape, data_format_, '0'), \
- GetTensorDim(input_shape, data_format_, '1'), \
- GetTensorDim(input_shape, data_format_, '2')}}; \
- const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \
- const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0), \
- filter_shape.dim_size(1), \
- filter_shape.dim_size(2)}}; \
- const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2'); \
- const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1'); \
- const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0'); \
- OP_REQUIRES(context, in_depth == filter_shape.dim_size(3), \
- errors::InvalidArgument( \
- label, ": input and filter must have the same depth")); \
- const int64 out_depth = filter_shape.dim_size(4); \
- OP_REQUIRES( \
- context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'), \
- errors::InvalidArgument( \
- label, ": filter and out_backprop must have the same out_depth")); \
- const std::array<int64, 3> dilations = { \
- {GetTensorDim(dilation_, data_format_, '0'), \
- GetTensorDim(dilation_, data_format_, '1'), \
- GetTensorDim(dilation_, data_format_, '2')}}; \
- const std::array<int64, 3> strides = { \
- {GetTensorDim(stride_, data_format_, '0'), \
- GetTensorDim(stride_, data_format_, '1'), \
- GetTensorDim(stride_, data_format_, '2')}}; \
- std::array<int64, 3> out, padding; \
- OP_REQUIRES_OK( \
- context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides, \
- padding_, &out, &padding)); \
- OP_REQUIRES(context, output_planes == out[0], \
- errors::InvalidArgument( \
- label, \
- ": Number of planes of out_backprop doesn't match " \
- "computed: actual = ", \
- output_planes, ", computed = ", out[0])); \
- OP_REQUIRES( \
- context, output_rows == out[1], \
- errors::InvalidArgument( \
- label, ": Number of rows of out_backprop doesn't match computed: ", \
- "actual = ", output_rows, ", computed = ", out[1])); \
- OP_REQUIRES( \
- context, output_cols == out[2], \
- errors::InvalidArgument( \
- label, ": Number of cols of out_backprop doesn't match computed: ", \
- "actual = ", output_cols, ", computed = ", out[2])); \
- const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \
- const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \
- const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \
- const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \
- const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \
- const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \
- const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \
- const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \
- const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \
- const auto bottom_pad_planes = \
- padded_out_planes - expanded_out_planes - top_pad_planes; \
- const auto bottom_pad_rows = \
- padded_out_rows - expanded_out_rows - top_pad_rows; \
- const auto right_pad_cols = \
- padded_out_cols - expanded_out_cols - left_pad_cols; \
- VLOG(2) << "Conv3d: " << label \
- << ": expanded_out_planes = " << expanded_out_planes \
- << ": expanded_out_rows = " << expanded_out_rows \
- << ", expanded_out_cols = " << expanded_out_cols \
- << ", padded_out_planes = " << padded_out_planes \
- << ", padded_out_rows = " << padded_out_rows \
- << ", padded_out_cols = " << padded_out_cols \
- << ", top_pad_planes = " << top_pad_planes \
- << ", top_pad_rows = " << top_pad_rows \
- << ", left_pad_cols = " << left_pad_cols \
- << ", bottom_pad_planes = " << bottom_pad_planes \
- << ", bottom_pad_rows = " << bottom_pad_rows \
- << ", right_pad_cols = " << right_pad_cols
-
-// Backprop for input.
+// Backprop for input that offloads computation to
+// Eigen::CuboidConvolutionBackwardInput.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
public:
@@ -192,6 +212,116 @@ class Conv3DBackpropInputOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape input_shape;
+ if (takes_shape_) {
+ const Tensor& input_sizes = context->input(0);
+ // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
+ OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
+ } else {
+ input_shape = context->input(0).shape();
+ }
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
+ Tensor* in_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_shape, &in_backprop));
+
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
+};
+
+// Custom backprop for input that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropInputOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+ // limit, fallback on Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& filter = context->input(1);
+ const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -200,51 +330,239 @@ class Conv3DBackpropInputOp : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
- // Fill out a padded out_backprop.
- TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
- padded_out_cols, out_depth});
- Tensor padded_output;
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // Use L3 cache size as target working set size.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ // Calculate size of matrices involved in MatMul: C = A x B.
+ const int64 size_A = output_image_size * dims.out_depth;
+
+ const int64 size_B = filter_total_size * dims.out_depth;
+
+ const int64 size_C = output_image_size * filter_total_size;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ // Use parallel tensor contractions if there is no batching.
+ //
+ // Compared to Conv2D code, this version is missing work size estimation. In
+ // benchmarks I didn't find a case when it's beneficial to run parallel
+ // contraction compared to sharding and matmuls.
+ const bool use_parallel_contraction = dims.batch_size == 1;
+
+ const size_t shard_size =
+ use_parallel_contraction
+ ? 1
+ : (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+ // If the temporary allocation overhead is too large, fallback on Eigen
+ // implementation which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
+
+ Tensor col_buffer;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // Fill a new "reverted" filter. We need to transpose the in_depth and
- // out_depth for the filter and reverse the planes, rows and cols.
- TensorShape r_filter_shape(
- {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
- Tensor r_filter;
- OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
- r_filter_shape, &r_filter));
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
- filter_rev_dims, r_filter.tensor<T, 5>());
- const Tensor& r_filter_cref = r_filter;
-
- // Now we can call conv_3d directly.
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* filter_data = filter.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+
+ auto in_backprop_flat = in_backprop->template flat<T>();
+ T* input_backprop_data = in_backprop_flat.data();
+ in_backprop_flat.device(context->eigen_device<Device>()) =
+ in_backprop_flat.constant(T(0));
+
+ if (use_parallel_contraction) {
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ // Initialize contraction dims (we need to transpose 'B' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 1;
+ contract_dims[0].second = 1;
+
+ for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
+ // Compute gradient into col_buffer.
+ TensorMap C(col_buffer_data, output_image_size, filter_total_size);
+
+ ConstTensorMap A(out_backprop_data + output_offset * image_id,
+ output_image_size, dims.out_depth);
+ ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
+
+ Col2im<T>(col_buffer_data, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_backprop_data);
+
+ input_backprop_data += input_offset;
+ }
+ } else {
+ typedef Eigen::Map<
+ Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+ MatrixMap;
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>>
+ ConstMatrixMap;
+
+ for (int image_id = 0; image_id < dims.batch_size;
+ image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
+ &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
+ &output_image_size, &filter_total_size,
+ &input_backprop_data, &col_buffer_data,
+ &out_backprop_data, &filter_data, &input_offset,
+ &output_offset, &size_C](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ T* im2col_buf = col_buffer_data + shard_id * size_C;
+ T* input_data = input_backprop_data + shard_id * input_offset;
+ const T* out_data = out_backprop_data + shard_id * output_offset;
+
+ // Compute gradient into 'im2col_buf'.
+ MatrixMap C(im2col_buf, output_image_size, filter_total_size);
+
+ ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
+ ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.noalias() = A * B.transpose();
+
+ Col2im<T>(im2col_buf, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_data);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ work_unit_size, shard);
+
+ input_backprop_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
+ }
}
private:
@@ -253,21 +571,48 @@ class Conv3DBackpropInputOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
};
+// Custom backrop input kernel is 30% - 4x faster when compiled with AVX2 than
+// default Eigen implementation (at the cost of ~2x-8x peak memory usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>); \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>);
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>);
+
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
-// Backprop for filter.
+// Backprop for filter that offloads computation to
+// Eigen::CuboidConvolutionBackwardFilter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
public:
@@ -323,8 +668,11 @@ class Conv3DBackpropFilterOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
- TensorShape filter_shape;
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
@@ -333,13 +681,13 @@ class Conv3DBackpropFilterOp : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, filter_shape, &filter_backprop));
@@ -349,70 +697,292 @@ class Conv3DBackpropFilterOp : public OpKernel {
return;
}
- // For the backprop of the filter, we need to also transpose the
- // out_backprop.
- // The shape of backprop is
- // [batch, out_z, out_y, out_x, out_depth]
- // And we need to change it to
- // [out_depth, out_x, out_y, out_z, batch]
- Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
- TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
- padded_out_cols, batch});
- Tensor padded_output;
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
+};
+
+// Custom backprop for filter that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropFilterOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+ // limit, fallback on Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
+ if (takes_shape_) {
+ const Tensor& filter_sizes = context->input(1);
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ filter_sizes.vec<int32>(), &filter_shape));
+ } else {
+ filter_shape = context->input(1).shape();
+ }
+
+ ConvBackpropDimensions dims;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // For the backprop of the filter, we need to transpose the input.
- // The shape of input is
- // [batch, in_z, in_y, in_x, in_depth]
- // And we need to change it to
- // [in_z, in_y, in_x, batch, in_depth]
- Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
- TensorShape in_shuffle_shape(
- {input_size[0], input_size[1], input_size[2], batch, in_depth});
- Tensor in_shuffle;
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
+ Tensor* filter_backprop;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- in_shuffle_shape, &in_shuffle));
- // No need for reversing this time.
- Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
- no_reverse, in_shuffle.tensor<T, 5>());
- const Tensor& in_shuffle_cref = in_shuffle;
-
- // The output of the conv_3d would be
- // [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth]
- // and we need to shuffle it back to
- // [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth];
- // And we need to reverse the filter backprops.
- // So we need to allocate (sigh) yet another piece of memory to hold the
- // output.
- TensorShape filter_shuffle_shape(
- {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
- Tensor filter_shuffle;
- OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<T>::v(),
- filter_shuffle_shape, &filter_shuffle));
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
-
- // Now copy the filter_backprop back to the destination.
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- const Tensor& filter_shuffle_cref = filter_shuffle;
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
- filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
+ context->allocate_output(0, filter_shape, &filter_backprop));
+
+ if (input_shape.num_elements() == 0) {
+ filter_backprop->template flat<T>().setZero();
+ return;
+ }
+
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ // Shard 'batch' images (volumes) into 'shard_size' groups of images
+ // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
+ // dividing the L3 cache size ('target_working_set_size') by the matmul size
+ // of an individual image ('work_unit_size').
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // TODO(andydavis)
+ // *) Consider reducing 'target_working_set_size' if L3 is shared by
+ // other concurrently running tensorflow ops.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ const int64 size_A = output_image_size * filter_total_size;
+
+ const int64 size_B = output_image_size * dims.out_depth;
+
+ const int64 size_C = filter_total_size * dims.out_depth;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ const size_t shard_size =
+ (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+ // If the temporary allocation overhead is too large, fallback on Eigen
+ // implementation which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
+
+ Tensor col_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* input_data = input.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+ T* filter_backprop_data = filter_backprop->template flat<T>().data();
+
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
+ C.setZero();
+
+ // Initialize contraction dims (we need to transpose 'A' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 0;
+ contract_dims[0].second = 0;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
+ &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
+ &bottom_pad_rows, &right_pad_cols, &input_offset,
+ &size_A](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ const T* input_data_shard = input_data + shard_id * input_offset;
+ T* col_data_shard = col_buffer_data + shard_id * size_A;
+
+ // When we compute the gradient with respect to the filters, we need
+ // to do im2col to allow gemm-type computation.
+ Im2col<T>(input_data_shard, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ col_data_shard);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ size_A, shard);
+
+ ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
+ filter_total_size);
+ ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
+ dims.out_depth);
+
+ // Gradient with respect to filter.
+ C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
+
+ input_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
}
private:
@@ -421,21 +991,60 @@ class Conv3DBackpropFilterOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
};
+// Custom backrop input kernel is 30% - 4x faster when compiled with AVX2 than
+// default Eigen implementation (at the cost of ~2x-8x peak memory usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropFilterOp<CPUDevice, T>); \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
.Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
.TypeConstraint<T>("T"), \
Conv3DBackpropFilterOp<CPUDevice, T>);
-TF_CALL_half(REGISTER_CPU_KERNEL);
+
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
+// WARNING: Eigen::half is not trivially copyable and can't be used in
+// custom backprop filter kernel because of memcpy and memset in Im2col.
+#define REGISTER_CPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
+
// GPU definitions of both ops.
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
@@ -445,7 +1054,8 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 5>::operator()( \
- const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 5, int>::ConstTensor in, \
typename TTypes<T, 5, int>::Tensor out); \
template <> \
void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
@@ -523,6 +1133,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -531,7 +1145,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
@@ -539,13 +1160,15 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
- dilation_[0] == 1 && dilation_[1] == 1 && dilation_[2] == 1 &&
- stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 &&
+ if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
+ dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
+ dims.stride(1) == 1 && dims.stride(2) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = batch * input_size[0] * input_size[1] * input_size[2];
- const uint64 k = out_depth;
- const uint64 n = in_depth;
+ const uint64 m = dims.batch_size * dims.input_size(0) *
+ dims.input_size(1) * dims.input_size(2);
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -567,13 +1190,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = batch;
- const uint64 k = out_depth;
- const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.batch_size;
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -597,65 +1221,59 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
- const bool planes_odd = (padding_planes % 2 != 0);
TensorShape compatible_input_shape;
if (rows_odd || cols_odd || planes_odd) {
// cuDNN only supports the same amount of padding on both sides.
compatible_input_shape = {
- batch,
- in_depth,
- input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd,
+ dims.batch_size,
+ dims.in_depth,
+ dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd,
};
} else {
- compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
- input_size[2]};
+ compatible_input_shape = {dims.batch_size, dims.in_depth,
+ dims.input_size(0), dims.input_size(1),
+ dims.input_size(2)};
}
CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
.set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
.set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -664,20 +1282,23 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
Tensor transformed_filter;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &transformed_filter));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 5>()(
- context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+ context->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
// Shape: batch, filters, z, y, x.
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
- if (out_depth > 1) {
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
+ if (dims.out_depth > 1) {
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
@@ -713,14 +1334,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = context->input(0).dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
@@ -799,10 +1420,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
if (rows_odd || cols_odd || planes_odd) {
Tensor in_backprop_remove_padding;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- {batch, in_depth, input_size[0],
- input_size[1], input_size[2]},
- &in_backprop_remove_padding));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ {dims.batch_size, dims.in_depth, dims.input_size(0),
+ dims.input_size(1), dims.input_size(2)},
+ &in_backprop_remove_padding));
// Remove the padding for odd spatial dimensions.
functor::PadInput<GPUDevice, T, int, 5>()(
@@ -896,6 +1518,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
@@ -905,7 +1531,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
@@ -914,13 +1545,15 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
- dilations[2] == 1 && dilations[1] == 1 && dilations[0] == 1 &&
- strides[2] == 1 && strides[1] == 1 && strides[0] == 1 &&
+ if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
+ dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
+ dims.stride(1) == 1 && dims.stride(0) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = in_depth;
- const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
- const uint64 n = out_depth;
+ const uint64 m = dims.in_depth;
+ const uint64 k = dims.batch_size * dims.input_size(1) *
+ dims.input_size(2) * dims.input_size(0);
+ const uint64 n = dims.out_depth;
// The shape of output backprop is
// [batch, out_z, out_y, out_x, out_depth]
@@ -951,13 +1584,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
- const uint64 k = batch;
- const uint64 n = out_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
+ const uint64 k = dims.batch_size;
+ const uint64 n = dims.out_depth;
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
input.template flat<T>().size());
@@ -979,30 +1613,24 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
- bool rows_odd = (padding_rows % 2 != 0);
- bool cols_odd = (padding_cols % 2 != 0);
- bool planes_odd = (padding_planes % 2 != 0);
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
+ const bool rows_odd = (padding_rows % 2 != 0);
+ const bool cols_odd = (padding_cols % 2 != 0);
Tensor compatible_input;
if (rows_odd || cols_odd || planes_odd) {
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format_, batch,
- {{input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd}},
- in_depth),
- &compatible_input));
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format_, dims.batch_size,
+ {{dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd}},
+ dims.in_depth),
+ &compatible_input));
functor::PadInput<GPUDevice, T, int, 5>()(
context->template eigen_device<GPUDevice>(),
To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
@@ -1016,35 +1644,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X,
GetTensorDim(compatible_input, data_format_, '2'))
.set_spatial_dim(DimIndex::Y,
GetTensorDim(compatible_input, data_format_, '1'))
.set_spatial_dim(DimIndex::Z,
GetTensorDim(compatible_input, data_format_, '0'))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -1052,19 +1680,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
Tensor pre_transformed_filter_backprop;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &pre_transformed_filter_backprop));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &pre_transformed_filter_backprop));
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
OP_REQUIRES_OK(
context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
- if (out_depth > 1) {
+ if (dims.out_depth > 1) {
functor::NHWCToNCHW<GPUDevice, T, 5>()(
context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
transformed_out_backprop.tensor<T, 5>());
@@ -1076,10 +1706,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
}
Tensor transformed_input;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
- compatible_input.dim_size(2),
- compatible_input.dim_size(3)};
- if (in_depth > 1) {
+ TensorShape nchw_shape = {
+ dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
+ compatible_input.dim_size(2), compatible_input.dim_size(3)};
+ if (dims.in_depth > 1) {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::value,
nchw_shape, &transformed_input));
@@ -1110,14 +1740,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = input.dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index ef692418d6..717a9f40a9 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -680,9 +680,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
TensorShape({filter.dim_size(3), filter.dim_size(2),
filter.dim_size(0), filter.dim_size(1)}),
&transformed_filter));
-
functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
Tensor transformed_output;
@@ -731,9 +731,15 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
if (cudnn_use_autotune &&
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
std::vector<AlgorithmDesc> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
- &algorithms));
+ OP_REQUIRES(
+ ctx,
+ stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+ stream->parent()),
+ &algorithms),
+ errors::Unknown("Failed to get convolution algorithm. This is probably "
+ "because cuDNN failed to initialize, so try looking to "
+ "see if a warning log message was printed above."));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
@@ -823,7 +829,8 @@ namespace functor {
extern template struct MatMulConvFunctor<GPUDevice, T>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index a1eed4e68c..83df4dce38 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -386,7 +386,8 @@ struct LaunchConvOp<GPUDevice, T> {
// filter: [x, y, z, in, out]
// t_filter: [out, in, x, y, z]
functor::TransformFilter<GPUDevice, T, int, 5>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
Tensor transformed_output;
@@ -434,10 +435,16 @@ struct LaunchConvOp<GPUDevice, T> {
if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmDesc> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
- stream->parent()),
- &algorithms));
+ OP_REQUIRES(ctx,
+ stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+ stream->parent()),
+ &algorithms),
+ errors::Unknown(
+ "Failed to get convolution algorithm. This is probably "
+ "because cuDNN failed to initialize, so try looking to "
+ "see if a warning log message was printed above."));
+
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
@@ -514,7 +521,8 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 5>::operator()( \
- const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 5, int>::ConstTensor in, \
typename TTypes<T, 5, int>::Tensor out); \
template <> \
void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index afc611f277..21d135decd 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -142,8 +142,12 @@ class ConvParameters {
template <typename T>
bool ShouldIncludeWinogradNonfusedAlgo(
se::StreamExecutor* stream_exec) const {
+ auto* dnn_support = stream_exec->AsDnn();
+ if (!dnn_support) {
+ return false;
+ }
// Skip this check for cuDNN 7 and newer.
- auto version = stream_exec->AsDnn()->GetVersion();
+ auto version = dnn_support->GetVersion();
if (version.ok() && version.ValueOrDie().major_version() >= 7) {
return true;
}
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index a5fa48f85e..46167db3a2 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -170,51 +170,33 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
return tensor_index;
}
-// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor.
-template <typename T, bool conjugate = false>
-__global__ void SwapDimension0And2InTensor3Simple(int nthreads, const T* input,
- Dimension<3> input_dims,
- T* output) {
- Dimension<3> output_dims;
- output_dims[0] = input_dims[2];
- output_dims[1] = input_dims[1];
- output_dims[2] = input_dims[0];
-
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int output_index = index;
-
- Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
-
- Index<3> input_tensor_index;
- input_tensor_index[0] = output_tensor_index[2];
- input_tensor_index[1] = output_tensor_index[1];
- input_tensor_index[2] = output_tensor_index[0];
-
- int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
-
- output[output_index] =
- maybe_conj<T, conjugate>::run(ldg(input + input_index));
- }
-}
-
-// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor.
-template <typename T, bool conjugate = false>
-__global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input,
- Dimension<3> input_dims,
- T* output) {
+// A simple CUDA custom kernel to shuffle dimensions of a 3D tensor according to
+// the given shuffle permutation in template parameters. Shuffle permutation
+// <sp0, sp1, sp2> shuffles dimensions such that input dimension 0 goes to sp0,
+// 1 goes to sp1 and 2 goes to sp2. For example, shuffle permutation <2, 0, 1>
+// will populate output so that input[x][y][z] is equal to (*output)[y][z][x].
+//
+// Requires that nthreads is equal to the total number of elements in the input
+// tensor.
+template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
+__global__ void ShuffleInTensor3Simple(int nthreads, const T* input,
+ Dimension<3> input_dims, T* output) {
Dimension<3> output_dims;
- output_dims[0] = input_dims[0];
- output_dims[1] = input_dims[2];
- output_dims[2] = input_dims[1];
-
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int output_index = index;
+ output_dims[sp0] = input_dims[0];
+ output_dims[sp1] = input_dims[1];
+ output_dims[sp2] = input_dims[2];
+
+ // Iterate over output as opposed to iterating over input for better
+ // performance. Iterating over output will generate sequential writes and
+ // random reads that performs better compared to sequential reads and random
+ // writes.
+ CUDA_1D_KERNEL_LOOP(output_index, nthreads) {
Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
Index<3> input_tensor_index;
- input_tensor_index[0] = output_tensor_index[0];
- input_tensor_index[1] = output_tensor_index[2];
- input_tensor_index[2] = output_tensor_index[1];
+ input_tensor_index[0] = output_tensor_index[sp0];
+ input_tensor_index[1] = output_tensor_index[sp1];
+ input_tensor_index[2] = output_tensor_index[sp2];
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
@@ -439,7 +421,7 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
template <typename T, int NDIMS>
struct TransformFilter<GPUDevice, T, int, NDIMS> {
typedef GPUDevice Device;
- void operator()(const Device& d,
+ void operator()(const Device& d, FilterTensorFormat dst_filter_format,
typename TTypes<T, NDIMS, int>::ConstTensor in,
typename TTypes<T, NDIMS, int>::Tensor out) {
Dimension<3> combined_dims;
@@ -450,13 +432,18 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
combined_dims[1] = in.dimension(NDIMS - 2); // input filters
combined_dims[2] = in.dimension(NDIMS - 1); // output filters
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- SwapDimension0And2InTensor3Simple<T>
+
+ CHECK(dst_filter_format == FORMAT_OIHW)
+ << "Unsupported output layout: " << ToString(dst_filter_format);
+
+ ShuffleInTensor3Simple<T, 2, 1, 0>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
};
-// Converts Cudnn filter format back to TensorFlow filter format.
+// Converts Cudnn filter format OIHW back to TensorFlow filter format HWIO.
+// TODO(hinsu): Support reverse transformation from filter format OHWI as well.
template <typename T, int NDIMS>
struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
typedef GPUDevice Device;
@@ -470,7 +457,7 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
combined_dims[2] *= in.dimension(i);
}
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- SwapDimension0And2InTensor3Simple<T>
+ ShuffleInTensor3Simple<T, 2, 1, 0>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
@@ -937,7 +924,7 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
- SwapDimension1And2InTensor3Simple<T, conjugate>
+ ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, input, input_dims, output);
}
@@ -969,7 +956,7 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
static_cast<int>(combined_dims[2])};
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_size, d);
- SwapDimension0And2InTensor3Simple<T, conjugate>
+ ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in, input_dims, out);
}
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 3a1ac73f64..87efdff789 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -628,6 +628,20 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "multi_device_iterator_ops",
+ srcs = ["multi_device_iterator_ops.cc"],
+ deps = [
+ ":dataset",
+ ":dataset_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "optional_ops",
srcs = ["optional_ops.cc"],
hdrs = ["optional_ops.h"],
@@ -675,6 +689,19 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "model_dataset_op",
+ srcs = ["model_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "dataset_ops",
srcs = ["dataset_ops.cc"],
deps = [
@@ -708,6 +735,8 @@ tf_kernel_library(
":map_and_batch_dataset_op",
":map_dataset_op",
":map_defun_op",
+ ":model_dataset_op",
+ ":multi_device_iterator_ops",
":optimize_dataset_op",
":optional_ops",
":padded_batch_dataset_op",
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index a25f78c6f1..d1db1d7bec 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -117,6 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 221b5ad835..34c6c86538 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -69,7 +69,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
- new FileIterator({this, strings::StrCat(prefix, "::FileIterator")}));
+ new FileIterator({this, strings::StrCat(prefix, "::FileCache")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -553,7 +553,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new MemoryIterator(
- {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_));
+ {this, strings::StrCat(prefix, "::MemoryCache")}, cache_));
}
const DataTypeVector& output_dtypes() const override {
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index ad2365b25b..0bb929b3ce 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -17,43 +17,101 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
-/* static */
-Status CapturedFunction::Create(
- const NameAttrList& func, std::vector<Tensor> captured_inputs,
- std::unique_ptr<CapturedFunction>* out_function) {
- return Create(func, std::move(captured_inputs), true, out_function);
-}
+namespace {
+
+// Simplistic implementation of the `StepStatsCollectorInterface` that only
+// cares about collecting the CPU time needed to execute a captured function.
+class SimpleStepStatsCollector : public StepStatsCollectorInterface {
+ public:
+ void IncrementProcessingTime(int64 delta) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override {
+ return new SimpleNodeExecStats(this);
+ }
+
+ string ReportAllocsOnResourceExhausted(const string& err) override {
+ return "";
+ }
+
+ int64 processing_time() {
+ tf_shared_lock l(mu_);
+ return processing_time_;
+ }
+
+ private:
+ class SimpleNodeExecStats : public NodeExecStatsInterface {
+ public:
+ explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
+ : step_stats_collector_(step_stats_collector) {}
+
+ void Done(const string& device) override {
+ step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
+ start_time_ns_);
+ delete this;
+ }
+
+ void RecordExecutorStarted() override {
+ start_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void RecordComputeStarted() override {}
+
+ void RecordComputeEnded() override {}
+
+ void RecordExecutorEnded() override {
+ end_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void SetMemory(OpKernelContext* ctx) override {}
+
+ void SetOutput(int slot, const Tensor* tensor) override {}
+
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override {}
+
+ void SetScheduled(int64 nanos) override {}
+
+ private:
+ int64 start_time_ns_ = 0;
+ int64 end_time_ns_ = 0;
+ SimpleStepStatsCollector* step_stats_collector_; // Not owned.
+ };
+
+ mutex mu_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+};
+
+} // namespace
/* static */
Status CapturedFunction::Create(
- const NameAttrList& func, std::vector<Tensor> captured_inputs,
- bool use_inter_op_parallelism,
+ const NameAttrList& func, OpKernelContext* ctx, const string& argument,
std::unique_ptr<CapturedFunction>* out_function) {
- out_function->reset(new CapturedFunction(func, std::move(captured_inputs),
- use_inter_op_parallelism));
- return Status::OK();
+ return CapturedFunction::Create(func, ctx, argument, true, out_function);
}
-/* static */
Status CapturedFunction::Create(
const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+ bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function) {
- OpInputList argument_inputs;
- TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs));
- std::vector<Tensor> arguments_t;
- arguments_t.reserve(argument_inputs.size());
- for (const Tensor& t : argument_inputs) {
- arguments_t.push_back(t);
- }
- return CapturedFunction::Create(func, std::move(arguments_t), out_function);
+ OpInputList inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs));
+ std::vector<Tensor> arguments(inputs.begin(), inputs.end());
+ *out_function = WrapUnique(new CapturedFunction(func, std::move(arguments),
+ use_inter_op_parallelism));
+ return Status::OK();
}
CapturedFunction::~CapturedFunction() {
@@ -358,7 +416,8 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args,
void CapturedFunction::RunAsync(IteratorContext* ctx,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done) {
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix) {
// NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
// be deleted before `done` is called. Take care not to capture `ctx` in any
// code that may execute asynchronously in this function.
@@ -368,13 +427,13 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
done(s);
return;
}
- auto frame =
+ OwnedArgsCallFrame* frame =
new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager();
- auto step_container = new ScopedStepContainer(
+ ScopedStepContainer* step_container = new ScopedStepContainer(
f_opts.step_id, [resource_mgr](const string& name) {
resource_mgr->Cleanup(name).IgnoreError();
});
@@ -389,25 +448,40 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// (such as queue kernels) that depend on the non-nullness of
// `OpKernelContext::cancellation_manager()`, but additional effort
// will be required to plumb it through the `IteratorContext`.
- auto c_mgr = new CancellationManager;
+ CancellationManager* c_mgr = new CancellationManager;
f_opts.cancellation_manager = c_mgr;
-
- tf_shared_lock l(mu_);
- ctx->lib()->Run(f_opts, handle, frame,
- std::bind(
- [rets, step_container, c_mgr, frame](
- FunctionLibraryRuntime::DoneCallback done,
- // Begin unbound arguments.
- Status s) {
- delete step_container;
- delete c_mgr;
- if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
- }
- delete frame;
- done(s);
- },
- std::move(done), std::placeholders::_1));
+ std::shared_ptr<SimpleStepStatsCollector> stats_collector;
+ if (ctx->model()) {
+ stats_collector = MakeUnique<SimpleStepStatsCollector>();
+ }
+ f_opts.stats_collector = stats_collector.get();
+
+ auto callback = std::bind(
+ [rets, step_container, c_mgr, frame](
+ const FunctionLibraryRuntime::DoneCallback& done,
+ const std::shared_ptr<model::Model>& model, const string& prefix,
+ const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
+ // Begin unbound arguments.
+ Status s) {
+ delete step_container;
+ delete c_mgr;
+ if (s.ok()) {
+ s = frame->ConsumeRetvals(rets);
+ }
+ delete frame;
+ if (model) {
+ model->AddProcessingTime(prefix, stats_collector->processing_time());
+ model->RecordStart(prefix, false /* stop_output */);
+ }
+ done(s);
+ if (model) {
+ model->RecordStop(prefix, false /* start_output */);
+ }
+ },
+ std::move(done), ctx->model(), prefix, std::move(stats_collector),
+ std::placeholders::_1);
+
+ ctx->lib()->Run(f_opts, handle, frame, std::move(callback));
}
CapturedFunction::CapturedFunction(const NameAttrList& func,
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index e44bc78b1c..a10376bf97 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -42,27 +42,19 @@ namespace data {
// context.
class CapturedFunction {
public:
- // Creates a new instance from a list of named attributes and captured inputs.
- //
- // NOTE(mrry): The `captured_inputs` are passed by value. For
- // efficiency, you are recommended to move this argument into the call.
- static Status Create(const NameAttrList& func,
- std::vector<Tensor> captured_inputs,
+ // Creates a new instance using a list of named attributes, fetching captured
+ // inputs from a context argument.
+ static Status Create(const NameAttrList& func, OpKernelContext* ctx,
+ const string& argument,
std::unique_ptr<CapturedFunction>* out_function);
- // Creates a new instance from a list of named attributes and captured inputs.
+ // Creates a new instance using a list of named attributes, fetching captured
+ // inputs from a context argument.
//
// If `use_inter_op_parallelism` is false, the runtime may use an executor
// that is optimized for small functions.
- static Status Create(const NameAttrList& func,
- std::vector<Tensor> captured_inputs,
- bool use_inter_op_parallelism,
- std::unique_ptr<CapturedFunction>* out_function);
-
- // Creates a new instance using a list of named attributes, fetching captured
- // inputs from a context argument.
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
- const string& argument,
+ const string& argument, bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function);
~CapturedFunction();
@@ -104,7 +96,8 @@ class CapturedFunction {
// in order to be able to deallocate them as early as possible.
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done);
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix);
// Returns the named list of function arguments.
const NameAttrList& func() { return func_; }
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index e7ac368ae3..e10833f525 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -44,5 +44,42 @@ Status MakeIteratorFromInputElement(
ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 234856ea39..6ec1350cd4 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -27,6 +27,16 @@ Status MakeIteratorFromInputElement(
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
+// Returns Status::OK() if `expected` and `received` types match,
+// errors::InvalidArgument otherwise.
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received);
+
+// Returns Status::OK() if `expected` and `received` shapes are compatible,
+// errors::InvalidArgument otherwise.
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received);
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index bf0aecaf3c..00884314a9 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -14,11 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace data {
@@ -37,14 +39,6 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
FunctionLibraryRuntime::Handle pred_handle;
OP_REQUIRES_OK(ctx,
ctx->function_library()->Instantiate(
@@ -61,9 +55,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
Node* ret_node = pred_body->ret_nodes[0];
Node* ret_input_node;
OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
+
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
if (ret_input_node->def().op() == "_Arg") {
int32 index = -1;
@@ -146,7 +141,13 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<FilterDatasetBase> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params) {}
+ : DatasetIterator<FilterDatasetBase>(params),
+ filtered_elements_(0),
+ dropped_elements_(0) {
+ std::vector<string> components =
+ str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+ prefix_end_ = components.back();
+ }
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -161,6 +162,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
// `input_impl_` and `f` are thread-safe. However, if multiple
// threads enter this method, outputs may be observed in a
// non-deterministic order.
+ auto stats_aggregator = ctx->stats_aggregator();
bool matched;
do {
{
@@ -183,8 +185,34 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
+ if (stats_aggregator) {
+ mutex_lock l(mu_);
+ dropped_elements_++;
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::dropped_elements"),
+ static_cast<float>((dropped_elements_)));
+ // TODO(shivaniagrawal): multiple pipelines would collect
+ // aggregated number of dropped elements for all the pipelines,
+ // exploit tagged_context here.
+ stats_aggregator->IncrementCounter(
+ prefix_end_, "dropped_elements", static_cast<float>(1));
+ }
}
} while (!matched);
+ // TODO(shivaniagrawal): add ratio of dropped_elements and
+ // filtered_elements as a histogram.
+ if (stats_aggregator) {
+ mutex_lock l(mu_);
+ filtered_elements_++;
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::filtered_elements"),
+ static_cast<float>((filtered_elements_)));
+ // TODO(shivaniagrawal): multiple pipelines would collect aggregated
+ // number of filtered elements for all the pipelines, exploit
+ // tagged_context here.
+ stats_aggregator->IncrementCounter(prefix_end_, "filtered_elements",
+ static_cast<float>(1));
+ }
*end_of_sequence = false;
return Status::OK();
}
@@ -197,6 +225,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
else
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impls_empty"), ""));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("filtered_elements"),
+ filtered_elements_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("dropped_elements"),
+ dropped_elements_));
return Status::OK();
}
@@ -207,12 +239,19 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
input_impl_.reset();
else
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("filtered_elements"),
+ &filtered_elements_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("dropped_elements"),
+ &dropped_elements_));
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ int64 filtered_elements_ GUARDED_BY(mu_);
+ int64 dropped_elements_ GUARDED_BY(mu_);
+ string prefix_end_;
};
const DatasetBase* const input_;
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index e3c45ef86c..2fada22a21 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -39,18 +39,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
-
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
}
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index ac5cc1b2c1..71a36314a0 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -145,44 +145,18 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
- OpInputList init_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
- &init_func_other_args_input));
- std::vector<Tensor> init_func_other_args;
- init_func_other_args.reserve(init_func_other_args_input.size());
- for (const Tensor& t : init_func_other_args_input) {
- init_func_other_args.push_back(t);
- }
std::unique_ptr<CapturedFunction> init_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args),
- &init_func));
-
- OpInputList next_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
- &next_func_other_args_input));
- std::vector<Tensor> next_func_other_args;
- next_func_other_args.reserve(next_func_other_args_input.size());
- for (const Tensor& t : next_func_other_args_input) {
- next_func_other_args.push_back(t);
- }
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(
+ init_func_, ctx, "init_func_other_args", &init_func));
+
std::unique_ptr<CapturedFunction> next_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args),
- &next_func));
-
- OpInputList finalize_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
- &finalize_func_other_args_input));
- std::vector<Tensor> finalize_func_other_args;
- finalize_func_other_args.reserve(finalize_func_other_args_input.size());
- for (const Tensor& t : finalize_func_other_args_input) {
- finalize_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> finalize_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- finalize_func_, std::move(finalize_func_other_args),
- &finalize_func));
+ next_func_, ctx, "next_func_other_args", &next_func));
+
+ std::unique_ptr<CapturedFunction> finalize_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx,
+ "finalize_func_other_args",
+ &finalize_func));
*output =
new Dataset(ctx, std::move(init_func), std::move(next_func),
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index e4fa557598..8b417bb1c2 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -42,50 +42,19 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- // Get captured inputs for the key, reduce, and window_size functions.
- OpInputList key_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("key_func_other_arguments",
- &key_func_other_argument_inputs));
- std::vector<Tensor> key_func_other_arguments;
- key_func_other_arguments.reserve(key_func_other_argument_inputs.size());
- for (const Tensor& t : key_func_other_argument_inputs) {
- key_func_other_arguments.push_back(t);
- }
- OpInputList reduce_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("reduce_func_other_arguments",
- &reduce_func_other_argument_inputs));
- std::vector<Tensor> reduce_func_other_arguments;
- reduce_func_other_arguments.reserve(
- reduce_func_other_argument_inputs.size());
- for (const Tensor& t : reduce_func_other_argument_inputs) {
- reduce_func_other_arguments.push_back(t);
- }
- OpInputList window_size_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx,
- ctx->input_list("window_size_func_other_arguments",
- &window_size_func_other_argument_inputs));
- std::vector<Tensor> window_size_func_other_arguments;
- window_size_func_other_arguments.reserve(
- window_size_func_other_argument_inputs.size());
- for (const Tensor& t : window_size_func_other_argument_inputs) {
- window_size_func_other_arguments.push_back(t);
- }
- // TODO(mrry): Refactor CapturedFunction to share the runtime
- // state between multiple functions?
std::unique_ptr<CapturedFunction> captured_key_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- key_func_, std::move(key_func_other_arguments),
- &captured_key_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
+ "key_func_other_arguments",
+ &captured_key_func));
std::unique_ptr<CapturedFunction> captured_reduce_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(reduce_func_,
- std::move(reduce_func_other_arguments),
- &captured_reduce_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
+ "reduce_func_other_arguments",
+ &captured_reduce_func));
std::unique_ptr<CapturedFunction> captured_window_size_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- window_size_func_, std::move(window_size_func_other_arguments),
- &captured_window_size_func));
+ OP_REQUIRES_OK(ctx,
+ CapturedFunction::Create(window_size_func_, ctx,
+ "window_size_func_other_arguments",
+ &captured_window_size_func));
*output = new Dataset(
ctx, input, key_func_, reduce_func_, window_size_func_,
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 0768f46665..0aa802b874 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -39,14 +39,6 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
const Tensor* cycle_length_t;
OP_REQUIRES_OK(ctx, ctx->input("cycle_length", &cycle_length_t));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cycle_length_t->shape()),
@@ -66,8 +58,8 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
errors::InvalidArgument("block_length must be greater than zero."));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output =
new Dataset(ctx, input, func_, std::move(captured_func), cycle_length,
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index fe6d705eab..c0bc507ec0 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -44,43 +44,6 @@ namespace {
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
-Status VerifyTypesMatch(const DataTypeVector& expected,
- const DataTypeVector& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " types but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (expected[i] != received[i]) {
- return errors::InvalidArgument("Data type mismatch at component ", i,
- ": expected ", DataTypeString(expected[i]),
- " but got ", DataTypeString(received[i]),
- ".");
- }
- }
- return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
- const std::vector<PartialTensorShape>& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " shapes but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (!expected[i].IsCompatibleWith(received[i])) {
- return errors::InvalidArgument("Incompatible shapes at component ", i,
- ": expected ", expected[i].DebugString(),
- " but got ", received[i].DebugString(),
- ".");
- }
- }
-
- return Status::OK();
-}
-
} // namespace
class IteratorResource : public ResourceBase {
@@ -403,12 +366,12 @@ class IteratorStateVariant {
}
string TypeName() const { return kIteratorVariantTypeName; }
void Encode(VariantTensorData* data) const { *data = *data_; }
- bool Decode(const VariantTensorData& data) {
+ bool Decode(VariantTensorData data) {
if (data.type_name() != TypeName()) {
return false;
}
std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
- *tensor_data = data;
+ std::swap(*tensor_data, data);
std::unique_ptr<VariantTensorDataReader> reader(
new VariantTensorDataReader(tensor_data.get()));
status_ = reader->status();
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 27c89b3661..2bbf4af664 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
+#include <atomic>
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
@@ -26,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
@@ -39,7 +41,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -49,14 +50,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 batch_size;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size));
OP_REQUIRES(
@@ -77,7 +70,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
case 2:
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx,
+ num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
break;
@@ -92,8 +86,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ParseScalarArgument(ctx, "drop_remainder", &drop_remainder));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, batch_size, num_parallel_calls,
drop_remainder, output_types_, output_shapes_, func_,
@@ -190,7 +184,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ : DatasetIterator<Dataset>(params),
+ num_parallel_calls_(params.dataset->num_parallel_calls_) {}
~Iterator() override {
mutex_lock l(mu_);
@@ -204,6 +199,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ AddTunableParameter(ctx, "parallelism",
+ &num_parallel_calls_ /* value */, 1 /* min */,
+ port::NumSchedulableCPUs() /* max */, &cond_var_);
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -218,12 +223,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
+ RecordStop(ctx);
cond_var_.wait(l);
+ RecordStart(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
@@ -326,11 +333,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void CallCompleted(const std::shared_ptr<BatchResult>& result)
LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- num_calls_--;
- result->num_calls--;
- }
+ mutex_lock l(mu_);
+ num_calls_--;
+ result->num_calls--;
cond_var_.notify_all();
}
@@ -365,7 +370,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ctx.get(), std::move(input_element), return_values.get(),
[this, ctx, result, return_values, offset](Status status) {
Callback(ctx, result, return_values, offset, status);
- });
+ },
+ prefix());
},
ctx, std::move(input_element)));
}
@@ -422,11 +428,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
result->output_allocated = true;
}
- int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
- dataset()->batch_size_;
- }
-
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<BatchResult>& result,
std::vector<Tensor>* out_tensors,
@@ -475,26 +476,34 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
LOCKS_EXCLUDED(mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
- new_calls.reserve(dataset()->num_parallel_calls_);
+ RecordStart(ctx.get());
+ auto stop_cleanup =
+ gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
+ new_calls.reserve(num_parallel_calls_);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_;
+ int64 max_batch_results =
+ (num_parallel_calls + dataset()->batch_size_ - 1) /
+ dataset()->batch_size_;
+ return num_calls_ >= num_parallel_calls ||
+ (batch_results_.size() > max_batch_results ||
+ (batch_results_.size() == max_batch_results &&
+ call_counter_ % dataset()->batch_size_ == 0));
+ };
while (true) {
{
mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= dataset()->num_parallel_calls_ ||
- batch_results_.size() > MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ == 0))) {
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < dataset()->num_parallel_calls_ &&
- (batch_results_.size() < MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ != 0))) {
+ while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
batch_results_.emplace_back(
new BatchResult(dataset()->batch_size_));
@@ -638,6 +647,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// user specified level of parallelism and there are slots available in
// the `batch_results_` buffer.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ std::atomic<int64> num_parallel_calls_;
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
@@ -661,7 +672,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const Eigen::ThreadPoolDevice* device_; // not owned
};
- const int graph_def_version_;
const int op_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index af301e2b42..f112e1dc43 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -38,18 +38,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments),
- use_inter_op_parallelism_, &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ use_inter_op_parallelism_,
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index b87d61ee44..6657f2b2b3 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -81,119 +81,167 @@ class MapDefunOp : public AsyncOpKernel {
}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- int64 batch_size;
- OP_REQUIRES_OK_ASYNC(ctx, GetInputBatchSize(ctx, &batch_size), done);
+ ComputeOptions* compute_opts = nullptr;
- // Inputs
- auto* args = new std::vector<Tensor>;
- auto* arg_shapes = new std::vector<TensorShape>;
+ OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
- // Create a copy because every `Compute` may have different output shapes.
- auto* output_shapes = new std::vector<PartialTensorShape>(output_shapes_);
- arg_shapes->reserve(ctx->num_inputs());
- args->reserve(ctx->num_inputs());
+ Status s = SetupOutputs(ctx, compute_opts);
+ if (!s.ok()) delete compute_opts;
+ OP_REQUIRES_OK_ASYNC(ctx, s, done);
- auto* mu = new mutex;
-
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- args->push_back(ctx->input(i));
- arg_shapes->push_back(ctx->input(i).shape());
- arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
- }
-
- // Outputs
- auto* output = new OpOutputList;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
-
- for (size_t i = 0; i < output_types().size(); ++i) {
- if (output_shapes_.at(i).IsFullyDefined()) {
- Tensor* out = nullptr;
- TensorShape output_shape;
- output_shapes_.at(i).AsTensorShape(&output_shape);
- output_shape.InsertDim(0, batch_size);
- OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out),
- done);
- }
- }
-
- SetRunOptions(ctx, &opts_, false);
+ FunctionLibraryRuntime::Options opts;
+ SetRunOptions(ctx, &opts, false);
// Run loop
StatusCallback callback = std::bind(
- [](OpKernelContext* ctx, std::vector<Tensor>* args,
- std::vector<TensorShape>* arg_shapes,
- std::vector<PartialTensorShape>* output_shapes, OpOutputList* output,
- mutex* mu, DoneCallback& done, const Status& status) {
- delete args;
- delete arg_shapes;
- delete output;
- delete output_shapes;
- delete mu;
+ [](OpKernelContext* ctx, ComputeOptions* compute_opts,
+ DoneCallback& done, const Status& status) {
+ delete compute_opts;
ctx->SetStatus(status);
done();
},
- ctx, args, arg_shapes, output_shapes, output, mu, std::move(done),
- std::placeholders::_1);
+ ctx, compute_opts, std::move(done), std::placeholders::_1);
auto* refcounted = new ReffedStatusCallback(std::move(callback));
- for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) {
- // Start from i = 1 because refcounted is initialized with refcount = 1
- refcounted->Ref();
- }
+ CancellationManager* parent_mgr = ctx->cancellation_manager();
- for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
- auto* call_frame = new MapFunctionCallFrame(
- *args, *arg_shapes, output_shapes, mu, output, this, i,
- static_cast<size_t>(batch_size));
+ for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
+ // We use a different cancellation manager each time the function is run
+ // to avoid the race condition between a function run error and other
+ // functions being cancelled as a result.
CancellationManager* c_mgr = new CancellationManager;
- opts_.cancellation_manager = c_mgr;
- ctx->function_library()->Run(
- opts_, func_handle_, call_frame,
- [call_frame, refcounted, c_mgr](const Status& func_status) {
- delete call_frame;
- delete c_mgr;
- refcounted->UpdateStatus(func_status);
- refcounted->Unref();
- });
+ CancellationToken token = parent_mgr->get_cancellation_token();
+ const bool success = parent_mgr->RegisterCallback(
+ token, [c_mgr]() { c_mgr->StartCancel(); });
+
+ opts.cancellation_manager = c_mgr;
+ if (!success) {
+ delete c_mgr;
+ refcounted->UpdateStatus(errors::Cancelled(
+ "MapDefunOp functions cancelled because parent graph cancelled"));
+ break;
+ }
+
+ auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
+
+ refcounted->Ref();
+ ctx->function_library()->Run(opts, func_handle_, call_frame,
+ [call_frame, refcounted, c_mgr, parent_mgr,
+ token](const Status& func_status) {
+ parent_mgr->DeregisterCallback(token);
+ delete c_mgr;
+ delete call_frame;
+ refcounted->UpdateStatus(func_status);
+ refcounted->Unref();
+ });
}
+
+ // Unref 1 because refcounted is initialized with refcount = 1
+ refcounted->Unref();
}
private:
FunctionLibraryRuntime::Handle func_handle_;
- FunctionLibraryRuntime::Options opts_;
std::vector<PartialTensorShape> output_shapes_;
+ struct ComputeOptions {
+ // These vary per MapDefunOp::ComputeAsync call, but must persist until
+ // all calls to the function are complete. This struct also encapsulates
+ // all the components that need to be passed to each MapFunctionCallFrame.
+
+ const std::vector<Tensor> args;
+ const std::vector<TensorShape> arg_shapes;
+ const int64 batch_size;
+
+ // Output of a compute call
+ std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
+ OpOutputList output GUARDED_BY(mu);
+ mutex mu;
+
+ // Create a copy of output_shapes because every `Compute` may expect a
+ // different output shape.
+ ComputeOptions(std::vector<Tensor> args,
+ std::vector<TensorShape> arg_shapes, int64 batch_size,
+ const std::vector<PartialTensorShape>& output_shapes_attr)
+ : args(std::move(args)),
+ arg_shapes(std::move(arg_shapes)),
+ batch_size(batch_size),
+ output_shapes(output_shapes_attr) {}
+ };
+
+ // Get inputs to Compute and check that they are valid.
+ Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
+ int64 batch_size =
+ ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ if (ctx->input(i).dims() == 0) {
+ return errors::InvalidArgument(
+ "All inputs must have rank at least 1. Input ", i,
+ " has a rank of 0.");
+ } else if (ctx->input(i).dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All inputs must have the same dimension 0. Input ", i,
+ " has leading dimension ", ctx->input(i).dim_size(0),
+ ", while all previous inputs have leading dimension ", batch_size);
+ }
+ }
+
+ std::vector<Tensor> args;
+ std::vector<TensorShape> arg_shapes;
+ args.reserve(ctx->num_inputs());
+ arg_shapes.reserve(ctx->num_inputs());
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ args.push_back(ctx->input(i));
+ arg_shapes.push_back(ctx->input(i).shape());
+ arg_shapes.at(i).RemoveDim(0);
+ }
+
+ *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
+ batch_size, output_shapes_);
+ return Status::OK();
+ }
+
+ Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
+ mutex_lock l(opts->mu);
+ TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output));
+
+ for (size_t i = 0; i < output_types().size(); ++i) {
+ if (output_shapes_.at(i).IsFullyDefined()) {
+ Tensor* out = nullptr;
+ TensorShape output_shape;
+ output_shapes_.at(i).AsTensorShape(&output_shape);
+ output_shape.InsertDim(0, opts->batch_size);
+ TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
+ }
+ }
+ return Status::OK();
+ }
+
class MapFunctionCallFrame : public CallFrameInterface {
public:
- MapFunctionCallFrame(const std::vector<Tensor>& args,
- const std::vector<TensorShape>& arg_shapes,
- std::vector<PartialTensorShape>* output_shapes,
- mutex* output_shapes_mutex, OpOutputList* output,
- OpKernel* kernel, size_t iter, size_t batch_size)
- : args_(args),
- arg_shapes_(arg_shapes),
- output_shapes_(output_shapes),
- output_shapes_mutex_(output_shapes_mutex),
- output_(output),
- kernel_(kernel),
- iter_(iter),
- batch_size_(batch_size) {}
+ MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
+ size_t iter)
+ : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
~MapFunctionCallFrame() override {}
- size_t num_args() const override { return args_.size(); }
+ size_t num_args() const override { return compute_opts_->args.size(); }
+
size_t num_retvals() const override {
return static_cast<size_t>(kernel_->num_outputs());
}
Status GetArg(int index, Tensor* val) const override {
- if (index < 0 || index >= args_.size()) {
+ if (index < 0 || index >= compute_opts_->args.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
- bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1),
- arg_shapes_.at(index));
+ bool result =
+ val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
+ compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
} else if (!val->IsAligned()) {
@@ -217,36 +265,34 @@ class MapDefunOp : public AsyncOpKernel {
index);
}
{ // Locking scope
- mutex_lock l(*output_shapes_mutex_);
- if (!output_shapes_->at(index).IsCompatibleWith(val.shape())) {
+ mutex_lock l(compute_opts_->mu);
+ if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
+ val.shape())) {
return errors::InvalidArgument(
"Mismatch in function retval shape, ", val.shape(),
- ", and expected output shape,",
- output_shapes_->at(index).DebugString(), ".");
+ ", and expected output shape, ",
+ compute_opts_->output_shapes.at(index).DebugString(), ".");
}
- if (!output_shapes_->at(index).IsFullyDefined()) {
+ if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
// Given val, we have new information about the output shape at
// this index. Store the shape and allocate the output accordingly.
- output_shapes_->at(index) = val.shape();
+ compute_opts_->output_shapes.at(index) = val.shape();
Tensor* out = nullptr;
TensorShape actual_shape = val.shape();
- actual_shape.InsertDim(0, batch_size_);
- TF_RETURN_IF_ERROR(output_->allocate(index, actual_shape, &out));
+ actual_shape.InsertDim(0, compute_opts_->batch_size);
+ TF_RETURN_IF_ERROR(
+ compute_opts_->output.allocate(index, actual_shape, &out));
}
+ return batch_util::CopyElementToSlice(
+ val, (compute_opts_->output)[index], iter_);
}
- return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
}
private:
- const std::vector<Tensor>& args_;
- const std::vector<TensorShape>& arg_shapes_;
- std::vector<PartialTensorShape>* output_shapes_;
- mutex* output_shapes_mutex_;
- OpOutputList* output_;
+ ComputeOptions* const compute_opts_; // Not owned
const OpKernel* kernel_;
const size_t iter_;
- const size_t batch_size_;
};
};
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
new file mode 100644
index 0000000000..9aa505f4f1
--- /dev/null
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -0,0 +1,183 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros;
+
+class ModelDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ModelDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Model")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override { return "ModelDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ model_(std::make_shared<model::Model>()) {}
+
+ ~Iterator() override {
+ // Signal the optimize thread to terminate it. We will then join that
+ // thread when we delete `this->optimize_thread_`.
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ cond_var_.notify_all();
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ IteratorContext ctx_with_model(CreateParams(ctx));
+ return dataset()->input_->MakeIterator(&ctx_with_model, prefix(),
+ &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx));
+ IteratorContext ctx_with_model(CreateParams(ctx));
+ return input_impl_->GetNext(&ctx_with_model, out_tensors,
+ end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ IteratorContext::Params CreateParams(IteratorContext* ctx) {
+ IteratorContext::Params params = ctx->params();
+ params.model = model_;
+ return params;
+ }
+
+ private:
+ Status EnsureOptimizeThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!optimize_thread_) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ optimize_thread_.reset(ctx->env()->StartThread(
+ {}, "optimize_thread",
+ [this, new_ctx]() { OptimizeThread(new_ctx); }));
+ }
+ return Status::OK();
+ }
+
+ void OptimizeThread(const std::shared_ptr<IteratorContext>& ctx) {
+ int64 last_optimization_ms = 0;
+ int64 optimization_period_ms = 10;
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ last_optimization_ms + optimization_period_ms >=
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros) {
+ cond_var_.wait_for(
+ l, std::chrono::milliseconds(
+ last_optimization_ms + optimization_period_ms -
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros));
+ }
+ if (cancelled_) return;
+ }
+ model_->Optimize(port::NumSchedulableCPUs());
+ // Exponentially increase the period of running the optimization
+ // until a threshold is reached.
+ if (optimization_period_ms < kOptimizationPeriodThresholdMs) {
+ if (optimization_period_ms << 1 < kOptimizationPeriodThresholdMs) {
+ optimization_period_ms <<= 1;
+ } else {
+ optimization_period_ms = kOptimizationPeriodThresholdMs;
+ }
+ }
+ last_optimization_ms =
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+ }
+ }
+
+ mutex mu_;
+ condition_variable cond_var_;
+ std::shared_ptr<model::Model> model_;
+ std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU),
+ ModelDatasetOp);
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
new file mode 100644
index 0000000000..5f143967d9
--- /dev/null
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -0,0 +1,633 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+struct HostBufferElement {
+ Status status;
+ bool end_of_sequence;
+ std::vector<Tensor> value;
+};
+
+using MultiDeviceIteratorCallback =
+ std::function<void(const HostBufferElement&)>;
+
+class MultiDeviceIterator : public ResourceBase {
+ public:
+ MultiDeviceIterator(const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes,
+ const std::vector<string>& devices,
+ std::unique_ptr<FunctionLibraryDefinition> flib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+ FunctionLibraryRuntime* lib)
+ : output_types_(output_types),
+ output_shapes_(output_shapes),
+ devices_(devices),
+ flib_def_(std::move(flib_def)),
+ pflr_(std::move(pflr)),
+ lib_(lib) {
+ DCHECK(lib_ != nullptr);
+ }
+
+ string DebugString() override {
+ return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
+ " devices");
+ }
+
+ Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size,
+ int64* incarnation_id) {
+ if (iterator) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_types_, iterator->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
+ }
+
+ mutex_lock l(mu_);
+ if (multi_device_buffer_) {
+ multi_device_buffer_->Reset();
+ }
+
+ ++incarnation_id_;
+ *incarnation_id = incarnation_id_;
+
+ multi_device_buffer_.reset(
+ new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_,
+ std::move(iterator)));
+ return Status::OK();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ if (lib_ != nullptr) {
+ ctx->set_lib(lib_);
+ }
+ tf_shared_lock l(mu_);
+ multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
+ std::move(callback));
+ }
+
+ const DataTypeVector& output_types() const { return output_types_; }
+
+ const std::vector<PartialTensorShape>& output_shapes() const {
+ return output_shapes_;
+ }
+
+ std::shared_ptr<const FunctionLibraryDefinition> function_library() {
+ tf_shared_lock l(mu_);
+ return lib_def_;
+ }
+
+ FunctionLibraryRuntime* const lib() {
+ tf_shared_lock l(mu_);
+ return lib_;
+ }
+
+ private:
+ // A private class that uses a background thread to keep a per device buffer
+ // full.
+ class MultiDeviceBuffer {
+ public:
+ MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
+ std::unique_ptr<IteratorBase> host_iterator)
+ : buffer_(size),
+ size_(size),
+ max_buffer_size_(max_buffer_size),
+ incarnation_id_(incarnation_id),
+ host_iterator_(std::move(host_iterator)) {}
+
+ ~MultiDeviceBuffer() {
+ {
+ mutex_lock l(mu_);
+ if (!background_thread_started_) return;
+ }
+ Reset();
+ }
+
+ void Reset() LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (background_thread_finished_) {
+ return;
+ }
+
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
+
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
+ }
+ RunPendingCallbacks();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ HostBufferElement elem;
+ if (incarnation_id_ != incarnation_id) {
+ elem.status = errors::InvalidArgument("Invalid incarnation id");
+ callback(elem);
+ return;
+ }
+
+ bool produced_output = false;
+ {
+ mutex_lock l(mu_);
+ if (cancelled_) {
+ elem.status = errors::Cancelled("Cancelled Multidevice iterator");
+ callback(elem);
+ return;
+ }
+
+ EnsureBackgroundThreadStarted(ctx);
+
+ if (!buffer_[shard_num].data.empty()) {
+ produced_output = true;
+ std::swap(elem, buffer_[shard_num].data.front());
+ buffer_[shard_num].data.pop_front();
+ // Wake up background thread if it is blocked on this element.
+ if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) {
+ buffer_[shard_num].cond_var.notify_all();
+ }
+ } else {
+ if (background_thread_finished_) {
+ produced_output = true;
+ elem.end_of_sequence = true;
+ } else {
+ buffer_[shard_num].callbacks.push_back(std::move(callback));
+ callback = nullptr;
+ }
+ }
+ }
+
+ if (produced_output) {
+ callback(elem);
+ }
+ }
+
+ private:
+ void EnsureBackgroundThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!background_thread_) {
+ background_thread_.reset(ctx->env()->StartThread(
+ {}, "multi_device_iterator_background_thread",
+ std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
+ this, new IteratorContext(*ctx))));
+ }
+ }
+
+ void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) {
+ // Run all remaining callbacks.
+ std::vector<MultiDeviceIteratorCallback> cancellation_callbacks;
+ std::vector<HostBufferElement> cancellation_elements;
+ {
+ mutex_lock l(mu_);
+
+ for (int i = 0; i < size_; ++i) {
+ while (!buffer_[i].callbacks.empty()) {
+ if (buffer_[i].data.empty()) {
+ HostBufferElement elem;
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ cancellation_elements.push_back(std::move(elem));
+ } else {
+ cancellation_elements.push_back(
+ std::move(buffer_[i].data.front()));
+ buffer_[i].data.pop_front();
+ }
+ cancellation_callbacks.push_back(
+ std::move(buffer_[i].callbacks.front()));
+ buffer_[i].callbacks.pop_front();
+ }
+ }
+ }
+ for (int i = 0; i < cancellation_callbacks.size(); ++i) {
+ cancellation_callbacks[i](cancellation_elements[i]);
+ }
+ }
+
+ void BackgroundThread(IteratorContext* ctx) {
+ {
+ mutex_lock l(mu_);
+ background_thread_started_ = true;
+ }
+ std::unique_ptr<IteratorContext> cleanup(ctx);
+ int shard_to_fetch = 0;
+ while (true) {
+ HostBufferElement elem;
+ MultiDeviceIteratorCallback callback = nullptr;
+ bool end_of_iterator = false;
+
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ buffer_[shard_to_fetch].data.size() >= max_buffer_size_) {
+ buffer_[shard_to_fetch].cond_var.wait(l);
+ }
+
+ if (cancelled_) {
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ return;
+ }
+ }
+
+ elem.status =
+ host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence);
+
+ if (elem.status.ok() && elem.end_of_sequence) {
+ end_of_iterator = true;
+ }
+
+ {
+ mutex_lock l(mu_);
+ // Try to find a callback, else just push stuff into buffer.
+ if (!buffer_[shard_to_fetch].callbacks.empty()) {
+ callback = buffer_[shard_to_fetch].callbacks.front();
+ buffer_[shard_to_fetch].callbacks.pop_front();
+ } else {
+ buffer_[shard_to_fetch].data.push_back(std::move(elem));
+ elem = HostBufferElement();
+ }
+ }
+
+ if (callback) {
+ (*ctx->runner())(std::bind(std::move(callback), std::move(elem)));
+ }
+
+ // Finish off the thread if we reach the end of the iterator. Runs
+ // pending callbacks.
+ if (end_of_iterator) {
+ {
+ mutex_lock l(mu_);
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ }
+ RunPendingCallbacks();
+ return;
+ }
+ shard_to_fetch = (shard_to_fetch + 1) % size_;
+ }
+ }
+
+ struct HostBuffer {
+ condition_variable cond_var;
+ std::deque<HostBufferElement> data;
+ std::deque<MultiDeviceIteratorCallback> callbacks;
+ };
+
+ mutex mu_;
+ std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
+ bool background_thread_finished_ GUARDED_BY(mu_) = false;
+ bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
+
+ std::vector<HostBuffer> buffer_;
+
+ const size_t size_;
+ const int64 max_buffer_size_;
+ const int64 incarnation_id_;
+ const std::unique_ptr<IteratorBase> host_iterator_;
+ };
+
+ mutex mu_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const std::vector<string> devices_;
+ const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ FunctionLibraryRuntime* const lib_ = nullptr; // not owned.
+ std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
+
+ int64 incarnation_id_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
+};
+
+// Just creates a MultiDeviceIterator and returns it.
+class MultiDeviceIteratorHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
+ }
+
+ // The resource is deleted from the resource manager only when it is private
+ // to kernel.
+ ~MultiDeviceIteratorHandleOp() override {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<MultiDeviceIterator>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ }
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ FunctionLibraryRuntime* lib;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+ OP_REQUIRES_OK(context, context->function_library()->Clone(
+ &flib_def, &pflr, &lib));
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(
+ context,
+ mgr->LookupOrCreate<MultiDeviceIterator>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, lib, &flib_def, &pflr](MultiDeviceIterator** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new MultiDeviceIterator(
+ output_types_, output_shapes_, devices_,
+ std::move(flib_def), std::move(pflr), lib);
+ return Status::OK();
+ }));
+
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
+
+ resource_ = resource;
+ }
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<MultiDeviceIterator>()));
+ }
+
+ private:
+ // During the first Compute(), resource is either created or looked up using
+ // shared_name. In the latter case, the resource found should be verified if
+ // it is compatible with this op's configuration. The verification may fail in
+ // cases such as two graphs asking queues of the same shared name to have
+ // inconsistent capacities.
+ Status VerifyResource(MultiDeviceIterator* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_types_, resource->output_types()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+ }
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ const int graph_def_version_;
+ string name_;
+ string container_;
+ std::vector<string> devices_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
+ MultiDeviceIteratorHandleOp);
+
+// Calls init on the MultiDeviceIterator.
+class MultiDeviceIteratorInitOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* tensor_max_buffer_size;
+ OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
+ int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()();
+
+ DatasetBase* dataset;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
+ core::ScopedUnref unref(resource);
+
+ std::unique_ptr<IteratorBase> iterator;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(resource->lib());
+ OP_REQUIRES_OK(
+ ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
+ int64 incarnation_id;
+ OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
+ &incarnation_id));
+ Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
+ tensor_incarnation_id.scalar<int64>()() = incarnation_id;
+ OP_REQUIRES_OK(ctx,
+ ctx->set_output("incarnation_id", tensor_incarnation_id));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU),
+ MultiDeviceIteratorInitOp);
+
+// Calls GetNextFromShard(shard) and returns a vector of Tensors as output.
+// TODO(rohanj): Implement using BackgroundWorker that Derek built?
+class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
+ public:
+ explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ thread_pool_(new thread::ThreadPool(
+ ctx->env(), ThreadOptions(),
+ strings::StrCat("multi_device_iterator_get_next_thread_",
+ SanitizeThreadSuffix(name())),
+ 1 /* num_threads */, false /* low_latency_hint */)) {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ const Tensor* tensor_shard_num;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done);
+ int32 shard_num = tensor_shard_num->scalar<int32>()();
+
+ const Tensor* tensor_incarnation_id;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
+ int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
+
+ MultiDeviceIterator* iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
+ thread_pool_->Schedule(std::bind(
+ [ctx, iterator, shard_num, incarnation_id](DoneCallback done) {
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ IteratorContext iter_ctx(std::move(params));
+
+ MultiDeviceIteratorCallback callback = std::bind(
+ [ctx](const HostBufferElement& elem, DoneCallback done) {
+ // iterator->Unref();
+ Status s = elem.status;
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else if (elem.end_of_sequence) {
+ ctx->SetStatus(errors::OutOfRange("End of sequence"));
+ } else {
+ for (int i = 0; i < elem.value.size(); ++i) {
+ ctx->set_output(i, elem.value[i]);
+ }
+ }
+ done();
+ },
+ std::placeholders::_1, std::move(done));
+
+ iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
+ callback);
+ iterator->Unref();
+ },
+ std::move(done)));
+ }
+
+ private:
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU),
+ MultiDeviceIteratorGetNextFromShardOp);
+
+class MultiDeviceIteratorToStringHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& resource_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
+ errors::InvalidArgument("resource_handle must be a scalar"));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is an MultiDeviceIterator.
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ resource->Unref();
+
+ Tensor* string_handle_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &string_handle_t));
+ string_handle_t->scalar<string>()() =
+ resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU),
+ MultiDeviceIteratorToStringHandleOp);
+
+class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES(
+ ctx,
+ output_types_.empty() || output_shapes_.empty() ||
+ output_types_.size() == output_shapes_.size(),
+ errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
+ "are set, they must have the same length."));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& string_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
+ errors::InvalidArgument("string_handle must be a scalar"));
+
+ ResourceHandle resource_handle;
+ OP_REQUIRES(
+ ctx,
+ resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
+ errors::InvalidArgument(
+ "Could not parse string_handle as a valid ResourceHandle"));
+
+ OP_REQUIRES(
+ ctx, resource_handle.device() == ctx->device()->attributes().name(),
+ errors::InvalidArgument("Attempted create an iterator on device \"",
+ ctx->device()->attributes().name(),
+ "\" from handle defined on device \"",
+ resource_handle.device(), "\""));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is an MultiDeviceIterator.
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
+ core::ScopedUnref unref_iterator(resource);
+ if (!output_types_.empty()) {
+ OP_REQUIRES_OK(ctx,
+ VerifyTypesMatch(output_types_, resource->output_types()));
+ }
+ if (!output_shapes_.empty()) {
+ OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_,
+ resource->output_shapes()));
+ }
+
+ Tensor* resource_handle_t;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
+ resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
+ }
+
+ private:
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
+ MultiDeviceIteratorFromStringHandleOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index b372d31a93..2ab5c83082 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -108,11 +108,8 @@ class OptionalFromValueOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
OpInputList components_input;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
- std::vector<Tensor> components;
- components.reserve(components_input.size());
- for (const Tensor& component_t : components_input) {
- components.push_back(component_t);
- }
+ std::vector<Tensor> components(components_input.begin(),
+ components_input.end());
OP_REQUIRES_OK(
ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
}
@@ -216,6 +213,14 @@ static Status OptionalDeviceCopy(
std::vector<Tensor> to_values;
to_values.reserve(from_values.size());
for (const Tensor& t : from_values) {
+ if (t.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ }
+ }
+ for (const Tensor& t : from_values) {
if (DMAHelper::CanUseDMA(&t)) {
Tensor tmp(t.dtype());
TF_RETURN_IF_ERROR(copy(t, &tmp));
@@ -231,10 +236,9 @@ static Status OptionalDeviceCopy(
return Status::OK();
}
-#define REGISTER_OPTIONAL_COPY(DIRECTION) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- OptionalVariant, DIRECTION, kOptionalVariantTypeName, \
- OptionalDeviceCopy)
+#define REGISTER_OPTIONAL_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+ OptionalVariant, DIRECTION, OptionalDeviceCopy)
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index fd0e6c4cd0..7b01c3b4e0 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -207,6 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 640f1565b7..2e6e0465f7 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <atomic>
#include <deque>
#include <utility>
@@ -44,14 +45,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
@@ -83,8 +76,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- interleave_func_, std::move(other_arguments), &captured_func));
+ ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+ &captured_func));
*output =
new Dataset(ctx, input, interleave_func_, std::move(captured_func),
@@ -252,6 +245,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -351,11 +345,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (must_wait_for_input) {
// Wait for elements to become available.
+ RecordStop(ctx);
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
workers_[interleave_indices_[next_index_]].cond_var.wait(l);
}
+ RecordStart(ctx);
}
}
return errors::Cancelled(
@@ -484,10 +480,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (reader->Contains(full_name("worker_threads_running"))) {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
return Status::OK();
@@ -583,10 +579,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
workers_[i].SetInputs(s, std::move(args));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
@@ -601,7 +597,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
// Produces elements into the worker's output buffers.
- void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
+ void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
+ const int64 thread_index) {
// Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
//
// 1. Any local state that may need to be checkpointed should be kept
@@ -622,10 +619,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
- std::unique_ptr<IteratorContext> ctx(ctx_ptr);
- auto cleanup = gtl::MakeCleanup([this, thread_index] {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
+ RecordStop(ctx.get());
});
bool make_new_iterator;
{
@@ -651,9 +649,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// 1. Build a new iterator or use the existing one.
if (make_new_iterator) {
// 1a. Get new input tensors or use the exiting ones.
-
bool read_new_input;
-
{
tf_shared_lock l(ckpt_mu_);
// worker_thread_states_[thread_index].input will be non-empty
@@ -665,7 +661,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (read_new_input) {
mutex_lock l(mu_);
while (!cancelled_ && !workers_[thread_index].is_producing) {
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
// Copy the input tensors so that we do not need to block on `mu_`
@@ -715,7 +713,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
@@ -764,7 +764,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
@@ -1093,9 +1095,6 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
-
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
@@ -1111,7 +1110,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
OP_REQUIRES(
@@ -1119,16 +1118,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
errors::InvalidArgument(
"num_parallel_calls must less than or equal to cycle_length."));
- // TODO(b/114267189): Use `other_arguments(inputs.begin(), inputs.end());`.
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- interleave_func_, std::move(other_arguments), &captured_func));
+ ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, interleave_func_,
std::move(captured_func), cycle_length, block_length,
@@ -1221,6 +1214,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
+ num_parallel_calls_(params.dataset->num_parallel_calls_),
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
@@ -1241,6 +1235,16 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ AddTunableParameter(ctx, "parallelism",
+ &num_parallel_calls_ /* value */, 1 /* min */,
+ dataset()->cycle_length_ /* max */, &cond_var_);
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
+ AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
@@ -1256,7 +1260,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
+ RecordStop(ctx);
cond_var_.wait(l);
+ RecordStart(ctx);
}
if (!invocation_results_.empty()) {
std::swap(result, invocation_results_.front());
@@ -1265,9 +1271,11 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
+ RecordStop(ctx);
result->notification.WaitForNotification();
+ RecordStart(ctx);
} while (result->skip);
if (result->status.ok()) {
@@ -1391,6 +1399,8 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
LOCKS_EXCLUDED(mu_) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
for (auto& result : results) {
if (!end_of_input) {
@@ -1408,56 +1418,66 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// Release the ownership of the cycle element iterator, closing the
// iterator if end of input was encountered.
- {
- if (end_of_input) {
- current_elements_[cycle_index].reset();
- }
- mutex_lock l(mu_);
- element_in_use_[cycle_index] = false;
- num_calls_--;
- if (end_of_input) {
- args_list_[cycle_index].clear();
- num_open_--;
- }
+ if (end_of_input) {
+ current_elements_[cycle_index].reset();
+ }
+ mutex_lock l(mu_);
+ element_in_use_[cycle_index] = false;
+ num_calls_--;
+ if (end_of_input) {
+ args_list_[cycle_index].clear();
+ num_open_--;
}
cond_var_.notify_all();
}
- int64 MaxInvocationResults() {
- return dataset()->cycle_length_ * dataset()->block_length_;
- }
-
// Method responsible for 1) creating iterators out of input elements, 2)
// determining the order in which elements are fetched from the iterators,
// and 3) scheduling the fetching of the elements to a threadpool.
//
// This method runs in the `runner_thread` background thread.
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ return element_in_use_[cycle_index_] ||
+ num_calls_ >= num_parallel_calls_ ||
+ invocation_results_.size() >=
+ dataset()->cycle_length_ * dataset()->block_length_;
+ };
while (true) {
- {
- mutex_lock l(mu_);
- // Wait until this thread is cancelled, the end of input has been
- // reached, or the cycle element at the `cycle_index_` position is
- // not in use and there is space in the `invocation_results_` queue.
- while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
- (element_in_use_[cycle_index_] ||
- num_calls_ >= dataset()->num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
- cond_var_.wait(l);
- }
+ mutex_lock l(mu_);
+ // Wait until this thread is cancelled, the end of input has been
+ // reached, or the cycle element at the `cycle_index_` position is
+ // not in use and there is space in the `invocation_results_` queue.
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
+ RecordStop(ctx.get());
+ cond_var_.wait(l);
+ RecordStart(ctx.get());
+ }
- if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
- return;
- }
+ if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+ return;
+ }
- while (!element_in_use_[cycle_index_] &&
- (!end_of_input_ || num_open_ > 0) &&
- num_calls_ < dataset()->num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
- if (!current_elements_[cycle_index_]) {
- // Try to create a new iterator from the next input element.
- Status status = input_impl_->GetNext(
- ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ while ((!end_of_input_ || num_open_ > 0) && !busy()) {
+ if (!current_elements_[cycle_index_]) {
+ // Try to create a new iterator from the next input element.
+ Status status = input_impl_->GetNext(
+ ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ if (!end_of_input_) {
+ Status status = MakeIteratorFromInputElement(
+ ctx.get(), args_list_[cycle_index_], cycle_index_,
+ dataset()->captured_func_.get(), prefix(),
+ &current_elements_[cycle_index_]);
if (!status.ok()) {
invocation_results_.emplace_back(new InvocationResult());
std::shared_ptr<InvocationResult>& result =
@@ -1466,39 +1486,25 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
result->notification.Notify();
break;
}
- if (!end_of_input_) {
- Status status = MakeIteratorFromInputElement(
- ctx.get(), args_list_[cycle_index_], cycle_index_,
- dataset()->captured_func_.get(), prefix(),
- &current_elements_[cycle_index_]);
- if (!status.ok()) {
- invocation_results_.emplace_back(new InvocationResult());
- std::shared_ptr<InvocationResult>& result =
- invocation_results_.back();
- result->status.Update(status);
- result->notification.Notify();
- break;
- }
- ++num_open_;
- }
+ ++num_open_;
}
- if (current_elements_[cycle_index_]) {
- // Pre-allocate invocation results for outputs to be fetched
- // and then fetch the outputs asynchronously.
- std::vector<std::shared_ptr<InvocationResult>> results;
- results.reserve(dataset()->block_length_);
- for (int i = 0; i < dataset()->block_length_; ++i) {
- invocation_results_.emplace_back(new InvocationResult());
- results.push_back(invocation_results_.back());
- }
- num_calls_++;
- element_in_use_[cycle_index_] = true;
- thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
- ctx, cycle_index_,
- std::move(results)));
+ }
+ if (current_elements_[cycle_index_]) {
+ // Pre-allocate invocation results for outputs to be fetched
+ // and then fetch the outputs asynchronously.
+ std::vector<std::shared_ptr<InvocationResult>> results;
+ results.reserve(dataset()->block_length_);
+ for (int i = 0; i < dataset()->block_length_; ++i) {
+ invocation_results_.emplace_back(new InvocationResult());
+ results.push_back(invocation_results_.back());
}
- cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+ num_calls_++;
+ element_in_use_[cycle_index_] = true;
+ thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+ ctx, cycle_index_,
+ std::move(results)));
}
+ cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
cond_var_.notify_all();
}
@@ -1601,6 +1607,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// and there are elements left to be fetched.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ std::atomic<int64> num_parallel_calls_;
+
// Iterator for input elements.
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index a0cb179eb8..6abe6c8338 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -44,25 +44,17 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int32 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments),
- use_inter_op_parallelism_, &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ use_inter_op_parallelism_,
+ &captured_func));
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
output_shapes_, use_inter_op_parallelism_,
@@ -97,31 +89,26 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- ParallelMapIteratorFunction map_func;
- if (use_inter_op_parallelism_) {
- map_func = [this](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done));
- };
- } else {
- map_func = [this](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- (*ctx->runner())(std::bind(
- [this, ctx, result](std::vector<Tensor>& input_element,
- StatusCallback& done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done));
- },
- std::move(input_element), std::move(done)));
+ const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
+ ParallelMapIteratorFunction map_func =
+ [this, new_prefix](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done), new_prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](
+ IteratorContext* ctx, std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
+ result, std::move(done)));
};
}
- return NewParallelMapIterator(
- {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
- std::move(init_func), std::move(map_func), num_parallel_calls_);
+ return NewParallelMapIterator({this, new_prefix}, input_,
+ std::move(init_func), std::move(map_func),
+ num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 4ae742aaaf..ee20249bfe 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -14,11 +14,15 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
+#include <atomic>
#include <deque>
#include <functional>
#include <utility>
#include <vector>
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/cpu_info.h"
+
namespace tensorflow {
namespace data {
namespace {
@@ -37,11 +41,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
num_parallel_calls_(num_parallel_calls) {}
~ParallelMapIterator() override {
- // TODO(mrry): Replace this cancellation logic with a
- // CancellationManager. The syntax would be more heavyweight,
- // but it would be possible to thread a cancellation manager
- // through the IteratorContext to upstream,
- // potentially-blocking iterators, when we add these.
mutex_lock l(mu_);
// Cancel the runner thread.
cancelled_ = true;
@@ -53,6 +52,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ if (num_parallel_calls_ == kAutoTune) {
+ num_parallel_calls_ = 1;
+ // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
+ // use it here for the maximum.
+ AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */,
+ 1 /* min */, port::NumSchedulableCPUs() /* max */,
+ &cond_var_);
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ }
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
if (init_func_) {
@@ -68,13 +78,17 @@ class ParallelMapIterator : public DatasetBaseIterator {
mutex_lock l(mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
+ RecordStop(ctx);
cond_var_.wait(l);
+ RecordStart(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
+ RecordStop(ctx);
result->notification.WaitForNotification();
+ RecordStart(ctx);
return ProcessResult(result, out_tensors, end_of_sequence);
}
@@ -87,9 +101,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("invocation_results.size"),
- invocation_results_.size()));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
+ invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
std::shared_ptr<InvocationResult> result = invocation_results_[i];
TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
@@ -176,9 +189,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
{
mutex_lock l(mu_);
num_calls_--;
+ cond_var_.notify_all();
}
result->notification.Notify();
- cond_var_.notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
@@ -193,9 +206,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
- // Call `func_(input_element)`, store the result in
- // `result->return_values`, and notify `result->notification` to unblock
- // a consumer.
+ // Call `func_(input_element)`, store the result in `result->return_values`,
+ // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
@@ -205,8 +217,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
std::move(done));
}
- int64 MaxInvocationResults() { return num_parallel_calls_; }
-
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
@@ -226,27 +236,33 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
new_calls.reserve(num_parallel_calls_);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_;
+ return num_calls_ >= num_parallel_calls ||
+ invocation_results_.size() >= num_parallel_calls;
+ };
while (true) {
{
mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
+ while (!busy()) {
invocation_results_.emplace_back(new InvocationResult());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
+ cond_var_.notify_all();
}
- cond_var_.notify_all();
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
@@ -295,7 +311,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
const DatasetBase* const input_dataset_; // Not owned.
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
- const int32 num_parallel_calls_;
// Used for coordination between the main thread and the runner thread.
mutex mu_;
// Used for coordination between the main thread and the runner thread. In
@@ -304,6 +319,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
// parallelism and there are slots available in the `invocation_results_`
// buffer.
condition_variable cond_var_;
+ // Identifies the maximum number of parallel calls.
+ std::atomic<int64> num_parallel_calls_;
// Counts the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 0cf5db017b..c28c06da62 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -87,11 +87,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
"Expected len(dense_defaults) == len(dense_keys) but got: ",
dense_default_tensors.size(), " vs. ", dense_keys_.size()));
- std::vector<Tensor> dense_defaults;
- dense_defaults.reserve(dense_default_tensors.size());
- for (const Tensor& dense_default_t : dense_default_tensors) {
- dense_defaults.push_back(dense_default_t);
- }
+ std::vector<Tensor> dense_defaults(dense_default_tensors.begin(),
+ dense_default_tensors.end());
for (int d = 0; d < dense_keys_.size(); ++d) {
const Tensor& def_value = dense_defaults[d];
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc
index 533d0bd5d2..da357339c9 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc
@@ -26,6 +26,13 @@ PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
}
}
+namespace {
+// Determines what strategy to use for increasing the buffer size limit. For
+// limits less than the threshold, an exponential increase is used, while for
+// limits greater than or equal to the threshold, a linear increase is used.
+size_t kBufferLimitThreshold = 2048;
+} // namespace
+
void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
switch (mode_) {
case Mode::kDisabled:
@@ -37,7 +44,11 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
return;
case Mode::kDownswing:
if (current_buffer_size == 0) {
- buffer_limit_ *= 2; // Increase the buffer size.
+ if (buffer_limit_ >= kBufferLimitThreshold) {
+ buffer_limit_ += kBufferLimitThreshold;
+ } else {
+ buffer_limit_ *= 2;
+ }
mode_ = Mode::kUpswing;
}
return;
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index ad7d5eb3ff..754ed772db 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
@@ -102,16 +103,18 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
+ auto stats_aggregator = ctx->stats_aggregator();
{
mutex_lock l(mu_);
- auto stats_aggregator = ctx->stats_aggregator();
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
+ RecordStop(ctx);
cond_var_.wait(l);
+ RecordStart(ctx);
}
if (cancelled_) {
@@ -133,6 +136,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex_lock parent_l(parent_mu_);
mutex_lock l(mu_);
+ if (stats_aggregator) {
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
+ }
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
@@ -216,6 +227,12 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
strings::StrCat(prefix_end_, "::buffer_utilization"),
{static_cast<float>(buffer_.size()) /
static_cast<float>(auto_tuner_.buffer_limit())});
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
}
// A new element is available. Forward the status from computing it, and
// (if we successfully got an element) the output values.
@@ -239,10 +256,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) {
- prefetch_thread_.reset(
- ctx->env()->StartThread({}, "prefetch_thread",
- std::bind(&Iterator::PrefetchThread, this,
- new IteratorContext(*ctx))));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ prefetch_thread_.reset(ctx->env()->StartThread(
+ {}, "prefetch_thread",
+ [this, new_ctx]() { PrefetchThread(new_ctx); }));
}
return Status::OK();
}
@@ -251,8 +268,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
// buffer.
//
// It owns the iterator context passed to it.
- void PrefetchThread(IteratorContext* ctx) {
- std::unique_ptr<IteratorContext> cleanup(ctx);
+ void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
while (true) {
std::vector<Tensor> value;
@@ -260,7 +278,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
{
mutex_lock l(mu_);
while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) {
@@ -277,8 +297,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex_lock parent_l(parent_mu_);
bool end_of_sequence;
BufferElement buffer_element;
- buffer_element.status =
- input_impl_->GetNext(ctx, &buffer_element.value, &end_of_sequence);
+ buffer_element.status = input_impl_->GetNext(
+ ctx.get(), &buffer_element.value, &end_of_sequence);
if (buffer_element.status.ok() && end_of_sequence) {
mutex_lock l(mu_);
prefetch_thread_finished_ = true;
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index 6e515d6cc8..dbe31f37b8 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -45,23 +45,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
OpInputList initial_state_inputs;
OP_REQUIRES_OK(ctx,
ctx->input_list("initial_state", &initial_state_inputs));
- std::vector<Tensor> initial_state;
- initial_state.reserve(initial_state_inputs.size());
- for (const Tensor& t : initial_state_inputs) {
- initial_state.push_back(t);
- }
-
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
+ std::vector<Tensor> initial_state(initial_state_inputs.begin(),
+ initial_state_inputs.end());
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(initial_state),
std::move(captured_func), state_types_, output_types_,
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index e1cefd23d8..ca4ea25b89 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -33,11 +33,7 @@ class TensorDatasetOp : public DatasetOpKernel {
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
// TODO(mrry): Validate that the shapes of the "components" tensors match
// the "shapes" attr.;
- std::vector<Tensor> components;
- components.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- components.push_back(t);
- }
+ std::vector<Tensor> components(inputs.begin(), inputs.end());
*output = new Dataset(ctx, std::move(components));
}
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index 3975086841..ac44623ce2 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -33,22 +33,44 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 window_size = 0;
- OP_REQUIRES_OK(
- ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
- *output = new Dataset(ctx, window_size, input);
+ int64 window_shift = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "shift", &window_shift));
+ OP_REQUIRES(
+ ctx, window_shift > 0,
+ errors::InvalidArgument("Window shift must be greater than zero."));
+
+ int64 window_stride = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "stride", &window_stride));
+ OP_REQUIRES(
+ ctx, window_stride > 0,
+ errors::InvalidArgument("Window stride must be greater than zero."));
+
+ bool drop_remainder;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));
+
+ *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
+ drop_remainder);
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input)
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size,
+ int64 window_shift, int64 window_stride, bool drop_remainder)
: DatasetBase(DatasetContext(ctx)),
+ input_(input),
window_size_(window_size),
- input_(input) {
+ window_shift_(window_shift),
+ window_stride_(window_stride),
+ drop_remainder_(drop_remainder) {
input_->Ref();
}
@@ -72,7 +94,8 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
}
string DebugString() const override {
- return strings::StrCat("WindowDatasetOp(", window_size_, ")::Dataset");
+ return strings::StrCat("WindowDatasetOp(", window_size_, window_shift_,
+ window_stride_, drop_remainder_, ")::Dataset");
}
protected:
@@ -81,10 +104,19 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
- Node* window_size = nullptr;
- TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size));
+ Node* window_size_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
+ Node* window_shift_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
+ Node* window_stride_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
+ Node* drop_remainder_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, window_size}, output));
+ b->AddDataset(this,
+ {input_graph_node, window_size_node, window_shift_node,
+ window_stride_node, drop_remainder_node},
+ output));
return Status::OK();
}
@@ -101,37 +133,79 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- // Each row of `window_elements` is a tuple of tensors from the
- // input iterator.
+ const int64 window_size = dataset()->window_size_;
+ const int64 window_shift = dataset()->window_shift_;
+ const int64 window_stride = dataset()->window_stride_;
std::vector<std::vector<Tensor>> window_elements;
+ Status status = Status::OK();
{
mutex_lock l(mu_);
- if (!input_impl_) {
+ if (!input_impl_ && buffer_.empty()) {
*end_of_sequence = true;
return Status::OK();
}
- window_elements.reserve(dataset()->window_size_);
- *end_of_sequence = false;
- for (int i = 0; i < dataset()->window_size_ && !*end_of_sequence;
- ++i) {
- std::vector<Tensor> window_element_tuple;
- TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &window_element_tuple,
- end_of_sequence));
- if (!*end_of_sequence) {
- window_elements.emplace_back(std::move(window_element_tuple));
- } else {
- input_impl_.reset();
+
+ // Add elements to the buffer.
+ size_t target_size = TargetBufferSize(window_size, window_stride);
+ if (input_impl_) {
+ *end_of_sequence = false;
+ for (size_t i = buffer_.size();
+ i < target_size && !*end_of_sequence; ++i) {
+ std::vector<Tensor> element;
+ Status status =
+ input_impl_->GetNext(ctx, &element, end_of_sequence);
+ if (!*end_of_sequence) {
+ buffer_.emplace_back(std::move(element), status);
+ } else {
+ input_impl_.reset();
+ }
}
}
+
+ // If there are not enough elements and `drop_remainder` is set, we do
+ // not wish to return a smaller window.
+ if (buffer_.empty() ||
+ (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
+ DCHECK(*end_of_sequence);
+ return Status::OK();
+ }
+
+ int num_elements = 1 + (buffer_.size() - 1) / window_stride;
+ window_elements.reserve(num_elements);
+ for (size_t i = 0; i < num_elements; ++i) {
+ status.Update(buffer_[window_stride * i].status);
+ if (!status.ok()) {
+ break;
+ }
+ window_elements.emplace_back(buffer_[window_stride * i].result);
+ }
+
+ // Shift the window, discarding elements if necessary.
+ int buffer_size = buffer_.size();
+ if (window_shift >= buffer_size) {
+ for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) {
+ bool end_of_input;
+ std::vector<Tensor> element;
+ // Ignore non-error status of discarded elements.
+ input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
+ if (end_of_input) {
+ input_impl_.reset();
+ }
+ }
+ buffer_.clear();
+ } else {
+ buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
+ }
}
- if (window_elements.empty()) {
- DCHECK(*end_of_sequence);
- return Status::OK();
+ if (!status.ok()) {
+ return status;
}
+ // Construct output tensors.
const size_t num_tuple_components = window_elements[0].size();
const int64 num_window_elements = window_elements.size();
+ *end_of_sequence = false;
for (size_t idx = 0; idx < num_tuple_components; ++idx) {
DatasetBase* window_dataset;
std::vector<std::vector<Tensor>> window_component_elements;
@@ -154,7 +228,6 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset,
&out_tensors->back()));
}
- *end_of_sequence = false;
return Status::OK();
}
@@ -167,6 +240,20 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
}
+ // Save buffer.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
+ buffer_.size()));
+ for (int64 i = 0; i < buffer_.size(); i++) {
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(strings::StrCat("buffer[", i, "].size"),
+ buffer_[i].result.size()));
+ for (int64 j = 0; j < buffer_[i].result.size(); j++) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
@@ -178,22 +265,92 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
input_impl_.reset();
}
+ // Restore buffer.
+ int64 buffer_size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
+ buffer_.resize(buffer_size);
+ for (int64 i = 0; i < buffer_size; i++) {
+ int64 vector_size;
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat("buffer[", i, "].size"), &vector_size));
+ buffer_[i].result.resize(vector_size);
+ for (int64 j = 0; j < vector_size; j++) {
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ &buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
private:
+ struct InvocationResult {
+ InvocationResult() = default;
+ InvocationResult(std::vector<Tensor>&& result, const Status& status)
+ : result(result), status(status) {}
+
+ std::vector<Tensor> result;
+ Status status;
+ };
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].error_message"));
+ }
+
+ size_t TargetBufferSize(int64 window_size, int64 window_stride) {
+ return (window_size - 1) * window_stride + 1;
+ }
+
mutex mu_;
+ std::deque<InvocationResult> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
- const int64 window_size_;
const DatasetBase* const input_;
+ const int64 window_size_;
+ const int64 window_shift_;
+ const int64 window_stride_;
+ const bool drop_remainder_;
};
};
REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
WindowDatasetOp);
-
} // namespace
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index b4dcf0a74b..ae451be7e2 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -91,8 +91,10 @@ class DecodeBmpOp : public OpKernel {
errors::InvalidArgument(
"Number of channels must be 1, 3 or 4, was ", channels_));
- OP_REQUIRES(context, width > 0 && header_size >= 0,
+ OP_REQUIRES(context, width > 0,
errors::InvalidArgument("Width must be positive"));
+ OP_REQUIRES(context, height != 0,
+ errors::InvalidArgument("Height must be nonzero"));
OP_REQUIRES(context, header_size >= 0,
errors::InvalidArgument("header size must be nonnegative"));
@@ -108,8 +110,7 @@ class DecodeBmpOp : public OpKernel {
const int32 abs_height = abs(height);
// there may be padding bytes when the width is not a multiple of 4 bytes
- // 8 * channels == bits per pixel
- const int row_size = (8 * channels_ * width + 31) / 32 * 4;
+ const int row_size = (channels_ * width + 3) / 4 * 4;
const int64 last_pixel_offset = static_cast<int64>(header_size) +
(abs_height - 1) * row_size +
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index 3eed847c16..6bfb5bd5bc 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -61,6 +61,9 @@ class DecodeCSVOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
for (int i = 0; i < record_defaults.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
errors::InvalidArgument(
"There should only be 1 default per field but field ", i,
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 2a25459194..76afd6f18c 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/util_ptx.cuh"
+#include "third_party/cub/util_ptx.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
index 862a97723f..e7882acc80 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -35,10 +35,10 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_radix_sort.cuh"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh"
-#include "external/cub_archive/cub/thread/thread_operators.cuh"
+#include "third_party/cub/device/device_radix_sort.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/iterator/constant_input_iterator.cuh"
+#include "third_party/cub/thread/thread_operators.cuh"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
index 27918b410b..8edf7d4a2c 100644
--- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -59,12 +59,12 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const array<
typename internal::traits<OutputBackward>::Index, 5>,
const TensorReverseOp<const Eigen::array<bool, 5>,
- const Kernel> > > >,
+ const Kernel>>>>,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> > > >,
+ const OutputBackward>>>>,
TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
internal::traits<OutputBackward>::NumDimensions>,
@@ -75,7 +75,7 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const DSizes<typename internal::traits<OutputBackward>::Index,
2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> >,
+ const OutputBackward>>,
const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
2>,
@@ -83,7 +83,7 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const array<
typename internal::traits<OutputBackward>::Index, 5>,
const TensorReverseOp<const Eigen::array<bool, 5>,
- const Kernel> > > > > > >::type
+ const Kernel>>>>>>>::type
CuboidConvolutionBackwardInput(
const Kernel& kernel, const OutputBackward& output_backward,
typename internal::traits<OutputBackward>::Index inputPlanes,
@@ -94,12 +94,12 @@ CuboidConvolutionBackwardInput(
typedef typename internal::traits<OutputBackward>::Index TensorIndex;
const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
internal::traits<Kernel>::NumDimensions,
- internal::traits<Kernel>::Layout, TensorIndex> >
+ internal::traits<Kernel>::Layout, TensorIndex>>
kern(kernel);
const TensorRef<
const Tensor<typename internal::traits<OutputBackward>::Scalar,
internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex> >
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
out(output_backward);
EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
@@ -239,8 +239,8 @@ CuboidConvolutionBackwardInput(
}
}
- // We will contract along the fused dimension that contains the kernelFilters,
- // kernelPlanes, kernelRows and kernelCols.
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, kernelPlanes, kernelRows and kernelCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
@@ -323,35 +323,69 @@ CuboidConvolutionBackwardInput(
*/
template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
- internal::traits<OutputBackward>::Layout == ColMajor,
- TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 5>,
- const TensorContractionOp<
- const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const OutputBackward>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const Input> > > > >,
- TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 5>,
- const TensorContractionOp<
- const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const Input> > >,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const OutputBackward> > > >::type
+ internal::traits<Input>::Layout == ColMajor,
+ const TensorReverseOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorReshapingOp<
+ const Eigen::DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<
+ IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const OutputBackward>>>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const TensorVolumePatchOp<
+ Dynamic, Dynamic, Dynamic,
+ const Eigen::TensorForcedEvalOp<
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Input>>>>>>>>,
+ const TensorReverseOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorReshapingOp<
+ const Eigen::DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const TensorContractionOp<
+ const array<
+ IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const TensorVolumePatchOp<
+ Dynamic, Dynamic, Dynamic,
+ const Eigen::TensorForcedEvalOp<
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Input>>>>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const OutputBackward>>>>>>>>::type
CuboidConvolutionBackwardKernel(
const Input& input, const OutputBackward& output_backward,
typename internal::traits<Input>::Index kernelPlanes,
@@ -362,11 +396,11 @@ CuboidConvolutionBackwardKernel(
typedef typename internal::traits<Input>::Index TensorIndex;
TensorRef<Tensor<typename internal::traits<Input>::Scalar,
internal::traits<Input>::NumDimensions,
- internal::traits<Input>::Layout, TensorIndex> >
+ internal::traits<Input>::Layout, TensorIndex>>
in(input);
TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar,
internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex> >
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
out(output_backward);
EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
@@ -380,6 +414,13 @@ CuboidConvolutionBackwardKernel(
internal::traits<OutputBackward>::NumDimensions,
YOU_MADE_A_PROGRAMMING_MISTAKE);
+ // We do not support higher dimensional backward convolutions, or convolutions
+ // without batch dimension.
+ // TODO(ezhulenev): Relax this constraint, and turn on tests without batch
+ // dimension in eigen_backward_cuboid_convolutions_test.cc.
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5,
+ YOU_MADE_A_PROGRAMMING_MISTAKE);
+
const TensorIndex inputPlanes =
isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
const TensorIndex inputRows =
@@ -401,6 +442,10 @@ CuboidConvolutionBackwardKernel(
const TensorIndex kernelChannels =
isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
+ // Number of batches in the input tensor.
+ const TensorIndex batch =
+ isColMajor ? in.dimension(4) : in.dimension(NumDims - 5);
+
// TODO(ezhulenev): Add support for inflated strides. Without inflated strides
// effective kernel planes/rows/cols are always the same as the kernel itself
// (see eigen_spatial_convolutions for details).
@@ -408,6 +453,7 @@ CuboidConvolutionBackwardKernel(
const TensorIndex kernelRowsEff = kernelRows;
const TensorIndex kernelColsEff = kernelCols;
+ // Compute forward padding from input and output_backward dimensions.
const TensorIndex padPlanes = numext::maxi<Index>(
0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes);
const TensorIndex padRows = numext::maxi<Index>(
@@ -416,92 +462,147 @@ CuboidConvolutionBackwardKernel(
0, (outputCols - 1) * strideCols + kernelColsEff - inputCols);
const TensorIndex padding_top_z = padPlanes / 2;
- const TensorIndex padding_bottom_z = padPlanes - padding_top_z;
const TensorIndex padding_top = padRows / 2;
- const TensorIndex padding_bottom = padRows - padding_top;
const TensorIndex padding_left = padCols / 2;
- const TensorIndex padding_right = padCols - padding_left;
- // Reshaped output_backward before contraction.
- DSizes<TensorIndex, 2> output_dims;
+ // Compute paddings for output_backward before extracting patches.
+ const auto expanded_out_planes = (outputPlanes - 1) * stridePlanes + 1;
+ const auto expanded_out_rows = (outputRows - 1) * strideRows + 1;
+ const auto expanded_out_cols = (outputCols - 1) * strideCols + 1;
+ const auto padded_out_planes = inputPlanes + kernelPlanes - 1;
+ const auto padded_out_rows = inputRows + kernelRows - 1;
+ const auto padded_out_cols = inputCols + kernelCols - 1;
+ const auto top_pad_planes = kernelPlanes - 1 - padding_top_z;
+ const auto top_pad_rows = kernelRows - 1 - padding_top;
+ const auto left_pad_cols = kernelCols - 1 - padding_left;
+ const auto bottom_pad_planes =
+ padded_out_planes - expanded_out_planes - top_pad_planes;
+ const auto bottom_pad_rows =
+ padded_out_rows - expanded_out_rows - top_pad_rows;
+ const auto right_pad_cols =
+ padded_out_cols - expanded_out_cols - left_pad_cols;
+
+ // Reorder output_backward dimensions.
+ array<TensorIndex, 5> output_backward_shuffle;
if (isColMajor) {
- output_dims[0] = kernelFilters;
- output_dims[1] = outputPlanes * outputRows * outputCols;
- for (int i = 4; i < NumDims; ++i) {
- output_dims[1] *= out.dimension(i);
- }
+ // From: [out_depth, out_planes, out_rows, out_cols, batch]
+ // To: [batch, out_planes, out_rows, out_cols, out_depth]
+ output_backward_shuffle = {4, 1, 2, 3, 0};
} else {
- output_dims[1] = kernelFilters;
- output_dims[0] = outputCols * outputRows * outputPlanes;
- for (int i = 0; i < NumDims - 4; ++i) {
- output_dims[0] *= out.dimension(i);
- }
+ // From: [batch, out_cols, out_rows, out_planes, out_depth]
+ // To: [out_depth, out_cols, out_rows, out_planes, batch]
+ output_backward_shuffle = {4, 1, 2, 3, 0};
}
- // Reshaped extract_volume_patches(in)
- DSizes<TensorIndex, 2> pre_contract_dims;
+ // Reorder input dimensions.
+ array<TensorIndex, 5> input_shuffle;
if (isColMajor) {
- pre_contract_dims[0] =
- kernelChannels * kernelPlanes * kernelRows * kernelCols;
- pre_contract_dims[1] = outputPlanes * outputRows * outputCols;
- for (int i = 4; i < NumDims; ++i) {
- pre_contract_dims[1] *= in.dimension(i);
- }
- eigen_assert(output_dims[1] == pre_contract_dims[1]);
+ // From: [in_depth, in_planes, in_rows, in_cols, batch]
+ // To: [in_depth, batch, in_planes, in_rows, in_cols]
+ input_shuffle = {0, 4, 1, 2, 3};
} else {
- pre_contract_dims[1] =
- kernelCols * kernelRows * kernelPlanes * kernelChannels;
- pre_contract_dims[0] = outputCols * outputRows * outputPlanes;
- for (int i = 0; i < NumDims - 4; ++i) {
- pre_contract_dims[0] *= in.dimension(i);
- }
- eigen_assert(output_dims[0] == pre_contract_dims[0]);
+ // From: [batch, in_cols, in_rows, in_planes, in_depth]
+ // To: [in_cols, in_rows, in_planes, batch, in_depth]
+ input_shuffle = {1, 2, 3, 0, 4};
}
- array<TensorIndex, 2> shuffle_dims;
- shuffle_dims[0] = 1;
- shuffle_dims[1] = 0;
+ // Input is playing the role of a "kernel" in this convolution.
+ DSizes<TensorIndex, 2> input_dims;
+ if (isColMajor) {
+ input_dims[0] = kernelChannels;
+ input_dims[1] = batch * inputPlanes * inputRows * inputCols;
+ } else {
+ input_dims[1] = kernelChannels;
+ input_dims[0] = inputCols * inputRows * inputPlanes * batch;
+ }
+ // Molds the output of the patch extraction result into a 2D tensor:
+ // - the first dimension (dims[0]): the patch values to be multiplied with the
+ // kernels
+ // - the second dimension (dims[1]): everything else
+ DSizes<TensorIndex, 2> pre_contract_dims;
+ if (isColMajor) {
+ pre_contract_dims[0] = batch * inputPlanes * inputRows * inputCols;
+ pre_contract_dims[1] =
+ kernelPlanes * kernelRows * kernelCols * kernelFilters;
+ } else {
+ pre_contract_dims[1] = inputCols * inputRows * inputPlanes * batch;
+ pre_contract_dims[0] =
+ kernelFilters * kernelCols * kernelRows * kernelPlanes;
+ }
+
+ // We will contract along the collapsed dimension that contains the
+ // batch, inputPlanes, inputRows and inputCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
contract_dims[0] = IndexPair<TensorIndex>(1, 0);
- DSizes<TensorIndex, 5> kernel_dims;
+ // Dimensions after contraction.
+ DSizes<TensorIndex, NumDims> post_contract_dims;
if (isColMajor) {
- kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels;
- kernel_dims[2] = kernelPlanes;
- kernel_dims[3] = kernelRows;
- kernel_dims[4] = kernelCols;
+ post_contract_dims[0] = kernelChannels;
+ post_contract_dims[1] = kernelPlanes;
+ post_contract_dims[2] = kernelRows;
+ post_contract_dims[3] = kernelCols;
+ post_contract_dims[4] = kernelFilters;
} else {
- kernel_dims[4] = kernelFilters;
- kernel_dims[3] = kernelChannels;
- kernel_dims[2] = kernelPlanes;
- kernel_dims[1] = kernelRows;
- kernel_dims[0] = kernelCols;
+ post_contract_dims[0] = kernelFilters;
+ post_contract_dims[1] = kernelCols;
+ post_contract_dims[2] = kernelRows;
+ post_contract_dims[3] = kernelPlanes;
+ post_contract_dims[4] = kernelChannels;
}
- return choose(
- Cond<internal::traits<Input>::Layout == ColMajor>(),
- output_backward.reshape(output_dims)
- .contract(input
+ // Reorder output of contraction to valid filter shape.
+ array<TensorIndex, 5> kernel_shuffle;
+ if (isColMajor) {
+ // From: [in_depth, kernel_planes, kernel_rows, kernel_cols, out_depth]
+ // To: [out_depth, in_depth, kernel_planes, kernel_rows, kernel_cols]
+ kernel_shuffle = {4, 0, 1, 2, 3};
+ } else {
+ // From: [out_depth, kernel_cols, kernel_rows, kernel_planes, in_depth]
+ // To: [kernel_cols, kernel_rows, kernel_planes, in_depth, out_depth]
+ kernel_shuffle = {1, 2, 3, 4, 0};
+ }
+
+ // Reverse kernel backprop dimensions.
+ array<TensorIndex, 5> kernel_reverse;
+ if (isColMajor) {
+ kernel_reverse = {false, false, true, true, true};
+ } else {
+ kernel_reverse = {true, true, true, false, false};
+ }
+
+ // Create convolution input (aka source of patches) from output backward
+ // tensor by shuffling dimensions.
+ const auto the_input =
+ output_backward.shuffle(output_backward_shuffle).eval();
+
+ // Create convolution kernel (aka filter) from input by shuffling and
+ // reshaping.
+ const auto the_kernel =
+ input.shuffle(input_shuffle).reshape(input_dims).eval();
+
+ return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
+ the_kernel.contract(
+ the_input
.extract_volume_patches(
- kernelPlanes, kernelRows, kernelCols, stridePlanes,
- strideRows, strideCols, 1, 1, 1, padding_top_z,
- padding_bottom_z, padding_top, padding_bottom,
- padding_left, padding_right)
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims),
- contract_dims)
- .reshape(kernel_dims),
- input
- .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
- stridePlanes, strideRows, strideCols, 1, 1, 1,
- padding_top_z, padding_bottom_z, padding_top,
- padding_bottom, padding_left, padding_right)
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims)
- .contract(output_backward.reshape(output_dims), contract_dims)
- .reshape(kernel_dims));
+ inputPlanes, inputRows, inputCols, 1, 1, 1,
+ stridePlanes, strideRows, strideCols,
+ top_pad_planes, bottom_pad_planes, top_pad_rows,
+ bottom_pad_rows, left_pad_cols, right_pad_cols)
+ .reshape(pre_contract_dims),
+ contract_dims),
+ the_input
+ .extract_volume_patches(
+ inputPlanes, inputRows, inputCols, 1, 1, 1,
+ stridePlanes, strideRows, strideCols, top_pad_planes,
+ bottom_pad_planes, top_pad_rows, bottom_pad_rows,
+ left_pad_cols, right_pad_cols)
+ .reshape(pre_contract_dims)
+ .contract(the_kernel, contract_dims))
+ .reshape(post_contract_dims)
+ .shuffle(kernel_shuffle)
+ .reverse(kernel_reverse);
}
} // end namespace Eigen
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
index 8d06107553..960920c55b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
@@ -238,8 +238,8 @@ SpatialConvolutionBackwardInput(
}
}
- // We will contract along the fused dimension that contains the kernelFilters,
- // the kernelRows and the kernelCols.
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, the kernelRows and the kernelCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
@@ -332,23 +332,16 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic,
- const Input> > > > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >,
TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 4>,
const TensorContractionOp<
const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward> > > >::type
@@ -456,12 +449,16 @@ SpatialConvolutionBackwardKernel(
eigen_assert(output_dims[0] == pre_contract_dims[0]);
}
- array<TensorIndex, 2> shuffle_dims;
- shuffle_dims[0] = 1;
- shuffle_dims[1] = 0;
-
+ // We will contract along the collapsed dimension that contains the
+ // outputCols, outputRows and OTHERS.
array<IndexPair<TensorIndex>, 1> contract_dims;
- contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+ if (isColMajor) {
+ // col-major: output_backward.contract(input.patches)
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
+ } else {
+ // row-major: input.patches.contract(output_backward)
+ contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+ }
// After the contraction, the kernel will have the desired shape
// out_depth X in_shape X kernel_rows X kernel_cols
@@ -487,8 +484,7 @@ SpatialConvolutionBackwardKernel(
kernelRows, kernelCols, row_stride, col_stride,
row_in_stride, col_in_stride, 1, 1, padding_top,
padding_bottom, padding_left, padding_right, OutScalar(0))
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims),
+ .reshape(pre_contract_dims),
contract_dims)
.reshape(kernel_dims),
input
@@ -497,7 +493,6 @@ SpatialConvolutionBackwardKernel(
padding_top, padding_bottom, padding_left,
padding_right, OutScalar(0))
.reshape(pre_contract_dims)
- .shuffle(shuffle_dims)
.contract(output_backward.reshape(output_dims), contract_dims)
.reshape(kernel_dims));
}
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
index 2229ec9659..673ec1458b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
@@ -1248,11 +1248,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
const int output_cols = input_cols - patch_cols + 1;
const int output_planes = input_planes - patch_planes + 1;
- Tensor<float, 4> input(input_depth, input_planes, input_rows, input_cols);
+ // TODO(ezhulenev): Support backward kernel convolution without batch
+ // dimension.
+ Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols,
+ /*num_batches*/ 1);
Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
patch_cols);
- Tensor<float, 4> output_backward(output_depth, output_planes, output_rows,
- output_cols);
+ Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+ output_cols, /*num_batches*/ 1);
output_backward = output_backward.constant(11.0f) + output_backward.random();
input = input.constant(2.0f) + input.random();
@@ -1282,9 +1285,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
if (output_i >= 0 && output_i < output_planes &&
output_j >= 0 && output_j < output_rows &&
output_k >= 0 && output_k < output_cols) {
- expected +=
- input(id, i, j, k) *
- output_backward(od, output_i, output_j, output_k);
+ expected += input(id, i, j, k, /*batch*/ 0) *
+ output_backward(od, output_i, output_j,
+ output_k, /*batch*/ 0);
}
}
}
@@ -1311,12 +1314,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
const int output_cols = input_cols - patch_cols + 1;
const int output_planes = input_planes - patch_planes + 1;
- Tensor<float, 4, RowMajor> input(input_cols, input_rows, input_planes,
- input_depth);
+ // TODO(ezhulenev): Support backward kernel convolution without batch
+ // dimension.
+ Tensor<float, 5, RowMajor> input(/*num_batches*/ 1, input_cols, input_rows,
+ input_planes, input_depth);
Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
input_depth, output_depth);
- Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows,
- output_planes, output_depth);
+ Tensor<float, 5, RowMajor> output_backward(
+ /*num_batches*/ 1, output_cols, output_rows, output_planes, output_depth);
output_backward = output_backward.constant(11.0f) + output_backward.random();
input = input.constant(2.0f) + input.random();
@@ -1346,9 +1351,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
if (output_i >= 0 && output_i < output_planes &&
output_j >= 0 && output_j < output_rows &&
output_k >= 0 && output_k < output_cols) {
- expected +=
- input(k, j, i, id) *
- output_backward(output_k, output_j, output_i, od);
+ expected += input(/*batch*/ 0, k, j, i, id) *
+ output_backward(/*batch*/ 0, output_k, output_j,
+ output_i, od);
}
}
}
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index 62e9f9123d..6a9a2accd8 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -21,6 +21,1412 @@ limitations under the License.
namespace Eigen {
+namespace internal {
+
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract volume patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor", that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
+// 1: out_planes * out_height * out_width * OTHERS (e.g batches, etc...)
+//
+// *) extracted patches are continuous in memory (innermost dimension assuming
+// col major layout)
+//
+// With this dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
+//
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar_,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionInputMapper<
+ Scalar_, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef Scalar_ Scalar;
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper VectorMapper;
+ typedef SubMapper LinearMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(
+ const TensorEvaluator<
+ const TensorReshapingOp<
+ NewDimension,
+ const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
+ Device>& tensor,
+ const nocontract_t&, const nocontract_t&, const contract_t&,
+ const contract_t&)
+ : m_impl(tensor.impl().impl()) {
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_patch_depth = tensor.impl().dimensions()[0];
+ m_patch_planes = tensor.impl().dimensions()[1];
+ m_patch_rows = tensor.impl().dimensions()[2];
+ m_patch_cols = tensor.impl().dimensions()[3];
+ m_num_patches = tensor.impl().dimensions()[4];
+ } else {
+ const int NumDims = tensor.impl().dimensions().size();
+ m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
+ m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
+ m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
+ m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
+ m_num_patches = tensor.impl().dimensions()[NumDims - 5];
+ }
+
+ // Strides for navigating through the single patch.
+ m_patch_plane_stride = m_patch_depth;
+ m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
+ m_patch_col_stride = m_patch_rows * m_patch_row_stride;
+
+ // Strides for the output tensor.
+ // IMPORTANT: These strides are used to locate an element in a patch at a
+ // depth zero (channel), which is not quite the same as "traditional"
+ // stride.
+ m_rowStride = m_patch_planes;
+ m_colStride = m_patch_rows * m_rowStride;
+ m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
+ m_otherStride = m_patchStride * m_num_patches;
+
+ m_outputPlanes = tensor.impl().outputPlanes();
+ m_outputRows = tensor.impl().outputRows();
+ m_outputCols = tensor.impl().outputCols();
+
+ m_outputPlanesRows = m_outputPlanes * m_outputRows;
+
+ m_plane_strides = tensor.impl().userPlaneStride();
+ m_row_strides = tensor.impl().userRowStride();
+ m_col_strides = tensor.impl().userColStride();
+
+ m_in_plane_strides = tensor.impl().userInPlaneStride();
+ m_in_row_strides = tensor.impl().userInRowStride();
+ m_in_col_strides = tensor.impl().userInColStride();
+
+ m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
+ m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
+ m_patch_col_inflate_strides = tensor.impl().colInflateStride();
+
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_inputDepth = tensor.impl().impl().dimensions()[0];
+ m_inputPlanes = tensor.impl().impl().dimensions()[1];
+ m_inputRows = tensor.impl().impl().dimensions()[2];
+ m_inputCols = tensor.impl().impl().dimensions()[3];
+ } else {
+ const int NumDims = tensor.impl().impl().dimensions().size();
+ m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
+ m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
+ m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
+ m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
+ }
+
+ // Strides for navigating through the input tensor.
+ m_planeInputStride = m_inputDepth;
+ m_rowInputStride = m_inputDepth * m_inputPlanes;
+ m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
+ m_patchInputStride =
+ m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
+
+ m_planePaddingTop = tensor.impl().planePaddingTop();
+ m_rowPaddingTop = tensor.impl().rowPaddingTop();
+ m_colPaddingLeft = tensor.impl().colPaddingLeft();
+
+ m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+
+ m_fastPatchPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_stride);
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
+
+ m_fastInputPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
+ m_fastInputRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
+ m_fastInputColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
+
+ m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
+ m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
+
+ m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
+ m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+ m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
+ m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+ m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);
+
+ m_fastOutputPlanesRows =
+ internal::TensorIntDivisor<Index>(m_outputPlanesRows);
+ }
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
+ : m_impl(base_mapper.m_impl) {
+ m_patch_depth = base_mapper.m_patch_depth;
+ m_patch_planes = base_mapper.m_patch_planes;
+ m_patch_rows = base_mapper.m_patch_rows;
+ m_patch_cols = base_mapper.m_patch_cols;
+ m_num_patches = base_mapper.m_num_patches;
+
+ m_patch_plane_stride = base_mapper.m_patch_plane_stride;
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
+ m_rowStride = base_mapper.m_rowStride;
+ m_colStride = base_mapper.m_colStride;
+ m_patchStride = base_mapper.m_patchStride;
+ m_otherStride = base_mapper.m_otherStride;
+
+ m_planeInputStride = base_mapper.m_planeInputStride;
+ m_rowInputStride = base_mapper.m_rowInputStride;
+ m_colInputStride = base_mapper.m_colInputStride;
+ m_patchInputStride = base_mapper.m_patchInputStride;
+ m_otherInputStride = base_mapper.m_otherInputStride;
+
+ m_inputDepth = base_mapper.m_inputDepth;
+ m_inputPlanes = base_mapper.m_inputPlanes;
+ m_inputRows = base_mapper.m_inputRows;
+ m_inputCols = base_mapper.m_inputCols;
+
+ m_outputPlanes = base_mapper.m_outputPlanes;
+ m_outputRows = base_mapper.m_outputRows;
+ m_outputCols = base_mapper.m_outputCols;
+
+ m_plane_strides = base_mapper.m_plane_strides;
+ m_row_strides = base_mapper.m_row_strides;
+ m_col_strides = base_mapper.m_col_strides;
+
+ m_in_plane_strides = base_mapper.m_in_plane_strides;
+ m_in_row_strides = base_mapper.m_in_row_strides;
+ m_in_col_strides = base_mapper.m_in_col_strides;
+
+ m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
+ m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
+ m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
+
+ m_planePaddingTop = base_mapper.m_planePaddingTop;
+ m_rowPaddingTop = base_mapper.m_rowPaddingTop;
+ m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+
+ m_outputPlanesRows = base_mapper.m_outputPlanesRows;
+
+ m_fastNumPatches = base_mapper.m_fastNumPatches;
+ m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
+ m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
+ m_fastInputRowStride = base_mapper.m_fastInputRowStride;
+ m_fastInputColStride = base_mapper.m_fastInputColStride;
+ m_fastRowStride = base_mapper.m_fastRowStride;
+ m_fastColStride = base_mapper.m_fastColStride;
+ m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
+ m_fastOutputRows = base_mapper.m_fastOutputRows;
+ m_fastOutputCols = base_mapper.m_fastOutputCols;
+ m_fastDimZero = base_mapper.m_fastDimZero;
+ m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
+ }
+
+ // If true, turns off some optimizations for loading packets since the image
+ // patches are "non-standard" such as there are non-trivial strides or
+ // inflations in the input.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
+ m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
+ m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
+ return SubMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
+ return LinearMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the coefficient at the patchIndex location instead of the usual
+ // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
+ // gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the packet at the patchIndex location instead of the usual m_rowIndex,
+ // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
+ return m_impl;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+ const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ private:
+ friend class TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>;
+
+ // Load coefficient from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset * m_in_col_strides;
+ const Index origInputCol =
+ (m_patch_col_inflate_strides == 1)
+ ? inputCol
+ : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
+ const Index origInputRow =
+ (m_patch_row_inflate_strides == 1)
+ ? inputRow
+ : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
+
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+ const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides;
+ const Index origInputPlane =
+ (m_patch_plane_inflate_strides == 1)
+ ? inputPlane
+ : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
+
+ if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 ||
+ origInputCol >= m_inputCols || origInputRow >= m_inputRows ||
+ origInputPlane >= m_inputPlanes ||
+ (inputCol != origInputCol * m_patch_col_inflate_strides) ||
+ (inputRow != origInputRow * m_patch_row_inflate_strides) ||
+ (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) {
+ return Scalar(0);
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + origInputPlane * m_planeInputStride +
+ origInputRow * m_rowInputStride +
+ origInputCol * m_colInputStride + otherIndex;
+
+ return m_impl.coeff(inputIndex);
+ }
+
+ // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
+ // and `in_strides` equal to 1 (template specialization without templates).
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ eigen_assert(!nonStandardPatches());
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+
+ const Index inputCol = colIndex + colOffset;
+ const Index inputRow = rowIndex + rowOffset;
+ const Index inputPlane = planeIndex + planeOffset;
+
+ if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
+ inputRow >= m_inputRows || inputPlane < 0 ||
+ inputPlane >= m_inputPlanes) {
+ return Scalar(0);
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputPlane * m_planeInputStride +
+ inputRow * m_rowInputStride +
+ inputCol * m_colInputStride + otherIndex;
+
+ return m_impl.coeff(inputIndex);
+ }
+
+ // Load packet from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+ if (nonStandardPatches()) {
+ return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+ return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+ eigen_assert(!nonStandardPatches());
+
+ if ((patchDepth() % packetSize) == 0) {
+ return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ } else {
+ // Offsets and input calculation here are identical to
+ // loadCoeffStandard(...), but repeated twice.
+
+ const Index patchOffsets[2] = {
+ patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
+
+ const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
+ patchOffsets[1] / m_fastColStride};
+ eigen_assert(colOffsets[0] <= colOffsets[1]);
+
+ const Index inputCols[2] = {colIndex + colOffsets[0],
+ colIndex + colOffsets[1]};
+ if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputCols[0] == inputCols[1]) {
+ const Index rowOffsets[2] = {
+ (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
+ (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
+ eigen_assert(rowOffsets[0] <= rowOffsets[1]);
+ const Index inputRows[2] = {rowIndex + rowOffsets[0],
+ rowIndex + rowOffsets[1]};
+
+ if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputRows[0] == inputRows[1]) {
+ const Index planeOffsets[2] = {
+ patchOffsets[0] - colOffsets[0] * m_colStride -
+ rowOffsets[0] * m_rowStride,
+ patchOffsets[1] - colOffsets[1] * m_colStride -
+ rowOffsets[1] * m_rowStride};
+ eigen_assert(planeOffsets[0] <= planeOffsets[1]);
+ const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
+ planeIndex + planeOffsets[1]};
+
+ if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
+ const Index depth = patchId - patchOffsets[0] * patchDepth();
+ const Index inputIndex =
+ depth + inputPlanes[0] * m_planeInputStride +
+ inputRows[0] * m_rowInputStride +
+ inputCols[0] * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+ }
+ }
+ }
+
+ return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+ eigen_assert(!nonStandardPatches());
+ eigen_assert((patchDepth() % packetSize) == 0);
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+ eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+
+ const Index inputCol = colIndex + colOffset;
+ const Index inputRow = rowIndex + rowOffset;
+ const Index inputPlane = planeIndex + planeOffset;
+
+ if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
+ inputCol >= m_inputCols || inputRow >= m_inputRows ||
+ inputPlane >= m_inputPlanes) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputPlane * m_planeInputStride +
+ inputRow * m_rowInputStride +
+ inputCol * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+ packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
+ Index colIndex, Index otherIndex) const {
+ const int packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_ALIGN_MAX
+ typename internal::remove_const<Scalar>::type values[packetSize];
+ for (int i = 0; i < packetSize; ++i) {
+ values[i] =
+ loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+ Packet rslt = internal::pload<Packet>(values);
+ return rslt;
+ }
+
+ // Precompute the indices (plane, row, col, other) of the first element of
+ // the given patch index, within the output tensor of the TensorVolumePatchOp.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
+ Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
+ Index& otherIndex) const {
+ const size_t NumInputDims = array_size<
+ typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
+
+ // Check if patchIndex might contain batch and other dimensions.
+ otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;
+
+ // Compute index of the patch within the batch (and other dimensions).
+ const Index patch3DIndex = (NumInputDims == 4)
+ ? patchIndex
+ : (patchIndex - otherIndex * m_num_patches);
+
+ otherIndex *= m_patchInputStride;
+
+ colIndex = patch3DIndex / m_fastOutputPlanesRows;
+ rowIndex =
+ (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
+ planeIndex =
+ patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;
+
+ colIndex = colIndex * m_col_strides - m_colPaddingLeft;
+ rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
+ planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
+ }
+
+ Index m_patch_depth; // number of channels in the patch
+ Index m_patch_planes; // number of planes in the patch
+ Index m_patch_rows; // number of rows in the patch
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract
+
+ // Strides for navigating through the single patch.
+ Index m_patch_plane_stride;
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+
+ // Strides for the output tensor (depth is not the part of the stride).
+ Index m_rowStride;
+ Index m_colStride;
+ Index m_patchStride;
+ Index m_otherStride;
+
+ Index m_planeInputStride; // Plane stride in the input tensor
+ Index m_rowInputStride; // Row stride in the input tensor
+ Index m_colInputStride; // Col stride in the input tensor
+ Index m_patchInputStride; // Patch stride in the input tensor
+ Index m_otherInputStride;
+
+ Index m_inputDepth; // Depth of the input tensor
+ Index m_inputPlanes; // Number of planes in the input tensor
+ Index m_inputRows; // Number of rows in the input tensor
+ Index m_inputCols; // Number of cols in the input tensor
+
+ Index m_outputPlanes; // Number of output planes
+ Index m_outputRows; // Number of output rows
+ Index m_outputCols; // Number of output cols
+ Index m_outputPlanesRows; // Cached outputPlanes * outputRows.
+
+ Index m_plane_strides; // User specified plane stride
+ Index m_row_strides; // User specified row stride
+ Index m_col_strides; // User specified col stride
+
+ // User specified plane/row/col atrous convolution strides.
+ Index m_in_plane_strides;
+ Index m_in_row_strides;
+ Index m_in_col_strides;
+
+ // User specified plane/row/col inflation strides in the image patch.
+ Index m_patch_plane_inflate_strides;
+ Index m_patch_row_inflate_strides;
+ Index m_patch_col_inflate_strides;
+
+ Index m_planePaddingTop; // Plane padding
+ Index m_rowPaddingTop; // Row padding
+ Index m_colPaddingLeft; // Column padding
+
+ // Fast representation of various divisors.
+ internal::TensorIntDivisor<Index> m_fastNumPatches;
+
+ internal::TensorIntDivisor<Index> m_fastPatchPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
+ internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastInputRowStride;
+ internal::TensorIntDivisor<Index> m_fastInputColStride;
+
+ internal::TensorIntDivisor<Index> m_fastRowStride;
+ internal::TensorIntDivisor<Index> m_fastColStride;
+
+ internal::TensorIntDivisor<Index> m_fastDimZero; // aka output depth
+ internal::TensorIntDivisor<Index> m_fastOutputPlanes;
+ internal::TensorIntDivisor<Index> m_fastOutputRows;
+ internal::TensorIntDivisor<Index> m_fastOutputCols;
+ internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
+
+ const TensorEvaluator<ArgType, Device> m_impl;
+};
+
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef typename packet_traits<Scalar>::type Packet;
+ typedef typename packet_traits<Scalar>::half HalfPacket;
+
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ ParentMapper;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef Self LinearMapper;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+ const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper),
+ m_depth_offset(vert_offset),
+ m_col_offset(horiz_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+ const Self& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper.m_base_mapper),
+ m_depth_offset(vert_offset + base_mapper.m_depth_offset),
+ m_col_offset(horiz_offset + base_mapper.m_col_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
+ return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
+ Index j) const {
+ return m_base_mapper(i + m_depth_offset, j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
+ return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex,
+ m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
+ Index j) const {
+ return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
+ j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
+ loadCoeffStandard(Index i) const {
+ return m_base_mapper.loadCoeffStandard(
+ i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
+ return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex,
+ m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+ loadPacketStandard(Index i) const {
+ return m_base_mapper.loadPacketStandard(
+ i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC bool aligned(Index) const {
+ return false;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_base_mapper.nonStandardPatches();
+ }
+
+ // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row,
+ // plane and depth index respectively that fits into the peeled_k elements
+ // starting at m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col,
+ const Index row) const {
+ const Index max_plane = fastPatchPlaneStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride() -
+ row * patchRowStride());
+ return std::min<Index>(1 + max_plane, patchPlanes());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes to prevent register
+ // spilling, instead of the variable that you would expect to see, we use
+ // another one, that is guaranteed to have the same value. E.g. patch depth is
+ // always the same as input depth, and it's also the same as input plane
+ // stride. Bunch of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const {
+ eigen_assert(m_base_mapper.m_patch_depth ==
+ m_base_mapper.m_planeInputStride &&
+ "Patch depth must be equal to plane input stride.");
+ return m_base_mapper.m_planeInputStride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const {
+ eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride &&
+ "Patch planes must be equal to row stride.");
+ return m_base_mapper.m_rowStride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const {
+ return m_base_mapper.m_patch_rows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const {
+ return m_base_mapper.m_patch_cols;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ return m_base_mapper.m_patch_row_stride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ return m_base_mapper.m_fastPatchRowStride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+ const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const {
+ const Index p = m_planeIndex + plane;
+ return p < 0 || p >= m_base_mapper.m_inputPlanes;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
+ const Index r = m_rowIndex + row;
+ return r < 0 || r >= m_base_mapper.m_inputRows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
+ const Index c = m_colIndex + col;
+ return c < 0 || c >= m_base_mapper.m_inputCols;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row,
+ const Index col) const {
+ const Index p = m_planeIndex + plane;
+ const Index r = m_rowIndex + row;
+ const Index c = m_colIndex + col;
+ return p * m_base_mapper.m_planeInputStride +
+ r * m_base_mapper.m_rowInputStride +
+ c * m_base_mapper.m_colInputStride + m_otherIndex;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index planeOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_base_mapper.m_colStride) /
+ m_base_mapper.m_fastRowStride;
+ const Index planeOffset = patchOffset -
+ colOffset * m_base_mapper.m_colStride -
+ rowOffset * m_base_mapper.m_rowStride;
+ return planeOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index rowOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_base_mapper.m_colStride) /
+ m_base_mapper.m_fastRowStride;
+ return rowOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index colOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ return colOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index depthOffset() const {
+ return m_depth_offset % patchDepth();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
+ getLinearMapper(Index i, Index j) const {
+ return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
+ }
+
+ private:
+ const ParentMapper& m_base_mapper;
+ Index m_depth_offset; // First row in the input matrix
+ Index m_col_offset; // First col in the input matrix
+
+ // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
+ // indices for the first element in a patch specified by col_offset
+ // (see computeBaseIndices(...) for details).
+ Index m_planeIndex;
+ Index m_rowIndex;
+ Index m_colIndex;
+ Index m_otherIndex;
+};
+
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted volume patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
+// A8 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) A0, A1, A2 ... - values from the same patch at different offsets.
+//
+// The traversal (packed rhs memory) order (B0 besides A0 in memory):
+// A0 B0 C0 D0 A1 B1 C1 D1 ...
+// E0 F0 G0 H0 E1 F1 G1 H1 ...
+// ...
+// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
+//
+// This traversal order must be the same as in default gemm_pack_rhs defined in
+// GeneralBlockPanelKernel.h.
+//
+// *) nr - number of registers along the 'n' dimension.
+// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+// Multiplication" paper.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment, int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+
+ typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if ((packet_size % 4) == 0 && !non_standard_patches) {
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet do not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
+
+ const bool pad_row0 = pad_col0 || dm0.padRow(r);
+ const bool pad_row1 = pad_col1 || dm1.padRow(r);
+ const bool pad_row2 = pad_col2 || dm2.padRow(r);
+ const bool pad_row3 = pad_col3 || dm3.padRow(r);
+
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
+
+ const bool pad0 = pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_row3 || dm3.padPlane(p);
+
+ const Index idx0 = dm0.baseIndex(p, r, c);
+ const Index idx1 = dm1.baseIndex(p, r, c);
+ const Index idx2 = dm2.baseIndex(p, r, c);
+ const Index idx3 = dm3.baseIndex(p, r, c);
+
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+ }
+
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
+ } else {
+ // Packet can span multiple planes, rows or columns, so we have to go
+ // though the slower "standard" path.
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = dm0.loadPacketStandard(k);
+ kernel.packet[1] = dm1.loadPacketStandard(k);
+ kernel.packet[2] = dm2.loadPacketStandard(k);
+ kernel.packet[3] = dm3.loadPacketStandard(k);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ }
+ }
+ }
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // Copy the remaining columns one at a time (nr==1).
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+// Template specialization for packet_size = 2. We must special-case packet
+// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ const int packet_size = 2;
+
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if (!non_standard_patches) {
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet do not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
+
+ const bool pad_row0 = dm0.padRow(r);
+ const bool pad_row1 = dm1.padRow(r);
+ const bool pad_row2 = dm2.padRow(r);
+ const bool pad_row3 = dm3.padRow(r);
+
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
+
+ const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+
+ const Index idx0 = dm0.baseIndex(p, r, c);
+ const Index idx1 = dm1.baseIndex(p, r, c);
+ const Index idx2 = dm2.baseIndex(p, r, c);
+ const Index idx3 = dm3.baseIndex(p, r, c);
+
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+ }
+
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
+ } else {
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = dm0.loadPacketStandard(k);
+ kernel0.packet[1] = dm1.loadPacketStandard(k);
+ kernel1.packet[0] = dm2.loadPacketStandard(k);
+ kernel1.packet[1] = dm3.loadPacketStandard(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ }
+ }
+ }
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!rhs.nonStandardPatches()) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // Copy the remaining columns one at a time (nr==1).
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+// Special case for non-vectorized types such as float16 (packet_size = 1).
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
+ Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ const Index packet_cols4 = (cols / 4) * 4;
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ if (!rhs.nonStandardPatches()) {
+ for (Index k = 0; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (Index k = 0; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // Copy the remaining columns one at a time (nr==1).
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+} // namespace internal
+
/** CuboidConvolution
* \ingroup CXX11_NeuralNetworks_Module
*
@@ -98,7 +1504,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];
// Spatial size of the kernel.
- const TensorIndex kernelDepth =
+ const TensorIndex kernelPlanes =
isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
const TensorIndex kernelRows =
isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
@@ -118,27 +1524,27 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
const TensorIndex inputCols =
isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
- TensorIndex out_depth;
+ TensorIndex out_planes;
TensorIndex out_height;
TensorIndex out_width;
switch (padding_type) {
case PADDING_VALID:
- out_depth = Eigen::divup(inputPlanes - kernelDepth + 1,
- static_cast<TensorIndex>(stridePlanes));
+ out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
+ static_cast<TensorIndex>(stridePlanes));
out_height = Eigen::divup(inputRows - kernelRows + 1,
static_cast<TensorIndex>(strideRows));
out_width = Eigen::divup(inputCols - kernelCols + 1,
static_cast<TensorIndex>(strideCols));
break;
case PADDING_SAME:
- out_depth =
+ out_planes =
Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
out_height =
Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
break;
default:
- out_depth = 0;
+ out_planes = 0;
out_height = 0;
out_width = 0;
eigen_assert(false && "unexpected padding");
@@ -147,9 +1553,9 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, 2> kernel_dims;
if (isColMajor) {
kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
} else {
- kernel_dims[0] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
kernel_dims[1] = kernelFilters;
}
@@ -160,15 +1566,15 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
pre_contract_dims[0] =
- kernelChannels * kernelDepth * kernelRows * kernelCols;
- pre_contract_dims[1] = out_depth * out_height * out_width;
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[1] = out_planes * out_height * out_width;
for (int i = 4; i < NumDims; ++i) {
pre_contract_dims[1] *= in.dimension(i);
}
} else {
pre_contract_dims[1] =
- kernelChannels * kernelDepth * kernelRows * kernelCols;
- pre_contract_dims[0] = out_depth * out_height * out_width;
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[0] = out_planes * out_height * out_width;
for (int i = 0; i < NumDims - 4; ++i) {
pre_contract_dims[0] *= in.dimension(i);
}
@@ -187,7 +1593,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, NumDims> post_contract_dims;
if (isColMajor) {
post_contract_dims[0] = kernelFilters;
- post_contract_dims[1] = out_depth;
+ post_contract_dims[1] = out_planes;
post_contract_dims[2] = out_height;
post_contract_dims[3] = out_width;
for (int i = 4; i < NumDims; ++i) {
@@ -195,7 +1601,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
}
} else {
post_contract_dims[NumDims - 1] = kernelFilters;
- post_contract_dims[NumDims - 2] = out_depth;
+ post_contract_dims[NumDims - 2] = out_planes;
post_contract_dims[NumDims - 3] = out_height;
post_contract_dims[NumDims - 4] = out_width;
for (int i = 0; i < NumDims - 4; ++i) {
@@ -208,13 +1614,13 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
kernel.reshape(kernel_dims)
.contract(input
.extract_volume_patches(
- kernelDepth, kernelRows, kernelCols, stridePlanes,
+ kernelPlanes, kernelRows, kernelCols, stridePlanes,
strideRows, strideCols, padding_type)
.reshape(pre_contract_dims),
contract_dims)
.reshape(post_contract_dims),
input
- .extract_volume_patches(kernelDepth, kernelRows, kernelCols,
+ .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
stridePlanes, strideRows, strideCols,
padding_type)
.reshape(pre_contract_dims)
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index a4dff4b91c..e926d73f87 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -22,8 +22,36 @@ namespace Eigen {
namespace internal {
-// TODO: Consolidate this part of the code with the image patch extraction code
-// since they are both very similar.
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract image patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor", that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelRows * kernelCols;
+// 1: out_height * out_width; * OTHERS (e.g batches, etc...)
+//
+// *) extracted patches are continuous in memory (innermost dimension assuming
+// col major layout)
+//
+// With this dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
+//
+// TODO(ezhulenev): Consolidate this part of the code with the image patch
+// extraction code since they are both very similar.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar_, typename Index,
typename nocontract_t, typename contract_t, int Side, int packet_size,
@@ -77,12 +105,17 @@ class TensorContractionInputMapper<
m_patch_cols = tensor.impl().dimensions()[2];
m_num_patches = tensor.impl().dimensions()[3];
} else {
- const int NumDims = tensor.impl().dimensions().size();
+ const size_t NumDims = tensor.impl().dimensions().size();
patch_depth = tensor.impl().dimensions()[NumDims - 1];
patch_rows = tensor.impl().dimensions()[NumDims - 2];
m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
m_num_patches = tensor.impl().dimensions()[NumDims - 4];
}
+
+ // Strides for navigating through the single patch.
+ m_patch_row_stride = patch_depth;
+ m_patch_col_stride = patch_rows * m_patch_row_stride;
+
m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
m_patch_col_inflate_strides = tensor.impl().colInflateStride();
@@ -111,6 +144,10 @@ class TensorContractionInputMapper<
m_rowPaddingTop = tensor.impl().rowPaddingTop();
m_colPaddingLeft = tensor.impl().colPaddingLeft();
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
m_fastInputRowStride =
internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
m_fastInputColStride =
@@ -126,6 +163,10 @@ class TensorContractionInputMapper<
: m_impl(base_mapper.m_impl) {
m_patch_cols = base_mapper.m_patch_cols;
m_num_patches = base_mapper.m_num_patches;
+
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
@@ -148,6 +189,8 @@ class TensorContractionInputMapper<
m_rowPaddingTop = base_mapper.m_rowPaddingTop;
m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
m_fastInputRowStride = base_mapper.m_fastInputRowStride;
m_fastInputColStride = base_mapper.m_fastInputColStride;
m_fastNumPatches = base_mapper.m_fastNumPatches;
@@ -238,6 +281,8 @@ class TensorContractionInputMapper<
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>;
+ // Load coefficient from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
Index colIndex, Index otherIndex) const {
@@ -250,6 +295,7 @@ class TensorContractionInputMapper<
(m_patch_col_inflate_strides == 1)
? inputCol
: ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
const Index rowOffset = patchOffset - colOffset * m_colStride;
const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
const Index origInputRow =
@@ -268,6 +314,8 @@ class TensorContractionInputMapper<
return m_impl.coeff(inputIndex);
}
+ // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
+ // and `in_strides` equal to 1 (template specialization without templates).
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
Index colIndex,
@@ -276,10 +324,9 @@ class TensorContractionInputMapper<
// Find the offset of the element wrt the location of the first element.
const Index patchOffset = patchId / m_fastDimZero;
-
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputCol = colIndex + colOffset;
const Index inputRow = rowIndex + rowOffset;
if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
inputRow >= m_inputRows) {
@@ -291,6 +338,8 @@ class TensorContractionInputMapper<
return m_impl.coeff(inputIndex);
}
+ // Load packet from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
Index colIndex,
@@ -318,12 +367,14 @@ class TensorContractionInputMapper<
if ((patchDepth() % packetSize) == 0) {
return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
} else {
+ // Offsets and input calculation here are identical to
+ // loadCoeffStandard(...), but repeated twice.
+
const Index patchOffsets[2] = {
patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
patchOffsets[1] / m_fastColStride};
-
const Index inputCols[2] = {colIndex + colOffsets[0],
colIndex + colOffsets[1]};
if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
@@ -371,8 +422,8 @@ class TensorContractionInputMapper<
eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputCol = colIndex + colOffset;
const Index inputRow = rowIndex + rowOffset;
if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
inputRow >= m_inputRows) {
@@ -401,7 +452,7 @@ class TensorContractionInputMapper<
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
Index patchIndex, Index& rowIndex, Index& colIndex,
Index& otherIndex) const {
- const int NumInputDims = array_size<
+ const size_t NumInputDims = array_size<
typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
const Index patch2DIndex = (NumInputDims == 3)
@@ -414,8 +465,15 @@ class TensorContractionInputMapper<
rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
}
- Index m_patch_cols; // number of colums in the patch
- Index m_num_patches; // number of patches to extract.
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract.
+
+ // Strides for navigating through the single patch.
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
Index m_patch_row_inflate_strides; // the strides for row inflation in the
// image patch
Index m_patch_col_inflate_strides; // the strides for col inflation in the
@@ -549,6 +607,40 @@ class TensorContractionSubMapper<
return m_base_mapper.nonStandardPatches();
}
+ // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
+ // index respectively that fits into the peeled_k elements starting at
+ // m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes to prevent register
+ // spilling, instead of the variable that you would expect to see, we use
+ // another one, that is guaranteed to have the same value. E.g. patch depth is
+ // always the same as input depth, and it's also the same as input row stride.
+ // Bunch of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchDepth() const {
return m_base_mapper.m_rowInputStride;
@@ -563,6 +655,28 @@ class TensorContractionSubMapper<
}
EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
const Index baseIndex) const {
const Index inputIndex = depth + baseIndex;
@@ -603,8 +717,7 @@ class TensorContractionSubMapper<
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index depthOffset() const {
- const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
- return patchOffset;
+ return m_depth_offset % patchDepth();
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
@@ -617,12 +730,44 @@ class TensorContractionSubMapper<
Index m_depth_offset; // First row in the input matrix
Index m_col_offset; // First col in the input matrix
- Index m_rowIndex; // precomputed row index corresponding to the col offset
- Index m_colIndex; // precomputed col index corresponding to the col offset
- Index
- m_otherIndex; // precomputed other index corresponding to the col offset
+ // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
+ // indices for the first element in a patch specified by col_offset
+ // (see computeBaseIndices(...) for details).
+ Index m_rowIndex;
+ Index m_colIndex;
+ Index m_otherIndex;
};
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted image patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
+// A8 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) A0, A1, A2 ... - values from the same patch at different offsets.
+//
+// The traversal (packed rhs memory) order (B0 besides A0 in memory):
+// A0 B0 C0 D0 A1 B1 C1 D1 ...
+// E0 F0 G0 H0 E1 F1 G1 H1 ...
+// ...
+// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
+//
+// This traversal order must be the same as in default gemm_pack_rhs defined in
+// GeneralBlockPanelKernel.h.
+//
+// *) nr - number of registers along the 'n' dimension.
+// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+// Multiplication" paper.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar, typename Index,
typename nocontract_t, typename contract_t, int packet_size,
@@ -649,9 +794,9 @@ struct gemm_pack_rhs<
inner_dim_reordered, Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -660,9 +805,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
const bool non_standard_patches = rhs.nonStandardPatches();
@@ -675,30 +817,27 @@ struct gemm_pack_rhs<
Index k = 0;
if ((packet_size % 4) == 0 && !non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns and rows, if we know that a single
+ // packet do not span across multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -709,14 +848,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
@@ -738,19 +876,9 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 4> kernel;
- kernel.packet[0] = dm0.loadPacketFast(k);
- kernel.packet[1] = dm1.loadPacketFast(k);
- kernel.packet[2] = dm2.loadPacketFast(k);
- kernel.packet[3] = dm3.loadPacketFast(k);
- ptranspose(kernel);
- pstoreu(block + 0 * packet_size, kernel.packet[0]);
- pstoreu(block + 1 * packet_size, kernel.packet[1]);
- pstoreu(block + 2 * packet_size, kernel.packet[2]);
- pstoreu(block + 3 * packet_size, kernel.packet[3]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 4> kernel;
@@ -767,6 +895,8 @@ struct gemm_pack_rhs<
}
}
}
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
if (!rhs.nonStandardPatches()) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
@@ -824,9 +954,9 @@ struct gemm_pack_rhs<
Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -835,9 +965,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const int packet_size = 2;
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
@@ -851,30 +978,27 @@ struct gemm_pack_rhs<
Index k = 0;
if (!non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns and rows if we know that a single
+ // packet do not span across multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -885,14 +1009,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -916,22 +1039,12 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 2> kernel0;
- PacketBlock<Packet, 2> kernel1;
- kernel0.packet[0] = dm0.loadPacketFast(k);
- kernel0.packet[1] = dm1.loadPacketFast(k);
- kernel1.packet[0] = dm2.loadPacketFast(k);
- kernel1.packet[1] = dm3.loadPacketFast(k);
- ptranspose(kernel0);
- ptranspose(kernel1);
- pstoreu(block + 0 * packet_size, kernel0.packet[0]);
- pstoreu(block + 1 * packet_size, kernel1.packet[0]);
- pstoreu(block + 2 * packet_size, kernel0.packet[1]);
- pstoreu(block + 3 * packet_size, kernel1.packet[1]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
+ // Packet can span multiple rows or columns, so we have to go
+ // though the slower "standard" path.
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -949,7 +1062,9 @@ struct gemm_pack_rhs<
}
}
}
- if (!rhs.nonStandardPatches()) {
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
block[1] = dm1.loadCoeffStandard(k);
@@ -968,7 +1083,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1006,8 +1121,7 @@ struct gemm_pack_rhs<
SubMapper;
typedef SubMapper DataMapper;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -1016,8 +1130,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
-
const Index packet_cols4 = (cols / 4) * 4;
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
@@ -1045,7 +1157,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
diff --git a/tensorflow/core/kernels/eigen_volume_patch.h b/tensorflow/core/kernels/eigen_volume_patch.h
index a3d795813d..80ab745bfe 100644
--- a/tensorflow/core/kernels/eigen_volume_patch.h
+++ b/tensorflow/core/kernels/eigen_volume_patch.h
@@ -43,6 +43,7 @@ struct CustomTensorEvaluator {
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = false,
+ PreferBlockAccess = false,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = NumDims == 6,
RawAccess = false
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.cc b/tensorflow/core/kernels/extract_volume_patches_op.cc
new file mode 100644
index 0000000000..52cd078a35
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/*
+See extract_image_patches_op* files and docs for extract_image_patches in
+../ops/image_ops.cc.
+
+Rates are not supported as of now, but the comments hint how to edit the code
+when rates are to be added.
+*/
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include <vector>
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+static inline void ParseAttributeVec5(OpKernelConstruction* context,
+ const string& attr_name,
+ std::vector<int32>* attr) {
+ OP_REQUIRES_OK(context, context->GetAttr(attr_name, attr));
+ OP_REQUIRES(
+ context, (*attr)[0] == 1 && (*attr)[4] == 1,
+ errors::Unimplemented("Only support ", attr_name, " across space."));
+ OP_REQUIRES(context, (*attr)[1] >= 1 && (*attr)[2] >= 1 && (*attr)[3] >= 1,
+ errors::OutOfRange(attr_name, " is out of range."));
+}
+
+template <typename Device, typename T>
+class ExtractVolumePatchesOp : public UnaryOp<T> {
+ public:
+ explicit ExtractVolumePatchesOp(OpKernelConstruction* context)
+ : UnaryOp<T>(context) {
+ ParseAttributeVec5(context, "ksizes", &ksizes_);
+ ParseAttributeVec5(context, "strides", &strides_);
+ // ParseAttributeVec5(context, "rates", &rates_);
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input tensor is of the following dimensions:
+ // [ batch, in_planes, in_rows, in_cols, channels ]
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input.shape().DebugString()));
+
+ const int batch = input.dim_size(0);
+ const int in_planes = input.dim_size(1);
+ const int in_rows = input.dim_size(2);
+ const int in_cols = input.dim_size(3);
+ const int depth = input.dim_size(4);
+
+ const int ksize_planes = ksizes_[1];
+ const int ksize_rows = ksizes_[2];
+ const int ksize_cols = ksizes_[3];
+
+ const int stride_planes = strides_[1];
+ const int stride_rows = strides_[2];
+ const int stride_cols = strides_[3];
+
+ /*
+ // TODO(hsgkim): enable rates
+ // Rates are disabled as of now due to Eigen's definitions of
+ // `extract_volume_patch` functions; none of them accept rates
+ // as its argument and rates are fixed to (1, 1, 1, 1, 1). A
+ // workaround has to be found for this.
+ // In order to enable rates, uncomment the following lines and use
+ // ksize_*_eff instead of ksize_* for the second argument of
+ // GetWindowedOutputSize calls.
+
+ const int rate_planes = rates_[1];
+ const int rate_rows = rates_[2];
+ const int rate_cols = rates_[3];
+
+ const int ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ int64 out_planes = 0, out_rows = 0, out_cols = 0;
+ int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_planes, ksize_planes, stride_planes,
+ padding_, &out_planes, &pad_planes));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_rows, ksize_rows, stride_rows,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_cols, ksize_cols, stride_cols,
+ padding_, &out_cols, &pad_cols));
+
+ const std::vector<int64> out_sizes = {
+ batch, out_planes, out_rows, out_cols,
+ ksize_planes * ksize_rows * ksize_cols * depth};
+ TensorShape out_shape(out_sizes);
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ // If there is nothing to compute, return.
+ if (out_shape.num_elements() == 0) {
+ return;
+ }
+
+ functor::ExtractVolumePatchesForward<Device, T>()(
+ context->eigen_device<Device>(), input.tensor<T, 5>(), ksize_planes,
+ ksize_rows, ksize_cols, stride_planes, stride_rows, stride_cols,
+ /* rate_planes, rate_rows, rate_cols, */
+ BrainPadding2EigenPadding(padding_), output->tensor<T, 5>());
+ }
+
+ private:
+ std::vector<int32> ksizes_;
+ std::vector<int32> strides_;
+ // std::vector<int32> rates_;
+
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExtractVolumePatchesOp);
+};
+
+// Registration of the CPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<CPUDevice, T>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#if GOOGLE_CUDA
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+// clang-format off
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ExtractVolumePatchesForward<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input, \
+ int patch_planes, int patch_rows, int patch_cols, \
+ int stride_planes, int stride_rows, int stride_cols, \
+ /* int rate_planes, int rate_rows, int rate_cols, */ \
+ const Eigen::PaddingType& padding, \
+ typename TTypes<T, 5>::Tensor output); \
+ extern template struct ExtractVolumePatchesForward<GPUDevice, T>;
+// clang-format on
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+
+#undef DECLARE_GPU_SPEC
+
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<GPUDevice, T>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.h b/tensorflow/core/kernels/extract_volume_patches_op.h
new file mode 100644
index 0000000000..7e0502b770
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+#define TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_volume_patch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct ExtractVolumePatchesForward {
+ void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input,
+ int patch_planes, int patch_rows, int patch_cols,
+ int stride_planes, int stride_rows, int stride_cols,
+ /* int rate_planes, int rate_rows, int rate_cols, */
+ const Eigen::PaddingType& padding,
+ typename TTypes<T, 5>::Tensor output) {
+ const int64 N = std::max(input.size(), output.size());
+ if (N <= std::numeric_limits<Index32>::max()) {
+ auto output_32bit = To32Bit(output);
+ output_32bit.device(d) =
+ To32Bit(input)
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output_32bit.dimensions());
+ } else {
+ output.device(d) =
+ input
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output.dimensions());
+ }
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
diff --git a/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
new file mode 100644
index 0000000000..c636493602
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+#define REGISTER(T) template struct ExtractVolumePatchesForward<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD
index 8bfa40304e..f2e0b2558f 100644
--- a/tensorflow/core/kernels/fuzzing/BUILD
+++ b/tensorflow/core/kernels/fuzzing/BUILD
@@ -43,4 +43,6 @@ tf_ops_fuzz_target_lib("example_proto_fast_parsing")
tf_ops_fuzz_target_lib("parse_tensor_op")
+tf_ops_fuzz_target_lib("decode_compressed")
+
tf_ops_fuzz_target_lib("decode_json_example")
diff --git a/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
new file mode 100644
index 0000000000..0a56f4b63f
--- /dev/null
+++ b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
+
+namespace tensorflow {
+namespace fuzzing {
+
+class FuzzDecodeCompressed : public FuzzStringInputOp {
+ void BuildGraph(const Scope& scope) override {
+ auto input =
+ tensorflow::ops::Placeholder(scope.WithOpName("input1"), DT_STRING);
+ auto d1 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d1"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType(""));
+ auto d2 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d2"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType("ZLIB"));
+ auto d3 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d3"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType("GZIP"));
+ Scope grouper =
+ scope.WithControlDependencies(std::vector<tensorflow::Operation>{
+ d1.output.op(), d2.output.op(), d3.output.op()});
+ (void)tensorflow::ops::NoOp(grouper.WithOpName("output"));
+ }
+};
+
+STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeCompressed);
+
+} // namespace fuzzing
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h
index cd2873bdca..7710cf93d6 100644
--- a/tensorflow/core/kernels/gather_functor.h
+++ b/tensorflow/core/kernels/gather_functor.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
index a88e9b0ddc..374a05850e 100644
--- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
@@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index bca1cff41c..2088c13586 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -77,9 +77,9 @@ static Status TensorListDeviceCopy(
return Status::OK();
}
-#define REGISTER_LIST_COPY(DIRECTION) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy)
+#define REGISTER_LIST_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
+ TensorListDeviceCopy)
REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
@@ -92,8 +92,7 @@ Status TensorListShape(const TensorList& t, TensorShape* s) {
return Status::OK();
}
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName,
- TensorListShape);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape);
bool TensorList::Decode(const VariantTensorData& data) {
tensors = data.tensors();
@@ -625,12 +624,11 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(bfloat16);
#undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- TensorList, TensorList::kTypeName,
+ TensorList,
TensorListBinaryAdd<CPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, TensorList,
- TensorList::kTypeName,
TensorListZerosLike<CPUDevice>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index c591226b76..a00bf700ca 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -94,11 +94,10 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bool);
#undef REGISTER_TENSOR_LIST_FROM_TENSOR_GPU
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- TensorList, TensorList::kTypeName,
+ TensorList,
TensorListBinaryAdd<GPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, TensorList,
- TensorList::kTypeName,
TensorListZerosLike<GPUDevice>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index 6b6a14e9a7..1ded012f3c 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <iostream>
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -90,6 +91,59 @@ class PrintOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
+class PrintV2Op : public OpKernel {
+ public:
+ explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));
+
+ auto output_stream_index =
+ std::find(std::begin(valid_output_streams_),
+ std::end(valid_output_streams_), output_stream_);
+
+ if (output_stream_index == std::end(valid_output_streams_)) {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
+ const string& msg = input_->scalar<string>()();
+
+ if (output_stream_ == "stdout") {
+ std::cout << msg << std::endl;
+ } else if (output_stream_ == "stderr") {
+ std::cerr << msg << std::endl;
+ } else if (output_stream_ == "log(info)") {
+ LOG(INFO) << msg << std::endl;
+ } else if (output_stream_ == "log(warning)") {
+ LOG(WARNING) << msg << std::endl;
+ } else if (output_stream_ == "log(error)") {
+ LOG(ERROR) << msg << std::endl;
+ } else {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+ const char* valid_output_streams_[6] = {"stdout", "stderr", "log(info)",
+ "log(warning)", "log(error)"};
+
+ private:
+ string output_stream_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);
+
class TimestampOp : public OpKernel {
public:
explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}
diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc
index 5e6958f364..a259d995fa 100644
--- a/tensorflow/core/kernels/logging_ops_test.cc
+++ b/tensorflow/core/kernels/logging_ops_test.cc
@@ -23,11 +23,33 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
+class PrintingV2GraphTest : public OpsTestBase {
+ protected:
+ Status Init(const string& output_stream = "log(warning)") {
+ TF_CHECK_OK(NodeDefBuilder("op", "PrintV2")
+ .Input(FakeInput(DT_STRING))
+ .Attr("output_stream", output_stream)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(PrintingV2GraphTest, StringSuccess) {
+ TF_ASSERT_OK(Init());
+ AddInputFromArray<string>(TensorShape({}), {"bar"});
+ TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(PrintingV2GraphTest, InvalidOutputStream) {
+ ASSERT_NE(::tensorflow::Status::OK(), (Init("invalid_output_stream")));
+}
+
class PrintingGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type1, DataType input_type2, string msg = "",
diff --git a/tensorflow/core/kernels/mirror_pad_op.h b/tensorflow/core/kernels/mirror_pad_op.h
index cc4b6941b9..62aa7d5c29 100644
--- a/tensorflow/core/kernels/mirror_pad_op.h
+++ b/tensorflow/core/kernels/mirror_pad_op.h
@@ -103,6 +103,7 @@ struct TensorEvaluator<const TensorMirrorPadOp<PaddingDimensions, ArgType>,
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = false,
+ PreferBlockAccess = false,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = true,
RawAccess = false
diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc
new file mode 100644
index 0000000000..a055351337
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc
@@ -0,0 +1,407 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#include "third_party/intel_mkl_dnn/include/mkldnn.h"
+#include "tensorflow/core/util/mkl_util.h"
+#endif
+
+// TODO(ezhulenev): Add numerical tests that will compare results of default
+// (aka Eigen) convolutions with MKL convolutions.
+
+// -------------------------------------------------------------------------- //
+// Performance Benchmarks. //
+// -------------------------------------------------------------------------- //
+
+// Compare performance of default Tensorflow convolution kernels (Eigen) with
+// MKL kernels on CPU.
+
+// Before running these benchmarks configure OpenMP environment variables:
+// export KMP_BLOCKTIME=0
+// export OMP_NUM_THREADS=${num_threads}
+
+namespace tensorflow {
+
+struct Conv2DDimensions {
+ Conv2DDimensions(int n, int h, int w, int c, int fc, int fh, int fw)
+ : input_batches(n),
+ input_height(h),
+ input_width(w),
+ input_depth(c),
+ filter_count(fc),
+ filter_height(fh),
+ filter_width(fw) {}
+
+ int input_batches;
+ int input_height;
+ int input_width;
+ int input_depth;
+ int filter_count;
+ int filter_height;
+ int filter_width;
+};
+
+static Tensor GetRandomTensor(const TensorShape& shape) {
+ Tensor tensor(DT_FLOAT, TensorShape(shape));
+ tensor.flat<float>() = tensor.flat<float>().setRandom();
+ return tensor;
+}
+
+// Get a random Tensor for the Conv2D input.
+static Tensor GetRandomInputTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.input_batches, dims.input_height,
+ dims.input_width, dims.input_depth});
+}
+
+// Get a random Tensor for the Conv2D filter.
+static Tensor GetRandomFilterTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.filter_height, dims.filter_width,
+ dims.input_depth, dims.filter_count});
+}
+
+// Get a random Tensor for the Conv2D output (assuming SAME padding).
+static Tensor GetRandomOutputTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.input_batches, dims.input_height,
+ dims.input_width, dims.filter_count});
+}
+
+// Get a Tensor encoding Conv2D input shape.
+static Tensor GetInputSizesTensor(const Conv2DDimensions& dims) {
+ return test::AsTensor<int32>({dims.input_batches, dims.input_height,
+ dims.input_width, dims.input_depth});
+}
+
+// Get a Tensor encoding Conv2D filter shape.
+static Tensor GetFilterSizesTensor(const Conv2DDimensions& dims) {
+ return test::AsTensor<int32>({dims.filter_height, dims.filter_width,
+ dims.input_depth, dims.filter_count});
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Tensor NonMklTensor() {
+ MklDnnShape non_mkl_shape;
+ non_mkl_shape.SetMklTensor(false);
+
+ auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
+ Tensor tensor(DT_UINT8, {size});
+
+ non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
+ size * sizeof(uint8));
+ return tensor;
+}
+#endif
+
+static Graph* DefaultConv2D(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+ Node* conv2d;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d"), "Conv2D")
+ .Input(input)
+ .Input(filter)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2D(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("mkl_conv_2d"), "_MklConv2D")
+ .Input(input)
+ .Input(filter)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d));
+
+ return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdInput(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_sizes_t = GetInputSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input_sizes =
+ test::graph::Constant(graph, input_sizes_t, "input_sizes");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* conv2d_bwd_input;
+ TF_CHECK_OK(
+ NodeBuilder(graph->NewName("conv_2d_bwd_input"), "Conv2DBackpropInput")
+ .Input(input_sizes)
+ .Input(filter)
+ .Input(out_backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d_bwd_input));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdInput(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_sizes_t = GetInputSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input_sizes =
+ test::graph::Constant(graph, input_sizes_t, "input_sizes");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d_bwd_input;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_input"),
+ "_MklConv2DBackpropInput")
+ .Input(input_sizes)
+ .Input(filter)
+ .Input(out_backprop)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d_bwd_input));
+
+ return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdFilter(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter_sizes =
+ test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* conv2d_bwd_filter;
+ TF_CHECK_OK(
+ NodeBuilder(graph->NewName("conv_2d_bwd_filter"), "Conv2DBackpropFilter")
+ .Input(input)
+ .Input(filter_sizes)
+ .Input(out_backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d_bwd_filter));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) {
+ Graph* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter_sizes =
+ test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d_bwd_filter;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_filter"),
+ "_MklConv2DBackpropFilter")
+ .Input(input)
+ .Input(filter_sizes)
+ .Input(out_backprop)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d_bwd_filter));
+
+ return graph;
+}
+#endif
+
+// Macro arguments names: --------------------------------------------------- //
+// N: batch size
+// H: height
+// W: width
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_NAME(p, type, N, H, W, C, FC, FH, FW) \
+ BM_CONCAT(BM_##p##_##type##_in_##N##_##H##_##W##_##C, _f_##FC##_##FH##_##FW)
+
+// Flops computation in these benchmarks are the same as in
+// eigen_benchmark_cpu_test.cc.
+
+#define BM_Conv2DT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (N) * (H) * (W) * (FC); \
+ int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2D)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdInputT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (N) * (H) * (W) * (C); \
+ int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdInput)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DBwdInputT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdFilterT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (FH) * (FW) * (C) * (FC); \
+ int64 flops_per_iter = num_computed_elements * ((N) * (H) * (W)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdFilter)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DBwdFilterT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+// ImageNet Convolutions ---------------------------------------------------- //
+
+BM_Conv2D(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2D(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2D(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2D(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2D(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2D(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2D(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdInput(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdInput(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdInput(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdFilter(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdFilter(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdFilter(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index 7a64788448..82dfece4a2 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -75,7 +75,7 @@ struct MultinomialFunctor<CPUDevice, T, OutputType> {
// lambda. Since we want to let each worker have its own copy, we pass
// "gen" by reference and explicitly do a copy assignment here.
random::PhiloxRandom gen_copy = gen;
- // Skip takes units of 128 bytes. +3 is so rounding doesn't lead to
+ // Skip takes units of 128 bits. +3 is so rounding doesn't lead to
// us using the same state in different batches.
gen_copy.Skip(start_row * (num_samples + 3) / 4);
random::SimplePhilox simple_philox(&gen_copy);
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 7bb403290d..fc1c9003aa 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -127,12 +127,12 @@ class PartitionedCallOp : public AsyncOpKernel {
optimization_options.graph = &graph;
optimization_options.flib_def = overlay_lib;
optimization_options.device_set = &device_set;
- Placer placer(graph.get(), &device_set);
OP_REQUIRES_OK_ASYNC(
ctx,
OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::PRE_PLACEMENT, optimization_options),
done);
+ Placer placer(graph.get(), &device_set);
OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
OP_REQUIRES_OK_ASYNC(
ctx,
@@ -210,7 +210,7 @@ class PartitionedCallOp : public AsyncOpKernel {
TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
DataType dtype = attr_value->type();
if (dtype == DT_RESOURCE) {
- ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
+ const ResourceHandle& handle = args[index].flat<ResourceHandle>()(0);
node->set_assigned_device_name(handle.device());
}
}
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 5fb1c92f94..272aa3b4f5 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
@@ -82,6 +83,9 @@ class QueueBase : public QueueInterface {
// NOTE(mrry): This method is deprecated. Use
// `tensorflow::batch_util::CopySliceToElement()` defined in
// "./batch_util.h" instead.
+ ABSL_DEPRECATED(
+ "Use `tensorflow::batch_util::CopySliceToElement()` defined in "
+ "\"./batch_util.h\" instead.")
static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
int64 index);
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index c4d404259b..97ddc852f7 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -65,7 +65,7 @@ class FakeQueueOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
- ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0);
+ const ResourceHandle& ref = context->input(0).flat<ResourceHandle>()(0);
handle_.AccessTensor(context)->flat<string>()(0) = ref.container();
handle_.AccessTensor(context)->flat<string>()(1) = ref.name();
context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index e37232539f..04a53697c0 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -231,7 +231,13 @@ class RandomUniformIntOp : public OpKernel {
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval.shape().DebugString()));
- // Verify that minval < maxval
+ // Allocate output, and exit early if possible
+ Tensor* output;
+ OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
+ if (output->NumElements() == 0) return;
+
+ // Verify that minval < maxval. This check intentionally happens after the
+ // early exit for empty output. Zero impossible things are fine.
IntType lo = minval.scalar<IntType>()();
IntType hi = maxval.scalar<IntType>()();
OP_REQUIRES(
@@ -243,8 +249,6 @@ class RandomUniformIntOp : public OpKernel {
Distribution;
Distribution dist(lo, hi);
- Tensor* output;
- OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
auto output_flat = output->flat<IntType>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(),
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 88b3c2ac76..bb8254eaac 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -21,11 +21,11 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_segmented_reduce.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
-#include "external/cub_archive/cub/warp/warp_reduce.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_segmented_reduce.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/warp/warp_reduce.cuh"
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/lib/core/bits.h"
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index 9cf953f4bf..8bfa44b2d0 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -50,6 +50,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
+
+REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
REGISTER_GPU_KERNELS(int64);
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index ebcfb673d1..26705a8d34 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -79,7 +79,7 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
void ReadVariableOp::Compute(OpKernelContext* ctx) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, 0);
+ const ResourceHandle& handle = HandleFromInput(ctx, 0);
const auto status = LookupResource(ctx, handle, &variable);
OP_REQUIRES(ctx, status.ok(),
errors::FailedPrecondition(
diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc
index 15a707a9c6..cded417986 100644
--- a/tensorflow/core/kernels/reverse_sequence_op.cc
+++ b/tensorflow/core/kernels/reverse_sequence_op.cc
@@ -64,7 +64,7 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
- " vs. ", input.dim_size(batch_dim)));
+ " vs. ", input.dim_size(batch_dim), ")"));
for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
OP_REQUIRES(context, seq_lens_vec[d] >= 0,
@@ -91,7 +91,7 @@ void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
- " vs. ", input.dim_size(batch_dim)));
+ " vs. ", input.dim_size(batch_dim), ")"));
}
template <>
@@ -127,6 +127,7 @@ class ReverseSequenceOp : public OpKernel {
auto seq_lens_t = seq_lens.vec<Tlen>();
CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
+ if (!context->status().ok()) return;
const int input_dims = input.dims();
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index e0194605ce..2f8aede427 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -145,6 +145,7 @@ class ScatterNdUpdateOp : public OpKernel {
if (dtype_ == DT_RESOURCE) {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock m(*v->mu());
DoCompute(c);
} else if (use_exclusive_lock_) {
diff --git a/tensorflow/core/kernels/searchsorted_op.cc b/tensorflow/core/kernels/searchsorted_op.cc
new file mode 100644
index 0000000000..dc627ac77a
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.cc
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<CPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+ for (int b = 0; b < batch_size; ++b) {
+ const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+ OutType* output_ptr = output->data() + b * num_values;
+ for (int i = 0; i < num_values; ++i) {
+ output_ptr[i] =
+ std::upper_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+ values(i + b * num_values)) -
+ sorted_inputs_ptr;
+ }
+ }
+
+ return Status::OK();
+ }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<CPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+ for (int b = 0; b < batch_size; ++b) {
+ const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+ OutType* output_ptr = output->data() + b * num_values;
+ for (int i = 0; i < num_values; ++i) {
+ output_ptr[i] =
+ std::lower_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+ values(i + b * num_values)) -
+ sorted_inputs_ptr;
+ }
+ }
+
+ return Status::OK();
+ }
+};
+} // namespace functor
+
+template <typename Device, typename T, typename OutType>
+class UpperBoundOp : public OpKernel {
+ public:
+ explicit UpperBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& sorted_inputs_t = ctx->input(0);
+ const Tensor& values_t = ctx->input(1);
+
+ // must have same batch dim_size for both
+ OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+ Status(error::INVALID_ARGUMENT,
+ "Leading dim_size of both tensors must match."));
+
+ // this is required because we do indexing in int32 on the GPU
+ OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+ Status(error::INVALID_ARGUMENT,
+ "values tensor size must less than INT_MAX"));
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+ if (output_t->dtype() == DT_INT32) {
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(sorted_inputs_t.dim_size(1),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("trailing dim_size must less than "
+ "INT_MAX for int32 output type, was ",
+ sorted_inputs_t.dim_size(1)));
+ }
+
+ auto output = output_t->template flat<OutType>();
+ const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+ const auto values = values_t.template flat<T>();
+ OP_REQUIRES_OK(
+ ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute(
+ ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+ sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+ }
+};
+
+template <typename Device, typename T, typename OutType>
+class LowerBoundOp : public OpKernel {
+ public:
+ explicit LowerBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& sorted_inputs_t = ctx->input(0);
+ const Tensor& values_t = ctx->input(1);
+
+ // must have same batch dim_size for both
+ OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+ Status(error::INVALID_ARGUMENT,
+ "Leading dim_size of both tensors must match."));
+
+ // this is required because we do indexing in int32 on the GPU
+ OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+ Status(error::INVALID_ARGUMENT,
+ "values tensor size must less than INT_MAX"));
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+ if (output_t->dtype() == DT_INT32) {
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(sorted_inputs_t.dim_size(1),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("trailing dim_size must less than "
+ "INT_MAX for int32 output type, was ",
+ sorted_inputs_t.dim_size(1)));
+ }
+
+ auto output = output_t->template flat<OutType>();
+ const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+ const auto values = values_t.template flat<T>();
+ OP_REQUIRES_OK(
+ ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute(
+ ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+ sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ UpperBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ UpperBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ UpperBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ UpperBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif // GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ LowerBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ LowerBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ LowerBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ LowerBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif // GOOGLE_CUDA
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/searchsorted_op.h b/tensorflow/core/kernels/searchsorted_op.h
new file mode 100644
index 0000000000..f075bf0fa2
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T, typename OutType>
+struct UpperBoundFunctor {
+ // Searches for values in sorted_inputs and returns the greatest possible
+ // index where they maintain sorted order.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output);
+};
+
+template <typename Device, typename T, typename OutType>
+struct LowerBoundFunctor {
+ // Searches for values in sorted_inputs and returns the lowest possible
+ // index where they maintain sorted order.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output);
+};
+} // namespace functor
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
diff --git a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
new file mode 100644
index 0000000000..263b5bf298
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {
+template <typename T, typename OutType>
+__global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size,
+ int sorted_inputs_size, int values_size,
+ const T* values, OutType* outputs) {
+ CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+ int bid = work_unit_id / values_size;
+ T value = values[work_unit_id];
+ outputs[work_unit_id] = cuda_helper::upper_bound<T, OutType>(
+ sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+ }
+}
+
+template <typename T, typename OutType>
+__global__ void LowerBoundKernel(const T* sorted_inputs, int batch_size,
+ int sorted_inputs_size, int values_size,
+ const T* values, OutType* outputs) {
+ CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+ int bid = work_unit_id / values_size;
+ T value = values[work_unit_id];
+ outputs[work_unit_id] = cuda_helper::lower_bound<T, OutType>(
+ sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+ }
+}
+} // namespace
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<GPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ const cudaStream_t& stream = GetCudaStream(context);
+ CudaLaunchConfig config =
+ GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+
+ UpperBoundKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, stream>>>(
+ sorted_inputs.data(), batch_size, num_inputs, num_values,
+ values.data(), output->data());
+
+ return Status::OK();
+ }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<GPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ const cudaStream_t& stream = GetCudaStream(context);
+ CudaLaunchConfig config =
+ GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+
+ LowerBoundKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, stream>>>(
+ sorted_inputs.data(), batch_size, num_inputs, num_values,
+ values.data(), output->data());
+
+ return Status::OK();
+ }
+};
+} // namespace functor
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::UpperBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::UpperBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::LowerBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::LowerBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc
index 9cd590ae61..30cb1e0a7f 100644
--- a/tensorflow/core/kernels/shape_op_test.cc
+++ b/tensorflow/core/kernels/shape_op_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/abi.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -60,8 +61,7 @@ Status GetShapeFromKnownVecSize(const KnownVecSize& ks, TensorShape* s) {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE");
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE",
- GetShapeFromKnownVecSize);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, GetShapeFromKnownVecSize);
static void ExpectHasError(const Status& s, StringPiece substr) {
EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
@@ -94,9 +94,9 @@ TEST_F(ShapeOpTest, Simple) {
Status s = session.Run({{input, variant_tensor}}, {shape_output}, &outputs);
EXPECT_FALSE(s.ok());
ExpectHasError(
- s,
- "No unary variant shape function found for Variant type_name: "
- "NO KNOWN SHAPE");
+ s, strings::StrCat(
+ "No unary variant shape function found for Variant type_index: ",
+ port::MaybeAbiDemangle(MakeTypeIndex<NoKnownShape>().name())));
}
{
diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc
index 393818730b..a4a59dbcbc 100644
--- a/tensorflow/core/kernels/split_lib_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc
@@ -54,6 +54,7 @@ void SplitCustom<Device, T>::operator()(
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
+TF_CALL_int64(DEFINE_GPU_KERNELS);
TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 7cc3c532c9..11db72bfa3 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -49,7 +49,12 @@ class SplitOpBase : public OpKernel {
void ComputeEasyCases(OpKernelContext* context, bool* done) {
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
- const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+ const Tensor& split_dim_tensor = context->input(0);
+ OP_REQUIRES(
+ context, split_dim_tensor.shape().dims() == 0,
+ errors::InvalidArgument("split_dim must be a scalar but has rank ",
+ split_dim_tensor.shape().dims()));
+ const int32 split_dim_orig = split_dim_tensor.flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
const int32 num_split = num_outputs();
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 65296f61fd..add4afafc9 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -131,10 +131,8 @@ class Stack : public ResourceBase {
};
Status GetStack(OpKernelContext* ctx, Stack** stack) {
- string key;
if (ctx->input_dtype(0) == DT_RESOURCE) {
- auto resource = ctx->input(0).flat<ResourceHandle>()(0);
- key = resource.name();
+ return LookupResource(ctx, HandleFromInput(ctx, 0), stack);
} else {
Tensor Tstack_handle = ctx->mutable_input(0, false);
if (Tstack_handle.NumElements() != 2) {
@@ -144,18 +142,18 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) {
}
const string& container = Tstack_handle.flat<string>()(0);
const string& stack_name = Tstack_handle.flat<string>()(1);
- key = strings::StrCat(container, stack_name);
- }
- ResourceMgr* rm = ctx->resource_manager();
- if (rm == nullptr) {
- return errors::Internal("No resource manager.");
- }
- auto* step_container = ctx->step_container();
- if (step_container == nullptr) {
- return errors::Internal("No step container.");
+ string key = strings::StrCat(container, stack_name);
+ ResourceMgr* rm = ctx->resource_manager();
+ if (rm == nullptr) {
+ return errors::Internal("No resource manager.");
+ }
+ auto* step_container = ctx->step_container();
+ if (step_container == nullptr) {
+ return errors::Internal("No step container.");
+ }
+ TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
+ return Status::OK();
}
- TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
- return Status::OK();
}
std::atomic<int64> Stack::stack_counter{0};
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 7b537fef5b..f0575de4d9 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -306,6 +306,7 @@ class StridedSliceAssignOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc
new file mode 100644
index 0000000000..e4a1887f8d
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class StringFormatOp : public OpKernel {
+ public:
+ explicit StringFormatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string template_;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("template", &template_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("placeholder", &placeholder_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+
+ split_template_ = absl::StrSplit(template_, placeholder_);
+ int64 num_placeholders = split_template_.size() - 1;
+ OP_REQUIRES(ctx, ctx->num_inputs() == num_placeholders,
+ errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", ctx->num_inputs())));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor* formatted_string = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &formatted_string));
+
+ string msg;
+ strings::StrAppend(&msg, split_template_[0].c_str());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_, true));
+ strings::StrAppend(&msg, split_template_[i + 1].c_str());
+ }
+
+ formatted_string->scalar<string>()() = msg;
+ }
+
+ private:
+ int32 summarize_ = 0;
+ string placeholder_;
+ std::vector<std::string> split_template_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringFormat").Device(DEVICE_CPU),
+ StringFormatOp);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_format_op_test.cc b/tensorflow/core/kernels/string_format_op_test.cc
new file mode 100644
index 0000000000..13130a5797
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+class StringFormatGraphTest : public OpsTestBase {
+ protected:
+ Status Init(int num_inputs, DataType input_type,
+ const string& template_ = "%s", const string& placeholder = "%s",
+ int summarize = 3) {
+ TF_CHECK_OK(NodeDefBuilder("op", "StringFormat")
+ .Input(FakeInput(num_inputs, input_type))
+ .Attr("template", template_)
+ .Attr("placeholder", placeholder)
+ .Attr("summarize", summarize)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(StringFormatGraphTest, Int32Success_7) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s"));
+
+ AddInputFromArray<int32>(TensorShape({7}), {1, 2, 3, 4, 5, 6, 7});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [1 2 3 ... 5 6 7]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(StringFormatGraphTest, Int32Success_3_3) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s", "%s", 1));
+
+ AddInputFromArray<int32>(TensorShape({3, 3}), {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [[1 ... 3]\n ..."
+ "\n [7 ... 9]]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+} // end namespace
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc
index a6829b29d9..435a7abdca 100644
--- a/tensorflow/core/kernels/string_length_op.cc
+++ b/tensorflow/core/kernels/string_length_op.cc
@@ -14,13 +14,18 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/string_util.h"
namespace tensorflow {
namespace {
class StringLengthOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+ explicit StringLengthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string unit;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
+ OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
+ }
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
@@ -32,10 +37,22 @@ class StringLengthOp : public OpKernel {
auto src = input.flat<string>();
auto dst = output->flat<int32>();
- for (int n = 0; n < src.size(); ++n) {
- dst(n) = src(n).size();
+ switch (unit_) {
+ case CharUnit::BYTE:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = src(n).size();
+ }
+ break;
+ case CharUnit::UTF8_CHAR:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = UTF8StrLen(src(n));
+ }
+ break;
}
}
+
+ private:
+ CharUnit unit_ = CharUnit::BYTE;
};
REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc
new file mode 100644
index 0000000000..3a9803a052
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.cc
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/string_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace {
+inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
+} // namespace
+
+namespace tensorflow {
+
+// Sets unit value based on str.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) {
+ if (str == "UTF8") {
+ *encoding = UnicodeEncoding::UTF8;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid encoding \"", str, "\": Should be one of: BYTE"));
+ }
+ return Status::OK();
+}
+
+// Sets unit value based on str.
+Status ParseCharUnit(const string& str, CharUnit* unit) {
+ if (str == "BYTE") {
+ *unit = CharUnit::BYTE;
+ } else if (str == "UTF8_CHAR") {
+ *unit = CharUnit::UTF8_CHAR;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR"));
+ }
+ return Status::OK();
+}
+
+// Return the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string) {
+ const int32 byte_size = string.size();
+ const char* const end = string.data() + byte_size;
+ const char* ptr = string.data();
+ int32 skipped_count = 0;
+ while (ptr < end) {
+ skipped_count += IsTrailByte(*ptr++) ? 1 : 0;
+ }
+ const int32 result = byte_size - skipped_count;
+ return result;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h
new file mode 100644
index 0000000000..390cf57702
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.h
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Enumeration for unicode encodings. Used by ops such as
+// tf.strings.unicode_encode and tf.strings.unicode_decode.
+// TODO(edloper): Add support for:
+// UTF16, UTF32, UTF16BE, UTF32BE, UTF16LE, UTF32LE
+enum class UnicodeEncoding { UTF8 };
+
+// Enumeration for character units. Used by string such as
+// tf.strings.length and tf.substr.
+// TODO(edloper): Add support for: UTF32_CHAR, etc.
+enum class CharUnit { BYTE, UTF8_CHAR };
+
+// Sets `encoding` based on `str`.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
+
+// Sets `unit` value based on `str`.
+Status ParseCharUnit(const string& str, CharUnit* unit);
+
+// Returns the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 22e45918a0..07f1d6e767 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cstddef>
+#include <cstdlib>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -25,6 +27,8 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
@@ -64,26 +68,28 @@ class SubstrOp : public OpKernel {
const T len =
tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
} else {
// Perform Op element-wise with tensor pos/len
auto pos_flat = pos_tensor.flat<T>();
auto len_flat = len_tensor.flat<T>();
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
}
} else {
@@ -142,14 +148,16 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
- string in = input_bcast(i);
+ StringPiece in(input_bcast(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
OP_REQUIRES(
- context, FastBoundsCheck(pos, input_bcast(i).size() + 1),
+ context,
+ FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1),
errors::InvalidArgument("pos ", pos, " out of range for string",
"b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
break;
}
@@ -192,16 +200,18 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
for (int j = 0; j < output_shape.dim_size(1); ++j) {
- string in = input_bcast(i, j);
+ StringPiece in(input_bcast(i, j));
const T pos =
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
- OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1),
- errors::InvalidArgument(
- "pos ", pos, " out of range for ", "string b'",
- in, "' at index (", i, ", ", j, ")"));
- output(i, j) = in.substr(pos, len);
+ OP_REQUIRES(
+ context, FastBoundsCheck(std::abs(pos), in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index (", i,
+ ", ", j, ")"));
+ StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ output(i, j).assign(sub_in.data(), sub_in.size());
}
}
break;
@@ -213,6 +223,16 @@ class SubstrOp : public OpKernel {
}
}
}
+
+ private:
+ // This adjusts the requested position. Note it does not perform any bound
+ // checks.
+ T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+ if (pos_requested < 0) {
+ return s.size() + pos_requested;
+ }
+ return pos_requested;
+ }
};
#define REGISTER_SUBSTR(type) \
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
new file mode 100644
index 0000000000..2e07050260
--- /dev/null
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -0,0 +1,105 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+ "TensorFlow provides stable Python API and C APIs as well as without API "
+ "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+ "Swift."};
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor position(DT_INT32, TensorShape({}));
+ position.flat<int32>().setConstant(pos);
+ Tensor length(DT_INT32, TensorShape({}));
+ length.flat<int32>().setConstant(len);
+
+ TF_CHECK_OK(NodeBuilder("substr_op", "Substr")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, position))
+ .Input(test::graph::Constant(g, length))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_Substr(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupSubstrGraph(input, 3, 30);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg(
+ 256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 765467bc1e..0e6c0ddccc 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -62,7 +62,8 @@ TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
}
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index e8dc4fad21..384a63e945 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -81,7 +81,8 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 2ec2651c04..a97a71b344 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -259,6 +259,7 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU),
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
@@ -290,7 +291,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
}
} else {
container = "_tensor_arrays";
- auto resource = ctx->input(0).flat<ResourceHandle>()(0);
+ const auto& resource = ctx->input(0).flat<ResourceHandle>()(0);
if (StringPiece(resource.name()).substr(0, container.size()) !=
container) {
return errors::InvalidArgument("Wrong input container. ",
@@ -576,6 +577,7 @@ TF_CALL_ALL_TYPES(REGISTER_READ)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
@@ -1218,6 +1220,7 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
index ca296d5aa0..2fbe1fe7cb 100644
--- a/tensorflow/core/kernels/topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include <cmath>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_segmented_radix_sort.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc
index d3c4f62071..83b83fcdb9 100644
--- a/tensorflow/core/kernels/training_op_helpers.cc
+++ b/tensorflow/core/kernels/training_op_helpers.cc
@@ -21,6 +21,7 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
if (ctx->input_dtype(input) == DT_RESOURCE) {
Var* var;
if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
+ core::ScopedUnref scoped_unref(var);
return var->mu();
} else {
ctx->CtxFailureWithWarning(
diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc
index 62e814ff77..8d839ba85a 100644
--- a/tensorflow/core/kernels/unravel_index_op.cc
+++ b/tensorflow/core/kernels/unravel_index_op.cc
@@ -97,10 +97,12 @@ class UnravelIndexOp : public OpKernel {
auto output = output_tensor->matrix<Tidx>();
- Eigen::array<int64, 2> reshape{{dims_tensor.NumElements(), 1}};
- Eigen::array<int64, 2> bcast({1, indices_tensor.NumElements()});
- Eigen::array<int64, 2> indices_reshape{{1, indices_tensor.NumElements()}};
- Eigen::array<int64, 2> indices_bcast({dims_tensor.NumElements(), 1});
+ Eigen::array<Eigen::Index, 2> reshape{{dims_tensor.NumElements(), 1}};
+ Eigen::array<Eigen::Index, 2> bcast({1, indices_tensor.NumElements()});
+ Eigen::array<Eigen::Index, 2> indices_reshape{
+ {1, indices_tensor.NumElements()}};
+ Eigen::array<Eigen::Index, 2> indices_bcast(
+ {dims_tensor.NumElements(), 1});
output = indices_tensor.vec<Tidx>()
.reshape(indices_reshape)
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h
index 8879d9dd4c..2255597651 100644
--- a/tensorflow/core/kernels/where_op_gpu.cu.h
+++ b/tensorflow/core/kernels/where_op_gpu.cu.h
@@ -21,10 +21,10 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_select.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_select.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"