Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--  tensorflow/core/kernels/BUILD                                    50
-rw-r--r--  tensorflow/core/kernels/batching_util/shared_batch_scheduler.h    6
-rw-r--r--  tensorflow/core/kernels/broadcast_to_op.cc                       91
-rw-r--r--  tensorflow/core/kernels/broadcast_to_op.h                       220
-rw-r--r--  tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc                34
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu.h                            5
-rw-r--r--  tensorflow/core/kernels/ctc_decoder_ops.cc                       34
-rw-r--r--  tensorflow/core/kernels/mkl_input_conversion_op.cc               35
-rw-r--r--  tensorflow/core/kernels/mkl_relu_op.cc                            8
-rw-r--r--  tensorflow/core/kernels/roll_op.cc                                7
-rw-r--r--  tensorflow/core/kernels/segment_reduction_ops.h                   8
-rw-r--r--  tensorflow/core/kernels/string_strip_op.cc                       53
-rw-r--r--  tensorflow/core/kernels/training_ops.cc                         150
-rw-r--r--  tensorflow/core/kernels/training_ops.h                           12
-rw-r--r--  tensorflow/core/kernels/training_ops_gpu.cu.cc                   30
15 files changed, 678 insertions, 65 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index f7f6a9b505..201cd35798 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -617,6 +617,7 @@ cc_library(
":batch_space_ops",
":bcast_ops",
":bitcast_op",
+ ":broadcast_to_op",
":concat_op",
":constant_op",
":depth_space_ops",
@@ -669,6 +670,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "broadcast_to_op",
+ prefix = "broadcast_to_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "concat_op",
prefix = "concat_op",
deps = ARRAY_DEPS,
@@ -4227,6 +4234,7 @@ cc_library(
":regex_replace_op",
":string_join_op",
":string_split_op",
+ ":string_strip_op",
":string_to_hash_bucket_op",
":substr_op",
],
@@ -4272,6 +4280,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "string_strip_op",
+ prefix = "string_strip_op",
+ deps = STRING_DEPS,
+)
+
+tf_kernel_library(
name = "substr_op",
prefix = "substr_op",
deps = STRING_DEPS,
@@ -5947,8 +5961,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -5963,8 +5976,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -5980,8 +5992,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -6001,8 +6012,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -6018,8 +6028,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:nn_ops_op_lib",
"//third_party/eigen3",
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -6035,8 +6044,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:nn_ops_op_lib",
"//third_party/eigen3",
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -6044,8 +6052,7 @@ tf_mkl_kernel_library(
srcs = ["mkl_fused_batch_norm_op.cc"],
deps = NN_DEPS + [
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -6053,8 +6060,7 @@ tf_mkl_kernel_library(
prefix = "mkl_aggregate_ops",
deps = MATH_DEPS + [
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -6062,8 +6068,7 @@ tf_mkl_kernel_library(
prefix = "mkl_concat_op",
deps = ARRAY_DEPS + [
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -6071,8 +6076,7 @@ tf_mkl_kernel_library(
prefix = "mkl_reshape_op",
deps = ARRAY_DEPS + [
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -6080,8 +6084,7 @@ tf_mkl_kernel_library(
prefix = "mkl_identity_op",
deps = ARRAY_DEPS + [
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
@@ -6089,8 +6092,7 @@ tf_mkl_kernel_library(
prefix = "mkl_lrn_op",
deps = NN_DEPS + [
"//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
+ ] + if_mkl(["@mkl_dnn"]),
)
tf_mkl_kernel_library(
diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
index edc88a0384..b4bce90841 100644
--- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
@@ -136,7 +136,7 @@ class SharedBatchScheduler
// (inclusive). If there is a need to quantize the batch sizes, i.e. only
// submit batches whose size is in a small set of allowed sizes, that can be
// done by adding padding in the process-batch callback.
- int max_batch_size = 1000;
+ size_t max_batch_size = 1000;
// If a task has been enqueued for this amount of time (in microseconds),
// and a thread is available, the scheduler will immediately form a batch
@@ -157,7 +157,7 @@ class SharedBatchScheduler
// If this limit is reached, Schedule() will return an UNAVAILABLE error.
// See the class documentation above for guidelines on how to tune this
// parameter.
- int max_enqueued_batches = 10;
+ size_t max_enqueued_batches = 10;
};
Status AddQueue(const QueueOptions& options,
std::function<void(std::unique_ptr<Batch<TaskType>>)>
@@ -394,7 +394,7 @@ Status SharedBatchScheduler<TaskType>::AddQueue(
std::function<void(std::unique_ptr<Batch<TaskType>>)>
process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue) {
- if (options.max_batch_size <= 0) {
+ if (options.max_batch_size == 0) {
return errors::InvalidArgument("max_batch_size must be positive; was ",
options.max_batch_size);
}
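
QueueOptions::max_batch_size and max_enqueued_batches change from int to size_t, so the old `<= 0` validation in AddQueue becomes redundant: an unsigned value cannot be negative, and `x <= 0` is the same predicate as `x == 0`. A minimal standalone sketch of that equivalence (illustrative values, not taken from the scheduler):

    #include <cassert>
    #include <cstddef>

    int main() {
      const std::size_t values[] = {0, 1, 1000};
      for (std::size_t x : values) {
        // For an unsigned quantity, "non-positive" collapses to "equal to
        // zero", which is why the AddQueue check is rewritten as == 0.
        assert((x <= 0) == (x == 0));
      }
      return 0;
    }
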
diff --git a/tensorflow/core/kernels/broadcast_to_op.cc b/tensorflow/core/kernels/broadcast_to_op.cc
new file mode 100644
index 0000000000..2810925bbc
--- /dev/null
+++ b/tensorflow/core/kernels/broadcast_to_op.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/broadcast_to_op.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+template <typename Device, typename T>
+class BroadcastToOp : public OpKernel {
+ public:
+ explicit BroadcastToOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& input_tensor = ctx->input(0);
+ const TensorShape& input_shape = input_tensor.shape();
+
+ const Tensor& shape_tensor = ctx->input(1);
+
+ TensorShape output_shape;
+ OP_REQUIRES_OK(ctx,
+ ctx->op_kernel().MakeShape(shape_tensor, &output_shape));
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor));
+
+ const Device& d = ctx->eigen_device<Device>();
+ functor::BroadcastTo<Device, T>()(d, ctx, *output_tensor, output_shape,
+ input_tensor, input_shape);
+ }
+};
+
+// As MakeShape is able to handle both DT_INT32 and DT_INT64,
+// no need to have TypeConstraint for `Tidx`
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("BroadcastTo").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ BroadcastToOp<CPUDevice, type>);
+
+TF_CALL_ALL_TYPES(REGISTER_KERNEL);
+#undef REGISTER_KERNEL
+
+#if GOOGLE_CUDA
+
+namespace functor {
+#define DECLARE_GPU_TEMPLATE(Type) \
+ template <> \
+ void BroadcastTo<GPUDevice, Type>::operator()( \
+ const GPUDevice& d, OpKernelContext* ctx, Tensor& output, \
+ const TensorShape& output_shape, const Tensor& input, \
+ const TensorShape& input_shape); \
+ extern template struct BroadcastTo<GPUDevice, Type>;
+
+TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_TEMPLATE);
+#undef DECLARE_GPU_KERNEL
+} // namespace functor
+
+#define REGISTER_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("BroadcastTo") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("shape"), \
+ BroadcastToOp<GPUDevice, type>);
+
+TF_CALL_GPU_ALL_TYPES(REGISTER_KERNEL);
+#undef REGISTER_KERNEL
+#endif
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/broadcast_to_op.h b/tensorflow/core/kernels/broadcast_to_op.h
new file mode 100644
index 0000000000..608e9b6ac9
--- /dev/null
+++ b/tensorflow/core/kernels/broadcast_to_op.h
@@ -0,0 +1,220 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_
+#define TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+
+namespace functor {
+
+template <typename Device, typename T>
+struct BroadcastTo {
+ void operator()(const Device &d, OpKernelContext *ctx, Tensor &output_tensor,
+ const TensorShape &output_shape, const Tensor &input_tensor,
+ const TensorShape &input_shape) {
+#define BROADCAST_SHAPE(broadcast, reshape, NDIMS, input_shape, output_shape) \
+ for (int i = 0; i < NDIMS; i++) { \
+ OP_REQUIRES(ctx, (broadcast[i] % reshape[i] == 0), \
+ errors::InvalidArgument("invalid shape to broadcast from ", \
+ input_shape.DebugString(), " to ", \
+ output_shape.DebugString())); \
+ broadcast[i] = broadcast[i] / reshape[i]; \
+ }
+
+ switch (output_shape.dims()) {
+ case 1: {
+ auto reshape = AsEigenDSizesWithPrefix<1>(input_shape);
+ auto broadcast = output_shape.AsEigenDSizes<1>();
+
+ BROADCAST_SHAPE(broadcast, reshape, 1, input_shape, output_shape);
+
+ auto output = output_tensor.tensor<T, 1>();
+ switch (input_shape.dims()) {
+ case 0: {
+ output.device(d) = output.constant(input_tensor.scalar<T>()());
+ } break;
+ case 1: {
+ auto input = input_tensor.tensor<T, 1>();
+ output.device(d) = input.broadcast(broadcast);
+ } break;
+ default:
+ ctx->CtxFailure(errors::InvalidArgument(
+ "invalid shape to broadcast from ", input_shape.DebugString(),
+ " to ", output_shape.DebugString()));
+ break;
+ }
+ } break;
+ case 2: {
+ auto reshape = AsEigenDSizesWithPrefix<2>(input_shape);
+ auto broadcast = output_shape.AsEigenDSizes<2>();
+
+ BROADCAST_SHAPE(broadcast, reshape, 2, input_shape, output_shape);
+
+ auto output = output_tensor.tensor<T, 2>();
+ switch (input_shape.dims()) {
+ case 0: {
+ output.device(d) = output.constant(input_tensor.scalar<T>()());
+ } break;
+ case 1: {
+ auto input = input_tensor.tensor<T, 1>();
+ output.device(d) = input.reshape(reshape).broadcast(broadcast);
+ } break;
+ case 2: {
+ auto input = input_tensor.tensor<T, 2>();
+ output.device(d) = input.broadcast(broadcast);
+ } break;
+ default:
+ ctx->CtxFailure(errors::InvalidArgument(
+ "invalid shape to broadcast from ", input_shape.DebugString(),
+ " to ", output_shape.DebugString()));
+ break;
+ }
+ } break;
+ case 3: {
+ auto reshape = AsEigenDSizesWithPrefix<3>(input_shape);
+ auto broadcast = output_shape.AsEigenDSizes<3>();
+
+ BROADCAST_SHAPE(broadcast, reshape, 3, input_shape, output_shape);
+
+ auto output = output_tensor.tensor<T, 3>();
+ switch (input_shape.dims()) {
+ case 0: {
+ output.device(d) = output.constant(input_tensor.scalar<T>()());
+ } break;
+ case 1: {
+ auto input = input_tensor.tensor<T, 1>();
+ output.device(d) = input.reshape(reshape).broadcast(broadcast);
+ } break;
+ case 2: {
+ auto input = input_tensor.tensor<T, 2>();
+ output.device(d) = input.reshape(reshape).broadcast(broadcast);
+ } break;
+ case 3: {
+ auto input = input_tensor.tensor<T, 3>();
+ output.device(d) = input.broadcast(broadcast);
+ } break;
+ default:
+ ctx->CtxFailure(errors::InvalidArgument(
+ "invalid shape to broadcast from ", input_shape.DebugString(),
+ " to ", output_shape.DebugString()));
+ break;
+ }
+ } break;
+ case 4: {
+ auto reshape = AsEigenDSizesWithPrefix<4>(input_shape);
+ auto broadcast = output_shape.AsEigenDSizes<4>();
+
+ BROADCAST_SHAPE(broadcast, reshape, 4, input_shape, output_shape);
+
+ auto output = output_tensor.tensor<T, 4>();
+ switch (input_shape.dims()) {
+ case 0: {
+ output.device(d) = output.constant(input_tensor.scalar<T>()());
+ } break;
+ case 1: {
+ auto input = input_tensor.tensor<T, 1>();
+ output.device(d) = input.reshape(reshape).broadcast(broadcast);
+ } break;
+ case 2: {
+ auto input = input_tensor.tensor<T, 2>();
+ output.device(d) = input.reshape(reshape).broadcast(broadcast);
+ } break;
+ case 3: {
+ auto input = input_tensor.tensor<T, 3>();
+ output.device(d) = input.reshape(reshape).broadcast(broadcast);
+ } break;
+ case 4: {
+ auto input = input_tensor.tensor<T, 4>();
+ output.device(d) = input.broadcast(broadcast);
+ } break;
+ default:
+ ctx->CtxFailure(errors::InvalidArgument(
+ "invalid shape to broadcast from ", input_shape.DebugString(),
+ " to ", output_shape.DebugString()));
+ break;
+ }
+ } break;
+ case 5: {
+ auto reshape = AsEigenDSizesWithPrefix<5>(input_shape);
+ auto broadcast = output_shape.AsEigenDSizes<5>();
+
+ BROADCAST_SHAPE(broadcast, reshape, 5, input_shape, output_shape);
+ auto output = output_tensor.tensor<T, 5>();
+ switch (input_shape.dims()) {
+ case 0: {
+ output.device(d) = output.constant(input_tensor.scalar<T>()());
+ } break;
+ case 1: {
+ auto input = input_tensor.tensor<T, 1>();
+ output.device(d) = input.reshape(reshape).broadcast(broadcast);
+ } break;
+ case 2: {
+ auto input = input_tensor.tensor<T, 2>();
+ output.device(d) = input.reshape(reshape).broadcast(broadcast);
+ } break;
+ case 3: {
+ auto input = input_tensor.tensor<T, 3>();
+ output.device(d) = input.reshape(reshape).broadcast(broadcast);
+ } break;
+ case 4: {
+ auto input = input_tensor.tensor<T, 4>();
+ output.device(d) = input.reshape(reshape).broadcast(broadcast);
+ } break;
+ case 5: {
+ auto input = input_tensor.tensor<T, 5>();
+ output.device(d) = input.broadcast(broadcast);
+ } break;
+ default:
+ ctx->CtxFailure(errors::InvalidArgument(
+ "invalid shape to broadcast from ", input_shape.DebugString(),
+ " to ", output_shape.DebugString()));
+ break;
+ }
+ } break;
+ default:
+ ctx->CtxFailure(errors::InvalidArgument(
+ "invalid shape to broadcast from ", input_shape.DebugString(),
+ " to ", output_shape.DebugString()));
+ break;
+ }
+ }
+
+ private:
+ template <int NDIMS>
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPrefix(
+ const TensorShape &shape) const {
+ Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes;
+ for (int d = 0; d < NDIMS - shape.dims(); d++) {
+ dsizes[d] = 1;
+ }
+ for (int d = NDIMS - shape.dims(); d < NDIMS; d++) {
+ dsizes[d] = shape.dim_size(d - (NDIMS - shape.dims()));
+ }
+ return dsizes;
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_
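
The functor above handles each output rank by padding the input shape with leading 1s (AsEigenDSizesWithPrefix) and then tiling with Eigen's broadcast, where each broadcast factor is the output dimension divided by the reshaped dimension (validated by BROADCAST_SHAPE). A standalone sketch of that reshape-then-broadcast pattern in plain Eigen, with illustrative shapes rather than TensorFlow types:

    #include <iostream>
    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      // Broadcast a [3] input to a [2, 3] output: reshape to [1, 3], then
      // tile the leading axis 2x (2 = output dim / reshaped dim).
      Eigen::Tensor<float, 1> input(3);
      input.setValues({1.f, 2.f, 3.f});

      Eigen::DSizes<Eigen::DenseIndex, 2> reshape(1, 3);
      Eigen::DSizes<Eigen::DenseIndex, 2> broadcast(2, 1);

      Eigen::Tensor<float, 2> output = input.reshape(reshape).broadcast(broadcast);
      std::cout << output << std::endl;  // each row reads 1 2 3
      return 0;
    }
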
diff --git a/tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc b/tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc
new file mode 100644
index 0000000000..6459571085
--- /dev/null
+++ b/tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/broadcast_to_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+#define INSTANTIATE_GPU_KERNEL(Type) \
+ template class functor::BroadcastTo<GPUDevice, Type>;
+TF_CALL_GPU_ALL_TYPES(INSTANTIATE_GPU_KERNEL);
+#undef INSTANTIATE_GPU_KERNEL
+
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index 4215c4541c..d2c8020bb6 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -139,9 +139,8 @@ class ConvParameters {
bool ShouldIncludeWinogradNonfusedAlgo(
se::StreamExecutor* stream_exec) const {
// Skip this check for cuDNN 7 and newer.
- se::port::StatusOr<std::tuple<int, int, int>> version =
- stream_exec->AsDnn()->GetVersion();
- if (version.ok() && std::get<0>(version.ValueOrDie()) >= 7) {
+ auto version = stream_exec->AsDnn()->GetVersion();
+ if (version.ok() && version.ValueOrDie().major_version() >= 7) {
return true;
}
return ShouldIncludeWinogradNonfusedAlgoPreCudnn7<T>();
diff --git a/tensorflow/core/kernels/ctc_decoder_ops.cc b/tensorflow/core/kernels/ctc_decoder_ops.cc
index 96bdb6a241..8cadeac68d 100644
--- a/tensorflow/core/kernels/ctc_decoder_ops.cc
+++ b/tensorflow/core/kernels/ctc_decoder_ops.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/ctc/ctc_beam_search.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
+#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
@@ -213,20 +214,29 @@ class CTCGreedyDecoderOp : public OpKernel {
// Perform best path decoding
std::vector<std::vector<std::vector<int> > > sequences(batch_size);
- for (int b = 0; b < batch_size; ++b) {
- sequences[b].resize(1);
- auto& sequence = sequences[b][0];
- int prev_indices = -1;
- for (int t = 0; t < seq_len_t(b); ++t) {
- int max_class_indices;
- log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
- if (max_class_indices != blank_index &&
- !(merge_repeated_ && max_class_indices == prev_indices)) {
- sequence.push_back(max_class_indices);
+ auto decode = [&](const int64 begin, const int64 end) {
+ for (int b = begin; b < end; ++b) {
+ sequences[b].resize(1);
+ auto &sequence = sequences[b][0];
+ int prev_indices = -1;
+ for (int t = 0; t < seq_len_t(b); ++t) {
+ int max_class_indices;
+ log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices);
+ if (max_class_indices != blank_index &&
+ !(merge_repeated_ && max_class_indices == prev_indices)) {
+ sequence.push_back(max_class_indices);
+ }
+ prev_indices = max_class_indices;
}
- prev_indices = max_class_indices;
}
- }
+ };
+
+ const int64 kCostPerUnit = 50 * max_time * num_classes;
+ const int64 total = batch_size;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *ctx->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, total,
+ kCostPerUnit, decode);
OP_REQUIRES_OK(
ctx, decode_helper_.StoreAllDecodedSequences(
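
The greedy decoding loop is now wrapped in a closure and dispatched through Shard, which splits [0, batch_size) into contiguous ranges across the CPU worker pool, using kCostPerUnit as a rough per-item cost estimate to decide how finely to split. A minimal sketch of that pattern (toy body and cost value; only the Shard utility from work_sharder.h is assumed):

    #include <vector>

    #include "tensorflow/core/framework/op_kernel.h"
    #include "tensorflow/core/util/work_sharder.h"

    void ParallelSquare(tensorflow::OpKernelContext* ctx, std::vector<float>* out) {
      const auto& workers = *ctx->device()->tensorflow_cpu_worker_threads();
      const tensorflow::int64 total = out->size();
      const tensorflow::int64 kCostPerUnit = 100;  // rough cost per element
      // Shard invokes the closure with disjoint [begin, end) ranges, possibly
      // from several threads, so the body must only touch its own slice.
      tensorflow::Shard(workers.num_threads, workers.workers, total, kCostPerUnit,
                        [out](tensorflow::int64 begin, tensorflow::int64 end) {
                          for (tensorflow::int64 i = begin; i < end; ++i) {
                            (*out)[i] = static_cast<float>(i) * static_cast<float>(i);
                          }
                        });
    }
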
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index ea763ce85b..cda1402b03 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -312,9 +312,8 @@ class MklInputConversionOp : public OpKernel {
VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
"different, "
<< "need to convert to same format";
-
- // Convert input0, and keep input1 unchanged
- // Create MklDnnShape for output mkl tensor based on input0
+ // TODO: For now, input0 is converted and input1 is unchanged
+ // we should choose the optimal MKL format to convert to.
Tensor* tensor_out;
MklDnnShape mkl_output_mkl_shape;
mkl_output_mkl_shape.SetMklTensor(true);
@@ -362,7 +361,8 @@ class MklInputConversionOp : public OpKernel {
// with MKL tensors)
VLOG(1) << "MklInputConversionOp: Broadcast needed, "
<< "converted MKL inputs to TF format";
-
+ // TODO: Cleanup op_data_type and has_avx512f_ after these two parameters
+ // are removed from ConvertMklToTf
MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
op_data_type, has_avx512f_,
kInputIndex_0);
@@ -403,19 +403,7 @@ class MklInputConversionOp : public OpKernel {
}
// Broadcast is needed if the shapes are not the same
- bool broadcast_needed;
-
- size_t in0_size = 1;
- for (size_t i = 0; i < mkl_shape->GetDimension(); ++i)
- in0_size *= mkl_shape->TfDimSize(i);
-
- size_t in1_size = 1;
- for (size_t i = 0; i < tf_tensor->shape().dims(); ++i)
- in1_size *= tf_tensor->shape().dim_size(i);
-
- broadcast_needed = (in0_size != in1_size);
-
- if (!broadcast_needed) {
+ if (mkl_shape->GetTfShape().num_elements() == tf_tensor->shape().num_elements() ) {
// Both shapes are same, convert the TF input to MKL
VLOG(1) << "MklInputConversionOp: No broadcast needed.";
VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
@@ -446,10 +434,19 @@ class MklInputConversionOp : public OpKernel {
// Create reorder between tensorflow layout and Mkl layout if necessary
std::vector<primitive> net;
- tf_input.CheckReorderToOpMem(
+ bool reordered = tf_input.CheckReorderToOpMem(
memory::primitive_desc(output_mkl_md, cpu_engine),
tensor_out, &net);
- stream(stream::kind::eager).submit(net).wait();
+ if(!reordered) {
+ // This is the case that the TF tensor has the same shape and format of
+ // mkl tensor. However, tf_tensor can not be simply forwarded to the output
+ // tensor since mkl data tensor is always one dimensional tensor.
+ // Tensor::CopyFrom shares the buffer of the other tensor while set its shape
+ // to the other tensor.
+ tensor_out->CopyFrom(*tf_tensor, tensor_out->shape());
+ }
+ else
+ stream(stream::kind::eager).submit(net).wait();
// -- The tensor in MKL format passes through --
ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
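
The new no-reorder path relies on Tensor::CopyFrom, which does not copy data: the destination aliases the source buffer and only the reported shape changes, which is why a TF-format tensor can be handed to the one-dimensional MKL data tensor here. A minimal sketch of that buffer-sharing behavior (illustrative helper; works for any shapes with equal element counts):

    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/framework/tensor_shape.h"

    // Make `dst` a flat (1-D) view that shares `src`'s underlying buffer.
    bool ShareAsFlat(const tensorflow::Tensor& src, tensorflow::Tensor* dst) {
      tensorflow::TensorShape flat({src.NumElements()});
      // CopyFrom succeeds only when the element counts match; no data moves.
      return dst->CopyFrom(src, flat);
    }
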
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 0a0f69522f..1ed43834dd 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -441,7 +441,9 @@ class MklReluOpBase : public OpKernel {
// Allocate output and MklDnnShape tensors separately for possible
// in-place operation
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {src_index}, dst_index, tf_shape_dst, &dst_tensor));
+ {static_cast<const int>(src_index)},
+ static_cast<const int>(dst_index),
+ tf_shape_dst, &dst_tensor));
AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
// Destination memory descriptor is same as source memory descriptor.
@@ -611,7 +613,9 @@ class MklReluGradOpBase : public OpKernel {
// Allocate diff_src and MklDnnShape tensors separately for possible
// in-place operation
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {diff_dst_index}, diff_src_index, tf_shape_diff_src,
+ {static_cast<const int>(diff_dst_index)},
+ static_cast<const int>(diff_src_index),
+ tf_shape_diff_src,
&diff_src_tensor));
AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);
diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc
index bcbdbee058..4b630809c5 100644
--- a/tensorflow/core/kernels/roll_op.cc
+++ b/tensorflow/core/kernels/roll_op.cc
@@ -254,8 +254,11 @@ class RollOp : public OpKernel {
// total modulo sum of shifts for each dimension
gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0);
for (int i = 0; i < num_shifts; i++) {
- const int axis = axis_flat(i);
- OP_REQUIRES(context, axis < num_dims,
+ int axis = axis_flat(i);
+ if (axis < 0) {
+ axis += num_dims;
+ }
+ OP_REQUIRES(context, 0 <= axis && axis < num_dims,
errors::InvalidArgument("axis ", axis, " is out of range"));
const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));
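
The roll kernel now accepts Python-style negative axes by normalizing them before the range check. A standalone sketch of that normalization (illustrative helper, not part of the kernel):

    #include <cassert>

    int NormalizeAxis(int axis, int num_dims) {
      if (axis < 0) axis += num_dims;  // e.g. -1 on a rank-4 tensor becomes 3
      assert(0 <= axis && axis < num_dims);  // kernel reports InvalidArgument
      return axis;
    }

    int main() {
      assert(NormalizeAxis(-1, 4) == 3);
      assert(NormalizeAxis(0, 4) == 0);
      return 0;
    }
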
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 183e5a1d58..bedd965966 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -16,6 +16,14 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
#define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_
+
+// This file requires the following include because it uses CudaAtomicMax:
+// #include "tensorflow/core/util/cuda_kernel_helper.h"
+
+// Unfortunately we can't add the #include, since it breaks compilation for
+// non-GPU targets. This only breaks in clang, because it's more strict for
+// template code and CudaAtomicMax is used in template context.
+
// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc
new file mode 100644
index 0000000000..ae700f4294
--- /dev/null
+++ b/tensorflow/core/kernels/string_strip_op.cc
@@ -0,0 +1,53 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/string_ops.cc.
+
+#include <string>
+
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+
+class StringStripOp : public OpKernel {
+ public:
+ explicit StringStripOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+ Tensor* output_tensor;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor));
+
+ const auto input = input_tensor->flat<string>();
+ auto output = output_tensor->flat<string>();
+
+ for (int64 i = 0; i < input.size(); ++i) {
+ StringPiece entry(input(i));
+ str_util::RemoveWhitespaceContext(&entry);
+ output(i) = entry.ToString();
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringStrip").Device(DEVICE_CPU), StringStripOp);
+
+} // namespace tensorflow
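
For reference, a minimal sketch of the per-element trimming the kernel performs, assuming str_util::RemoveWhitespaceContext strips leading and trailing whitespace from the StringPiece view in place:

    #include <string>

    #include "tensorflow/core/lib/core/stringpiece.h"
    #include "tensorflow/core/lib/strings/str_util.h"

    std::string StripOne(const std::string& s) {
      tensorflow::StringPiece view(s);
      tensorflow::str_util::RemoveWhitespaceContext(&view);  // "  a b " -> "a b"
      return view.ToString();
    }
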
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index f53c567c4d..5b13b10937 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -330,6 +330,27 @@ struct ApplyAdamSYCL {
template <typename T>
struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {};
+template <typename Device, typename T>
+struct ApplyAdaMaxNonCuda {
+ void operator()(const Device& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+ typename TTypes<T>::ConstScalar beta1_power,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar beta1,
+ typename TTypes<T>::ConstScalar beta2,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad) {
+ m.device(d) += (grad - m) * (T(1) - beta1());
+ // Here v is u in section 7.1
+ v.device(d) = (beta2() * v).cwiseMax(grad.abs());
+ // var is θ in section 7.1
+ var.device(d) -= lr() / (T(1) - beta1_power()) * (m / (v + epsilon()));
+ }
+};
+
+template <typename T>
+struct ApplyAdaMax<CPUDevice, T> : ApplyAdaMaxNonCuda<CPUDevice, T> {};
+
template <typename T>
struct ApplyRMSProp<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
@@ -2752,6 +2773,135 @@ REGISTER_KERNELS(GPU, double);
#undef REGISTER_KERNELS
template <typename Device, typename T>
+class ApplyAdaMaxOp : public OpKernel {
+ public:
+ explicit ApplyAdaMaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_,
+ {0, 1, 2});
+
+ Tensor var;
+ OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+ ctx, 0, use_exclusive_lock_, false, &var));
+ Tensor m;
+ OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+ ctx, 1, use_exclusive_lock_, false, &m));
+ Tensor v;
+ OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
+ ctx, 2, use_exclusive_lock_, false, &v));
+ OP_REQUIRES(
+ ctx, var.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", requested_input(0)));
+ OP_REQUIRES(
+ ctx, m.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", requested_input(1)));
+ OP_REQUIRES(
+ ctx, v.IsInitialized(),
+ errors::FailedPrecondition(
+ "Attempting to use uninitialized variables: ", requested_input(2)));
+
+ const Tensor& beta1_power = ctx->input(3);
+ const Tensor& lr = ctx->input(4);
+ const Tensor& beta1 = ctx->input(5);
+ const Tensor& beta2 = ctx->input(6);
+ const Tensor& epsilon = ctx->input(7);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()),
+ errors::InvalidArgument("beta1_power is not a scalar: ",
+ beta1_power.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()),
+ errors::InvalidArgument("lr is not a scalar : ",
+ lr.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()),
+ errors::InvalidArgument("beta1 is not a scalar: ",
+ beta1.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()),
+ errors::InvalidArgument("beta2 is not a scalar: ",
+ beta2.shape().DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()),
+ errors::InvalidArgument("epsilon is not a scalar: ",
+ epsilon.shape().DebugString()));
+
+ const Tensor& grad = ctx->input(8);
+ OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()),
+ errors::InvalidArgument("var and m do not have the same shape",
+ var.shape().DebugString(), " ",
+ m.shape().DebugString()));
+ OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()),
+ errors::InvalidArgument("var and v do not have the same shape",
+ var.shape().DebugString(), " ",
+ v.shape().DebugString()));
+ OP_REQUIRES(
+ ctx, var.shape().IsSameSize(grad.shape()),
+ errors::InvalidArgument("var and grad do not have the same shape",
+ var.shape().DebugString(), " ",
+ grad.shape().DebugString()));
+
+ const Device& device = ctx->template eigen_device<Device>();
+ functor::ApplyAdaMax<Device, T>()(
+ device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
+ beta1_power.scalar<T>(), lr.scalar<T>(),
+ beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
+ grad.flat<T>());
+
+ MaybeForwardRefInputToRefOutput(ctx, 0, 0);
+ }
+
+ private:
+ bool use_exclusive_lock_;
+};
+
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ApplyAdaMax").Device(DEVICE_##D).TypeConstraint<T>("T"), \
+ ApplyAdaMaxOp<D##Device, T>); \
+ REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdaMax") \
+ .HostMemory("var") \
+ .HostMemory("m") \
+ .HostMemory("v") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T"), \
+ ApplyAdaMaxOp<D##Device, T>);
+#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
+
+TF_CALL_half(REGISTER_CPU_KERNELS);
+TF_CALL_float(REGISTER_CPU_KERNELS);
+TF_CALL_double(REGISTER_CPU_KERNELS);
+
+#if GOOGLE_CUDA
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ApplyAdaMax<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::Flat var, \
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \
+ typename TTypes<T>::ConstScalar beta1_power, \
+ typename TTypes<T>::ConstScalar lr, \
+ typename TTypes<T>::ConstScalar beta1, \
+ typename TTypes<T>::ConstScalar beta2, \
+ typename TTypes<T>::ConstScalar epsilon, \
+ typename TTypes<T>::ConstFlat grad); \
+ extern template struct ApplyAdaMax<GPUDevice, T>;
+DECLARE_GPU_SPEC(Eigen::half);
+DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(double);
+#undef DECLARE_GPU_SPEC
+} // namespace functor
+
+REGISTER_KERNELS(GPU, Eigen::half);
+REGISTER_KERNELS(GPU, float);
+REGISTER_KERNELS(GPU, double);
+#endif
+#undef REGISTER_CPU_KERNELS
+#undef REGISTER_KERNELS
+
+template <typename Device, typename T>
class ApplyRMSPropOp : public OpKernel {
public:
explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
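
The new ApplyAdaMax functors implement the AdaMax update from Section 7.1 of the Adam paper, which the comments reference (the code's v plays the role of the paper's u, and epsilon is added to the denominator for numerical safety). A scalar sketch of one step, mirroring the CPU functor with illustrative names:

    #include <algorithm>
    #include <cmath>

    // One AdaMax step on a single parameter (scalar view of the Flat update).
    void AdaMaxStep(float* var, float* m, float* v, float grad, float lr,
                    float beta1, float beta1_power, float beta2, float epsilon) {
      *m += (grad - *m) * (1.0f - beta1);            // m <- beta1*m + (1-beta1)*g
      *v = std::max(beta2 * (*v), std::fabs(grad));  // v <- max(beta2*v, |g|)
      *var -= lr / (1.0f - beta1_power) * (*m / (*v + epsilon));
    }
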
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h
index 7ee956053a..f536a61eb0 100644
--- a/tensorflow/core/kernels/training_ops.h
+++ b/tensorflow/core/kernels/training_ops.h
@@ -140,6 +140,18 @@ struct ApplyAdam {
};
template <typename Device, typename T>
+struct ApplyAdaMax {
+ void operator()(const Device& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+ typename TTypes<T>::ConstScalar beta1_power,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar beta1,
+ typename TTypes<T>::ConstScalar beta2,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad);
+};
+
+template <typename Device, typename T>
struct ApplyRMSProp {
void operator()(const Device& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index 0376a3b2c6..2aa17f2a0f 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -143,6 +143,32 @@ struct ApplyAdam<GPUDevice, T> {
};
template <typename T>
+struct ApplyAdaMax<GPUDevice, T> {
+ void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
+ typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
+ typename TTypes<T>::ConstScalar beta1_power,
+ typename TTypes<T>::ConstScalar lr,
+ typename TTypes<T>::ConstScalar beta1,
+ typename TTypes<T>::ConstScalar beta2,
+ typename TTypes<T>::ConstScalar epsilon,
+ typename TTypes<T>::ConstFlat grad) {
+ Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
+ bcast[0] = grad.dimension(0);
+ Eigen::Sizes<1> single;
+ const auto one = static_cast<T>(1.0);
+ m.device(d) =
+ m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
+ (grad - m);
+ v.device(d) =
+ (beta2.reshape(single).broadcast(bcast) * v).cwiseMax(grad.abs());
+ var.device(d) -=
+ lr / (beta1_power.constant(one) -
+ beta1_power).reshape(single).broadcast(bcast) *
+ (m / (v + epsilon));
+ }
+};
+
+template <typename T>
struct ApplyRMSProp<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
@@ -278,6 +304,10 @@ template struct functor::ApplyAdam<GPUDevice, Eigen::half>;
template struct functor::ApplyAdam<GPUDevice, float>;
template struct functor::ApplyAdam<GPUDevice, double>;
+template struct functor::ApplyAdaMax<GPUDevice, Eigen::half>;
+template struct functor::ApplyAdaMax<GPUDevice, float>;
+template struct functor::ApplyAdaMax<GPUDevice, double>;
+
template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyRMSProp<GPUDevice, float>;
template struct functor::ApplyRMSProp<GPUDevice, double>;
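
On the GPU path the hyperparameters are rank-0 tensors in device memory, so the functor cannot read them on the host the way the CPU code does with beta1(); instead each scalar is reshaped to a size-1 tensor and broadcast to the flat length so the whole update remains a single device expression. A standalone sketch of that idiom in plain Eigen (illustrative shapes, evaluated on the CPU for simplicity):

    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      Eigen::Tensor<float, 1> grad(4);
      grad.setConstant(2.0f);
      Eigen::Tensor<float, 0> lr;  // scalar tensor, analogous to ConstScalar
      lr() = 0.1f;

      Eigen::array<Eigen::DenseIndex, 1> bcast{{grad.dimension(0)}};
      Eigen::Sizes<1> single;
      // Reshape the scalar to [1] and tile it to the flat length, as the GPU
      // functor does for lr, beta1, beta2, beta1_power and epsilon.
      Eigen::Tensor<float, 1> step = lr.reshape(single).broadcast(bcast) * grad;
      return 0;
    }
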