Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/BUILD | 50
-rw-r--r-- | tensorflow/core/kernels/batching_util/shared_batch_scheduler.h | 6
-rw-r--r-- | tensorflow/core/kernels/broadcast_to_op.cc | 91
-rw-r--r-- | tensorflow/core/kernels/broadcast_to_op.h | 220
-rw-r--r-- | tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc | 34
-rw-r--r-- | tensorflow/core/kernels/conv_ops_gpu.h | 5
-rw-r--r-- | tensorflow/core/kernels/ctc_decoder_ops.cc | 34
-rw-r--r-- | tensorflow/core/kernels/mkl_input_conversion_op.cc | 35
-rw-r--r-- | tensorflow/core/kernels/mkl_relu_op.cc | 8
-rw-r--r-- | tensorflow/core/kernels/roll_op.cc | 7
-rw-r--r-- | tensorflow/core/kernels/segment_reduction_ops.h | 8
-rw-r--r-- | tensorflow/core/kernels/string_strip_op.cc | 53
-rw-r--r-- | tensorflow/core/kernels/training_ops.cc | 150
-rw-r--r-- | tensorflow/core/kernels/training_ops.h | 12
-rw-r--r-- | tensorflow/core/kernels/training_ops_gpu.cu.cc | 30
15 files changed, 678 insertions, 65 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f7f6a9b505..201cd35798 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -617,6 +617,7 @@ cc_library( ":batch_space_ops", ":bcast_ops", ":bitcast_op", + ":broadcast_to_op", ":concat_op", ":constant_op", ":depth_space_ops", @@ -669,6 +670,12 @@ tf_kernel_library( ) tf_kernel_library( + name = "broadcast_to_op", + prefix = "broadcast_to_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( name = "concat_op", prefix = "concat_op", deps = ARRAY_DEPS, @@ -4227,6 +4234,7 @@ cc_library( ":regex_replace_op", ":string_join_op", ":string_split_op", + ":string_strip_op", ":string_to_hash_bucket_op", ":substr_op", ], @@ -4272,6 +4280,12 @@ tf_kernel_library( ) tf_kernel_library( + name = "string_strip_op", + prefix = "string_strip_op", + deps = STRING_DEPS, +) + +tf_kernel_library( name = "substr_op", prefix = "substr_op", deps = STRING_DEPS, @@ -5947,8 +5961,7 @@ tf_mkl_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core:nn_ops_op_lib", "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -5963,8 +5976,7 @@ tf_mkl_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core:nn_ops_op_lib", "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -5980,8 +5992,7 @@ tf_mkl_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core:nn_ops_op_lib", "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -6001,8 +6012,7 @@ tf_mkl_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core:nn_ops_op_lib", "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -6018,8 +6028,7 @@ tf_mkl_kernel_library( "//tensorflow/core:nn_ops_op_lib", "//third_party/eigen3", "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -6035,8 +6044,7 @@ tf_mkl_kernel_library( "//tensorflow/core:nn_ops_op_lib", "//third_party/eigen3", "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -6044,8 +6052,7 @@ tf_mkl_kernel_library( srcs = ["mkl_fused_batch_norm_op.cc"], deps = NN_DEPS + [ "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -6053,8 +6060,7 @@ tf_mkl_kernel_library( prefix = "mkl_aggregate_ops", deps = MATH_DEPS + [ "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -6062,8 +6068,7 @@ tf_mkl_kernel_library( prefix = "mkl_concat_op", deps = ARRAY_DEPS + [ "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -6071,8 +6076,7 @@ tf_mkl_kernel_library( prefix = "mkl_reshape_op", deps = ARRAY_DEPS + [ "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -6080,8 +6084,7 @@ tf_mkl_kernel_library( prefix = "mkl_identity_op", deps = ARRAY_DEPS + [ "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( @@ -6089,8 +6092,7 @@ tf_mkl_kernel_library( prefix = "mkl_lrn_op", deps = NN_DEPS + [ "//third_party/mkl:intel_binary_blob", - "@mkl_dnn", - ], + ] + if_mkl(["@mkl_dnn"]), ) tf_mkl_kernel_library( diff --git 
a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h index edc88a0384..b4bce90841 100644 --- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h @@ -136,7 +136,7 @@ class SharedBatchScheduler // (inclusive). If there is a need to quantize the batch sizes, i.e. only // submit batches whose size is in a small set of allowed sizes, that can be // done by adding padding in the process-batch callback. - int max_batch_size = 1000; + size_t max_batch_size = 1000; // If a task has been enqueued for this amount of time (in microseconds), // and a thread is available, the scheduler will immediately form a batch @@ -157,7 +157,7 @@ class SharedBatchScheduler // If this limit is reached, Schedule() will return an UNAVAILABLE error. // See the class documentation above for guidelines on how to tune this // parameter. - int max_enqueued_batches = 10; + size_t max_enqueued_batches = 10; }; Status AddQueue(const QueueOptions& options, std::function<void(std::unique_ptr<Batch<TaskType>>)> @@ -394,7 +394,7 @@ Status SharedBatchScheduler<TaskType>::AddQueue( std::function<void(std::unique_ptr<Batch<TaskType>>)> process_batch_callback, std::unique_ptr<BatchScheduler<TaskType>>* queue) { - if (options.max_batch_size <= 0) { + if (options.max_batch_size == 0) { return errors::InvalidArgument("max_batch_size must be positive; was ", options.max_batch_size); } diff --git a/tensorflow/core/kernels/broadcast_to_op.cc b/tensorflow/core/kernels/broadcast_to_op.cc new file mode 100644 index 0000000000..2810925bbc --- /dev/null +++ b/tensorflow/core/kernels/broadcast_to_op.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/broadcast_to_op.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +template <typename Device, typename T> +class BroadcastToOp : public OpKernel { + public: + explicit BroadcastToOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& input_tensor = ctx->input(0); + const TensorShape& input_shape = input_tensor.shape(); + + const Tensor& shape_tensor = ctx->input(1); + + TensorShape output_shape; + OP_REQUIRES_OK(ctx, + ctx->op_kernel().MakeShape(shape_tensor, &output_shape)); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor)); + + const Device& d = ctx->eigen_device<Device>(); + functor::BroadcastTo<Device, T>()(d, ctx, *output_tensor, output_shape, + input_tensor, input_shape); + } +}; + +// As MakeShape is able to handle both DT_INT32 and DT_INT64, +// no need to have TypeConstraint for `Tidx` +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("BroadcastTo").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + BroadcastToOp<CPUDevice, type>); + +TF_CALL_ALL_TYPES(REGISTER_KERNEL); +#undef REGISTER_KERNEL + +#if GOOGLE_CUDA + +namespace functor { +#define DECLARE_GPU_TEMPLATE(Type) \ + template <> \ + void BroadcastTo<GPUDevice, Type>::operator()( \ + const GPUDevice& d, OpKernelContext* ctx, Tensor& output, \ + const TensorShape& output_shape, const Tensor& input, \ + const TensorShape& input_shape); \ + extern template struct BroadcastTo<GPUDevice, Type>; + +TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_TEMPLATE); +#undef DECLARE_GPU_KERNEL +} // namespace functor + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("BroadcastTo") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("shape"), \ + BroadcastToOp<GPUDevice, type>); + +TF_CALL_GPU_ALL_TYPES(REGISTER_KERNEL); +#undef REGISTER_KERNEL +#endif + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/broadcast_to_op.h b/tensorflow/core/kernels/broadcast_to_op.h new file mode 100644 index 0000000000..608e9b6ac9 --- /dev/null +++ b/tensorflow/core/kernels/broadcast_to_op.h @@ -0,0 +1,220 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_ +#define TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { + +namespace functor { + +template <typename Device, typename T> +struct BroadcastTo { + void operator()(const Device &d, OpKernelContext *ctx, Tensor &output_tensor, + const TensorShape &output_shape, const Tensor &input_tensor, + const TensorShape &input_shape) { +#define BROADCAST_SHAPE(broadcast, reshape, NDIMS, input_shape, output_shape) \ + for (int i = 0; i < NDIMS; i++) { \ + OP_REQUIRES(ctx, (broadcast[i] % reshape[i] == 0), \ + errors::InvalidArgument("invalid shape to broadcast from ", \ + input_shape.DebugString(), " to ", \ + output_shape.DebugString())); \ + broadcast[i] = broadcast[i] / reshape[i]; \ + } + + switch (output_shape.dims()) { + case 1: { + auto reshape = AsEigenDSizesWithPrefix<1>(input_shape); + auto broadcast = output_shape.AsEigenDSizes<1>(); + + BROADCAST_SHAPE(broadcast, reshape, 1, input_shape, output_shape); + + auto output = output_tensor.tensor<T, 1>(); + switch (input_shape.dims()) { + case 0: { + output.device(d) = output.constant(input_tensor.scalar<T>()()); + } break; + case 1: { + auto input = input_tensor.tensor<T, 1>(); + output.device(d) = input.broadcast(broadcast); + } break; + default: + ctx->CtxFailure(errors::InvalidArgument( + "invalid shape to broadcast from ", input_shape.DebugString(), + " to ", output_shape.DebugString())); + break; + } + } break; + case 2: { + auto reshape = AsEigenDSizesWithPrefix<2>(input_shape); + auto broadcast = output_shape.AsEigenDSizes<2>(); + + BROADCAST_SHAPE(broadcast, reshape, 2, input_shape, output_shape); + + auto output = output_tensor.tensor<T, 2>(); + switch (input_shape.dims()) { + case 0: { + output.device(d) = output.constant(input_tensor.scalar<T>()()); + } break; + case 1: { + auto input = input_tensor.tensor<T, 1>(); + output.device(d) = input.reshape(reshape).broadcast(broadcast); + } break; + case 2: { + auto input = input_tensor.tensor<T, 2>(); + output.device(d) = input.broadcast(broadcast); + } break; + default: + ctx->CtxFailure(errors::InvalidArgument( + "invalid shape to broadcast from ", input_shape.DebugString(), + " to ", output_shape.DebugString())); + break; + } + } break; + case 3: { + auto reshape = AsEigenDSizesWithPrefix<3>(input_shape); + auto broadcast = output_shape.AsEigenDSizes<3>(); + + BROADCAST_SHAPE(broadcast, reshape, 3, input_shape, output_shape); + + auto output = output_tensor.tensor<T, 3>(); + switch (input_shape.dims()) { + case 0: { + output.device(d) = output.constant(input_tensor.scalar<T>()()); + } break; + case 1: { + auto input = input_tensor.tensor<T, 1>(); + output.device(d) = input.reshape(reshape).broadcast(broadcast); + } break; + case 2: { + auto input = input_tensor.tensor<T, 2>(); + output.device(d) = input.reshape(reshape).broadcast(broadcast); + } break; + case 3: { + auto input = input_tensor.tensor<T, 3>(); + output.device(d) = input.broadcast(broadcast); + } break; + default: + ctx->CtxFailure(errors::InvalidArgument( + "invalid shape to broadcast from ", input_shape.DebugString(), + " to ", output_shape.DebugString())); + break; + } + } 
break; + case 4: { + auto reshape = AsEigenDSizesWithPrefix<4>(input_shape); + auto broadcast = output_shape.AsEigenDSizes<4>(); + + BROADCAST_SHAPE(broadcast, reshape, 4, input_shape, output_shape); + + auto output = output_tensor.tensor<T, 4>(); + switch (input_shape.dims()) { + case 0: { + output.device(d) = output.constant(input_tensor.scalar<T>()()); + } break; + case 1: { + auto input = input_tensor.tensor<T, 1>(); + output.device(d) = input.reshape(reshape).broadcast(broadcast); + } break; + case 2: { + auto input = input_tensor.tensor<T, 2>(); + output.device(d) = input.reshape(reshape).broadcast(broadcast); + } break; + case 3: { + auto input = input_tensor.tensor<T, 3>(); + output.device(d) = input.reshape(reshape).broadcast(broadcast); + } break; + case 4: { + auto input = input_tensor.tensor<T, 4>(); + output.device(d) = input.broadcast(broadcast); + } break; + default: + ctx->CtxFailure(errors::InvalidArgument( + "invalid shape to broadcast from ", input_shape.DebugString(), + " to ", output_shape.DebugString())); + break; + } + } break; + case 5: { + auto reshape = AsEigenDSizesWithPrefix<5>(input_shape); + auto broadcast = output_shape.AsEigenDSizes<5>(); + + BROADCAST_SHAPE(broadcast, reshape, 5, input_shape, output_shape); + auto output = output_tensor.tensor<T, 5>(); + switch (input_shape.dims()) { + case 0: { + output.device(d) = output.constant(input_tensor.scalar<T>()()); + } break; + case 1: { + auto input = input_tensor.tensor<T, 1>(); + output.device(d) = input.reshape(reshape).broadcast(broadcast); + } break; + case 2: { + auto input = input_tensor.tensor<T, 2>(); + output.device(d) = input.reshape(reshape).broadcast(broadcast); + } break; + case 3: { + auto input = input_tensor.tensor<T, 3>(); + output.device(d) = input.reshape(reshape).broadcast(broadcast); + } break; + case 4: { + auto input = input_tensor.tensor<T, 4>(); + output.device(d) = input.reshape(reshape).broadcast(broadcast); + } break; + case 5: { + auto input = input_tensor.tensor<T, 5>(); + output.device(d) = input.broadcast(broadcast); + } break; + default: + ctx->CtxFailure(errors::InvalidArgument( + "invalid shape to broadcast from ", input_shape.DebugString(), + " to ", output_shape.DebugString())); + break; + } + } break; + default: + ctx->CtxFailure(errors::InvalidArgument( + "invalid shape to broadcast from ", input_shape.DebugString(), + " to ", output_shape.DebugString())); + break; + } + } + + private: + template <int NDIMS> + Eigen::DSizes<Eigen::DenseIndex, NDIMS> AsEigenDSizesWithPrefix( + const TensorShape &shape) const { + Eigen::DSizes<Eigen::DenseIndex, NDIMS> dsizes; + for (int d = 0; d < NDIMS - shape.dims(); d++) { + dsizes[d] = 1; + } + for (int d = NDIMS - shape.dims(); d < NDIMS; d++) { + dsizes[d] = shape.dim_size(d - (NDIMS - shape.dims())); + } + return dsizes; + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_ diff --git a/tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc b/tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc new file mode 100644 index 0000000000..6459571085 --- /dev/null +++ b/tensorflow/core/kernels/broadcast_to_op_gpu.cu.cc @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/broadcast_to_op.h" +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +#define INSTANTIATE_GPU_KERNEL(Type) \ + template class functor::BroadcastTo<GPUDevice, Type>; +TF_CALL_GPU_ALL_TYPES(INSTANTIATE_GPU_KERNEL); +#undef INSTANTIATE_GPU_KERNEL + +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 4215c4541c..d2c8020bb6 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -139,9 +139,8 @@ class ConvParameters { bool ShouldIncludeWinogradNonfusedAlgo( se::StreamExecutor* stream_exec) const { // Skip this check for cuDNN 7 and newer. - se::port::StatusOr<std::tuple<int, int, int>> version = - stream_exec->AsDnn()->GetVersion(); - if (version.ok() && std::get<0>(version.ValueOrDie()) >= 7) { + auto version = stream_exec->AsDnn()->GetVersion(); + if (version.ok() && version.ValueOrDie().major_version() >= 7) { return true; } return ShouldIncludeWinogradNonfusedAlgoPreCudnn7<T>(); diff --git a/tensorflow/core/kernels/ctc_decoder_ops.cc b/tensorflow/core/kernels/ctc_decoder_ops.cc index 96bdb6a241..8cadeac68d 100644 --- a/tensorflow/core/kernels/ctc_decoder_ops.cc +++ b/tensorflow/core/kernels/ctc_decoder_ops.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/ctc/ctc_beam_search.h" #include "tensorflow/core/util/sparse/sparse_tensor.h" +#include "tensorflow/core/util/work_sharder.h" namespace tensorflow { @@ -213,20 +214,29 @@ class CTCGreedyDecoderOp : public OpKernel { // Perform best path decoding std::vector<std::vector<std::vector<int> > > sequences(batch_size); - for (int b = 0; b < batch_size; ++b) { - sequences[b].resize(1); - auto& sequence = sequences[b][0]; - int prev_indices = -1; - for (int t = 0; t < seq_len_t(b); ++t) { - int max_class_indices; - log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices); - if (max_class_indices != blank_index && - !(merge_repeated_ && max_class_indices == prev_indices)) { - sequence.push_back(max_class_indices); + auto decode = [&](const int64 begin, const int64 end) { + for (int b = begin; b < end; ++b) { + sequences[b].resize(1); + auto &sequence = sequences[b][0]; + int prev_indices = -1; + for (int t = 0; t < seq_len_t(b); ++t) { + int max_class_indices; + log_prob_t(b, 0) += -RowMax(input_list_t[t], b, &max_class_indices); + if (max_class_indices != blank_index && + !(merge_repeated_ && max_class_indices == prev_indices)) { + sequence.push_back(max_class_indices); + } + prev_indices = max_class_indices; } - prev_indices = max_class_indices; } - } + }; + + const int64 kCostPerUnit = 50 * max_time * num_classes; + const int64 total = batch_size; + const DeviceBase::CpuWorkerThreads& worker_threads = + *ctx->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, total, + kCostPerUnit, decode); OP_REQUIRES_OK( ctx, decode_helper_.StoreAllDecodedSequences( diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc index ea763ce85b..cda1402b03 100644 --- a/tensorflow/core/kernels/mkl_input_conversion_op.cc +++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc @@ -312,9 +312,8 @@ class MklInputConversionOp : public OpKernel { VLOG(1) << "MklInputConversionOp: Shape is same, but format is " "different, " << "need to convert to same format"; - - // Convert input0, and keep input1 unchanged - // Create MklDnnShape for output mkl tensor based on input0 + // TODO: For now, input0 is converted and input1 is unchanged + // we should choose the optimal MKL format to convert to. 
Tensor* tensor_out; MklDnnShape mkl_output_mkl_shape; mkl_output_mkl_shape.SetMklTensor(true); @@ -362,7 +361,8 @@ class MklInputConversionOp : public OpKernel { // with MKL tensors) VLOG(1) << "MklInputConversionOp: Broadcast needed, " << "converted MKL inputs to TF format"; - + // TODO: Cleanup op_data_type and has_avx512f_ after these two parameters + // are removed from ConvertMklToTf MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str, op_data_type, has_avx512f_, kInputIndex_0); @@ -403,19 +403,7 @@ class MklInputConversionOp : public OpKernel { } // Broadcast is needed if the shapes are not the same - bool broadcast_needed; - - size_t in0_size = 1; - for (size_t i = 0; i < mkl_shape->GetDimension(); ++i) - in0_size *= mkl_shape->TfDimSize(i); - - size_t in1_size = 1; - for (size_t i = 0; i < tf_tensor->shape().dims(); ++i) - in1_size *= tf_tensor->shape().dim_size(i); - - broadcast_needed = (in0_size != in1_size); - - if (!broadcast_needed) { + if (mkl_shape->GetTfShape().num_elements() == tf_tensor->shape().num_elements() ) { // Both shapes are same, convert the TF input to MKL VLOG(1) << "MklInputConversionOp: No broadcast needed."; VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index @@ -446,10 +434,19 @@ class MklInputConversionOp : public OpKernel { // Create reorder between tensorflow layout and Mkl layout if necessary std::vector<primitive> net; - tf_input.CheckReorderToOpMem( + bool reordered = tf_input.CheckReorderToOpMem( memory::primitive_desc(output_mkl_md, cpu_engine), tensor_out, &net); - stream(stream::kind::eager).submit(net).wait(); + if(!reordered) { + // This is the case that the TF tensor has the same shape and format of + // mkl tensor. However, tf_tensor can not be simply forwarded to the output + // tensor since mkl data tensor is always one dimensional tensor. + // Tensor::CopyFrom shares the buffer of the other tensor while set its shape + // to the other tensor. + tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()); + } + else + stream(stream::kind::eager).submit(net).wait(); // -- The tensor in MKL format passes through -- ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index); diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 0a0f69522f..1ed43834dd 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -441,7 +441,9 @@ class MklReluOpBase : public OpKernel { // Allocate output and MklDnnShape tensors separately for possible // in-place operation OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {src_index}, dst_index, tf_shape_dst, &dst_tensor)); + {static_cast<const int>(src_index)}, + static_cast<const int>(dst_index), + tf_shape_dst, &dst_tensor)); AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst); // Destination memory descriptor is same as source memory descriptor. 
@@ -611,7 +613,9 @@ class MklReluGradOpBase : public OpKernel { // Allocate diff_src and MklDnnShape tensors separately for possible // in-place operation OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {diff_dst_index}, diff_src_index, tf_shape_diff_src, + {static_cast<const int>(diff_dst_index)}, + static_cast<const int>(diff_src_index), + tf_shape_diff_src, &diff_src_tensor)); AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src); diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc index bcbdbee058..4b630809c5 100644 --- a/tensorflow/core/kernels/roll_op.cc +++ b/tensorflow/core/kernels/roll_op.cc @@ -254,8 +254,11 @@ class RollOp : public OpKernel { // total modulo sum of shifts for each dimension gtl::InlinedVector<int, 4> shift_mod_sum(num_dims, 0); for (int i = 0; i < num_shifts; i++) { - const int axis = axis_flat(i); - OP_REQUIRES(context, axis < num_dims, + int axis = axis_flat(i); + if (axis < 0) { + axis += num_dims; + } + OP_REQUIRES(context, 0 <= axis && axis < num_dims, errors::InvalidArgument("axis ", axis, " is out of range")); const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1); const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i)); diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index 183e5a1d58..bedd965966 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -16,6 +16,14 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ #define TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ + +// This file requires the following include because it uses CudaAtomicMax: +// #include "tensorflow/core/util/cuda_kernel_helper.h" + +// Unfortunately we can't add the #include, since it breaks compilation for +// non-GPU targets. This only breaks in clang, because it's more strict for +// template code and CudaAtomicMax is used in template context. + // This file requires the following include because it uses CudaAtomicMax: // #include "tensorflow/core/util/cuda_kernel_helper.h" diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc new file mode 100644 index 0000000000..ae700f4294 --- /dev/null +++ b/tensorflow/core/kernels/string_strip_op.cc @@ -0,0 +1,53 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/string_ops.cc. 
+ +#include <string> + +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +class StringStripOp : public OpKernel { + public: + explicit StringStripOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + Tensor* output_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, input_tensor->shape(), &output_tensor)); + + const auto input = input_tensor->flat<string>(); + auto output = output_tensor->flat<string>(); + + for (int64 i = 0; i < input.size(); ++i) { + StringPiece entry(input(i)); + str_util::RemoveWhitespaceContext(&entry); + output(i) = entry.ToString(); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("StringStrip").Device(DEVICE_CPU), StringStripOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index f53c567c4d..5b13b10937 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -330,6 +330,27 @@ struct ApplyAdamSYCL { template <typename T> struct ApplyAdam<CPUDevice, T> : ApplyAdamNonCuda<CPUDevice, T> {}; +template <typename Device, typename T> +struct ApplyAdaMaxNonCuda { + void operator()(const Device& d, typename TTypes<T>::Flat var, + typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, + typename TTypes<T>::ConstScalar beta1_power, + typename TTypes<T>::ConstScalar lr, + typename TTypes<T>::ConstScalar beta1, + typename TTypes<T>::ConstScalar beta2, + typename TTypes<T>::ConstScalar epsilon, + typename TTypes<T>::ConstFlat grad) { + m.device(d) += (grad - m) * (T(1) - beta1()); + // Here v is u in section 7.1 + v.device(d) = (beta2() * v).cwiseMax(grad.abs()); + // var is θ in section 7.1 + var.device(d) -= lr() / (T(1) - beta1_power()) * (m / (v + epsilon())); + } +}; + +template <typename T> +struct ApplyAdaMax<CPUDevice, T> : ApplyAdaMaxNonCuda<CPUDevice, T> {}; + template <typename T> struct ApplyRMSProp<CPUDevice, T> { void operator()(const CPUDevice& d, typename TTypes<T>::Flat var, @@ -2752,6 +2773,135 @@ REGISTER_KERNELS(GPU, double); #undef REGISTER_KERNELS template <typename Device, typename T> +class ApplyAdaMaxOp : public OpKernel { + public: + explicit ApplyAdaMaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_)); + } + + void Compute(OpKernelContext* ctx) override { + auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, + {0, 1, 2}); + + Tensor var; + OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( + ctx, 0, use_exclusive_lock_, false, &var)); + Tensor m; + OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( + ctx, 1, use_exclusive_lock_, false, &m)); + Tensor v; + OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>( + ctx, 2, use_exclusive_lock_, false, &v)); + OP_REQUIRES( + ctx, var.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", requested_input(0))); + OP_REQUIRES( + ctx, m.IsInitialized(), + errors::FailedPrecondition( + "Attempting to use uninitialized variables: ", requested_input(1))); + OP_REQUIRES( + ctx, v.IsInitialized(), + errors::FailedPrecondition( + 
"Attempting to use uninitialized variables: ", requested_input(2))); + + const Tensor& beta1_power = ctx->input(3); + const Tensor& lr = ctx->input(4); + const Tensor& beta1 = ctx->input(5); + const Tensor& beta2 = ctx->input(6); + const Tensor& epsilon = ctx->input(7); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1_power.shape()), + errors::InvalidArgument("beta1_power is not a scalar: ", + beta1_power.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr.shape()), + errors::InvalidArgument("lr is not a scalar : ", + lr.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta1.shape()), + errors::InvalidArgument("beta1 is not a scalar: ", + beta1.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(beta2.shape()), + errors::InvalidArgument("beta2 is not a scalar: ", + beta2.shape().DebugString())); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(epsilon.shape()), + errors::InvalidArgument("epsilon is not a scalar: ", + epsilon.shape().DebugString())); + + const Tensor& grad = ctx->input(8); + OP_REQUIRES(ctx, var.shape().IsSameSize(m.shape()), + errors::InvalidArgument("var and m do not have the same shape", + var.shape().DebugString(), " ", + m.shape().DebugString())); + OP_REQUIRES(ctx, var.shape().IsSameSize(v.shape()), + errors::InvalidArgument("var and v do not have the same shape", + var.shape().DebugString(), " ", + v.shape().DebugString())); + OP_REQUIRES( + ctx, var.shape().IsSameSize(grad.shape()), + errors::InvalidArgument("var and grad do not have the same shape", + var.shape().DebugString(), " ", + grad.shape().DebugString())); + + const Device& device = ctx->template eigen_device<Device>(); + functor::ApplyAdaMax<Device, T>()( + device, var.flat<T>(), m.flat<T>(), v.flat<T>(), + beta1_power.scalar<T>(), lr.scalar<T>(), + beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(), + grad.flat<T>()); + + MaybeForwardRefInputToRefOutput(ctx, 0, 0); + } + + private: + bool use_exclusive_lock_; +}; + +#define REGISTER_KERNELS(D, T) \ + REGISTER_KERNEL_BUILDER( \ + Name("ApplyAdaMax").Device(DEVICE_##D).TypeConstraint<T>("T"), \ + ApplyAdaMaxOp<D##Device, T>); \ + REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdaMax") \ + .HostMemory("var") \ + .HostMemory("m") \ + .HostMemory("v") \ + .Device(DEVICE_##D) \ + .TypeConstraint<T>("T"), \ + ApplyAdaMaxOp<D##Device, T>); +#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T); + +TF_CALL_half(REGISTER_CPU_KERNELS); +TF_CALL_float(REGISTER_CPU_KERNELS); +TF_CALL_double(REGISTER_CPU_KERNELS); + +#if GOOGLE_CUDA +// Forward declarations of the functor specializations for GPU. 
+namespace functor { +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ApplyAdaMax<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::Flat var, \ + typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \ + typename TTypes<T>::ConstScalar beta1_power, \ + typename TTypes<T>::ConstScalar lr, \ + typename TTypes<T>::ConstScalar beta1, \ + typename TTypes<T>::ConstScalar beta2, \ + typename TTypes<T>::ConstScalar epsilon, \ + typename TTypes<T>::ConstFlat grad); \ + extern template struct ApplyAdaMax<GPUDevice, T>; +DECLARE_GPU_SPEC(Eigen::half); +DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(double); +#undef DECLARE_GPU_SPEC +} // namespace functor + +REGISTER_KERNELS(GPU, Eigen::half); +REGISTER_KERNELS(GPU, float); +REGISTER_KERNELS(GPU, double); +#endif +#undef REGISTER_CPU_KERNELS +#undef REGISTER_KERNELS + +template <typename Device, typename T> class ApplyRMSPropOp : public OpKernel { public: explicit ApplyRMSPropOp(OpKernelConstruction* ctx) : OpKernel(ctx) { diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h index 7ee956053a..f536a61eb0 100644 --- a/tensorflow/core/kernels/training_ops.h +++ b/tensorflow/core/kernels/training_ops.h @@ -140,6 +140,18 @@ struct ApplyAdam { }; template <typename Device, typename T> +struct ApplyAdaMax { + void operator()(const Device& d, typename TTypes<T>::Flat var, + typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, + typename TTypes<T>::ConstScalar beta1_power, + typename TTypes<T>::ConstScalar lr, + typename TTypes<T>::ConstScalar beta1, + typename TTypes<T>::ConstScalar beta2, + typename TTypes<T>::ConstScalar epsilon, + typename TTypes<T>::ConstFlat grad); +}; + +template <typename Device, typename T> struct ApplyRMSProp { void operator()(const Device& d, typename TTypes<T>::Flat var, typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom, diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc index 0376a3b2c6..2aa17f2a0f 100644 --- a/tensorflow/core/kernels/training_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc @@ -143,6 +143,32 @@ struct ApplyAdam<GPUDevice, T> { }; template <typename T> +struct ApplyAdaMax<GPUDevice, T> { + void operator()(const GPUDevice& d, typename TTypes<T>::Flat var, + typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, + typename TTypes<T>::ConstScalar beta1_power, + typename TTypes<T>::ConstScalar lr, + typename TTypes<T>::ConstScalar beta1, + typename TTypes<T>::ConstScalar beta2, + typename TTypes<T>::ConstScalar epsilon, + typename TTypes<T>::ConstFlat grad) { + Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast; + bcast[0] = grad.dimension(0); + Eigen::Sizes<1> single; + const auto one = static_cast<T>(1.0); + m.device(d) = + m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) * + (grad - m); + v.device(d) = + (beta2.reshape(single).broadcast(bcast) * v).cwiseMax(grad.abs()); + var.device(d) -= + lr / (beta1_power.constant(one) - + beta1_power).reshape(single).broadcast(bcast) * + (m / (v + epsilon)); + } +}; + +template <typename T> struct ApplyRMSProp<GPUDevice, T> { void operator()(const GPUDevice& d, typename TTypes<T>::Flat var, typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom, @@ -278,6 +304,10 @@ template struct functor::ApplyAdam<GPUDevice, Eigen::half>; template struct functor::ApplyAdam<GPUDevice, float>; template struct functor::ApplyAdam<GPUDevice, double>; +template struct functor::ApplyAdaMax<GPUDevice, Eigen::half>; 
+template struct functor::ApplyAdaMax<GPUDevice, float>; +template struct functor::ApplyAdaMax<GPUDevice, double>; + template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>; template struct functor::ApplyRMSProp<GPUDevice, float>; template struct functor::ApplyRMSProp<GPUDevice, double>; |
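
Note on the new BroadcastTo functor in broadcast_to_op.h: it pads the input shape with leading 1s up to the output rank (AsEigenDSizesWithPrefix) and then divides each output dimension by the padded input dimension to obtain per-dimension broadcast factors (the BROADCAST_SHAPE macro). The standalone sketch below mirrors only that shape arithmetic; the function name and vector-based signature are illustrative and are not part of the diff.

#include <cstddef>
#include <cstdio>
#include <vector>

// Pad `in` with leading 1s to the output rank, then compute the per-dimension
// broadcast factor out[i] / padded_in[i]. A non-divisible pair corresponds to
// the "invalid shape to broadcast" error raised by the BROADCAST_SHAPE macro.
std::vector<long> BroadcastFactors(std::vector<long> in,
                                   const std::vector<long>& out) {
  in.insert(in.begin(), out.size() - in.size(), 1);  // e.g. [3, 1] -> [1, 3, 1]
  std::vector<long> factors(out.size());
  for (std::size_t i = 0; i < out.size(); ++i) {
    if (out[i] % in[i] != 0) return {};  // shapes are not broadcast-compatible
    factors[i] = out[i] / in[i];
  }
  return factors;
}

int main() {
  // Broadcasting a [3, 1] input to a [2, 3, 4] output: the padded input is
  // [1, 3, 1] (the Eigen reshape) and the factors are [2, 1, 4] (the Eigen
  // broadcast), matching the reshape(...).broadcast(...) calls in the functor.
  for (long f : BroadcastFactors({3, 1}, {2, 3, 4})) std::printf("%ld ", f);
  std::printf("\n");
  return 0;
}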
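
The new ApplyAdaMax kernels in training_ops.h, training_ops.cc, and training_ops_gpu.cu.cc implement the AdaMax update from section 7.1 of the Adam paper, with v storing the infinity-norm accumulator u. As a reference, the per-element update performed by the CPU functor can be written as a plain scalar loop; the function name and raw-pointer signature below are hypothetical and only illustrate the math.

#include <algorithm>
#include <cmath>
#include <cstddef>

// Scalar sketch of one AdaMax step over n parameters:
//   m_t   = beta1 * m_{t-1} + (1 - beta1) * g_t
//   u_t   = max(beta2 * u_{t-1}, |g_t|)           (stored in v)
//   var_t = var_{t-1} - lr / (1 - beta1^t) * m_t / (u_t + epsilon)
void ApplyAdaMaxReference(float* var, float* m, float* v, std::size_t n,
                          float beta1_power, float lr, float beta1,
                          float beta2, float epsilon, const float* grad) {
  for (std::size_t i = 0; i < n; ++i) {
    m[i] += (grad[i] - m[i]) * (1.0f - beta1);
    v[i] = std::max(beta2 * v[i], std::abs(grad[i]));
    var[i] -= lr / (1.0f - beta1_power) * (m[i] / (v[i] + epsilon));
  }
}

Unlike Adam, only the first moment is bias-corrected (the lr / (1 - beta1_power) factor); the max-based accumulator needs no correction, which is why the op takes beta1_power but no beta2_power input.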