path: root/tensorflow/core/kernels
author     AG Ramesh <ag.ramesh@intel.com>   2018-09-27 12:24:49 -0700
committer  GitHub <noreply@github.com>       2018-09-27 12:24:49 -0700
commit     0136c7307f036290fa3ca308c1a9c67c053d903f (patch)
tree       c7f601ed59d079d5fef55c57f1664e86e367893b /tensorflow/core/kernels
parent     268bf6b118646c8e93162d591263bca907c7db28 (diff)
parent     d2a674a959101c35b8cf65c79a603baa16936805 (diff)
Merge branch 'master' into agramesh/fix_mkl_slice
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--  tensorflow/core/kernels/BUILD | 46
-rw-r--r--  tensorflow/core/kernels/batch_matmul_op_complex.cc | 10
-rw-r--r--  tensorflow/core/kernels/batch_matmul_op_real.cc | 9
-rw-r--r--  tensorflow/core/kernels/batching_util/BUILD | 20
-rw-r--r--  tensorflow/core/kernels/bincount_op_gpu.cu.cc | 2
-rw-r--r--  tensorflow/core/kernels/boosted_trees/boosted_trees.proto | 13
-rw-r--r--  tensorflow/core/kernels/boosted_trees/resources.cc | 26
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc | 321
-rw-r--r--  tensorflow/core/kernels/conv_ops.h | 44
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc | 26
-rw-r--r--  tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc | 26
-rw-r--r--  tensorflow/core/kernels/cwise_op_xdivy.cc | 38
-rw-r--r--  tensorflow/core/kernels/cwise_op_xlogy.cc | 41
-rw-r--r--  tensorflow/core/kernels/cwise_ops.h | 45
-rw-r--r--  tensorflow/core/kernels/cwise_ops_common.cc | 4
-rw-r--r--  tensorflow/core/kernels/data/BUILD | 15
-rw-r--r--  tensorflow/core/kernels/data/dataset_utils.cc | 37
-rw-r--r--  tensorflow/core/kernels/data/dataset_utils.h | 10
-rw-r--r--  tensorflow/core/kernels/data/generator_dataset_op.cc | 9
-rw-r--r--  tensorflow/core/kernels/data/iterator_ops.cc | 37
-rw-r--r--  tensorflow/core/kernels/data/multi_device_iterator_ops.cc | 633
-rw-r--r--  tensorflow/core/kernels/data/optional_ops.cc | 8
-rw-r--r--  tensorflow/core/kernels/data/prefetch_dataset_op.cc | 16
-rw-r--r--  tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc | 14
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc | 2
-rw-r--r--  tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc | 8
-rw-r--r--  tensorflow/core/kernels/eigen_cuboid_convolution.h | 432
-rw-r--r--  tensorflow/core/kernels/eigen_spatial_convolutions.h | 342
-rw-r--r--  tensorflow/core/kernels/extract_volume_patches_op.cc | 197
-rw-r--r--  tensorflow/core/kernels/extract_volume_patches_op.h | 58
-rw-r--r--  tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc | 38
-rw-r--r--  tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc | 32
-rw-r--r--  tensorflow/core/kernels/gather_nd_op_cpu_impl.h | 6
-rw-r--r--  tensorflow/core/kernels/histogram_op_gpu.cu.cc | 2
-rw-r--r--  tensorflow/core/kernels/matmul_op.cc | 8
-rw-r--r--  tensorflow/core/kernels/mkl_batch_matmul_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/mkl_matmul_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/mkl_slice_op.cc | 358
-rw-r--r--  tensorflow/core/kernels/random_op.cc | 10
-rw-r--r--  tensorflow/core/kernels/reduction_gpu_kernels.cu.h | 10
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.cc | 60
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.h | 10
-rw-r--r--  tensorflow/core/kernels/scatter_nd_op.cc | 1
-rw-r--r--  tensorflow/core/kernels/split_lib_gpu.cu.cc | 1
-rw-r--r--  tensorflow/core/kernels/strided_slice_op.cc | 1
-rw-r--r--  tensorflow/core/kernels/string_length_op.cc | 23
-rw-r--r--  tensorflow/core/kernels/string_util.cc | 63
-rw-r--r--  tensorflow/core/kernels/string_util.h | 45
-rw-r--r--  tensorflow/core/kernels/tensor_array.cc | 3
-rw-r--r--  tensorflow/core/kernels/tensor_array.h | 3
-rw-r--r--  tensorflow/core/kernels/tensor_array_ops.cc | 3
-rw-r--r--  tensorflow/core/kernels/topk_op_gpu.cu.cc | 6
-rw-r--r--  tensorflow/core/kernels/training_op_helpers.cc | 44
-rw-r--r--  tensorflow/core/kernels/training_op_helpers.h | 37
-rw-r--r--  tensorflow/core/kernels/training_ops.cc | 8
-rw-r--r--  tensorflow/core/kernels/transpose_op.cc | 10
-rw-r--r--  tensorflow/core/kernels/unicode_script_op.cc | 53
-rw-r--r--  tensorflow/core/kernels/where_op_gpu.cu.h | 8
58 files changed, 2772 insertions(+), 568 deletions(-)
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 08245e6ea0..0b8e9ec527 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -217,6 +217,19 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "extract_volume_patches_op",
+ prefix = "extract_volume_patches_op",
+ deps = [
+ ":bounds_check",
+ ":eigen_helpers",
+ ":ops_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
cc_library(
name = "conv_3d",
hdrs = ["conv_3d.h"],
@@ -622,6 +635,7 @@ cc_library(
":diag_op",
":edit_distance_op",
":extract_image_patches_op",
+ ":extract_volume_patches_op",
":gather_nd_op",
":gather_op",
":guarantee_const_op",
@@ -2014,8 +2028,8 @@ tf_kernel_library(
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:resource_variable_ops_op_lib",
- "//third_party/eigen3",
],
)
@@ -4417,11 +4431,20 @@ cc_library(
":string_strip_op",
":string_to_hash_bucket_op",
":substr_op",
+ ":unicode_script_op",
],
)
+cc_library(
+ name = "string_util",
+ srcs = ["string_util.cc"],
+ hdrs = ["string_util.h"],
+ deps = ["//tensorflow/core:lib"],
+)
+
STRING_DEPS = [
":bounds_check",
+ ":string_util",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -5152,6 +5175,7 @@ filegroup(
"spacetobatch_functor.h",
"spacetodepth_op.h",
"spectrogram.h",
+ "string_util.h",
"tensor_array.h",
"tile_functor.h",
"tile_ops_cpu_impl.h",
@@ -5231,6 +5255,8 @@ filegroup(
"cwise_op_squared_difference.cc",
"cwise_op_sub.cc",
"cwise_op_tanh.cc",
+ "cwise_op_xlogy.cc",
+ "cwise_op_xdivy.cc",
"data_format_ops.cc",
"decode_wav_op.cc",
"deep_conv2d.cc",
@@ -5320,6 +5346,7 @@ filegroup(
"spectrogram_op.cc",
"stack_ops.cc",
"string_join_op.cc",
+ "string_util.cc",
"summary_op.cc",
"tensor_array.cc",
"tensor_array_ops.cc",
@@ -5445,6 +5472,7 @@ filegroup(
"batch_kernels.*",
"regex_full_match_op.cc",
"regex_replace_op.cc",
+ "unicode_script_op.cc",
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
"mkl_*",
"xsmm_*",
@@ -6390,6 +6418,12 @@ tf_mkl_kernel_library(
)
tf_mkl_kernel_library(
+ name = "mkl_slice_op",
+ prefix = "mkl_slice_op",
+ deps = ARRAY_DEPS + mkl_deps(),
+)
+
+tf_mkl_kernel_library(
name = "mkl_identity_op",
prefix = "mkl_identity_op",
deps = ARRAY_DEPS + mkl_deps(),
@@ -6533,6 +6567,16 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "unicode_script_op",
+ srcs = ["unicode_script_op.cc"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:string_ops_op_lib",
+ "@icu//:common",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
index 54c45bfe63..f48bd0c318 100644
--- a/tensorflow/core/kernels/batch_matmul_op_complex.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -17,14 +17,18 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own complex64/128 kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
-#endif
+#endif // GOOGLE_CUDA
} // namespace tensorflow
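
For context on the guard above: the default TensorFlow batch-matmul registrations are skipped only in a full MKL-ML build, i.e. when INTEL_MKL and ENABLE_MKL are defined and INTEL_MKL_DNN_ONLY is not. The standalone sketch below is not part of the commit; it simply evaluates that condition for a few configurations to make the truth table explicit.

// Standalone sketch (not from the commit): mirrors the preprocessor guard
//   #if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
// as a constexpr predicate so the registration truth table can be checked.
#include <cstdio>

constexpr bool RegistersDefaultKernels(bool intel_mkl, bool mkl_dnn_only,
                                       bool enable_mkl) {
  return !intel_mkl || mkl_dnn_only || !enable_mkl;
}

int main() {
  // Only the full MKL-ML build (INTEL_MKL && !INTEL_MKL_DNN_ONLY && ENABLE_MKL)
  // relies on mkl_batch_matmul_op.cc; every other configuration keeps the
  // default TensorFlow registrations.
  std::printf("MKL-ML build:       %d\n", RegistersDefaultKernels(true, false, true));   // 0
  std::printf("MKL-DNN-only build: %d\n", RegistersDefaultKernels(true, true, true));    // 1
  std::printf("non-MKL build:      %d\n", RegistersDefaultKernels(false, false, false)); // 1
  return 0;
}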
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 584b507c70..25ae795d8e 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -21,10 +21,15 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own float and double kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
+
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 792eb74e31..039b0db144 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -1,7 +1,7 @@
# Description: Utilities.
package(
- default_visibility = ["//tensorflow:internal"],
+ default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
@@ -12,7 +12,11 @@ cc_library(
name = "periodic_function_dynamic",
srcs = ["periodic_function.cc"],
hdrs = ["periodic_function.h"],
- visibility = ["//visibility:public"],
+ visibility = [
+ "//learning/serving:__subpackages__",
+ "//tensorflow:internal",
+ "//tensorflow_serving:__subpackages__",
+ ],
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
@@ -21,7 +25,11 @@ cc_library(
cc_library(
name = "periodic_function",
- visibility = ["//visibility:public"],
+ visibility = [
+ "//learning/serving:__subpackages__",
+ "//tensorflow:internal",
+ "//tensorflow_serving:__subpackages__",
+ ],
deps = [
":periodic_function_dynamic",
"//tensorflow/core:lib",
@@ -190,7 +198,11 @@ cc_library(
testonly = 1,
srcs = ["fake_clock_env.cc"],
hdrs = ["fake_clock_env.h"],
- visibility = ["//visibility:public"],
+ visibility = [
+ "//learning/serving:__subpackages__",
+ "//tensorflow:internal",
+ "//tensorflow_serving:__subpackages__",
+ ],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
index 6074b3e1f6..7d09e9b820 100644
--- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
index c9664f0c1c..1ab72af059 100644
--- a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+++ b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
@@ -11,6 +11,7 @@ message Node {
oneof node {
Leaf leaf = 1;
BucketizedSplit bucketized_split = 2;
+ CategoricalSplit categorical_split = 3;
}
NodeMetadata metadata = 777;
}
@@ -57,6 +58,18 @@ message BucketizedSplit {
int32 right_id = 4;
}
+message CategoricalSplit {
+ // Categorical feature column and split describing the rule feature value ==
+ // value.
+ int32 feature_id = 1;
+ int32 value = 2;
+
+ // Node children indexing into a contiguous
+ // vector of nodes starting from the root.
+ int32 left_id = 3;
+ int32 right_id = 4;
+}
+
// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index cc90bb2f45..2798722536 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -60,14 +60,26 @@ int32 BoostedTreesEnsembleResource::next_node(
DCHECK_LT(tree_id, tree_ensemble_->trees_size());
DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- const auto& split = node.bucketized_split();
- if (bucketized_features[split.feature_id()](index_in_batch) <=
- split.threshold()) {
- return split.left_id();
- } else {
- return split.right_id();
+
+ switch (node.node_case()) {
+ case boosted_trees::Node::kBucketizedSplit: {
+ const auto& split = node.bucketized_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) <=
+ split.threshold())
+ ? split.left_id()
+ : split.right_id();
+ }
+ case boosted_trees::Node::kCategoricalSplit: {
+ const auto& split = node.categorical_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) ==
+ split.value())
+ ? split.left_id()
+ : split.right_id();
+ }
+ default:
+ DCHECK(false) << "Node type " << node.node_case() << " not supported.";
}
+ return -1;
}
float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
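
The switch added above routes an example to the left or right child depending on the split type: a bucketized split compares the feature value against a threshold, while the new categorical split tests equality with a single value. A minimal sketch of just that routing rule, using plain ints instead of the proto and resource types:

// Minimal sketch (plain ints instead of the proto/resource types) of the
// routing rule implemented by the switch in next_node() above.
#include <cassert>

struct Split {
  int left_id;
  int right_id;
};

// Bucketized split: go left iff feature_value <= threshold.
int NextNodeBucketized(const Split& s, int feature_value, int threshold) {
  return feature_value <= threshold ? s.left_id : s.right_id;
}

// Categorical split: go left iff feature_value == value.
int NextNodeCategorical(const Split& s, int feature_value, int value) {
  return feature_value == value ? s.left_id : s.right_id;
}

int main() {
  Split s{/*left_id=*/1, /*right_id=*/2};
  assert(NextNodeBucketized(s, /*feature_value=*/3, /*threshold=*/5) == 1);
  assert(NextNodeCategorical(s, /*feature_value=*/3, /*value=*/5) == 2);
  return 0;
}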
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 717a9f40a9..78856c4a99 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -264,150 +264,198 @@ class LaunchXsmmConvOp<CPUDevice, float> {
};
#endif
+#define TF_REQUIRES(EXP, STATUS) \
+ do { \
+ if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \
+ } while (false)
+
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params) {
+ TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
+ TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
+ TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
+ string data_format_string;
+ TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string));
+ TF_REQUIRES(FormatFromString(data_format_string, &params->data_format),
+ errors::InvalidArgument("Invalid data format"));
+
+ const auto& strides = params->strides;
+ const auto& dilations = params->dilations;
+ const auto& data_format = params->data_format;
+
+ TF_REQUIRES(dilations.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ TF_REQUIRES(strides.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int64 stride_n = GetTensorDim(strides, data_format, 'N');
+ const int64 stride_c = GetTensorDim(strides, data_format, 'C');
+ const int64 stride_h = GetTensorDim(strides, data_format, 'H');
+ const int64 stride_w = GetTensorDim(strides, data_format, 'W');
+ TF_REQUIRES(
+ stride_n == 1 && stride_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ TF_REQUIRES(stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+
+ const int64 dilation_n = GetTensorDim(dilations, data_format, 'N');
+ const int64 dilation_c = GetTensorDim(dilations, data_format, 'C');
+ const int64 dilation_h = GetTensorDim(dilations, data_format, 'H');
+ const int64 dilation_w = GetTensorDim(dilations, data_format, 'W');
+ TF_REQUIRES(
+ dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ TF_REQUIRES(
+ dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ return Status::OK();
+}
+
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions) {
+ // Check that 2D convolution input and filter have exactly 4 dimensions.
+ TF_REQUIRES(input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+ TF_REQUIRES(filter.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter.shape().DebugString()));
+ for (int i = 0; i < 3; i++) {
+ TF_REQUIRES(
+ FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
+ errors::InvalidArgument("filter too large"));
+ }
+
+ // The last dimension for input is in_depth. Check that it is the same as the
+ // filter's in_depth or it is evenly divisible by filter's in_depth.
+ const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C');
+ const int64 patch_depth_raw = filter.dim_size(2);
+ TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input depth too large"));
+ TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Patch depth too large"));
+ const int in_depth = static_cast<int>(in_depth_raw);
+ const int patch_depth = static_cast<int>(patch_depth_raw);
+ TF_REQUIRES(in_depth % patch_depth == 0,
+ errors::InvalidArgument(
+ "input depth must be evenly divisible by filter depth: ",
+ in_depth, " vs ", patch_depth));
+
+ // The last dimension for filter is out_depth.
+ const int out_depth = static_cast<int>(filter.dim_size(3));
+
+ // The second dimension for input is rows/height.
+ // The first dimension for filter is rows/height.
+ const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H');
+ TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input rows too large"));
+ const int input_rows = static_cast<int>(input_rows_raw);
+ const int filter_rows = static_cast<int>(filter.dim_size(0));
+
+ // The third dimension for input is columns/width.
+ // The second dimension for filter is columns/width.
+ const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W');
+ TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input cols too large"));
+ const int input_cols = static_cast<int>(input_cols_raw);
+ const int filter_cols = static_cast<int>(filter.dim_size(1));
+
+ // The first dimension for input is batch.
+ const int64 batch_raw = GetTensorDim(input, params.data_format, 'N');
+ TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("batch is too large"));
+ const int batch = static_cast<int>(batch_raw);
+
+ // Take the stride and dilation from the second and third dimensions only (we
+ // do not support striding or dilation on the batch or depth dimension).
+ const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
+ const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
+ const int dilation_rows =
+ GetTensorDim(params.dilations, params.data_format, 'H');
+ const int dilation_cols =
+ GetTensorDim(params.dilations, params.data_format, 'W');
+
+ // Compute windowed output sizes for rows and columns.
+ int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+ input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
+ &out_rows, &pad_rows));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+ input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
+ &out_cols, &pad_cols));
+
+ dimensions->batch = batch;
+ dimensions->input_rows = input_rows;
+ dimensions->input_cols = input_cols;
+ dimensions->in_depth = in_depth;
+ dimensions->filter_rows = filter_rows;
+ dimensions->filter_cols = filter_cols;
+ dimensions->patch_depth = patch_depth;
+ dimensions->out_depth = out_depth;
+ dimensions->stride_rows = stride_rows;
+ dimensions->stride_cols = stride_cols;
+ dimensions->dilation_rows = dilation_rows;
+ dimensions->dilation_cols = dilation_cols;
+ dimensions->out_rows = out_rows;
+ dimensions->out_cols = out_cols;
+ dimensions->pad_rows = pad_rows;
+ dimensions->pad_cols = pad_cols;
+
+ return Status::OK();
+}
+
+#undef TF_REQUIRES
+
template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
public:
explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
- OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
- string data_format;
- OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
- OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
+
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- OP_REQUIRES(context, strides_.size() == 4,
- errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
- const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
- const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
- const int64 stride_h = GetTensorDim(strides_, data_format_, 'H');
- const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
- OP_REQUIRES(
- context, stride_n == 1 && stride_c == 1,
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
- OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
- errors::InvalidArgument(
- "Row and column strides should be larger than 0."));
-
- const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
- OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
void Compute(OpKernelContext* context) override {
// Input tensor is of the following dimensions:
// [ batch, in_rows, in_cols, in_depth ]
-
const Tensor& input = context->input(0);
// Input filter is of the following dimensions:
// [ filter_rows, filter_cols, in_depth, out_depth]
const Tensor& filter = context->input(1);
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- OP_REQUIRES(context, filter.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
- filter.shape().DebugString()));
-
- for (int i = 0; i < 3; i++) {
- OP_REQUIRES(
- context,
- FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
- errors::InvalidArgument("filter too large"));
- }
+ Conv2DDimensions dimensions;
+ OP_REQUIRES_OK(context,
+ ComputeConv2DDimension(params_, input, filter, &dimensions));
- // The last dimension for input is in_depth. It must be the same as the
- // filter's in_depth or be evenly divisible by filter's in_depth.
- const int64 in_depth = GetTensorDim(input, data_format_, 'C');
- const int64 patch_depth = filter.dim_size(2);
- OP_REQUIRES(context, in_depth % patch_depth == 0,
- errors::InvalidArgument(
- "input depth must be evenly divisible by filter depth: ",
- in_depth, " vs ", patch_depth));
-
- // The last dimension for filter is out_depth.
- const int out_depth = static_cast<int>(filter.dim_size(3));
-
- // The second dimension for input is rows/height.
- // The first dimension for filter is rows/height.
- const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input rows too large"));
- const int input_rows = static_cast<int>(input_rows_raw);
- const int filter_rows = static_cast<int>(filter.dim_size(0));
-
- // The third dimension for input is columns/width.
- // The second dimension for filter is columns/width.
- const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input cols too large"));
- const int input_cols = static_cast<int>(input_cols_raw);
- const int filter_cols = static_cast<int>(filter.dim_size(1));
-
- // The first dimension for input is batch.
- const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
- OP_REQUIRES(context,
- FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("batch is too large"));
- const int batch = static_cast<int>(batch_raw);
-
- // For now we take the stride and dilation from the second and third
- // dimensions only (we do not support striding or dilation on the batch or
- // depth dimension).
- const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
-
- const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
- const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
-
- int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_rows, filter_rows, dilation_rows,
- stride_rows, padding_, &out_rows, &pad_rows));
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_cols, filter_cols, dilation_cols,
- stride_cols, padding_, &out_cols, &pad_cols));
- TensorShape out_shape =
- ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
+ TensorShape out_shape = ShapeFromFormat(
+ params_.data_format, dimensions.batch, dimensions.out_rows,
+ dimensions.out_cols, dimensions.out_depth);
// Output tensor is of the following dimensions:
// [ in_batch, out_rows, out_cols, out_depth ]
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
- VLOG(2) << "Conv2D: in_depth = " << in_depth
- << ", patch_depth = " << patch_depth
- << ", input_cols = " << input_cols
- << ", filter_cols = " << filter_cols
- << ", input_rows = " << input_rows
- << ", filter_rows = " << filter_rows
- << ", stride_rows = " << stride_rows
- << ", stride_cols = " << stride_cols
- << ", dilation_rows = " << dilation_rows
- << ", dilation_cols = " << dilation_cols
- << ", out_depth = " << out_depth;
+ VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
+ << ", patch_depth = " << dimensions.patch_depth
+ << ", input_cols = " << dimensions.input_cols
+ << ", filter_cols = " << dimensions.filter_cols
+ << ", input_rows = " << dimensions.input_rows
+ << ", filter_rows = " << dimensions.filter_rows
+ << ", stride_rows = " << dimensions.stride_rows
+ << ", stride_cols = " << dimensions.stride_cols
+ << ", dilation_rows = " << dimensions.dilation_rows
+ << ", dilation_cols = " << dimensions.dilation_cols
+ << ", out_depth = " << dimensions.out_depth;
// If there is nothing to compute, return.
if (out_shape.num_elements() == 0) {
@@ -416,36 +464,41 @@ class Conv2DOp : public BinaryOp<T> {
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
if (LaunchXsmmConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
#endif
if (LaunchDeepConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
- output, data_format_);
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, params_.padding,
+ output, params_.data_format);
}
private:
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
+ Conv2DParameters params_;
bool use_cudnn_;
- Padding padding_;
- TensorFormat data_format_;
- LaunchConv2DOp<Device, T> launcher_;
bool cudnn_use_autotune_;
+ LaunchConv2DOp<Device, T> launcher_;
+
TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
};
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index adf4601b43..7ec878e0b2 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -66,6 +66,50 @@ struct Im2ColBufferResource : public ResourceBase {
string DebugString() { return "Im2ColBufferResource"; }
};
+// Convolution parameters specified by Op attributes.
+struct Conv2DParameters {
+ std::vector<int32> dilations;
+ std::vector<int32> strides;
+ Padding padding;
+ TensorFormat data_format;
+};
+
+// Convolution dimensions inferred from parameters, input and filter tensors.
+struct Conv2DDimensions {
+ int batch;
+ int input_rows;
+ int input_cols;
+ int in_depth;
+
+ int filter_rows;
+ int filter_cols;
+ int patch_depth;
+ int out_depth;
+
+ int stride_rows;
+ int stride_cols;
+
+ int dilation_rows;
+ int dilation_cols;
+
+ int64 out_rows;
+ int64 out_cols;
+ int64 pad_rows;
+ int64 pad_cols;
+};
+
+// Initializes and validates Conv2D parameters configured by OpKernel
+// attributes.
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params);
+
+// Computes and validates convolution dimensions from Conv2D parameters. If
+// parameters are valid, dimensions will be updated with derived convolution
+// dimensions, otherwise an error will be returned.
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions);
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_
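
ComputeConv2DDimension delegates the per-dimension output-size and padding computation to GetWindowedOutputSizeV2, whose body is not part of this diff. The sketch below reproduces the usual SAME/VALID conventions for that computation; treat the formulas as an assumption for illustration, not a copy of the TensorFlow implementation.

// Standalone sketch of a windowed-output-size computation under the usual
// SAME/VALID conventions (assumed here, not copied from TensorFlow).
#include <algorithm>
#include <cstdint>
#include <cstdio>

enum class Padding { kValid, kSame };

void WindowedOutputSize(int64_t input, int64_t filter, int64_t dilation,
                        int64_t stride, Padding padding, int64_t* output,
                        int64_t* pad_before) {
  const int64_t effective_filter = (filter - 1) * dilation + 1;
  if (padding == Padding::kValid) {
    *output = (input - effective_filter + stride) / stride;
    *pad_before = 0;
  } else {  // SAME: output depends only on input size and stride.
    *output = (input + stride - 1) / stride;
    const int64_t padding_needed =
        std::max<int64_t>(0, (*output - 1) * stride + effective_filter - input);
    *pad_before = padding_needed / 2;
  }
}

int main() {
  int64_t out = 0, pad = 0;
  // 224-wide input, 3-wide filter, stride 2, dilation 1, SAME padding -> 112.
  WindowedOutputSize(224, 3, 1, 2, Padding::kSame, &out, &pad);
  std::printf("out = %lld, pad_before = %lld\n",
              static_cast<long long>(out), static_cast<long long>(pad));
  return 0;
}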
diff --git a/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc
new file mode 100644
index 0000000000..e4b21a66c6
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(xdivy, Eigen::half, float, double, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc
new file mode 100644
index 0000000000..1e1b5a426e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(xlogy, Eigen::half, float, double, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_xdivy.cc b/tensorflow/core/kernels/cwise_op_xdivy.cc
new file mode 100644
index 0000000000..6a6aec5e86
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_xdivy.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Xdivy", functor::xdivy, float, Eigen::half, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Xdivy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::xdivy<TYPE>>);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER5(BinaryOp, GPU, "Xdivy", functor::xdivy, float, Eigen::half, double,
+ complex64, complex128);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_xlogy.cc b/tensorflow/core/kernels/cwise_op_xlogy.cc
new file mode 100644
index 0000000000..e71a9109b2
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_xlogy.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Xlogy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::xlogy<TYPE>>);
+REGISTER_SYCL_KERNEL(Eigen::half);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+REGISTER_SYCL_KERNEL(complex64);
+REGISTER_SYCL_KERNEL(complex128);
+#undef REGISTER_SYCL_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER5(BinaryOp, GPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
+ complex64, complex128);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 22eb66e979..66ba827a90 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -471,6 +471,45 @@ struct functor_traits<bitwise_xor_op<Scalar>> {
enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
};
+// TODO(srvasude): Add packet versions of this operation.
+template <typename Scalar>
+struct xlogy_op {
+ EIGEN_EMPTY_STRUCT_CTOR(xlogy_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x, const Scalar& y) const {
+ if (x == Scalar(0.)) {
+ return Scalar(0.);
+ }
+ return x * numext::log(y);
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<xlogy_op<Scalar>> {
+ enum {
+ Cost = (sizeof(Scalar) == 4 ? 40 : 85) + Eigen::NumTraits<Scalar>::MulCost,
+ PacketAccess = false
+ };
+};
+
+template <typename Scalar>
+// TODO(srvasude): Add packet versions of this operation.
+struct xdivy_op {
+ EIGEN_EMPTY_STRUCT_CTOR(xdivy_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x, const Scalar& y) const {
+ if (x == Scalar(0.)) {
+ return Scalar(0.);
+ }
+ return x / y;
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<xdivy_op<Scalar>> {
+ enum { Cost = Eigen::NumTraits<Scalar>::MulCost, PacketAccess = false };
+};
+
} // end namespace internal
} // end namespace Eigen
@@ -830,6 +869,12 @@ struct squared_difference
Eigen::internal::scalar_difference_op<T>>> {};
template <typename T>
+struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};
+
+template <typename T>
+struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};
+
+template <typename T>
struct less : base<T, Eigen::internal::less<T>, bool> {};
template <typename T>
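
The key property of the new functors is that they return exactly zero when x == 0, sidestepping the NaN/Inf that the naive expressions 0 * log(0) or 0 / 0 would produce. A minimal scalar sketch of that behavior, using plain std:: math independently of Eigen:

// Scalar sketch of the xlogy/xdivy semantics added above.
#include <cassert>
#include <cmath>

template <typename Scalar>
Scalar Xlogy(Scalar x, Scalar y) {
  return x == Scalar(0) ? Scalar(0) : x * std::log(y);
}

template <typename Scalar>
Scalar Xdivy(Scalar x, Scalar y) {
  return x == Scalar(0) ? Scalar(0) : x / y;
}

int main() {
  assert(Xlogy(0.0, 0.0) == 0.0);  // naive 0 * log(0) would be NaN
  assert(Xdivy(0.0, 0.0) == 0.0);  // naive 0 / 0 would be NaN
  assert(std::abs(Xlogy(2.0, std::exp(1.0)) - 2.0) < 1e-12);
  assert(Xdivy(6.0, 3.0) == 2.0);
  return 0;
}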
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index 980edffceb..8ad3b4d1fc 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -20,9 +20,9 @@ namespace tensorflow {
BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
DataType in)
: OpKernel(ctx) {
-#ifndef INTEL_MKL
+#if !defined(INTEL_MKL) || !defined(ENABLE_MKL)
OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
-#endif
+#endif // !INTEL_MKL || !ENABLE_MKL
}
void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index b3c359010d..87efdff789 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -628,6 +628,20 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "multi_device_iterator_ops",
+ srcs = ["multi_device_iterator_ops.cc"],
+ deps = [
+ ":dataset",
+ ":dataset_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "optional_ops",
srcs = ["optional_ops.cc"],
hdrs = ["optional_ops.h"],
@@ -722,6 +736,7 @@ tf_kernel_library(
":map_dataset_op",
":map_defun_op",
":model_dataset_op",
+ ":multi_device_iterator_ops",
":optimize_dataset_op",
":optional_ops",
":padded_batch_dataset_op",
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index e7ac368ae3..e10833f525 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -44,5 +44,42 @@ Status MakeIteratorFromInputElement(
ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+
} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 234856ea39..6ec1350cd4 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -27,6 +27,16 @@ Status MakeIteratorFromInputElement(
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
+// Returns Status::OK() if `expected` and `received` types match,
+// errors::InvalidArgument otherwise.
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received);
+
+// Returns Status::OK() if `expected` and `received` shapes are compatible,
+// errors::InvalidArgument otherwise.
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received);
+
} // namespace data
} // namespace tensorflow
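
The two helpers moved into dataset_utils differ in strictness: dtypes must match exactly, while shapes only need to be compatible, with unknown dimensions matching anything. A simplified sketch of that contract, using plain vectors (with -1 as "unknown") as stand-ins for DataTypeVector and PartialTensorShape:

// Simplified stand-ins, not the TensorFlow classes; PartialTensorShape also
// handles unknown rank, which is omitted here.
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;  // -1 means "unknown dimension"

bool TypesMatch(const std::vector<int>& expected,
                const std::vector<int>& received) {
  return expected == received;  // exact, element-wise equality
}

bool ShapesCompatible(const Shape& expected, const Shape& received) {
  if (expected.size() != received.size()) return false;
  for (size_t i = 0; i < expected.size(); ++i) {
    if (expected[i] != -1 && received[i] != -1 && expected[i] != received[i])
      return false;
  }
  return true;
}

int main() {
  const Shape expected{-1, 128};  // unknown batch dimension
  const Shape ok{32, 128};
  const Shape bad{32, 64};
  assert(ShapesCompatible(expected, ok));    // unknown dim matches anything
  assert(!ShapesCompatible(expected, bad));  // known dims must agree
  const std::vector<int> t1{1, 2};           // stand-in dtype codes
  const std::vector<int> t2{2, 1};
  assert(TypesMatch(t1, t1));
  assert(!TypesMatch(t1, t2));               // exact match required
  return 0;
}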
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index 71a36314a0..b4367d5a11 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -86,8 +86,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
- TF_RETURN_IF_ERROR(
- dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
return Status::OK();
}
@@ -96,6 +94,12 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
bool* end_of_sequence) override {
mutex_lock l(mu_);
+ if (!initialized_) {
+ TF_RETURN_IF_ERROR(
+ dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ initialized_ = true;
+ }
+
if (finalized_) {
*end_of_sequence = true;
return Status::OK();
@@ -123,6 +127,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
private:
mutex mu_;
+ bool initialized_ GUARDED_BY(mu_) = false;
bool finalized_ GUARDED_BY(mu_) = false;
std::vector<Tensor> state_ GUARDED_BY(mu_);
};
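
The change above defers running init_func_ from Initialize() to the first GetNextInternal() call, guarded by initialized_ under the iterator's mutex. A stripped-down sketch of that lazy, one-time initialization pattern (names are illustrative, not TensorFlow APIs):

// Sketch of lazy one-time initialization under a mutex.
#include <functional>
#include <mutex>

class LazyIterator {
 public:
  explicit LazyIterator(std::function<void()> init) : init_(std::move(init)) {}

  // First call pays the initialization cost; later calls skip it.
  void GetNext() {
    std::lock_guard<std::mutex> l(mu_);
    if (!initialized_) {
      init_();  // corresponds to init_func_->RunWithBorrowedArgs(...)
      initialized_ = true;
    }
    // ... produce the next element ...
  }

 private:
  std::mutex mu_;
  bool initialized_ = false;
  std::function<void()> init_;
};

int main() {
  LazyIterator it([] { /* expensive generator setup happens here, once */ });
  it.GetNext();  // runs the init function
  it.GetNext();  // init is skipped from now on
  return 0;
}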
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 30c6585ba2..c0bc507ec0 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -44,43 +44,6 @@ namespace {
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
-Status VerifyTypesMatch(const DataTypeVector& expected,
- const DataTypeVector& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " types but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (expected[i] != received[i]) {
- return errors::InvalidArgument("Data type mismatch at component ", i,
- ": expected ", DataTypeString(expected[i]),
- " but got ", DataTypeString(received[i]),
- ".");
- }
- }
- return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
- const std::vector<PartialTensorShape>& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " shapes but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (!expected[i].IsCompatibleWith(received[i])) {
- return errors::InvalidArgument("Incompatible shapes at component ", i,
- ": expected ", expected[i].DebugString(),
- " but got ", received[i].DebugString(),
- ".");
- }
- }
-
- return Status::OK();
-}
-
} // namespace
class IteratorResource : public ResourceBase {
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
new file mode 100644
index 0000000000..5f143967d9
--- /dev/null
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -0,0 +1,633 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+struct HostBufferElement {
+ Status status;
+ bool end_of_sequence;
+ std::vector<Tensor> value;
+};
+
+using MultiDeviceIteratorCallback =
+ std::function<void(const HostBufferElement&)>;
+
+class MultiDeviceIterator : public ResourceBase {
+ public:
+ MultiDeviceIterator(const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes,
+ const std::vector<string>& devices,
+ std::unique_ptr<FunctionLibraryDefinition> flib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+ FunctionLibraryRuntime* lib)
+ : output_types_(output_types),
+ output_shapes_(output_shapes),
+ devices_(devices),
+ flib_def_(std::move(flib_def)),
+ pflr_(std::move(pflr)),
+ lib_(lib) {
+ DCHECK(lib_ != nullptr);
+ }
+
+ string DebugString() override {
+ return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
+ " devices");
+ }
+
+ Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size,
+ int64* incarnation_id) {
+ if (iterator) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_types_, iterator->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
+ }
+
+ mutex_lock l(mu_);
+ if (multi_device_buffer_) {
+ multi_device_buffer_->Reset();
+ }
+
+ ++incarnation_id_;
+ *incarnation_id = incarnation_id_;
+
+ multi_device_buffer_.reset(
+ new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_,
+ std::move(iterator)));
+ return Status::OK();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ if (lib_ != nullptr) {
+ ctx->set_lib(lib_);
+ }
+ tf_shared_lock l(mu_);
+ multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
+ std::move(callback));
+ }
+
+ const DataTypeVector& output_types() const { return output_types_; }
+
+ const std::vector<PartialTensorShape>& output_shapes() const {
+ return output_shapes_;
+ }
+
+ std::shared_ptr<const FunctionLibraryDefinition> function_library() {
+ tf_shared_lock l(mu_);
+ return lib_def_;
+ }
+
+ FunctionLibraryRuntime* const lib() {
+ tf_shared_lock l(mu_);
+ return lib_;
+ }
+
+ private:
+ // A private class that uses a background thread to keep a per device buffer
+ // full.
+ class MultiDeviceBuffer {
+ public:
+ MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
+ std::unique_ptr<IteratorBase> host_iterator)
+ : buffer_(size),
+ size_(size),
+ max_buffer_size_(max_buffer_size),
+ incarnation_id_(incarnation_id),
+ host_iterator_(std::move(host_iterator)) {}
+
+ ~MultiDeviceBuffer() {
+ {
+ mutex_lock l(mu_);
+ if (!background_thread_started_) return;
+ }
+ Reset();
+ }
+
+ void Reset() LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (background_thread_finished_) {
+ return;
+ }
+
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
+
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
+ }
+ RunPendingCallbacks();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ HostBufferElement elem;
+ if (incarnation_id_ != incarnation_id) {
+ elem.status = errors::InvalidArgument("Invalid incarnation id");
+ callback(elem);
+ return;
+ }
+
+ bool produced_output = false;
+ {
+ mutex_lock l(mu_);
+ if (cancelled_) {
+ elem.status = errors::Cancelled("Cancelled Multidevice iterator");
+ callback(elem);
+ return;
+ }
+
+ EnsureBackgroundThreadStarted(ctx);
+
+ if (!buffer_[shard_num].data.empty()) {
+ produced_output = true;
+ std::swap(elem, buffer_[shard_num].data.front());
+ buffer_[shard_num].data.pop_front();
+ // Wake up background thread if it is blocked on this element.
+ if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) {
+ buffer_[shard_num].cond_var.notify_all();
+ }
+ } else {
+ if (background_thread_finished_) {
+ produced_output = true;
+ elem.end_of_sequence = true;
+ } else {
+ buffer_[shard_num].callbacks.push_back(std::move(callback));
+ callback = nullptr;
+ }
+ }
+ }
+
+ if (produced_output) {
+ callback(elem);
+ }
+ }
+
+ private:
+ void EnsureBackgroundThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!background_thread_) {
+ background_thread_.reset(ctx->env()->StartThread(
+ {}, "multi_device_iterator_background_thread",
+ std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
+ this, new IteratorContext(*ctx))));
+ }
+ }
+
+ void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) {
+ // Run all remaining callbacks.
+ std::vector<MultiDeviceIteratorCallback> cancellation_callbacks;
+ std::vector<HostBufferElement> cancellation_elements;
+ {
+ mutex_lock l(mu_);
+
+ for (int i = 0; i < size_; ++i) {
+ while (!buffer_[i].callbacks.empty()) {
+ if (buffer_[i].data.empty()) {
+ HostBufferElement elem;
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ cancellation_elements.push_back(std::move(elem));
+ } else {
+ cancellation_elements.push_back(
+ std::move(buffer_[i].data.front()));
+ buffer_[i].data.pop_front();
+ }
+ cancellation_callbacks.push_back(
+ std::move(buffer_[i].callbacks.front()));
+ buffer_[i].callbacks.pop_front();
+ }
+ }
+ }
+ for (int i = 0; i < cancellation_callbacks.size(); ++i) {
+ cancellation_callbacks[i](cancellation_elements[i]);
+ }
+ }
+
+ void BackgroundThread(IteratorContext* ctx) {
+ {
+ mutex_lock l(mu_);
+ background_thread_started_ = true;
+ }
+ std::unique_ptr<IteratorContext> cleanup(ctx);
+ int shard_to_fetch = 0;
+ while (true) {
+ HostBufferElement elem;
+ MultiDeviceIteratorCallback callback = nullptr;
+ bool end_of_iterator = false;
+
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ buffer_[shard_to_fetch].data.size() >= max_buffer_size_) {
+ buffer_[shard_to_fetch].cond_var.wait(l);
+ }
+
+ if (cancelled_) {
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ return;
+ }
+ }
+
+ elem.status =
+ host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence);
+
+ if (elem.status.ok() && elem.end_of_sequence) {
+ end_of_iterator = true;
+ }
+
+ {
+ mutex_lock l(mu_);
+ // Try to find a callback, else just push stuff into buffer.
+ if (!buffer_[shard_to_fetch].callbacks.empty()) {
+ callback = buffer_[shard_to_fetch].callbacks.front();
+ buffer_[shard_to_fetch].callbacks.pop_front();
+ } else {
+ buffer_[shard_to_fetch].data.push_back(std::move(elem));
+ elem = HostBufferElement();
+ }
+ }
+
+ if (callback) {
+ (*ctx->runner())(std::bind(std::move(callback), std::move(elem)));
+ }
+
+ // Finish off the thread if we reach the end of the iterator. Runs
+ // pending callbacks.
+ if (end_of_iterator) {
+ {
+ mutex_lock l(mu_);
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ }
+ RunPendingCallbacks();
+ return;
+ }
+ shard_to_fetch = (shard_to_fetch + 1) % size_;
+ }
+ }
+
+ struct HostBuffer {
+ condition_variable cond_var;
+ std::deque<HostBufferElement> data;
+ std::deque<MultiDeviceIteratorCallback> callbacks;
+ };
+
+ mutex mu_;
+ std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
+ bool background_thread_finished_ GUARDED_BY(mu_) = false;
+ bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
+
+ std::vector<HostBuffer> buffer_;
+
+ const size_t size_;
+ const int64 max_buffer_size_;
+ const int64 incarnation_id_;
+ const std::unique_ptr<IteratorBase> host_iterator_;
+ };
+
+ mutex mu_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const std::vector<string> devices_;
+ const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ FunctionLibraryRuntime* const lib_ = nullptr; // not owned.
+ std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
+
+ int64 incarnation_id_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
+};
+
+// Just creates a MultiDeviceIterator and returns it.
+class MultiDeviceIteratorHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
+ }
+
+ // The resource is deleted from the resource manager only when it is private
+ // to kernel.
+ ~MultiDeviceIteratorHandleOp() override {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<MultiDeviceIterator>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ }
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ FunctionLibraryRuntime* lib;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+ OP_REQUIRES_OK(context, context->function_library()->Clone(
+ &flib_def, &pflr, &lib));
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(
+ context,
+ mgr->LookupOrCreate<MultiDeviceIterator>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, lib, &flib_def, &pflr](MultiDeviceIterator** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new MultiDeviceIterator(
+ output_types_, output_shapes_, devices_,
+ std::move(flib_def), std::move(pflr), lib);
+ return Status::OK();
+ }));
+
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
+
+ resource_ = resource;
+ }
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<MultiDeviceIterator>()));
+ }
+
+ private:
+ // During the first Compute(), the resource is either created or looked up
+ // using shared_name. In the latter case, the resource found should be
+ // verified to ensure it is compatible with this op's configuration. The
+ // verification may fail in cases such as two graphs asking for queues with
+ // the same shared name but inconsistent capacities.
+ Status VerifyResource(MultiDeviceIterator* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_types_, resource->output_types()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+ }
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ const int graph_def_version_;
+ string name_;
+ string container_;
+ std::vector<string> devices_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
+ MultiDeviceIteratorHandleOp);
+
+// Calls init on the MultiDeviceIterator.
+class MultiDeviceIteratorInitOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* tensor_max_buffer_size;
+ OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
+ int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()();
+
+ DatasetBase* dataset;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
+ core::ScopedUnref unref(resource);
+
+ std::unique_ptr<IteratorBase> iterator;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(resource->lib());
+ OP_REQUIRES_OK(
+ ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
+ int64 incarnation_id;
+ OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
+ &incarnation_id));
+ Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
+ tensor_incarnation_id.scalar<int64>()() = incarnation_id;
+ OP_REQUIRES_OK(ctx,
+ ctx->set_output("incarnation_id", tensor_incarnation_id));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU),
+ MultiDeviceIteratorInitOp);
+
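Init above hands back an incarnation_id that every later GetNextFromShard call must echo; if the iterator has since been re-initialized, the stored id no longer matches and the read is rejected. A hedged sketch of that guard in isolation (the types and names below are illustrative, not the TensorFlow ones):

#include <cstdint>
#include <iostream>
#include <string>

// Illustrative only: a buffer remembers which "incarnation" of the iterator
// produced it, so readers holding an old id are rejected after a re-Init.
struct Buffer {
  std::int64_t incarnation_id = 0;
};

struct Iterator {
  std::int64_t current_incarnation = 0;
  Buffer buffer;

  // Re-initialization bumps the incarnation and replaces the buffer.
  std::int64_t Init() {
    ++current_incarnation;
    buffer = Buffer{current_incarnation};
    return current_incarnation;
  }

  std::string GetNext(std::int64_t caller_incarnation) const {
    if (caller_incarnation != buffer.incarnation_id) {
      return "error: stale incarnation id";  // caller must re-Init first
    }
    return "element";
  }
};

int main() {
  Iterator it;
  const std::int64_t id = it.Init();
  std::cout << it.GetNext(id) << "\n";   // ok
  const std::int64_t id2 = it.Init();    // re-init invalidates the old id
  std::cout << it.GetNext(id) << "\n";   // rejected
  std::cout << it.GetNext(id2) << "\n";  // ok
  return 0;
}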
+// Calls GetNextFromShard(shard) and returns a vector of Tensors as output.
+// TODO(rohanj): Implement using BackgroundWorker that Derek built?
+class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
+ public:
+ explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ thread_pool_(new thread::ThreadPool(
+ ctx->env(), ThreadOptions(),
+ strings::StrCat("multi_device_iterator_get_next_thread_",
+ SanitizeThreadSuffix(name())),
+ 1 /* num_threads */, false /* low_latency_hint */)) {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ const Tensor* tensor_shard_num;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done);
+ int32 shard_num = tensor_shard_num->scalar<int32>()();
+
+ const Tensor* tensor_incarnation_id;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
+ int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
+
+ MultiDeviceIterator* iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
+ thread_pool_->Schedule(std::bind(
+ [ctx, iterator, shard_num, incarnation_id](DoneCallback done) {
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ IteratorContext iter_ctx(std::move(params));
+
+ MultiDeviceIteratorCallback callback = std::bind(
+ [ctx](const HostBufferElement& elem, DoneCallback done) {
+ // iterator->Unref();
+ Status s = elem.status;
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else if (elem.end_of_sequence) {
+ ctx->SetStatus(errors::OutOfRange("End of sequence"));
+ } else {
+ for (int i = 0; i < elem.value.size(); ++i) {
+ ctx->set_output(i, elem.value[i]);
+ }
+ }
+ done();
+ },
+ std::placeholders::_1, std::move(done));
+
+ iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
+ callback);
+ iterator->Unref();
+ },
+ std::move(done)));
+ }
+
+ private:
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU),
+ MultiDeviceIteratorGetNextFromShardOp);
+
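The async kernel above adapts its DoneCallback into a per-element MultiDeviceIteratorCallback: record the element's status or values on the op context, then signal completion exactly once. A small sketch of that adapter shape with std::function stand-ins (Element, Context and MakeCallback are illustrative, not the real TensorFlow types):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Illustrative stand-ins for HostBufferElement and OpKernelContext.
struct Element {
  bool ok = true;
  bool end_of_sequence = false;
  std::vector<std::string> values;
};

struct Context {
  std::vector<std::string> outputs;
  std::string status = "OK";
};

using Done = std::function<void()>;
using ElementCallback = std::function<void(const Element&)>;

// Adapt a completion callback into a per-element callback: record the outputs
// or an error on the context, then always signal completion exactly once.
ElementCallback MakeCallback(Context* ctx, Done done) {
  return [ctx, done](const Element& elem) {
    if (!elem.ok) {
      ctx->status = "internal error";
    } else if (elem.end_of_sequence) {
      ctx->status = "out of range: end of sequence";
    } else {
      ctx->outputs = elem.values;
    }
    done();
  };
}

int main() {
  Context ctx;
  const ElementCallback cb = MakeCallback(&ctx, [] { std::cout << "done\n"; });
  cb(Element{true, false, {"a", "b"}});
  std::cout << ctx.status << ", outputs: " << ctx.outputs.size() << "\n";
  return 0;
}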
+class MultiDeviceIteratorToStringHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& resource_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
+ errors::InvalidArgument("resource_handle must be a scalar"));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is a MultiDeviceIterator.
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ resource->Unref();
+
+ Tensor* string_handle_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &string_handle_t));
+ string_handle_t->scalar<string>()() =
+ resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU),
+ MultiDeviceIteratorToStringHandleOp);
+
+class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES(
+ ctx,
+ output_types_.empty() || output_shapes_.empty() ||
+ output_types_.size() == output_shapes_.size(),
+ errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
+ "are set, they must have the same length."));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& string_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
+ errors::InvalidArgument("string_handle must be a scalar"));
+
+ ResourceHandle resource_handle;
+ OP_REQUIRES(
+ ctx,
+ resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
+ errors::InvalidArgument(
+ "Could not parse string_handle as a valid ResourceHandle"));
+
+ OP_REQUIRES(
+ ctx, resource_handle.device() == ctx->device()->attributes().name(),
+ errors::InvalidArgument("Attempted create an iterator on device \"",
+ ctx->device()->attributes().name(),
+ "\" from handle defined on device \"",
+ resource_handle.device(), "\""));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is a MultiDeviceIterator.
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
+ core::ScopedUnref unref_iterator(resource);
+ if (!output_types_.empty()) {
+ OP_REQUIRES_OK(ctx,
+ VerifyTypesMatch(output_types_, resource->output_types()));
+ }
+ if (!output_shapes_.empty()) {
+ OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_,
+ resource->output_shapes()));
+ }
+
+ Tensor* resource_handle_t;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
+ resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
+ }
+
+ private:
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
+ MultiDeviceIteratorFromStringHandleOp);
+
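Together, the two ops above round-trip a resource handle through a string tensor: serialize the handle on one side, then parse it back and re-validate device and output signature before reusing it. A sketch of the same round trip with a plain struct standing in for the ResourceHandle proto (Serialize/Parse below only emulate, and are not, the proto's SerializeAsString/ParseFromString):

#include <iostream>
#include <optional>
#include <string>

struct Handle {
  std::string device;
  std::string name;

  // Emulates SerializeAsString(): a trivial "device|name" encoding.
  std::string Serialize() const { return device + "|" + name; }

  // Emulates ParseFromString(): returns nullopt on malformed input.
  static std::optional<Handle> Parse(const std::string& s) {
    const auto sep = s.find('|');
    if (sep == std::string::npos) return std::nullopt;
    return Handle{s.substr(0, sep), s.substr(sep + 1)};
  }
};

int main() {
  const Handle h{"/device:CPU:0", "multi_device_iterator_0"};
  const std::string wire = h.Serialize();  // what the string tensor would hold
  const std::optional<Handle> parsed = Handle::Parse(wire);
  if (!parsed || parsed->device != "/device:CPU:0") {
    std::cout << "invalid or mismatched handle\n";
    return 1;
  }
  std::cout << "restored handle to " << parsed->name << "\n";
  return 0;
}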
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index 346e4ceebd..2ab5c83082 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -213,6 +213,14 @@ static Status OptionalDeviceCopy(
std::vector<Tensor> to_values;
to_values.reserve(from_values.size());
for (const Tensor& t : from_values) {
+ if (t.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ }
+ }
+ for (const Tensor& t : from_values) {
if (DMAHelper::CanUseDMA(&t)) {
Tensor tmp(t.dtype());
TF_RETURN_IF_ERROR(copy(t, &tmp));
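The new loop above checks every element for nested DT_VARIANT values before any copy is issued, so an unsupported optional fails cleanly instead of leaving a partially copied result. A small sketch of that validate-then-copy pattern in generic C++ (DType, TensorLike and CopyValues are illustrative, not the TensorFlow DMA helpers):

#include <iostream>
#include <string>
#include <vector>

enum class DType { kFloat, kInt, kVariant };

struct TensorLike {
  DType dtype;
  std::string payload;
};

// First pass: fail fast if any element is unsupported. Second pass: copy.
// This keeps the destination untouched unless every element can be copied.
bool CopyValues(const std::vector<TensorLike>& from,
                std::vector<TensorLike>* to, std::string* error) {
  for (const TensorLike& t : from) {
    if (t.dtype == DType::kVariant) {
      *error = "copying nested variants is not supported";
      return false;
    }
  }
  to->reserve(from.size());
  for (const TensorLike& t : from) {
    to->push_back(t);  // stands in for the device copy
  }
  return true;
}

int main() {
  const std::vector<TensorLike> from = {{DType::kFloat, "a"},
                                        {DType::kVariant, "b"}};
  std::vector<TensorLike> to;
  std::string error;
  if (!CopyValues(from, &to, &error)) {
    std::cout << "error: " << error << ", copied " << to.size() << "\n";
  }
  return 0;
}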
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 2a1e9c85f1..754ed772db 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -103,9 +103,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
+ auto stats_aggregator = ctx->stats_aggregator();
{
mutex_lock l(mu_);
- auto stats_aggregator = ctx->stats_aggregator();
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
@@ -136,6 +136,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
mutex_lock parent_l(parent_mu_);
mutex_lock l(mu_);
+ if (stats_aggregator) {
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
+ }
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
@@ -219,6 +227,12 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
strings::StrCat(prefix_end_, "::buffer_utilization"),
{static_cast<float>(buffer_.size()) /
static_cast<float>(auto_tuner_.buffer_limit())});
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
}
// A new element is available. Forward the status from computing it, and
// (if we successfully got an element) the output values.
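With this change the prefetch iterator reports buffer_size and buffer_capacity as separate scalars next to the existing buffer_utilization histogram, so the ratio can be reconstructed downstream. A tiny sketch of how the three values relate (PrefetchStats and Report are illustrative names, not the TensorFlow API):

#include <cstddef>
#include <iostream>

// Given the current number of buffered elements and the autotuned limit,
// the three reported statistics are related as follows.
struct PrefetchStats {
  float buffer_size;
  float buffer_capacity;
  float buffer_utilization;  // == buffer_size / buffer_capacity
};

PrefetchStats Report(std::size_t buffered, std::size_t limit) {
  PrefetchStats s;
  s.buffer_size = static_cast<float>(buffered);
  s.buffer_capacity = static_cast<float>(limit);
  s.buffer_utilization = limit == 0 ? 0.0f : s.buffer_size / s.buffer_capacity;
  return s;
}

int main() {
  const PrefetchStats s = Report(3, 8);
  std::cout << s.buffer_size << " / " << s.buffer_capacity << " = "
            << s.buffer_utilization << "\n";  // 3 / 8 = 0.375
  return 0;
}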
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index f5314f7a75..7e528a71be 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -34,16 +34,18 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
- *output = new Dataset(ctx, input, stats_aggregator_resource);
+ *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource);
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const Tensor& resource_handle,
StatsAggregatorResource* stats_aggregator_resource)
: DatasetBase(DatasetContext(ctx)),
input_(input),
+ resource_handle_(resource_handle),
stats_aggregator_resource_(stats_aggregator_resource) {
input_->Ref();
stats_aggregator_resource_->Ref();
@@ -75,8 +77,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* resource_handle_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, resource_handle_node}, output));
+ return Status::OK();
}
private:
@@ -129,6 +136,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
+ const Tensor resource_handle_;
StatsAggregatorResource* stats_aggregator_resource_;
};
};
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 2a25459194..76afd6f18c 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/util_ptx.cuh"
+#include "third_party/cub/util_ptx.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
index 862a97723f..e7882acc80 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -35,10 +35,10 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_radix_sort.cuh"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh"
-#include "external/cub_archive/cub/thread/thread_operators.cuh"
+#include "third_party/cub/device/device_radix_sort.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/iterator/constant_input_iterator.cuh"
+#include "third_party/cub/thread/thread_operators.cuh"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index c41fbc42d3..6a9a2accd8 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -40,8 +40,8 @@ namespace internal {
// at the given vertical and horizontal offsets.
//
// "Virtual matrix" dimensions:
-// *0: kernelChannels * kernelDepth * kernelRows * kernelCols;
-// 1: out_depth * out_height * out_width; * OTHERS (e.g batches, etc...)
+// *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
+// 1: out_planes * out_height * out_width * OTHERS (e.g batches, etc...)
//
// *) extracted patches are continuous in memory (innermost dimension assuming
// col major layout)
@@ -113,6 +113,11 @@ class TensorContractionInputMapper<
m_num_patches = tensor.impl().dimensions()[NumDims - 5];
}
+ // Strides for navigating through the single patch.
+ m_patch_plane_stride = m_patch_depth;
+ m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
+ m_patch_col_stride = m_patch_rows * m_patch_row_stride;
+
// Strides for the output tensor.
// IMPORTANT: These strides are used to locate an element in a patch at a
// depth zero (channel), which is not quite the same as "traditional"
@@ -166,6 +171,13 @@ class TensorContractionInputMapper<
m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+ m_fastPatchPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_stride);
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
+
m_fastInputPlaneStride =
internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
m_fastInputRowStride =
@@ -195,6 +207,10 @@ class TensorContractionInputMapper<
m_patch_cols = base_mapper.m_patch_cols;
m_num_patches = base_mapper.m_num_patches;
+ m_patch_plane_stride = base_mapper.m_patch_plane_stride;
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
m_rowStride = base_mapper.m_rowStride;
m_colStride = base_mapper.m_colStride;
m_patchStride = base_mapper.m_patchStride;
@@ -234,6 +250,9 @@ class TensorContractionInputMapper<
m_outputPlanesRows = base_mapper.m_outputPlanesRows;
m_fastNumPatches = base_mapper.m_fastNumPatches;
+ m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
m_fastInputRowStride = base_mapper.m_fastInputRowStride;
m_fastInputColStride = base_mapper.m_fastInputColStride;
@@ -305,9 +324,9 @@ class TensorContractionInputMapper<
}
EIGEN_DEVICE_FUNC
- EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_patch_depth; }
+ EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; }
EIGEN_DEVICE_FUNC
- EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_patch_planes; }
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; }
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
EIGEN_DEVICE_FUNC
@@ -391,14 +410,13 @@ class TensorContractionInputMapper<
const Index patchOffset = patchId / m_fastDimZero;
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
-
const Index rowOffset =
(patchOffset - colOffset * m_colStride) / m_fastRowStride;
- const Index inputRow = rowIndex + rowOffset;
-
const Index planeOffset =
patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+
+ const Index inputCol = colIndex + colOffset;
+ const Index inputRow = rowIndex + rowOffset;
const Index inputPlane = planeIndex + planeOffset;
if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
@@ -524,12 +542,13 @@ class TensorContractionInputMapper<
eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset =
(patchOffset - colOffset * m_colStride) / m_fastRowStride;
- const Index inputRow = rowIndex + rowOffset;
const Index planeOffset =
patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+
+ const Index inputCol = colIndex + colOffset;
+ const Index inputRow = rowIndex + rowOffset;
const Index inputPlane = planeIndex + planeOffset;
if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
@@ -564,7 +583,7 @@ class TensorContractionInputMapper<
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
Index& otherIndex) const {
- const int NumInputDims = array_size<
+ const size_t NumInputDims = array_size<
typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
// Check if patchIndex might contain batch and other dimensions.
@@ -594,7 +613,12 @@ class TensorContractionInputMapper<
Index m_patch_cols; // number of columns in the patch
Index m_num_patches; // number of patches to extract
- // Strides for the output tensor.
+ // Strides for navigating through the single patch.
+ Index m_patch_plane_stride;
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+
+ // Strides for the output tensor (depth is not the part of the stride).
Index m_rowStride;
Index m_colStride;
Index m_patchStride;
@@ -637,6 +661,10 @@ class TensorContractionInputMapper<
// Fast representation of various divisors.
internal::TensorIntDivisor<Index> m_fastNumPatches;
+ internal::TensorIntDivisor<Index> m_fastPatchPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
internal::TensorIntDivisor<Index> m_fastInputRowStride;
internal::TensorIntDivisor<Index> m_fastInputColStride;
@@ -750,13 +778,62 @@ class TensorContractionSubMapper<
return m_base_mapper.nonStandardPatches();
}
+ // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row,
+ // plane and depth index respectively that fits into the peeled_k elements
+ // starting at m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col,
+ const Index row) const {
+ const Index max_plane = fastPatchPlaneStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride() -
+ row * patchRowStride());
+ return std::min<Index>(1 + max_plane, patchPlanes());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes, to prevent register
+ // spilling, we use a different variable that is guaranteed to have the same
+ // value as the one you would expect to see. E.g. patch depth is always the
+ // same as the input depth, and it is also the same as the input plane
+ // stride. A number of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchDepth() const {
- return m_base_mapper.m_patch_depth;
+ eigen_assert(m_base_mapper.m_patch_depth ==
+ m_base_mapper.m_planeInputStride &&
+ "Patch depth must be equal to plane input stride.");
+ return m_base_mapper.m_planeInputStride;
}
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchPlanes() const {
- return m_base_mapper.m_patch_planes;
+ eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride &&
+ "Patch planes must be equal to row stride.");
+ return m_base_mapper.m_rowStride;
}
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchRows() const {
@@ -768,6 +845,36 @@ class TensorContractionSubMapper<
}
EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ return m_base_mapper.m_patch_row_stride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ return m_base_mapper.m_fastPatchRowStride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
const Index baseIndex) const {
const Index inputIndex = depth + baseIndex;
@@ -832,8 +939,7 @@ class TensorContractionSubMapper<
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index depthOffset() const {
- const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
- return patchOffset;
+ return m_depth_offset % patchDepth();
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
@@ -859,24 +965,29 @@ class TensorContractionSubMapper<
// matrix" constructed from extracted volume patches) in contiguous memory.
//
// Given column major input (A0 beside A1 in memory):
-// A0 B0 C0 D0 E0 F0 G0 H0 ...
-// A1 B1 C1 D1 E1 F1 G1 H1 ...
-// A2 B2 C2 D2 E2 F2 G2 H2 ...
-// A3 B3 C3 D3 E3 F3 G3 H3 ...
-// A4 B4 C4 D4 E4 F4 G4 H4 ...
-// A5 B5 C5 D5 E5 F5 G5 H5 ...
-// A6 B6 C6 D6 E6 F6 G6 H6 ...
-// A7 B7 C7 D7 E7 F7 G7 H7 ...
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
// A8 ...
// ...
//
-// Packing yields row major output (A0 beside A1 in memory):
-// A0 A1 A2 A3 A4 A5 A6 A7
-// B0 B1 B2 B3 B4 B5 B6 B7
-// C0 ...
+// *) A, B, C, ... - patches extracted from the original input.
+// *) A0, A1, A2 ... - values from the same patch at different offsets.
+//
+// The traversal (packed rhs memory) order (B0 besides A0 in memory):
+// A0 B0 C0 D0 A1 B1 C1 D1 ...
+// E0 F0 G0 H0 E1 F1 G1 H1 ...
// ...
+// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
+//
+// This traversal order must be the same as in default gemm_pack_rhs defined in
+// GeneralBlockPanelKernel.h.
//
-// *) A, B, C, ... - patches extracted from the original input.
// *) nr - number of registers along the 'n' dimension.
// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
// Multiplication" paper.
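The packing fast path relies on the per-patch strides set up by the input mapper: plane stride = patch depth, row stride = patch planes * plane stride, col stride = patch rows * row stride, and maxCol/maxRow/maxPlane bound the loops by dividing the remaining peeled_k span by those strides. A short worked example with hypothetical patch dimensions, using the same formulas:

#include <algorithm>
#include <cstdint>
#include <iostream>

using Index = std::int64_t;

int main() {
  // Hypothetical dimensions of one extracted volume patch.
  const Index patch_depth = 8;    // channels, innermost
  const Index patch_planes = 3;
  const Index patch_rows = 5;
  const Index patch_cols = 4;

  // Strides for navigating within a single patch (matches the mapper setup).
  const Index plane_stride = patch_depth;                // 8
  const Index row_stride = patch_planes * plane_stride;  // 24
  const Index col_stride = patch_rows * row_stride;      // 120

  // Bounds used by the packing loop: how many columns / rows / planes fit
  // into peeled_k elements starting at depth_offset.
  const Index depth_offset = 0;
  const Index peeled_k = 256;

  const Index max_col =
      std::min<Index>(1 + (depth_offset + peeled_k) / col_stride, patch_cols);
  const Index col = 0;  // first column of the block
  const Index max_row = std::min<Index>(
      1 + (depth_offset + peeled_k - col * col_stride) / row_stride,
      patch_rows);
  const Index row = 0;  // first row of the block
  const Index max_plane = std::min<Index>(
      1 + (depth_offset + peeled_k - col * col_stride - row * row_stride) /
              plane_stride,
      patch_planes);

  std::cout << "max_col=" << max_col        // min(1 + 256/120, 4) = 3
            << " max_row=" << max_row       // min(1 + 256/24, 5)  = 5
            << " max_plane=" << max_plane   // min(1 + 256/8, 3)   = 3
            << "\n";
  return 0;
}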
@@ -905,7 +1016,11 @@ struct gemm_pack_rhs<
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>
SubMapper;
+
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -914,9 +1029,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
const bool non_standard_patches = rhs.nonStandardPatches();
@@ -929,81 +1041,58 @@ struct gemm_pack_rhs<
Index k = 0;
if ((packet_size % 4) == 0 && !non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
-
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
- const Index patch_planes = rhs.patchPlanes();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
- startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
-
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- Eigen::divup(
- peeled_k - c * patch_rows * patch_planes * patch_depth,
- patch_planes * patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet does not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
- const Index startPlane =
- ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
- const Index max_planes = std::min<Index>(
- Eigen::divup(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth, // row
- patch_depth) +
- startPlane,
- patch_planes);
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
- const bool pad_row0 = dm0.padRow(r);
- const bool pad_row1 = dm1.padRow(r);
- const bool pad_row2 = dm2.padRow(r);
- const bool pad_row3 = dm3.padRow(r);
+ const bool pad_row0 = pad_col0 || dm0.padRow(r);
+ const bool pad_row1 = pad_col1 || dm1.padRow(r);
+ const bool pad_row2 = pad_col2 || dm2.padRow(r);
+ const bool pad_row3 = pad_col3 || dm3.padRow(r);
- for (Index p = startPlane; p < max_planes; ++p) {
- eigen_assert(k < peeled_k);
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
- const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
- const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
- const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
- const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+ const bool pad0 = pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_row3 || dm3.padPlane(p);
const Index idx0 = dm0.baseIndex(p, r, c);
const Index idx1 = dm1.baseIndex(p, r, c);
const Index idx2 = dm2.baseIndex(p, r, c);
const Index idx3 = dm3.baseIndex(p, r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow) && (p == startPlane))
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
? rhs.depthOffset()
: 0;
- const Index max_depth = std::min<Index>(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth - // row
- p * patch_depth + // plane
- startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
-
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
@@ -1026,20 +1115,12 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 4> kernel;
- kernel.packet[0] = dm0.loadPacketFast(k);
- kernel.packet[1] = dm1.loadPacketFast(k);
- kernel.packet[2] = dm2.loadPacketFast(k);
- kernel.packet[3] = dm3.loadPacketFast(k);
- ptranspose(kernel);
- pstoreu(block + 0 * packet_size, kernel.packet[0]);
- pstoreu(block + 1 * packet_size, kernel.packet[1]);
- pstoreu(block + 2 * packet_size, kernel.packet[2]);
- pstoreu(block + 3 * packet_size, kernel.packet[3]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
+ // A packet can span multiple planes, rows or columns, so we have to go
+ // through the slower "standard" path.
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = dm0.loadPacketStandard(k);
@@ -1055,7 +1136,9 @@ struct gemm_pack_rhs<
}
}
}
- if (!rhs.nonStandardPatches()) {
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
block[1] = dm1.loadCoeffStandard(k);
@@ -1074,7 +1157,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1113,6 +1196,9 @@ struct gemm_pack_rhs<
inner_dim_reordered, Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -1121,9 +1207,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const int packet_size = 2;
const Index packet_cols4 = (cols / 4) * 4;
@@ -1138,56 +1221,39 @@ struct gemm_pack_rhs<
Index k = 0;
if (!non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
-
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
- const Index patch_planes = rhs.patchPlanes();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) +
- startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
-
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- Eigen::divup(
- peeled_k - c * patch_rows * patch_planes * patch_depth,
- patch_planes * patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet does not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
- const Index startPlane =
- ((c == startCol) && (r == startRow)) ? rhs.planeOffset() : 0;
- const Index max_planes = std::min<Index>(
- Eigen::divup(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth, // row
- patch_depth) +
- startPlane,
- patch_planes);
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
const bool pad_row0 = dm0.padRow(r);
const bool pad_row1 = dm1.padRow(r);
const bool pad_row2 = dm2.padRow(r);
const bool pad_row3 = dm3.padRow(r);
- for (Index p = startPlane; p < max_planes; ++p) {
- eigen_assert(k < peeled_k);
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
@@ -1199,20 +1265,14 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(p, r, c);
const Index idx3 = dm3.baseIndex(p, r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow) && (p == startPlane))
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
? rhs.depthOffset()
: 0;
- const Index max_depth = std::min<Index>(
- peeled_k -
- c * patch_rows * patch_planes * patch_depth - // col
- r * patch_planes * patch_depth - // row
- p * patch_depth + // plane
- startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
-
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -1237,21 +1297,9 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 2> kernel0;
- PacketBlock<Packet, 2> kernel1;
- kernel0.packet[0] = dm0.loadPacketFast(k);
- kernel0.packet[1] = dm1.loadPacketFast(k);
- kernel1.packet[0] = dm2.loadPacketFast(k);
- kernel1.packet[1] = dm3.loadPacketFast(k);
- ptranspose(kernel0);
- ptranspose(kernel1);
- pstoreu(block + 0 * packet_size, kernel0.packet[0]);
- pstoreu(block + 1 * packet_size, kernel1.packet[0]);
- pstoreu(block + 2 * packet_size, kernel0.packet[1]);
- pstoreu(block + 3 * packet_size, kernel1.packet[1]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 2> kernel0;
@@ -1270,6 +1318,8 @@ struct gemm_pack_rhs<
}
}
}
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
if (!rhs.nonStandardPatches()) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
@@ -1289,7 +1339,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1328,6 +1378,8 @@ struct gemm_pack_rhs<
SubMapper;
typedef SubMapper DataMapper;
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
Index depth, Index cols, Index stride = 0,
@@ -1335,8 +1387,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
-
const Index packet_cols4 = (cols / 4) * 4;
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
@@ -1364,7 +1414,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1454,7 +1504,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];
// Spatial size of the kernel.
- const TensorIndex kernelDepth =
+ const TensorIndex kernelPlanes =
isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
const TensorIndex kernelRows =
isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
@@ -1474,27 +1524,27 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
const TensorIndex inputCols =
isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
- TensorIndex out_depth;
+ TensorIndex out_planes;
TensorIndex out_height;
TensorIndex out_width;
switch (padding_type) {
case PADDING_VALID:
- out_depth = Eigen::divup(inputPlanes - kernelDepth + 1,
- static_cast<TensorIndex>(stridePlanes));
+ out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
+ static_cast<TensorIndex>(stridePlanes));
out_height = Eigen::divup(inputRows - kernelRows + 1,
static_cast<TensorIndex>(strideRows));
out_width = Eigen::divup(inputCols - kernelCols + 1,
static_cast<TensorIndex>(strideCols));
break;
case PADDING_SAME:
- out_depth =
+ out_planes =
Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
out_height =
Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
break;
default:
- out_depth = 0;
+ out_planes = 0;
out_height = 0;
out_width = 0;
eigen_assert(false && "unexpected padding");
@@ -1503,9 +1553,9 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, 2> kernel_dims;
if (isColMajor) {
kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
} else {
- kernel_dims[0] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
kernel_dims[1] = kernelFilters;
}
@@ -1516,15 +1566,15 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
pre_contract_dims[0] =
- kernelChannels * kernelDepth * kernelRows * kernelCols;
- pre_contract_dims[1] = out_depth * out_height * out_width;
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[1] = out_planes * out_height * out_width;
for (int i = 4; i < NumDims; ++i) {
pre_contract_dims[1] *= in.dimension(i);
}
} else {
pre_contract_dims[1] =
- kernelChannels * kernelDepth * kernelRows * kernelCols;
- pre_contract_dims[0] = out_depth * out_height * out_width;
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[0] = out_planes * out_height * out_width;
for (int i = 0; i < NumDims - 4; ++i) {
pre_contract_dims[0] *= in.dimension(i);
}
@@ -1543,7 +1593,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, NumDims> post_contract_dims;
if (isColMajor) {
post_contract_dims[0] = kernelFilters;
- post_contract_dims[1] = out_depth;
+ post_contract_dims[1] = out_planes;
post_contract_dims[2] = out_height;
post_contract_dims[3] = out_width;
for (int i = 4; i < NumDims; ++i) {
@@ -1551,7 +1601,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
}
} else {
post_contract_dims[NumDims - 1] = kernelFilters;
- post_contract_dims[NumDims - 2] = out_depth;
+ post_contract_dims[NumDims - 2] = out_planes;
post_contract_dims[NumDims - 3] = out_height;
post_contract_dims[NumDims - 4] = out_width;
for (int i = 0; i < NumDims - 4; ++i) {
@@ -1564,13 +1614,13 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
kernel.reshape(kernel_dims)
.contract(input
.extract_volume_patches(
- kernelDepth, kernelRows, kernelCols, stridePlanes,
+ kernelPlanes, kernelRows, kernelCols, stridePlanes,
strideRows, strideCols, padding_type)
.reshape(pre_contract_dims),
contract_dims)
.reshape(post_contract_dims),
input
- .extract_volume_patches(kernelDepth, kernelRows, kernelCols,
+ .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
stridePlanes, strideRows, strideCols,
padding_type)
.reshape(pre_contract_dims)
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index a4dff4b91c..e926d73f87 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -22,8 +22,36 @@ namespace Eigen {
namespace internal {
-// TODO: Consolidate this part of the code with the image patch extraction code
-// since they are both very similar.
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract image patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor" that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelRows * kernelCols;
+// 1: out_height * out_width; * OTHERS (e.g batches, etc...)
+//
+// *) extracted patches are contiguous in memory (innermost dimension assuming
+// col major layout)
+//
+// With these dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
+//
+// TODO(ezhulenev): Consolidate this part of the code with the image patch
+// extraction code since they are both very similar.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar_, typename Index,
typename nocontract_t, typename contract_t, int Side, int packet_size,
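Concretely, a row of the "virtual matrix" (patchId) decomposes into per-patch depth, row and column offsets by dividing by the patch depth and the column stride, which is exactly what loadCoeffStandard does further down before adding the precomputed base indices. A short sketch of that decomposition with hypothetical 2-D patch dimensions:

#include <cstdint>
#include <iostream>

using Index = std::int64_t;

int main() {
  // Hypothetical patch dimensions; depth (channels) is the innermost dimension.
  const Index patch_depth = 4;
  const Index patch_rows = 3;
  const Index patch_cols = 2;

  // An arbitrary offset within a single patch (a row of the virtual matrix).
  const Index patch_id = 21;

  const Index depth_offset = patch_id % patch_depth;                // 21 % 4 = 1
  const Index patch_offset = patch_id / patch_depth;                // 21 / 4 = 5
  const Index col_offset = patch_offset / patch_rows;               // 5 / 3  = 1
  const Index row_offset = patch_offset - col_offset * patch_rows;  // 5 - 3  = 2

  std::cout << "patchId " << patch_id << " -> depth " << depth_offset
            << ", row " << row_offset << ", col " << col_offset
            << " (col must stay below " << patch_cols << ")\n";
  return 0;
}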
@@ -77,12 +105,17 @@ class TensorContractionInputMapper<
m_patch_cols = tensor.impl().dimensions()[2];
m_num_patches = tensor.impl().dimensions()[3];
} else {
- const int NumDims = tensor.impl().dimensions().size();
+ const size_t NumDims = tensor.impl().dimensions().size();
patch_depth = tensor.impl().dimensions()[NumDims - 1];
patch_rows = tensor.impl().dimensions()[NumDims - 2];
m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
m_num_patches = tensor.impl().dimensions()[NumDims - 4];
}
+
+ // Strides for navigating through the single patch.
+ m_patch_row_stride = patch_depth;
+ m_patch_col_stride = patch_rows * m_patch_row_stride;
+
m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
m_patch_col_inflate_strides = tensor.impl().colInflateStride();
@@ -111,6 +144,10 @@ class TensorContractionInputMapper<
m_rowPaddingTop = tensor.impl().rowPaddingTop();
m_colPaddingLeft = tensor.impl().colPaddingLeft();
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
m_fastInputRowStride =
internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
m_fastInputColStride =
@@ -126,6 +163,10 @@ class TensorContractionInputMapper<
: m_impl(base_mapper.m_impl) {
m_patch_cols = base_mapper.m_patch_cols;
m_num_patches = base_mapper.m_num_patches;
+
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
@@ -148,6 +189,8 @@ class TensorContractionInputMapper<
m_rowPaddingTop = base_mapper.m_rowPaddingTop;
m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
m_fastInputRowStride = base_mapper.m_fastInputRowStride;
m_fastInputColStride = base_mapper.m_fastInputColStride;
m_fastNumPatches = base_mapper.m_fastNumPatches;
@@ -238,6 +281,8 @@ class TensorContractionInputMapper<
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>;
+ // Load coefficient from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
Index colIndex, Index otherIndex) const {
@@ -250,6 +295,7 @@ class TensorContractionInputMapper<
(m_patch_col_inflate_strides == 1)
? inputCol
: ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
const Index rowOffset = patchOffset - colOffset * m_colStride;
const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
const Index origInputRow =
@@ -268,6 +314,8 @@ class TensorContractionInputMapper<
return m_impl.coeff(inputIndex);
}
+ // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
+ // and `in_strides` equal to 1 (template specialization without templates).
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
Index colIndex,
@@ -276,10 +324,9 @@ class TensorContractionInputMapper<
// Find the offset of the element wrt the location of the first element.
const Index patchOffset = patchId / m_fastDimZero;
-
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputCol = colIndex + colOffset;
const Index inputRow = rowIndex + rowOffset;
if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
inputRow >= m_inputRows) {
@@ -291,6 +338,8 @@ class TensorContractionInputMapper<
return m_impl.coeff(inputIndex);
}
+ // Load packet from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
Index colIndex,
@@ -318,12 +367,14 @@ class TensorContractionInputMapper<
if ((patchDepth() % packetSize) == 0) {
return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
} else {
+ // Offsets and input calculation here are identical to
+ // loadCoeffStandard(...), but repeated twice.
+
const Index patchOffsets[2] = {
patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
patchOffsets[1] / m_fastColStride};
-
const Index inputCols[2] = {colIndex + colOffsets[0],
colIndex + colOffsets[1]};
if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
@@ -371,8 +422,8 @@ class TensorContractionInputMapper<
eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputCol = colIndex + colOffset;
const Index inputRow = rowIndex + rowOffset;
if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
inputRow >= m_inputRows) {
@@ -401,7 +452,7 @@ class TensorContractionInputMapper<
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
Index patchIndex, Index& rowIndex, Index& colIndex,
Index& otherIndex) const {
- const int NumInputDims = array_size<
+ const size_t NumInputDims = array_size<
typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
const Index patch2DIndex = (NumInputDims == 3)
@@ -414,8 +465,15 @@ class TensorContractionInputMapper<
rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
}
- Index m_patch_cols; // number of colums in the patch
- Index m_num_patches; // number of patches to extract.
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract.
+
+ // Strides for navigating through the single patch.
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
Index m_patch_row_inflate_strides; // the strides for row inflation in the
// image patch
Index m_patch_col_inflate_strides; // the strides for col inflation in the
@@ -549,6 +607,40 @@ class TensorContractionSubMapper<
return m_base_mapper.nonStandardPatches();
}
+ // Max(Col|Row|Depth): compute the upper limit for the column, row and depth
+ // index respectively that fits into the peeled_k elements starting at
+ // m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes, to prevent register
+ // spilling, we use a different variable that is guaranteed to have the same
+ // value as the one you would expect to see. E.g. patch depth is always the
+ // same as the input depth, and it is also the same as the input row stride.
+ // A number of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchDepth() const {
return m_base_mapper.m_rowInputStride;
@@ -563,6 +655,28 @@ class TensorContractionSubMapper<
}
EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
const Index baseIndex) const {
const Index inputIndex = depth + baseIndex;
@@ -603,8 +717,7 @@ class TensorContractionSubMapper<
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index depthOffset() const {
- const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
- return patchOffset;
+ return m_depth_offset % patchDepth();
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
@@ -617,12 +730,44 @@ class TensorContractionSubMapper<
Index m_depth_offset; // First row in the input matrix
Index m_col_offset; // First col in the input matrix
- Index m_rowIndex; // precomputed row index corresponding to the col offset
- Index m_colIndex; // precomputed col index corresponding to the col offset
- Index
- m_otherIndex; // precomputed other index corresponding to the col offset
+ // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base
+ // indices for the first element in a patch specified by col_offset
+ // (see computeBaseIndices(...) for details).
+ Index m_rowIndex;
+ Index m_colIndex;
+ Index m_otherIndex;
};
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted image patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
+// A8 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) A0, A1, A2 ... - values from the same patch at different offsets.
+//
+// The traversal (packed rhs memory) order (B0 besides A0 in memory):
+// A0 B0 C0 D0 A1 B1 C1 D1 ...
+// E0 F0 G0 H0 E1 F1 G1 H1 ...
+// ...
+// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
+//
+// This traversal order must be the same as in default gemm_pack_rhs defined in
+// GeneralBlockPanelKernel.h.
+//
+// *) nr - number of registers along the 'n' dimension.
+// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+// Multiplication" paper.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar, typename Index,
typename nocontract_t, typename contract_t, int packet_size,
@@ -649,9 +794,9 @@ struct gemm_pack_rhs<
inner_dim_reordered, Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -660,9 +805,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
const bool non_standard_patches = rhs.nonStandardPatches();
@@ -675,30 +817,27 @@ struct gemm_pack_rhs<
Index k = 0;
if ((packet_size % 4) == 0 && !non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+            // Iterate over patch columns and rows if we know that a single
+            // packet does not span multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -709,14 +848,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
@@ -738,19 +876,9 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 4> kernel;
- kernel.packet[0] = dm0.loadPacketFast(k);
- kernel.packet[1] = dm1.loadPacketFast(k);
- kernel.packet[2] = dm2.loadPacketFast(k);
- kernel.packet[3] = dm3.loadPacketFast(k);
- ptranspose(kernel);
- pstoreu(block + 0 * packet_size, kernel.packet[0]);
- pstoreu(block + 1 * packet_size, kernel.packet[1]);
- pstoreu(block + 2 * packet_size, kernel.packet[2]);
- pstoreu(block + 3 * packet_size, kernel.packet[3]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 4> kernel;
@@ -767,6 +895,8 @@ struct gemm_pack_rhs<
}
}
}
+
+      // Copy the remaining coefficients of the column block after peeled_k.
if (!rhs.nonStandardPatches()) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
@@ -824,9 +954,9 @@ struct gemm_pack_rhs<
Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -835,9 +965,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const int packet_size = 2;
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
@@ -851,30 +978,27 @@ struct gemm_pack_rhs<
Index k = 0;
if (!non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+ // Iterate over patch columns and rows if we know that a single
+          // packet does not span multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -885,14 +1009,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -916,22 +1039,12 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 2> kernel0;
- PacketBlock<Packet, 2> kernel1;
- kernel0.packet[0] = dm0.loadPacketFast(k);
- kernel0.packet[1] = dm1.loadPacketFast(k);
- kernel1.packet[0] = dm2.loadPacketFast(k);
- kernel1.packet[1] = dm3.loadPacketFast(k);
- ptranspose(kernel0);
- ptranspose(kernel1);
- pstoreu(block + 0 * packet_size, kernel0.packet[0]);
- pstoreu(block + 1 * packet_size, kernel1.packet[0]);
- pstoreu(block + 2 * packet_size, kernel0.packet[1]);
- pstoreu(block + 3 * packet_size, kernel1.packet[1]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
+ // Packet can span multiple rows or columns, so we have to go
+            // through the slower "standard" path.
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -949,7 +1062,9 @@ struct gemm_pack_rhs<
}
}
}
- if (!rhs.nonStandardPatches()) {
+
+      // Copy the remaining coefficients of the column block after peeled_k.
+ if (!non_standard_patches) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
block[1] = dm1.loadCoeffStandard(k);
@@ -968,7 +1083,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1006,8 +1121,7 @@ struct gemm_pack_rhs<
SubMapper;
typedef SubMapper DataMapper;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -1016,8 +1130,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
-
const Index packet_cols4 = (cols / 4) * 4;
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
@@ -1045,7 +1157,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
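
The traversal-order comment in the hunk above (A0 B0 C0 D0 A1 B1 C1 D1 ... with nr == 4) can be illustrated with a small standalone C++ sketch. It is purely illustrative: the helper name PackRhsNr4 is made up, and it only shows the interleaved column order that the packing produces, without the packet transposes or the peeled_k depth handling.

    // Sketch of the nr == 4 packed-rhs layout: 4 consecutive columns of a
    // column-major matrix are interleaved element by element.
    #include <cstdio>
    #include <vector>

    std::vector<float> PackRhsNr4(const std::vector<float>& m, int rows, int cols) {
      std::vector<float> block;
      block.reserve(rows * cols);
      const int packet_cols4 = (cols / 4) * 4;
      for (int j = 0; j < packet_cols4; j += 4) {
        for (int k = 0; k < rows; ++k) {
          block.push_back(m[k + (j + 0) * rows]);
          block.push_back(m[k + (j + 1) * rows]);
          block.push_back(m[k + (j + 2) * rows]);
          block.push_back(m[k + (j + 3) * rows]);
        }
      }
      // Remaining columns (cols % 4) are copied one at a time, as in the kernel.
      for (int j = packet_cols4; j < cols; ++j) {
        for (int k = 0; k < rows; ++k) block.push_back(m[k + j * rows]);
      }
      return block;
    }

    int main() {
      // 3x5 column-major matrix; value encodes (col, row) as col * 10 + row.
      const int rows = 3, cols = 5;
      std::vector<float> m(rows * cols);
      for (int j = 0; j < cols; ++j)
        for (int k = 0; k < rows; ++k) m[k + j * rows] = j * 10 + k;
      for (float v : PackRhsNr4(m, rows, cols)) std::printf("%g ", v);
      std::printf("\n");  // 0 10 20 30 1 11 21 31 2 12 22 32 40 41 42
    }
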
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.cc b/tensorflow/core/kernels/extract_volume_patches_op.cc
new file mode 100644
index 0000000000..52cd078a35
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/*
+See extract_image_patches_op* files and docs for extract_image_patches in
+../ops/image_ops.cc.
+
+Rates are not supported for now, but the comments hint at how to edit the code
+when rates are added.
+*/
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include <vector>
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+static inline void ParseAttributeVec5(OpKernelConstruction* context,
+ const string& attr_name,
+ std::vector<int32>* attr) {
+ OP_REQUIRES_OK(context, context->GetAttr(attr_name, attr));
+ OP_REQUIRES(
+ context, (*attr)[0] == 1 && (*attr)[4] == 1,
+ errors::Unimplemented("Only support ", attr_name, " across space."));
+ OP_REQUIRES(context, (*attr)[1] >= 1 && (*attr)[2] >= 1 && (*attr)[3] >= 1,
+ errors::OutOfRange(attr_name, " is out of range."));
+}
+
+template <typename Device, typename T>
+class ExtractVolumePatchesOp : public UnaryOp<T> {
+ public:
+ explicit ExtractVolumePatchesOp(OpKernelConstruction* context)
+ : UnaryOp<T>(context) {
+ ParseAttributeVec5(context, "ksizes", &ksizes_);
+ ParseAttributeVec5(context, "strides", &strides_);
+ // ParseAttributeVec5(context, "rates", &rates_);
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input tensor is of the following dimensions:
+ // [ batch, in_planes, in_rows, in_cols, channels ]
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input.shape().DebugString()));
+
+ const int batch = input.dim_size(0);
+ const int in_planes = input.dim_size(1);
+ const int in_rows = input.dim_size(2);
+ const int in_cols = input.dim_size(3);
+ const int depth = input.dim_size(4);
+
+ const int ksize_planes = ksizes_[1];
+ const int ksize_rows = ksizes_[2];
+ const int ksize_cols = ksizes_[3];
+
+ const int stride_planes = strides_[1];
+ const int stride_rows = strides_[2];
+ const int stride_cols = strides_[3];
+
+ /*
+ // TODO(hsgkim): enable rates
+    // Rates are disabled for now because none of Eigen's
+    // `extract_volume_patch` functions accepts a rates argument, so rates
+    // are fixed to (1, 1, 1, 1, 1). A workaround has to be found for this.
+ // In order to enable rates, uncomment the following lines and use
+ // ksize_*_eff instead of ksize_* for the second argument of
+ // GetWindowedOutputSize calls.
+
+ const int rate_planes = rates_[1];
+ const int rate_rows = rates_[2];
+ const int rate_cols = rates_[3];
+
+ const int ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ int64 out_planes = 0, out_rows = 0, out_cols = 0;
+ int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_planes, ksize_planes, stride_planes,
+ padding_, &out_planes, &pad_planes));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_rows, ksize_rows, stride_rows,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_cols, ksize_cols, stride_cols,
+ padding_, &out_cols, &pad_cols));
+
+ const std::vector<int64> out_sizes = {
+ batch, out_planes, out_rows, out_cols,
+ ksize_planes * ksize_rows * ksize_cols * depth};
+ TensorShape out_shape(out_sizes);
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ // If there is nothing to compute, return.
+ if (out_shape.num_elements() == 0) {
+ return;
+ }
+
+ functor::ExtractVolumePatchesForward<Device, T>()(
+ context->eigen_device<Device>(), input.tensor<T, 5>(), ksize_planes,
+ ksize_rows, ksize_cols, stride_planes, stride_rows, stride_cols,
+ /* rate_planes, rate_rows, rate_cols, */
+ BrainPadding2EigenPadding(padding_), output->tensor<T, 5>());
+ }
+
+ private:
+ std::vector<int32> ksizes_;
+ std::vector<int32> strides_;
+ // std::vector<int32> rates_;
+
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExtractVolumePatchesOp);
+};
+
+// Registration of the CPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<CPUDevice, T>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#if GOOGLE_CUDA
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+// clang-format off
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ExtractVolumePatchesForward<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input, \
+ int patch_planes, int patch_rows, int patch_cols, \
+ int stride_planes, int stride_rows, int stride_cols, \
+ /* int rate_planes, int rate_rows, int rate_cols, */ \
+ const Eigen::PaddingType& padding, \
+ typename TTypes<T, 5>::Tensor output); \
+ extern template struct ExtractVolumePatchesForward<GPUDevice, T>;
+// clang-format on
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+
+#undef DECLARE_GPU_SPEC
+
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<GPUDevice, T>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
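
The output shape computed above is derived per spatial dimension with GetWindowedOutputSize. A rough standalone sketch of the usual VALID/SAME arithmetic follows, assuming rate == 1 (as in this op); the function name is invented, the before/after padding split is simplified to a total, and the real helper reports an error Status instead of a bool.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    enum class Padding { kValid, kSame };

    // Computes the windowed output size and the total padding needed.
    bool WindowedOutputSize(int64_t in, int64_t ksize, int64_t stride,
                            Padding padding, int64_t* out, int64_t* pad_total) {
      if (stride <= 0 || ksize <= 0) return false;
      if (padding == Padding::kValid) {
        if (in < ksize) return false;            // real helper returns InvalidArgument
        *out = (in - ksize) / stride + 1;
        *pad_total = 0;
      } else {  // SAME
        *out = (in + stride - 1) / stride;       // ceil(in / stride)
        *pad_total = std::max<int64_t>(0, (*out - 1) * stride + ksize - in);
      }
      return true;
    }

    int main() {
      int64_t out = 0, pad = 0;
      WindowedOutputSize(10, 3, 2, Padding::kSame, &out, &pad);
      std::printf("SAME: out=%lld pad=%lld\n", (long long)out, (long long)pad);   // out=5 pad=1
      WindowedOutputSize(10, 3, 2, Padding::kValid, &out, &pad);
      std::printf("VALID: out=%lld pad=%lld\n", (long long)out, (long long)pad);  // out=4 pad=0
    }
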
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.h b/tensorflow/core/kernels/extract_volume_patches_op.h
new file mode 100644
index 0000000000..7e0502b770
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+#define TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_volume_patch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct ExtractVolumePatchesForward {
+ void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input,
+ int patch_planes, int patch_rows, int patch_cols,
+ int stride_planes, int stride_rows, int stride_cols,
+ /* int rate_planes, int rate_rows, int rate_cols, */
+ const Eigen::PaddingType& padding,
+ typename TTypes<T, 5>::Tensor output) {
+ const int64 N = std::max(input.size(), output.size());
+ if (N <= std::numeric_limits<Index32>::max()) {
+ auto output_32bit = To32Bit(output);
+ output_32bit.device(d) =
+ To32Bit(input)
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output_32bit.dimensions());
+ } else {
+ output.device(d) =
+ input
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output.dimensions());
+ }
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
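
The functor above picks 32-bit indexing whenever both tensors fit in an int32 range, which keeps index arithmetic cheap (especially on GPU). A minimal sketch of that dispatch, with invented names and the Eigen expressions replaced by callbacks:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <limits>

    template <typename Fn32, typename Fn64>
    void DispatchByIndexWidth(int64_t input_size, int64_t output_size,
                              Fn32 run32, Fn64 run64) {
      const int64_t n = std::max(input_size, output_size);
      if (n <= std::numeric_limits<int32_t>::max()) {
        run32();  // both tensors addressable with int32 indices
      } else {
        run64();  // fall back to int64 indexing
      }
    }

    int main() {
      DispatchByIndexWidth(int64_t{1} << 20, int64_t{1} << 22,
                           [] { std::puts("32-bit path"); },
                           [] { std::puts("64-bit path"); });
    }
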
diff --git a/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
new file mode 100644
index 0000000000..c636493602
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+#define REGISTER(T) template struct ExtractVolumePatchesForward<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
index c90ad2cfeb..ada1235449 100644
--- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
@@ -31,9 +31,37 @@ class FuzzParseTensor : public FuzzSession {
}
void FuzzImpl(const uint8_t* data, size_t size) final {
+ // We need to be sure that we don't request too many elements (i.e., we
+    // don't make ASAN OOM). In theory, a tensor shape can have an arbitrarily
+    // large number of elements, up to the limit of memory available to the OS.
+ // However, due to the tracing done in ASAN, after 2^32 bytes of requested
+ // memory we would get a crash in the fuzzer (see b/34190148). Hence, let's
+ // try parsing the proto here, check that the size (if valid) is below a
+ // maximum threshold (using 2^20 for convenience), and then run the
+ // remainder of the fuzzer testing. Of course, this duplicates some work
+ // but it's better than repeating the investigation whenever Autofuzz
+ // detects another similar OOM.
+ string as_string = string(reinterpret_cast<const char*>(data), size);
+ TensorProto proto;
+ if (!ParseProtoUnlimited(&proto, as_string)) {
+ LOG(WARNING) << "Unable to parse proto of tensor\n";
+ return;
+ }
+ if (!TensorShape::IsValid(proto.tensor_shape())) {
+ LOG(WARNING) << "Invalid tensor shape\n";
+ return;
+ }
+ TensorShape shape(proto.tensor_shape());
+ const int64 num_elements = shape.num_elements();
+ const int64 max_num_elements = 1 << 20;
+ if (num_elements > max_num_elements) {
+ LOG(WARNING) << "Requiring a tensor with too many elements\n";
+ return;
+ }
+
+ // Now we can do the actual fuzz implementation
Tensor input_tensor(tensorflow::DT_STRING, TensorShape({}));
- input_tensor.scalar<string>()() =
- string(reinterpret_cast<const char*>(data), size);
+ input_tensor.scalar<string>()() = as_string;
// TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
RunOneInput(input_tensor).IgnoreError();
}
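
The guard added above boils down to counting the requested elements before handing the input to the op. A standalone sketch of that counting logic with an overflow check; the proto/shape plumbing is elided and the helper name is hypothetical.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Returns true if the product of dims stays within the fuzz budget.
    bool WithinFuzzBudget(const std::vector<int64_t>& dims,
                          int64_t max_num_elements = int64_t{1} << 20) {
      int64_t num_elements = 1;
      for (int64_t d : dims) {
        if (d < 0) return false;  // invalid shape
        // Reject if multiplying by d would exceed the budget (also avoids overflow).
        if (d != 0 && num_elements > max_num_elements / d) return false;
        num_elements *= d;
      }
      return num_elements <= max_num_elements;
    }

    int main() {
      std::printf("%d\n", WithinFuzzBudget({1024, 1024}));  // 1: exactly 2^20 elements
      std::printf("%d\n", WithinFuzzBudget({4096, 4096}));  // 0: too many elements
    }
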
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index 277ee2be02..1c78de253e 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -114,7 +114,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
slice_size, Tindices, Tparams, Tout, &error_loc);
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// Eigen implementation below is not highly performant. gather_nd_generator
// does not seem to be called in parallel, leading to very poor performance.
// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
@@ -126,12 +126,12 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
const Eigen::array<Eigen::DenseIndex, 1> loc{i};
gather_nd_generator(loc);
}
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(gather_nd_generator)
.sum();
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
index a88e9b0ddc..374a05850e 100644
--- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
@@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 79967aab38..4ad390a411 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -578,7 +578,7 @@ struct MatMulFunctor<SYCLDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// MKL does not support half, bfloat16 and int32 types for
// matrix-multiplication, so register the kernel to use default Eigen based
@@ -606,9 +606,9 @@ TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
-#endif
+#endif // INTEL_MKL_DNN_ONLY
-#else // INTEL MKL
+#else // INTEL_MKL && ENABLE_MKL
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
@@ -616,7 +616,7 @@ TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_float(REGISTER_GPU);
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 0841395dc3..bc135de11e 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -223,10 +223,12 @@ class BatchMatMulMkl : public OpKernel {
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulMkl<CPUDevice, TYPE>)
+#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
+#endif // ENABLE_MKL
} // end namespace tensorflow
#endif
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 077d62ce32..f4788f4851 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -217,7 +217,7 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
reinterpret_cast<MKL_Complex16*>(c), ldc);
}
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
};
#define REGISTER_CPU(T) \
@@ -225,6 +225,7 @@ class MklMatMulOp : public OpKernel {
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
+#ifdef ENABLE_MKL
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
@@ -233,7 +234,8 @@ TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
+#endif // ENABLE_MKL
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc
new file mode 100644
index 0000000000..d63e14adf6
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_slice_op.cc
@@ -0,0 +1,358 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/array_ops.cc.
+
+#ifdef INTEL_MKL
+#ifndef INTEL_MKL_ML_ONLY
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/prefetch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "mkldnn.hpp"
+#include "tensorflow/core/util/mkl_util.h"
+
+using mkldnn::stream;
+using mkldnn::view;
+
+namespace tensorflow {
+
+namespace {
+
+gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
+ gtl::InlinedVector<int64, 4> out;
+ if (tensor.dtype() == DT_INT32) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int32>()(i));
+ }
+ } else if (tensor.dtype() == DT_INT64) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int64>()(i));
+ }
+ } else {
+ // tensor must be either int32 or int64
+ DCHECK(false);
+ }
+ return out;
+}
+
+} // namespace
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// A version of SharedValidation (slice_op.h) written for input that is in
+// either Mkl layout or Tensorflow layout.
+// Shared code that validates input shapes and checks for identity. It does
+// not depend on T, so we avoid duplicating it for every T (float, double, etc.).
+static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size) {
+ const int kInputTensorIndex = 0;
+ const int kInputBeginIndex = 1;
+ const int kInputSizeIndex = 2;
+ const Tensor& input = MklGetInput(context, kInputTensorIndex);
+ const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
+ const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);
+
+ MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
+ GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
+ GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
+ GetMklShape(context, kInputSizeIndex, &size_mkl_shape);
+
+ // Begin and size tensors cannot be in MklDnn layout.
+ DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
+ DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);
+
+ TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
+ ? input_mkl_shape.GetTfShape()
+ : input.shape();
+ const int input_dims = input_tf_shape.dims();
+
+ OP_REQUIRES(
+ context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
+ context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
+ begin_tensor.NumElements() == input_dims &&
+ size_tensor.NumElements() == input_dims,
+ errors::InvalidArgument(
+ "Expected begin and size arguments to be 1-D tensors of size ",
+ input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
+ " and ", size_tensor.shape().DebugString(), " instead."));
+
+ *begin = IntTensorToInt64Vec(begin_tensor);
+ *size = IntTensorToInt64Vec(size_tensor);
+ for (int i = 0; i < input_dims; ++i) {
+ if ((*size)[i] == -1) {
+ // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
+ (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
+ }
+ }
+
+ *is_identity = true;
+ for (int i = 0; i < input_dims; ++i) {
+ int64 b = (*begin)[i];
+ int64 s = (*size)[i];
+ if (input_tf_shape.dim_size(i) == 0) {
+ OP_REQUIRES(
+ context, b == 0 && s == 0,
+ errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
+ ") and size[", i, "] == 0 ", "(got ", s,
+ ") when ", "input.dim_size(", i, ") == 0"));
+ } else {
+ OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected begin[", i, "] in [0, ",
+ input_tf_shape.dim_size(i),
+ "], but got ", b));
+ OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected size[", i, "] in [0, ",
+ input_tf_shape.dim_size(i) - b,
+ "], but ", "got ", s));
+ }
+ const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
+ (*is_identity) &= take_all;
+ }
+}
+
+// A version of SharedSliceCommonCases function written for input tensor
+// that may be in MklDnn layout or in Tensorflow layout.
+template <typename T>
+static void CheckCommonCasesForMklInputs(OpKernelContext* context,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size,
+ bool* done) {
+ bool is_identity = true;
+ *done = false;
+
+ ValidateMklInputs(context, &is_identity, begin, size);
+ if (!context->status().ok()) return;
+
+ const Tensor& input = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ context->set_output(0, input);
+ // Mkl metadata tensor in this case can just be forwarded from input to
+ // output.
+ AllocateOutputSetMklShape(context, 0, input_mkl_shape);
+ *done = true;
+ }
+}
+
+// MKL-DNN implementation of Slice
+template <typename Device, typename T>
+class MklDnnSliceOp : public OpKernel {
+ public:
+ explicit MklDnnSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ ~MklDnnSliceOp() {}
+
+ void Compute(OpKernelContext* context) override {
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> size;
+ bool done = false;
+
+ CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);
+ if (!context->status().ok() || done == true) return;
+
+    // MKL-DNN itself supports tensors with more than 8 (and fewer than 12)
+    // dimensions, but to mimic the functionality of the Eigen Slice op for
+    // CPU we reject higher-dimensional inputs here.
+ if (begin.size() >= 8) {
+ OP_REQUIRES(
+ context, false,
+ errors::Unimplemented("MklDnnSliceOp : Unhandled input dimensions"));
+ }
+
+ ComputeMklDnnSlice(context, begin, size);
+ }
+
+ private:
+ // Slice op implemented using MKL-DNN APIs.
+ void ComputeMklDnnSlice(OpKernelContext* context,
+ const gtl::InlinedVector<int64, 4>& begin,
+ const gtl::InlinedVector<int64, 4>& size) {
+ try {
+ // MKL-DNN API usage below is guided by description at:
+ // https://github.com/01org/mkl-dnn/issues/69
+ //
+ // Relevant part of the description is copied below:
+ //
+ // Let's say you want to copy a part of memory into another buffer (and
+ // probably change the format). Then your steps are:
+ //
+ // 1. create memory primitive descriptor in_mem_pd and memory primitive
+ // in_mem_p for the entire source data.
+ // 2. create view primitive descriptor in_submem_pd based on in_mem_pd,
+ // initial offsets, and sub-sizes
+ // 3. create memory primitive descriptor out_mem_pd and memory primitive
+ // out_mem_p for the output (the logical sizes should match sub-sizes
+ // used in step 2, but the format might be arbitrary)
+ // 4. create reorder primitive descriptor reorder_pd based on in_submem_pd
+ // and out_mem_pd
+ // 5. create reorder primitive itself based on reorder_pd, in_mem_p, and
+ // out_mem_p.
+ //
+ // Please notice that there is no view primitive. There is only view
+ // primitive descriptor. And the reorder uses source memory as input but
+ // traverses it according to a view in_submem_pd.
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> src(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ // Populate offsets and sizes in memory::dims format based on vector.
+ memory::dims begin_dims = {};
+ begin_dims.resize(begin.size());
+ for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
+ memory::dims size_dims = {};
+ bool empty = false;
+ size_dims.resize(size.size());
+ for (size_t i = 0; i < size.size(); ++i) {
+ size_dims[i] = size[i];
+ if (size_dims[i] == 0) empty = true;
+ }
+
+ Tensor* output_tensor = nullptr;
+ MklDnnShape output_mkl_shape;
+
+      // If any requested slice size is 0, the result is empty. Just return
+      // an empty output tensor and a dummy Mkl-shape tensor.
+ if (empty) { // for empty dims
+ auto shape_to = MklDnnDimsToTFShape(size_dims);
+ AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
+ output_mkl_shape);
+ return;
+ }
+
+ // Step 1 (as per above description) - Create memory for user data.
+ // We use blocked format here to describe input tensor.
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (input_mkl_shape.IsMklTensor()) {
+ auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
+ auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format);
+ begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format);
+ size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);
+ auto input_md = input_mkl_shape.GetMklLayout();
+ src.SetUsrMem(input_md, &input_tensor);
+ } else {
+ // Initialize input dimensions and strides to be used when input is not
+ // in MklDnn layout.
+ memory::dims input_dims, input_strides;
+ input_dims = TFShapeToMklDnnDims(input_tensor.shape());
+ input_strides = CalculateTFStrides(input_dims);
+ // Create input memory descriptor.
+ auto input_md =
+ MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
+ src.SetUsrMem(input_md, &input_tensor);
+ }
+
+ // Step 2 - create view primitive descriptor
+ auto view_pd =
+ view::primitive_desc(src.GetUsrMemPrimDesc(), size_dims, begin_dims)
+ .dst_primitive_desc();
+ auto output_strides = CalculateTFStrides(size_dims);
+ auto output_md =
+ MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
+ auto output_pd = memory::primitive_desc(output_md, cpu_engine);
+
+ // Step 3 - Create memory for output. If input is in MklDnn layout, then
+ // output is also in MklDnn layout. Otherwise, output is in Tensorflow
+ // layout.
+ AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
+ &output_tensor, &output_mkl_shape);
+ DCHECK(output_tensor);
+ DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
+ output.SetUsrMem(output_md, output_tensor);
+
+ std::vector<primitive> net;
+ // Step 4 - create reorder primitive desc between view_pd and output_pd.
+ auto reorder_pd =
+ reorder::primitive_desc(view_pd, output.GetUsrMemPrimDesc());
+ // Step 5 - create reorder primitive itself.
+ net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *output.GetUsrMem()));
+ // Execute the reorder primitive.
+ stream(stream::kind::eager).submit(net).wait();
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+ string(e.message) + ", in file " + string(__FILE__) +
+ ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ void AllocateOutputTensor(OpKernelContext* context,
+ const MklDnnShape& input_mkl_shape,
+ memory::primitive_desc* output_pd,
+ const memory::dims& output_dims,
+ Tensor** output_tensor,
+ MklDnnShape* output_mkl_shape) {
+ DCHECK(output_tensor);
+ DCHECK(output_mkl_shape);
+
+ TensorShape output_tf_shape;
+
+ if (input_mkl_shape.IsMklTensor()) {
+ // Since input tensor is in Mkl layout, output tensor will be in Mkl
+ // layout.
+
+ // Allocate shape of Mkl tensor.
+ output_mkl_shape->SetMklTensor(true);
+ output_mkl_shape->SetMklLayout(output_pd);
+ output_mkl_shape->SetElemType(MklDnnType<T>());
+ output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(), output_dims,
+ input_mkl_shape.GetTfDataFormat());
+
+ output_tf_shape.AddDim((output_pd->get_size() / sizeof(T)) + 1);
+ } else {
+ // If input is not in Mkl layout, then output won't be in Mkl layout.
+ output_mkl_shape->SetMklTensor(false);
+ output_tf_shape = MklDnnDimsToTFShape(output_dims);
+ }
+
+ AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
+ *output_mkl_shape);
+ }
+};
+
+// MKL-DNN Slice registration
+#define REGISTER_MKL_SLICE(type) \
+ REGISTER_KERNEL_BUILDER(Name("_MklSlice") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDnnSliceOp<CPUDevice, type>);
+
+TF_CALL_float(REGISTER_MKL_SLICE);
+#undef REGISTER_MKL_SLICE
+
+} // namespace tensorflow
+
+#endif  // !INTEL_MKL_ML_ONLY
+#endif // INTEL_MKL
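
Logically, the view + reorder sequence above copies a begin/size block out of the input, converting layout along the way when needed. A plain row-major 2-D sketch of the same slice semantics, without any MKL-DNN types; names are illustrative only.

    #include <cstdio>
    #include <vector>

    // Copy the block starting at (begin_r, begin_c) with extent (size_r, size_c)
    // out of a row-major rows x cols tensor.
    std::vector<float> Slice2D(const std::vector<float>& in, int rows, int cols,
                               int begin_r, int begin_c, int size_r, int size_c) {
      std::vector<float> out(size_r * size_c);
      for (int r = 0; r < size_r; ++r)
        for (int c = 0; c < size_c; ++c)
          out[r * size_c + c] = in[(begin_r + r) * cols + (begin_c + c)];
      return out;
    }

    int main() {
      std::vector<float> in(4 * 5);
      for (int i = 0; i < 20; ++i) in[i] = i;
      // begin = {1, 2}, size = {2, 3} -> rows 1..2, cols 2..4.
      for (float v : Slice2D(in, 4, 5, 1, 2, 2, 3)) std::printf("%g ", v);
      std::printf("\n");  // 7 8 9 12 13 14
    }
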
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index e37232539f..04a53697c0 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -231,7 +231,13 @@ class RandomUniformIntOp : public OpKernel {
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval.shape().DebugString()));
- // Verify that minval < maxval
+ // Allocate output, and exit early if possible
+ Tensor* output;
+ OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
+ if (output->NumElements() == 0) return;
+
+ // Verify that minval < maxval. This check intentionally happens after the
+ // early exit for empty output. Zero impossible things are fine.
IntType lo = minval.scalar<IntType>()();
IntType hi = maxval.scalar<IntType>()();
OP_REQUIRES(
@@ -243,8 +249,6 @@ class RandomUniformIntOp : public OpKernel {
Distribution;
Distribution dist(lo, hi);
- Tensor* output;
- OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
auto output_flat = output->flat<IntType>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(),
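
The reordering above means range validation only happens when there is actually something to sample. A minimal sketch of the same control flow, with the Philox sampling replaced by a placeholder and all names invented:

    #include <cstdio>
    #include <optional>
    #include <vector>

    std::optional<std::vector<int>> UniformInt(int64_t n, int lo, int hi) {
      std::vector<int> out(n);
      if (n == 0) return out;             // empty request: nothing to validate
      if (lo >= hi) return std::nullopt;  // invalid range only matters now
      for (auto& v : out) v = lo;         // placeholder for the actual sampling
      return out;
    }

    int main() {
      std::printf("%d\n", UniformInt(0, 5, 5).has_value());  // 1: empty output is fine
      std::printf("%d\n", UniformInt(3, 5, 5).has_value());  // 0: invalid range rejected
    }
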
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 88b3c2ac76..bb8254eaac 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -21,11 +21,11 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_segmented_reduce.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
-#include "external/cub_archive/cub/warp/warp_reduce.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_segmented_reduce.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/warp/warp_reduce.cuh"
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/lib/core/bits.h"
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 26705a8d34..427044ca67 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -51,7 +51,9 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif
-#include "tensorflow/core/kernels/resource_variable_ops.h"
+#include <memory>
+#include <vector>
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -60,10 +62,12 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
+#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -72,6 +76,8 @@ limitations under the License.
namespace tensorflow {
REGISTER_RESOURCE_HANDLE_KERNEL(Var);
+REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
+ ResourceHandlesOp<Var>);
ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
@@ -101,13 +107,58 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
ctx->set_output(0, t);
}
+ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
+ int n;
+ OP_REQUIRES_OK(c, c->GetAttr("N", &n));
+ OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
+ OP_REQUIRES(c, n == dtypes_.size(),
+ errors::InvalidArgument(
+ "Mismatched number of arguments to ReadVariablesOp (", n,
+ " vs. ", dtypes_.size(), ")"));
+}
+
+void ReadVariablesOp::Compute(OpKernelContext* ctx) {
+ std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
+ dtypes_.size());
+ std::vector<const ResourceHandle*> handles(dtypes_.size());
+ for (size_t i = 0; i < dtypes_.size(); ++i) {
+ handles[i] = &HandleFromInput(ctx, i);
+ }
+ const auto status = LookupResources(ctx, handles, &variables);
+ OP_REQUIRES(ctx, status.ok(),
+ errors::FailedPrecondition(
+ "Error while reading resource variable. This could mean that "
+ "the variable was uninitialized. ",
+ status.ToString()));
+
+ for (size_t i = 0; i < dtypes_.size(); ++i) {
+ // We're acquiring a reference to the underlying buffer while
+ // holding a shared lock to guarantee ordering of reads and
+ // writes.
+ tf_shared_lock ml(*variables[i]->mu());
+ const Tensor& t = *variables[i]->tensor();
+ OP_REQUIRES(ctx, dtypes_[i] == t.dtype(),
+ errors::InvalidArgument(
+ "Trying to read variable ", handles[i]->name(),
+ " from Container: ", handles[i]->container(),
+ " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
+ " got ", DataTypeString(t.dtype())));
+ ctx->set_output(i, t);
+ }
+}
+
REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
ReadVariableOp);
+REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
+ ReadVariablesOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
ReadVariableOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
+ ReadVariablesOp);
#define REGISTER_GPU_KERNELS(type) \
namespace functor { \
@@ -121,7 +172,12 @@ REGISTER_KERNEL_BUILDER(
.Device(DEVICE_GPU) \
.HostMemory("resource") \
.TypeConstraint<type>("dtype"), \
- ResourceHandleOp<Var>)
+ ResourceHandleOp<Var>) \
+ REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("resources") \
+ .TypeConstraint<type>("dtypes"), \
+ ResourceHandlesOp<Var>)
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
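
A sketch of what the batched read above does, with TF types swapped for standard ones (std::shared_mutex in place of the variable mutex, a plain struct in place of Var; FakeVar and ReadAll are invented names). The point is the same: each variable is read under a shared lock, so reads are ordered against concurrent writes, and the dtype is checked before the value is returned.

    #include <cstdio>
    #include <shared_mutex>
    #include <string>
    #include <vector>

    struct FakeVar {
      std::shared_mutex mu;
      std::string dtype;
      float value = 0.f;
    };

    bool ReadAll(const std::vector<FakeVar*>& vars,
                 const std::vector<std::string>& dtypes, std::vector<float>* out) {
      out->clear();
      for (size_t i = 0; i < vars.size(); ++i) {
        // Shared lock: readers do not block each other, but are ordered vs. writers.
        std::shared_lock<std::shared_mutex> lock(vars[i]->mu);
        if (vars[i]->dtype != dtypes[i]) return false;  // dtype mismatch -> error
        out->push_back(vars[i]->value);
      }
      return true;
    }

    int main() {
      FakeVar a, b;
      a.dtype = "float"; a.value = 1.f;
      b.dtype = "float"; b.value = 2.f;
      std::vector<float> out;
      const bool ok = ReadAll({&a, &b}, {"float", "float"}, &out);
      std::printf("ok=%d values=%g,%g\n", ok, out[0], out[1]);  // ok=1 values=1,2
    }
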
diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h
index 9b60106f13..cffb732c38 100644
--- a/tensorflow/core/kernels/resource_variable_ops.h
+++ b/tensorflow/core/kernels/resource_variable_ops.h
@@ -28,6 +28,16 @@ class ReadVariableOp : public OpKernel {
DataType dtype_;
};
+class ReadVariablesOp : public OpKernel {
+ public:
+ explicit ReadVariablesOp(OpKernelConstruction* c);
+ void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+
+ private:
+ DataTypeVector dtypes_;
+};
+
class DestroyResourceOp : public OpKernel {
public:
explicit DestroyResourceOp(OpKernelConstruction* ctx);
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index e0194605ce..2f8aede427 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -145,6 +145,7 @@ class ScatterNdUpdateOp : public OpKernel {
if (dtype_ == DT_RESOURCE) {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock m(*v->mu());
DoCompute(c);
} else if (use_exclusive_lock_) {
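
The ScopedUnref added above matters because LookupResource hands back the Var with its refcount already incremented; without an unref on every exit path the reference leaks. A minimal RAII sketch of the same idea (the types and names below are invented, not TF's core::ScopedUnref):

    #include <cstdio>

    struct Refcounted {
      int refs = 1;
      void Ref() { ++refs; }
      void Unref() { if (--refs == 0) std::puts("destroyed"); }
    };

    class ScopedUnrefSketch {
     public:
      explicit ScopedUnrefSketch(Refcounted* o) : o_(o) {}
      ~ScopedUnrefSketch() { if (o_) o_->Unref(); }  // runs on every exit path
      ScopedUnrefSketch(const ScopedUnrefSketch&) = delete;
      ScopedUnrefSketch& operator=(const ScopedUnrefSketch&) = delete;
     private:
      Refcounted* o_;
    };

    int main() {
      Refcounted var;  // stands in for the resource behind the handle
      var.Ref();       // the lookup's extra reference
      {
        ScopedUnrefSketch scoped(&var);  // released even if the body errors out
      }
      std::printf("refs left: %d\n", var.refs);  // 1: only the original owner
    }
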
diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc
index 393818730b..a4a59dbcbc 100644
--- a/tensorflow/core/kernels/split_lib_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc
@@ -54,6 +54,7 @@ void SplitCustom<Device, T>::operator()(
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
+TF_CALL_int64(DEFINE_GPU_KERNELS);
TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 7b537fef5b..f0575de4d9 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -306,6 +306,7 @@ class StridedSliceAssignOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc
index a6829b29d9..435a7abdca 100644
--- a/tensorflow/core/kernels/string_length_op.cc
+++ b/tensorflow/core/kernels/string_length_op.cc
@@ -14,13 +14,18 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/string_util.h"
namespace tensorflow {
namespace {
class StringLengthOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+ explicit StringLengthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string unit;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
+ OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
+ }
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
@@ -32,10 +37,22 @@ class StringLengthOp : public OpKernel {
auto src = input.flat<string>();
auto dst = output->flat<int32>();
- for (int n = 0; n < src.size(); ++n) {
- dst(n) = src(n).size();
+ switch (unit_) {
+ case CharUnit::BYTE:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = src(n).size();
+ }
+ break;
+ case CharUnit::UTF8_CHAR:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = UTF8StrLen(src(n));
+ }
+ break;
}
}
+
+ private:
+ CharUnit unit_ = CharUnit::BYTE;
};
REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc
new file mode 100644
index 0000000000..3a9803a052
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.cc
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/string_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace {
+inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
+} // namespace
+
+namespace tensorflow {
+
+// Sets the encoding value based on str.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) {
+ if (str == "UTF8") {
+ *encoding = UnicodeEncoding::UTF8;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+        "Invalid encoding \"", str, "\": Should be one of: UTF8"));
+ }
+ return Status::OK();
+}
+
+// Sets unit value based on str.
+Status ParseCharUnit(const string& str, CharUnit* unit) {
+ if (str == "BYTE") {
+ *unit = CharUnit::BYTE;
+ } else if (str == "UTF8_CHAR") {
+ *unit = CharUnit::UTF8_CHAR;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR"));
+ }
+ return Status::OK();
+}
+
+// Return the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string) {
+ const int32 byte_size = string.size();
+ const char* const end = string.data() + byte_size;
+ const char* ptr = string.data();
+ int32 skipped_count = 0;
+ while (ptr < end) {
+ skipped_count += IsTrailByte(*ptr++) ? 1 : 0;
+ }
+ const int32 result = byte_size - skipped_count;
+ return result;
+}
+
+} // namespace tensorflow
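
UTF8StrLen above relies on the fact that every UTF-8 code point contains exactly one non-continuation byte, so subtracting the continuation bytes (what IsTrailByte detects, i.e. bytes of the form 10xxxxxx) from the byte length gives the character count. A standalone check of that identity, using the unsigned-mask form of the same test; the result may be off for invalid UTF-8, just as noted above.

    #include <cstdio>
    #include <string>

    int Utf8Length(const std::string& s) {
      int count = 0;
      for (unsigned char c : s) {
        // Count bytes that are NOT continuation bytes (0x80..0xBF). This is the
        // unsigned equivalent of static_cast<signed char>(c) < -0x40.
        if ((c & 0xC0) != 0x80) ++count;
      }
      return count;
    }

    int main() {
      std::printf("%d\n", Utf8Length("abc"));          // 3
      std::printf("%d\n", Utf8Length("caf\xC3\xA9"));  // 4: 'e with acute' is 2 bytes, 1 char
      std::printf("%d\n", Utf8Length("\xE2\x82\xAC")); // 1: U+20AC is 3 bytes, 1 char
    }
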
diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h
new file mode 100644
index 0000000000..390cf57702
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.h
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Enumeration for unicode encodings. Used by ops such as
+// tf.strings.unicode_encode and tf.strings.unicode_decode.
+// TODO(edloper): Add support for:
+// UTF16, UTF32, UTF16BE, UTF32BE, UTF16LE, UTF32LE
+enum class UnicodeEncoding { UTF8 };
+
+// Enumeration for character units. Used by string ops such as
+// tf.strings.length and tf.substr.
+// TODO(edloper): Add support for: UTF32_CHAR, etc.
+enum class CharUnit { BYTE, UTF8_CHAR };
+
+// Sets `encoding` based on `str`.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
+
+// Sets `unit` value based on `str`.
+Status ParseCharUnit(const string& str, CharUnit* unit);
+
+// Returns the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 765467bc1e..0e6c0ddccc 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -62,7 +62,8 @@ TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
}
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index e8dc4fad21..384a63e945 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -81,7 +81,8 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index fe93b91eb8..a97a71b344 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -259,6 +259,7 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU),
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
@@ -576,6 +577,7 @@ TF_CALL_ALL_TYPES(REGISTER_READ)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
@@ -1218,6 +1220,7 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
index ca296d5aa0..2fbe1fe7cb 100644
--- a/tensorflow/core/kernels/topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include <cmath>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_segmented_radix_sort.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc
index d3c4f62071..4262a5404b 100644
--- a/tensorflow/core/kernels/training_op_helpers.cc
+++ b/tensorflow/core/kernels/training_op_helpers.cc
@@ -15,13 +15,16 @@ limitations under the License.
#include "tensorflow/core/kernels/training_op_helpers.h"
+#include "tensorflow/core/util/ptr_util.h"
+
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource) {
+ *maybe_resource = nullptr;
if (ctx->input_dtype(input) == DT_RESOURCE) {
- Var* var;
- if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
- return var->mu();
+ if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
+ return (*maybe_resource)->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
@@ -32,12 +35,13 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
}
// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
-// in address order to mitigate deadlock. Returns a vector of acquired mutexes.
-// Safe to pass duplicates - will only lock each distinct mutex once. If
-// do_lock is false, returns immediately. Note that this silently doesn't lock
-// mutexes for invalid variable references; in all usages this is followed by
-// GetInputTensor which will signal a failure.
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// in address order to mitigate deadlock. Returns a structure that, when
+// deleted, will release the acquired mutexes. Safe to pass duplicates - will
+// only lock each distinct mutex once. If do_lock is false, returns
+// immediately. Note that this silently doesn't lock mutexes for invalid
+// variable references; in all usages this is followed by GetInputTensor which
+// will signal a failure.
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
bool any_resource = false;
for (auto i : input_ids) {
@@ -46,14 +50,16 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
break;
}
}
- std::vector<mutex_lock> locks;
if (!do_lock && !any_resource) {
- return locks;
+ return VariableInputLockHolder({}, {});
}
+ std::vector<Var*> vars;
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
- mutex* mutex = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mutex = GetTrainingVariableMutex(ctx, input, &var);
+ if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
@@ -63,13 +69,19 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
+ std::unique_ptr<std::vector<mutex_lock>> locks =
+ MakeUnique<std::vector<mutex_lock>>();
+ locks->reserve(acquire_order.size());
+
for (auto input : acquire_order) {
- mutex* mu = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, input, &var);
+ core::ScopedUnref scoped_unref(var);
if (mu != nullptr) {
- locks.emplace_back(*mu);
+ locks->emplace_back(*mu);
}
}
- return locks;
+ return VariableInputLockHolder(std::move(vars), std::move(locks));
}
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
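The rewritten helper above acquires each distinct variable mutex exactly once, in increasing address order, so concurrent training ops can never deadlock by locking the same pair of mutexes in opposite orders. A hedged, self-contained sketch of that ordering idea with plain std::mutex follows; the name LockAllInAddressOrder is illustrative, while TensorFlow's version works on the Var mutexes and returns the VariableInputLockHolder declared in the header.

#include <algorithm>
#include <mutex>
#include <vector>

// Sketch only: lock a set of mutexes in a single global (address) order.
std::vector<std::unique_lock<std::mutex>> LockAllInAddressOrder(
    std::vector<std::mutex*> mus) {
  std::sort(mus.begin(), mus.end());                          // global order
  mus.erase(std::unique(mus.begin(), mus.end()), mus.end());  // lock each once
  std::vector<std::unique_lock<std::mutex>> locks;
  locks.reserve(mus.size());
  for (std::mutex* mu : mus) locks.emplace_back(*mu);
  return locks;  // released by the lock destructors, as with the holder above
}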
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index 071cb371a7..9f173a80f7 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -23,9 +23,42 @@ limitations under the License.
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input);
+// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
+//
+// If `input` corresponds to a `DT_RESOURCE`-type variable input,
+// `*maybe_resource` will be updated to contain the underlying resource, and the
+// caller will be responsible for calling `Unref()` on that resource.
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource);
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// Utility structure that releases a sequence of borrowed mutexes when it is
+// deleted.
+struct VariableInputLockHolder {
+ public:
+ VariableInputLockHolder(std::vector<Var*> vars,
+ std::unique_ptr<std::vector<mutex_lock>> locks)
+ : vars_(std::move(vars)), locks_(std::move(locks)) {}
+
+ VariableInputLockHolder(VariableInputLockHolder&& other)
+ : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {}
+
+ ~VariableInputLockHolder() {
+ // Release the locks before unreffing the Vars, because each lock
+ // is potentially borrowed from a Var in vars_.
+ locks_.reset();
+ for (Var* var : vars_) {
+ var->Unref();
+ }
+ }
+
+ private:
+ std::vector<Var*> vars_;
+ // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
+ // because a `std::vector<mutex_lock>` is not movable on all platforms.
+ std::unique_ptr<std::vector<mutex_lock>> locks_;
+};
+
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
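The destructor above encodes an ordering constraint: every mutex_lock borrows its mutex from one of the Vars, so the locks must be destroyed before the Vars are unreffed. A hedged, self-contained sketch of that constraint with stand-in types follows (Resource and Holder are illustrative, not tensorflow::Var or the holder above).

#include <memory>
#include <mutex>
#include <vector>

struct Resource {                 // stand-in for a refcounted Var
  std::mutex mu;                  // the mutex a lock below may borrow
  int refs = 1;                   // sketch only: not thread-safe
  void Unref() { if (--refs == 0) delete this; }
};

struct Holder {                   // stand-in for VariableInputLockHolder
  std::vector<Resource*> resources;
  std::unique_ptr<std::vector<std::unique_lock<std::mutex>>> locks;
  ~Holder() {
    locks.reset();                             // 1) release borrowed mutexes
    for (Resource* r : resources) r->Unref();  // 2) then drop the references
  }
};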
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 9a07ded17d..acf162deec 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -561,7 +561,9 @@ class ApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- mutex* mu = GetTrainingVariableMutex(ctx, 0);
+ Var* resource;
+ mutex* mu = GetTrainingVariableMutex(ctx, 0, &resource);
+ core::ScopedUnref scoped_unref(resource);
if (use_exclusive_lock_ && mu != nullptr) {
mutex_lock l1(*mu);
// Don't try to acquire a lock on the second ref as they share the same
@@ -710,7 +712,9 @@ class SparseApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- mutex* mu = GetTrainingVariableMutex(ctx, 0);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, 0, &var);
+ core::ScopedUnref scoped_unref(var);
// mu_accum is actually the same mutex as mu_var since currently we use a
// global mutex.
//
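Both Compute methods above pair the Var returned by GetTrainingVariableMutex with core::ScopedUnref, so the reference taken by the resource lookup is dropped on every exit path; per the new helper, the pointer may legitimately be null for ref-type inputs, which the guard has to tolerate. A hedged sketch of an equivalent guard with an illustrative name (UnrefGuard) follows; core::ScopedUnref is the real class used in this diff.

// Sketch only: RAII guard that drops one reference on scope exit.
template <typename T>
class UnrefGuard {
 public:
  explicit UnrefGuard(T* obj) : obj_(obj) {}
  ~UnrefGuard() {
    if (obj_ != nullptr) obj_->Unref();  // tolerate null, e.g. ref-type inputs
  }
  UnrefGuard(const UnrefGuard&) = delete;
  UnrefGuard& operator=(const UnrefGuard&) = delete;

 private:
  T* obj_;
};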
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 0f0f65c5a3..48e392c070 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
perm, out);
}
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -230,11 +230,8 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
MklConjugateTransposeCpuOp);
-TF_CALL_ALL_TYPES(REGISTER);
-#undef REGISTER
-
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -246,9 +243,10 @@ TF_CALL_ALL_TYPES(REGISTER);
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
ConjugateTransposeCpuOp);
+#endif // INTEL_MKL && ENABLE_MKL
+
TF_CALL_ALL_TYPES(REGISTER)
#undef REGISTER
-#endif // INTEL_MKL
#if GOOGLE_CUDA
Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
diff --git a/tensorflow/core/kernels/unicode_script_op.cc b/tensorflow/core/kernels/unicode_script_op.cc
new file mode 100644
index 0000000000..085e397eba
--- /dev/null
+++ b/tensorflow/core/kernels/unicode_script_op.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "unicode/errorcode.h" // TF:icu
+#include "unicode/uscript.h" // TF:icu
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class UnicodeScriptOp : public OpKernel {
+ public:
+ explicit UnicodeScriptOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(context, context->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<int32>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<int32>();
+
+ icu::ErrorCode status;
+ for (int i = 0; i < input_flat.size(); i++) {
+ UScriptCode script_code = uscript_getScript(input_flat(i), status);
+ if (status.isSuccess()) {
+ output_flat(i) = script_code;
+ } else {
+ output_flat(i) = -1;
+ status.reset();
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("UnicodeScript").Device(DEVICE_CPU),
+ UnicodeScriptOp);
+
+} // namespace tensorflow
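The new kernel maps each input code point to its ICU script code with uscript_getScript, writing -1 when ICU reports an error. A hedged standalone sketch of the same lookup outside the kernel follows; it uses ICU's C UErrorCode API rather than the icu::ErrorCode wrapper used above, and the printed values come from ICU's USCRIPT_* enumeration.

#include <cstdio>
#include "unicode/uscript.h"

int main() {
  UErrorCode status = U_ZERO_ERROR;
  // U+0416 CYRILLIC CAPITAL LETTER ZHE; expect USCRIPT_CYRILLIC.
  UScriptCode script = uscript_getScript(0x0416, &status);
  if (U_SUCCESS(status)) {
    std::printf("script code %d (%s)\n", script, uscript_getName(script));
  } else {
    std::printf("lookup failed; the kernel above would emit -1\n");
  }
  return 0;
}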
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h
index 8879d9dd4c..2255597651 100644
--- a/tensorflow/core/kernels/where_op_gpu.cu.h
+++ b/tensorflow/core/kernels/where_op_gpu.cu.h
@@ -21,10 +21,10 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_select.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_select.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"