Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--  tensorflow/core/kernels/BUILD  29
-rw-r--r--  tensorflow/core/kernels/as_string_op.cc  11
-rw-r--r--  tensorflow/core/kernels/batch_matmul_op_real.cc  3
-rw-r--r--  tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h  57
-rw-r--r--  tensorflow/core/kernels/cast_op.cc  4
-rw-r--r--  tensorflow/core/kernels/cast_op.h  164
-rw-r--r--  tensorflow/core/kernels/cast_op_gpu.cu.cc  48
-rw-r--r--  tensorflow/core/kernels/cast_op_impl.h  152
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_bfloat.cc  11
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_bool.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_complex128.cc  6
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_complex64.cc  6
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_double.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_float.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_half.cc  6
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_int16.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_int32.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_int64.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_int8.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_uint16.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_uint32.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_uint64.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_impl_uint8.cc  9
-rw-r--r--  tensorflow/core/kernels/cast_op_test.cc  29
-rw-r--r--  tensorflow/core/kernels/conv_ops_test.cc  4
-rw-r--r--  tensorflow/core/kernels/crop_and_resize_op_benchmark_test.cc  72
-rw-r--r--  tensorflow/core/kernels/cwise_op_tan.cc  3
-rw-r--r--  tensorflow/core/kernels/data/BUILD  62
-rw-r--r--  tensorflow/core/kernels/data/cache_dataset_ops.cc  392
-rw-r--r--  tensorflow/core/kernels/data/filter_by_component_dataset_op.cc  169
-rw-r--r--  tensorflow/core/kernels/data/generator_dataset_op.cc  292
-rw-r--r--  tensorflow/core/kernels/data/generator_dataset_op.h  41
-rw-r--r--  tensorflow/core/kernels/data/iterator_ops.cc  717
-rw-r--r--  tensorflow/core/kernels/data/iterator_ops.h  140
-rw-r--r--  tensorflow/core/kernels/data/map_defun_op.cc  192
-rw-r--r--  tensorflow/core/kernels/data/optional_ops.cc  270
-rw-r--r--  tensorflow/core/kernels/data/optional_ops.h  36
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_dataset_op.cc  286
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_iterator.cc  318
-rw-r--r--  tensorflow/core/kernels/data/parallel_map_iterator.h  44
-rw-r--r--  tensorflow/core/kernels/data/prefetch_dataset_op.cc  550
-rw-r--r--  tensorflow/core/kernels/data/prefetch_dataset_op.h  39
-rw-r--r--  tensorflow/core/kernels/data/stats_dataset_ops.cc  2
-rw-r--r--  tensorflow/core/kernels/function_ops.cc  296
-rw-r--r--  tensorflow/core/kernels/function_ops.h  79
-rw-r--r--  tensorflow/core/kernels/functional_ops.cc  80
-rw-r--r--  tensorflow/core/kernels/fused_batch_norm_op.cc  24
-rw-r--r--  tensorflow/core/kernels/matmul_op.cc  34
-rw-r--r--  tensorflow/core/kernels/mkl_avgpooling_op.cc  277
-rw-r--r--  tensorflow/core/kernels/mkl_concat_op.cc  8
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc  9
-rw-r--r--  tensorflow/core/kernels/mkl_conv_grad_input_ops.cc  9
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.cc  17
-rw-r--r--  tensorflow/core/kernels/mkl_conv_ops.h  1
-rw-r--r--  tensorflow/core/kernels/mkl_fused_batch_norm_op.cc  908
-rw-r--r--  tensorflow/core/kernels/mkl_matmul_op.cc  30
-rw-r--r--  tensorflow/core/kernels/mkl_maxpooling_op.cc  255
-rw-r--r--  tensorflow/core/kernels/mkl_pooling_ops_common.cc  181
-rw-r--r--  tensorflow/core/kernels/mkl_pooling_ops_common.h  435
-rw-r--r--  tensorflow/core/kernels/mkl_reshape_op.cc  6
-rw-r--r--  tensorflow/core/kernels/non_max_suppression_op.cc  111
-rw-r--r--  tensorflow/core/kernels/non_max_suppression_op_test.cc  55
-rw-r--r--  tensorflow/core/kernels/partitioned_function_ops.cc  33
-rw-r--r--  tensorflow/core/kernels/quantize_and_dequantize_op.h  16
-rw-r--r--  tensorflow/core/kernels/quantize_and_dequantize_op_test.cc  16
-rw-r--r--  tensorflow/core/kernels/resource_variable_ops.cc  68
-rw-r--r--  tensorflow/core/kernels/save_restore_tensor.cc  173
-rw-r--r--  tensorflow/core/kernels/softmax_op.cc  9
-rw-r--r--  tensorflow/core/kernels/softmax_op_gpu.cu.cc  7
-rw-r--r--  tensorflow/core/kernels/spacetobatch_op.cc  113
-rw-r--r--  tensorflow/core/kernels/strided_slice_op.cc  4
-rw-r--r--  tensorflow/core/kernels/training_op_helpers.cc  9
-rw-r--r--  tensorflow/core/kernels/training_op_helpers.h  14
73 files changed, 5039 insertions, 2492 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 2cb54bd973..d142e36772 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -22,6 +22,7 @@ package_group(
"//learning/brain/research/sparse_matrix/...",
"//learning/faster_training/...",
"//tensorflow/...",
+ "//third_party/car/...",
],
)
@@ -124,6 +125,7 @@ tf_kernel_library(
":bounds_check",
":dense_update_functor",
":ops_util",
+ ":training_op_helpers",
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -781,7 +783,7 @@ tf_kernel_library(
tf_kernel_library(
name = "quantize_and_dequantize_op",
prefix = "quantize_and_dequantize_op",
- deps = ARRAY_DEPS,
+ deps = ARRAY_DEPS + [":cwise_op"],
)
tf_kernel_library(
@@ -2346,6 +2348,22 @@ tf_cuda_cc_test(
)
tf_cuda_cc_test(
+ name = "crop_and_resize_op_benchmark_test",
+ srcs = ["crop_and_resize_op_benchmark_test.cc"],
+ deps = [
+ ":image",
+ ":ops_testutil",
+ ":ops_util",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+tf_cuda_cc_test(
name = "resize_benchmark_test",
srcs = ["resize_op_benchmark_test.cc"],
deps = [
@@ -2836,6 +2854,8 @@ tf_kernel_library(
srcs = [] + if_mkl([
"mkl_batch_matmul_op.cc",
]),
+ # <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
+ hdrs = ["batch_matmul_op_impl.h"],
# Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true,
# to avoid long compiling time. See https://github.com/tensorflow/tensorflow/issues/10521
copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
@@ -3772,7 +3792,7 @@ tf_kernel_library(
"spacetodepth_op.h",
"spacetodepth_op_gpu.cu.cc",
],
- visibility = ["//visibility:private"],
+ visibility = [":friends"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -4869,6 +4889,7 @@ filegroup(
"fill_functor.cc",
"fill_functor.h",
"function_ops.cc",
+ "function_ops.h",
"gather_functor.h",
"gather_nd_op.cc",
"gather_nd_op.h",
@@ -5350,10 +5371,6 @@ cc_library(
srcs = if_android(["decode_image_op.cc"]),
copts = tf_copts(),
linkopts = ["-ldl"],
- tags = [
- "manual",
- "notap",
- ],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:android_gif_internal",
diff --git a/tensorflow/core/kernels/as_string_op.cc b/tensorflow/core/kernels/as_string_op.cc
index a7757d1361..e6d6c40f76 100644
--- a/tensorflow/core/kernels/as_string_op.cc
+++ b/tensorflow/core/kernels/as_string_op.cc
@@ -47,6 +47,7 @@ class AsStringOp : public OpKernel {
case DT_FLOAT:
case DT_DOUBLE:
case DT_COMPLEX64:
+ case DT_COMPLEX128:
break;
default:
OP_REQUIRES(ctx, !(scientific || shortest),
@@ -83,6 +84,7 @@ class AsStringOp : public OpKernel {
case DT_FLOAT:
case DT_DOUBLE:
case DT_COMPLEX64:
+ case DT_COMPLEX128:
if (shortest) {
strings::Appendf(&format_, "g");
} else if (scientific) {
@@ -100,7 +102,7 @@ class AsStringOp : public OpKernel {
DataTypeString(dtype)));
}
- if (dtype == DT_COMPLEX64) {
+ if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
format_ = strings::Printf("(%s,%s)", format_.c_str(), format_.c_str());
}
}
@@ -144,6 +146,13 @@ class AsStringOp : public OpKernel {
format_.c_str(), input_flat(i).real(), input_flat(i).imag());
}
} break;
+ case (DT_COMPLEX128): {
+ const auto& input_flat = input_tensor->flat<complex128>();
+ for (int i = 0; i < input_flat.size(); ++i) {
+ output_flat(i) = strings::Printf(
+ format_.c_str(), input_flat(i).real(), input_flat(i).imag());
+ }
+ } break;
default:
bool can_encode_type = false;
OP_REQUIRES(context, can_encode_type,
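
For reference, a minimal standalone sketch (not part of this commit) of how the composed format behaves for the newly supported complex128 case: the op builds a printf-style specifier for one component (for example "%g" when shortest is set) and then wraps it as "(%g,%g)", printing the real and imaginary parts. std::printf is used here purely to keep the example self-contained.

#include <complex>
#include <cstdio>

int main() {
  // Hypothetical illustration of the wrapped "(%g,%g)" format for complex128.
  std::complex<double> z(1.5, -2.25);
  std::printf("(%g,%g)\n", z.real(), z.imag());  // prints (1.5,-2.25)
  return 0;
}
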
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index fe259c1634..aa7a2752e8 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -31,8 +31,7 @@ TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
#if GOOGLE_CUDA
TF_CALL_float(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_GPU);
-// TODO(csigg): Implement Stream::ThenBlasGemv for Eigen::half and uncomment.
-// TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
+TF_CALL_half(REGISTER_BATCH_MATMUL_GPU);
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h
index b77c14d012..656b6ced6d 100644
--- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h
@@ -147,13 +147,21 @@ class AdaptiveSharedBatchScheduler
// Tracks processing latency and adjusts in_flight_batches_limit to minimize.
void CallbackWrapper(const internal::ASBSBatch<TaskType>* batch,
- BatchProcessor callback);
+ BatchProcessor callback, bool is_express);
// Schedules batch if in_flight_batches_limit_ is not met.
void MaybeScheduleNextBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Schedules the earliest closed batch in batches_
+ // if batch_thread_pool_ has an idle thread.
+ // Batches scheduled this way are called express batches.
+ // Express batches are not limited by in_flight_batches_limit_, and
+ // their latencies will not affect in_flight_batches_limit_.
+ void MaybeScheduleClosedBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
// Notifies scheduler of non-empty batch which is eligible for processing.
- void AddBatch(const internal::ASBSBatch<TaskType>* batch);
+ void AddBatch(const internal::ASBSBatch<TaskType>* batch,
+ bool also_schedule_closed_batch);
// Removes queue from scheduler.
void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);
@@ -180,8 +188,10 @@ class AdaptiveSharedBatchScheduler
// results in an actual cap of 3 80% of the time, and 4 20% of the time.
double in_flight_batches_limit_ GUARDED_BY(mu_);
- // Number of batches currently being processed.
+ // Number of regular batches currently being processed.
int64 in_flight_batches_ GUARDED_BY(mu_) = 0;
+ // Number of express batches currently being processed.
+ int64 in_flight_express_batches_ GUARDED_BY(mu_) = 0;
// RNG engine and distribution.
std::default_random_engine rand_engine_;
@@ -363,10 +373,14 @@ Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
- const internal::ASBSBatch<TaskType>* batch) {
+ const internal::ASBSBatch<TaskType>* batch,
+ bool also_schedule_closed_batch) {
mutex_lock l(mu_);
batches_.push_back(batch);
MaybeScheduleNextBatch();
+ if (also_schedule_closed_batch) {
+ MaybeScheduleClosedBatch();
+ }
}
template <typename TaskType>
@@ -407,19 +421,45 @@ void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleNextBatch() {
batch->queue()->ReleaseBatch(batch);
batch_thread_pool_->Schedule(
std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper, this,
- batch, queues_and_callbacks_[batch->queue()]));
+ batch, queues_and_callbacks_[batch->queue()], false));
in_flight_batches_++;
}
template <typename TaskType>
+void AdaptiveSharedBatchScheduler<TaskType>::MaybeScheduleClosedBatch() {
+ if (in_flight_batches_ + in_flight_express_batches_ >=
+ options_.num_batch_threads) {
+ return;
+ }
+ for (auto it = batches_.begin(); it != batches_.end(); it++) {
+ if ((*it)->IsClosed()) {
+ const internal::ASBSBatch<TaskType>* batch = *it;
+ batches_.erase(it);
+ batch->queue()->ReleaseBatch(batch);
+ batch_thread_pool_->Schedule(
+ std::bind(&AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper,
+ this, batch, queues_and_callbacks_[batch->queue()], true));
+ in_flight_express_batches_++;
+ return;
+ }
+ }
+}
+
+template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::CallbackWrapper(
const internal::ASBSBatch<TaskType>* batch,
- AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback) {
+ AdaptiveSharedBatchScheduler<TaskType>::BatchProcessor callback,
+ bool is_express) {
int64 start_time = batch->creation_time_micros();
callback(std::unique_ptr<Batch<TaskType>>(
const_cast<internal::ASBSBatch<TaskType>*>(batch)));
int64 end_time = GetEnv()->NowMicros();
mutex_lock l(mu_);
+ if (is_express) {
+ in_flight_express_batches_--;
+ MaybeScheduleClosedBatch();
+ return;
+ }
in_flight_batches_--;
batch_count_++;
batch_latency_sum_ += end_time - start_time;
@@ -496,6 +536,7 @@ Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
" is larger than maximum batch size ",
options_.max_batch_size);
}
+ bool is_old_batch_closed = false;
{
mutex_lock l(mu_);
// Current batch is full, create another if allowed.
@@ -505,6 +546,7 @@ Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
return errors::Unavailable("The batch scheduling queue is full");
}
current_batch_->Close();
+ is_old_batch_closed = true;
current_batch_ = nullptr;
}
if (!current_batch_) {
@@ -516,7 +558,8 @@ Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
num_enqueued_tasks_++;
}
// AddBatch must be called outside of lock, since it may call ReleaseBatch.
- if (new_batch != nullptr) scheduler_->AddBatch(new_batch);
+ if (new_batch != nullptr)
+ scheduler_->AddBatch(new_batch, is_old_batch_closed);
return Status::OK();
}
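
To make the express-batch accounting above concrete, here is a minimal standalone sketch (not the TensorFlow implementation; the struct, function names, and the num_batch_threads value are assumptions) of the admission check MaybeScheduleClosedBatch performs: express batches bypass the adaptive in_flight_batches_limit_ but are still capped by the total thread-pool size.

#include <cstdint>

struct SchedulerState {
  int64_t in_flight_batches = 0;          // regular batches being processed
  int64_t in_flight_express_batches = 0;  // express (closed) batches
  int64_t num_batch_threads = 4;          // assumed thread-pool size
  double in_flight_batches_limit = 2.0;   // adaptive limit, regular batches only
};

// Returns true if a closed batch may be scheduled as an "express" batch.
bool CanScheduleExpressBatch(const SchedulerState& s) {
  // Express batches are gated only by idle threads, not by the adaptive limit,
  // and their latencies do not feed back into in_flight_batches_limit.
  return s.in_flight_batches + s.in_flight_express_batches <
         s.num_batch_threads;
}
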
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index b4c97df38b..0478c93280 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -59,6 +59,8 @@ CastOpBase::CastOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &external_dst_dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Truncate", &use_truncation_));
+
// Quantized data types use the same underlying format as their non quantized
// version so we use the non quantized implementation for casting.
if (external_dst_dtype_ == DT_QUINT8) {
@@ -100,7 +102,7 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
out->set_dtype(dst_dtype_);
- work_(ctx, in, out);
+ work_(ctx, in, out, use_truncation_);
out->set_dtype(external_dst_dtype_);
}
}
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h
index aae1e7ff19..527ab528c9 100644
--- a/tensorflow/core/kernels/cast_op.h
+++ b/tensorflow/core/kernels/cast_op.h
@@ -24,8 +24,71 @@ limitations under the License.
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
+// Note that the GPU cast functor templates need to be instantiated unlike the
+// CPU ones, and hence their specializations are different from those for CPUs.
+#ifdef SPECIALIZE_FOR_GPUS
+#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \
+ template <typename Device> \
+ struct CastFunctor<Device, OUT_TYPE, IN_OUT> { \
+ void operator()(const Device& d, \
+ typename TTypes<OUT_TYPE>::Flat out_tensor, \
+ typename TTypes<IN_OUT>::ConstFlat in_tensor, \
+ bool truncate = false) { \
+ if (truncate) { \
+ out_tensor.device(d) = \
+ in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \
+ .template cast<OUT_TYPE>(); \
+ } else { \
+ out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
+ } \
+ } \
+ }; \
+ template struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT>;
+#else
+#define SPECIALIZE_CAST(DEVICE, OUT_TYPE, IN_OUT) \
+ template <> \
+ struct CastFunctor<DEVICE, OUT_TYPE, IN_OUT> { \
+ void operator()(const DEVICE& d, \
+ typename TTypes<OUT_TYPE>::Flat out_tensor, \
+ typename TTypes<IN_OUT>::ConstFlat in_tensor, \
+ bool truncate = false) { \
+ if (truncate) { \
+ out_tensor.device(d) = \
+ in_tensor.unaryExpr(LSBZeroSetter<IN_OUT, OUT_TYPE>()) \
+ .template cast<OUT_TYPE>(); \
+ } else { \
+ out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
+ } \
+ } \
+ };
+#endif
+
+#define CAST_FUNCTORS(devname) \
+ SPECIALIZE_CAST(devname, float, double) \
+ SPECIALIZE_CAST(devname, float, std::complex<double>) \
+ SPECIALIZE_CAST(devname, std::complex<float>, std::complex<double>) \
+ SPECIALIZE_CAST(devname, std::complex<float>, double) \
+ SPECIALIZE_CAST(devname, Eigen::half, double) \
+ SPECIALIZE_CAST(devname, Eigen::half, float) \
+ SPECIALIZE_CAST(devname, Eigen::half, std::complex<double>) \
+ SPECIALIZE_CAST(devname, Eigen::half, std::complex<float>) \
+ SPECIALIZE_CAST(devname, bfloat16, float) \
+ template <typename OUT_TYPE, typename IN_OUT> \
+ struct CastFunctor<devname, OUT_TYPE, IN_OUT> { \
+ void operator()(const devname& d, \
+ typename TTypes<OUT_TYPE>::Flat out_tensor, \
+ typename TTypes<IN_OUT>::ConstFlat in_tensor, \
+ bool truncate = false) { \
+ out_tensor.device(d) = in_tensor.template cast<OUT_TYPE>(); \
+ } \
+ };
+
namespace tensorflow {
+typedef std::function<void(OpKernelContext*, const Tensor&, Tensor*,
+ bool trunc)>
+ CastFunctorType;
+
// Common base class of Cast kernels
class CastOpBase : public OpKernel {
public:
@@ -38,8 +101,8 @@ class CastOpBase : public OpKernel {
DataType dst_dtype_;
DataType external_src_dtype_;
DataType external_dst_dtype_;
- std::function<void(OpKernelContext*, const Tensor&, Tensor*)> work_ = nullptr;
-
+ bool use_truncation_;
+ CastFunctorType work_ = nullptr;
Status Unimplemented();
TF_DISALLOW_COPY_AND_ASSIGN(CastOpBase);
@@ -56,6 +119,23 @@ class CpuCastOp : public CastOpBase {
namespace functor {
+template <typename I>
+constexpr int MantissaWidth() {
+ return std::numeric_limits<I>::digits;
+}
+
+template <>
+constexpr int MantissaWidth<Eigen::half>() {
+ // Remember, there's 1 hidden bit
+ return 10 + 1;
+}
+
+template <>
+constexpr int MantissaWidth<bfloat16>() {
+ // Remember, there's 1 hidden bit
+ return 7 + 1;
+}
+
template <typename Device, typename Tout, typename Tin>
void Cast(const Device& d, typename TTypes<Tout>::Flat o,
typename TTypes<Tin>::ConstFlat i) {
@@ -65,7 +145,85 @@ void Cast(const Device& d, typename TTypes<Tout>::Flat o,
template <typename Device, typename Tout, typename Tin>
struct CastFunctor {
void operator()(const Device& d, typename TTypes<Tout>::Flat o,
- typename TTypes<Tin>::ConstFlat i);
+ typename TTypes<Tin>::ConstFlat i, bool truncate = false);
+};
+
+// Only enable LSBZeroSetterHelper for 64 and 32 bit input data types.
+// Specialize for others if needed in future.
+template <typename I>
+typename std::enable_if<sizeof(I) == 8, void>::type EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
+ // Only zero the bits for non-NaNs.
+ // For NaNs, let the non-truncation version handle it.
+ if (!std::isnan(t)) {
+ uint64_t* p = reinterpret_cast<uint64_t*>(&t);
+ *p &= (0xFFFFFFFFFFFFFFFF << n);
+ }
+}
+
+template <typename I>
+typename std::enable_if<sizeof(I) == 4, void>::type EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE static LSBZeroSetterHelper(I& t, int n) {
+ // Only zero the bits for non-NaNs.
+ // For NaNs, let the non-truncation version handle it.
+ if (!std::isnan(t)) {
+ uint32_t* p = reinterpret_cast<uint32_t*>(&t);
+ *p &= (0xFFFFFFFF << n);
+ }
+}
+
+// Set n least significant bits to 0
+template <typename I, typename O>
+struct LSBZeroSetter {
+ EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const I operator()(const I& a) const {
+ constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
+ static_assert(
+ bits > 0,
+ "The output type must have fewer mantissa bits than the input type\n");
+ I t = a;
+ LSBZeroSetterHelper(t, bits);
+ return t;
+ }
+};
+
+template <typename I, typename O>
+struct LSBZeroSetter<std::complex<I>, std::complex<O>> {
+ EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
+ const std::complex<I>& a) const {
+ constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
+ static_assert(
+ bits > 0,
+ "The output type must have fewer mantissa bits than the input type\n");
+ I re = std::real(a);
+ I img = std::imag(a);
+ LSBZeroSetterHelper(re, bits);
+ LSBZeroSetterHelper(img, bits);
+ std::complex<I> toReturn(re, img);
+ return toReturn;
+ }
+};
+
+template <typename I, typename O>
+struct LSBZeroSetter<std::complex<I>, O> {
+ EIGEN_EMPTY_STRUCT_CTOR(LSBZeroSetter)
+ // Sets the n least significant mantissa bits of the real and imaginary
+ // parts to 0
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const std::complex<I> operator()(
+ const std::complex<I>& a) const {
+ constexpr int bits = MantissaWidth<I>() - MantissaWidth<O>();
+ static_assert(
+ bits > 0,
+ "The output type must have fewer mantissa bits than the input type\n");
+ I re = std::real(a);
+ I img = std::imag(a);
+ LSBZeroSetterHelper(re, bits);
+ LSBZeroSetterHelper(img, bits);
+ std::complex<I> toReturn(re, img);
+ return toReturn;
+ }
};
} // end namespace functor
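
As a worked illustration of what LSBZeroSetter does, the standalone sketch below (assumptions: IEEE-754 float with 24 mantissa bits including the hidden bit, a bfloat16-like target with 8) zeroes the low mantissa bits of the wider type before a narrowing cast, so the cast truncates rather than rounds; NaNs pass through untouched, matching the helper above. It uses memcpy instead of reinterpret_cast purely to sidestep aliasing concerns in a self-contained example.

#include <cmath>
#include <cstdint>
#include <cstring>
#include <cstdio>

// Zero the n least significant mantissa bits of a non-NaN float.
float TruncateLowMantissaBits(float x, int bits_to_zero) {
  if (std::isnan(x)) return x;  // leave NaNs to the non-truncating path
  uint32_t u;
  std::memcpy(&u, &x, sizeof(u));
  u &= 0xFFFFFFFFu << bits_to_zero;  // clear the n low bits
  std::memcpy(&x, &u, sizeof(x));
  return x;
}

int main() {
  // float keeps 24 mantissa bits (hidden bit included), bfloat16 keeps 8,
  // so 24 - 8 = 16 low bits are cleared; e.g. 1.2345678f -> 1.234375.
  float x = 1.2345678f;
  std::printf("%.9g -> %.9g\n", x, TruncateLowMantissaBits(x, 24 - 8));
  return 0;
}
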
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc
index 607e7f5efd..036996fca2 100644
--- a/tensorflow/core/kernels/cast_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc
@@ -18,22 +18,19 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/bfloat16.h"
+#define SPECIALIZE_FOR_GPUS
#include "tensorflow/core/kernels/cast_op.h"
+#undef SPECIALIZE_FOR_GPUS
namespace tensorflow {
namespace functor {
typedef Eigen::GpuDevice GPUDevice;
-template <typename O, typename I>
-struct CastFunctor<GPUDevice, O, I> {
- void operator()(const GPUDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- Cast<GPUDevice, O, I>(d, o, i);
- }
-};
+CAST_FUNCTORS(GPUDevice);
#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>
+
#define DEFINE_ALL_FROM(in_type) \
DEFINE(in_type, bool); \
DEFINE(in_type, uint8); \
@@ -59,14 +56,43 @@ DEFINE_ALL_FROM(int8);
DEFINE_ALL_FROM(int16);
DEFINE_ALL_FROM(int32);
DEFINE_ALL_FROM(int64);
-DEFINE_ALL_FROM(Eigen::half);
-DEFINE_ALL_FROM(float);
DEFINE_ALL_FROM(double);
-DEFINE_ALL_FROM(std::complex<float>);
DEFINE_ALL_FROM(std::complex<double>);
-DEFINE(bfloat16, float);
DEFINE(float, bfloat16);
+#define DEFINE_ALL_TO_FLOAT(out_type) \
+ DEFINE(out_type, bool); \
+ DEFINE(out_type, uint8); \
+ DEFINE(out_type, uint16); \
+ DEFINE(out_type, uint32); \
+ DEFINE(out_type, uint64); \
+ DEFINE(out_type, int8); \
+ DEFINE(out_type, int16); \
+ DEFINE(out_type, int32); \
+ DEFINE(out_type, int64); \
+ DEFINE(out_type, Eigen::half); \
+ DEFINE(out_type, float); \
+ DEFINE(out_type, std::complex<float>)
+
+#define DEFINE_ALL_TO_HALF(out_type) \
+ DEFINE(out_type, bool); \
+ DEFINE(out_type, uint8); \
+ DEFINE(out_type, uint16); \
+ DEFINE(out_type, uint32); \
+ DEFINE(out_type, uint64); \
+ DEFINE(out_type, int8); \
+ DEFINE(out_type, int16); \
+ DEFINE(out_type, int32); \
+ DEFINE(out_type, int64); \
+ DEFINE(out_type, Eigen::half)
+
+DEFINE_ALL_TO_HALF(Eigen::half);
+DEFINE_ALL_TO_HALF(bfloat16);
+DEFINE_ALL_TO_FLOAT(float);
+DEFINE_ALL_TO_FLOAT(std::complex<float>);
+
+#undef DEFINE_ALL_TO_FLOAT
+#undef DEFINE_ALL_TO_HALF
#undef DEFINE_ALL_FROM
#undef DEFINE
diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h
index fe821b25df..b899bac681 100644
--- a/tensorflow/core/kernels/cast_op_impl.h
+++ b/tensorflow/core/kernels/cast_op_impl.h
@@ -25,22 +25,10 @@ namespace tensorflow {
namespace functor {
-template <typename O, typename I>
-struct CastFunctor<Eigen::ThreadPoolDevice, O, I> {
- void operator()(const Eigen::ThreadPoolDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- o.device(d) = i.template cast<O>();
- }
-};
+CAST_FUNCTORS(Eigen::ThreadPoolDevice);
#ifdef TENSORFLOW_USE_SYCL
-template <typename O, typename I>
-struct CastFunctor<Eigen::SyclDevice, O, I> {
- void operator()(const Eigen::SyclDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- o.device(d) = i.template cast<O>();
- }
-};
+CAST_FUNCTORS(Eigen::SyclDevice);
#endif // TENSORFLOW_USE_SYCL
} // namespace functor
@@ -68,139 +56,103 @@ struct CastFunctor<Eigen::SyclDevice, O, I> {
CURRY_TYPES3_NO_BF16(FN, arg0, arg1) \
FN(arg0, arg1, bfloat16);
-#define CAST_CASE(DEVICE, IN, OUT) \
- if (DataTypeToEnum<OUT>::value == dst_dtype) { \
- return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) { \
- functor::CastFunctor<DEVICE, OUT, IN> func; \
- func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>()); \
- }; \
+#define CAST_CASE(DEVICE, IN, OUT) \
+ if (DataTypeToEnum<OUT>::value == dst_dtype) { \
+ return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out, \
+ bool truncate) { \
+ functor::CastFunctor<DEVICE, OUT, IN> func; \
+ func(ctx->eigen_device<DEVICE>(), out->flat<OUT>(), inp.flat<IN>(), \
+ truncate); \
+ }; \
}
// The functions below are implemented in the cast_op_impl_*.cc files.
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBool(DataType dst_dtype);
+CastFunctorType GetCpuCastFromBool(DataType dst_dtype);
+
+CastFunctorType GetCpuCastFromUint8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint8(DataType dst_dtype);
+CastFunctorType GetCpuCastFromUint16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint16(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint32(DataType dst_dtype);
+CastFunctorType GetCpuCastFromUint32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint64(DataType dst_dtype);
+CastFunctorType GetCpuCastFromUint64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt8(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt16(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt32(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt64(DataType dst_dtype);
+CastFunctorType GetCpuCastFromInt64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromHalf(DataType dst_dtype);
+CastFunctorType GetCpuCastFromHalf(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromFloat(DataType dst_dtype);
+CastFunctorType GetCpuCastFromFloat(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromDouble(DataType dst_dtype);
+CastFunctorType GetCpuCastFromDouble(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex64(DataType dst_dtype);
+CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex128(DataType dst_dtype);
+CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBfloat(DataType dst_dtype);
+CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype);
#if GOOGLE_CUDA
// Same, for GPU.
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBool(DataType dst_dtype);
+CastFunctorType GetGpuCastFromBool(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint8(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint16(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint32(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint64(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt8(DataType dst_dtype);
+CastFunctorType GetGpuCastFromUint64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt16(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt32(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt64(DataType dst_dtype);
+CastFunctorType GetGpuCastFromInt64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromHalf(DataType dst_dtype);
+CastFunctorType GetGpuCastFromHalf(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromFloat(DataType dst_dtype);
+CastFunctorType GetGpuCastFromFloat(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromDouble(DataType dst_dtype);
+CastFunctorType GetGpuCastFromDouble(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex64(DataType dst_dtype);
+CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex128(DataType dst_dtype);
+CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBfloat(DataType dst_dtype);
+CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype);
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromBool(DataType dst_dtype);
+CastFunctorType GetSyclCastFromBool(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint8(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint8(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint16(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint32(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint64(DataType dst_dtype);
+CastFunctorType GetSyclCastFromUint64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt16(DataType dst_dtype);
+CastFunctorType GetSyclCastFromInt16(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt32(DataType dst_dtype);
+CastFunctorType GetSyclCastFromInt32(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt64(DataType dst_dtype);
+CastFunctorType GetSyclCastFromInt64(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromFloat(DataType dst_dtype);
+CastFunctorType GetSyclCastFromFloat(DataType dst_dtype);
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromDouble(DataType dst_dtype);
+CastFunctorType GetSyclCastFromDouble(DataType dst_dtype);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cast_op_impl_bfloat.cc b/tensorflow/core/kernels/cast_op_impl_bfloat.cc
index bfa7ba0d47..96aae15608 100644
--- a/tensorflow/core/kernels/cast_op_impl_bfloat.cc
+++ b/tensorflow/core/kernels/cast_op_impl_bfloat.cc
@@ -22,20 +22,19 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBfloat(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromBfloat(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, bfloat16);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBfloat(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromBfloat(DataType dst_dtype) {
if (dst_dtype == DT_FLOAT) {
- return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out) {
+ return [](OpKernelContext* ctx, const Tensor& inp, Tensor* out,
+ bool truncate) {
functor::CastFunctor<GPUDevice, float, bfloat16> func;
func(ctx->eigen_device<GPUDevice>(), out->flat<float>(),
- inp.flat<bfloat16>());
+ inp.flat<bfloat16>(), truncate);
};
}
return nullptr;
diff --git a/tensorflow/core/kernels/cast_op_impl_bool.cc b/tensorflow/core/kernels/cast_op_impl_bool.cc
index c5c7394b43..792d4781f2 100644
--- a/tensorflow/core/kernels/cast_op_impl_bool.cc
+++ b/tensorflow/core/kernels/cast_op_impl_bool.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromBool(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromBool(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, bool);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromBool(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromBool(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, bool);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromBool(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromBool(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromBool(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, bool);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_complex128.cc b/tensorflow/core/kernels/cast_op_impl_complex128.cc
index 52899d58cd..9a184e5954 100644
--- a/tensorflow/core/kernels/cast_op_impl_complex128.cc
+++ b/tensorflow/core/kernels/cast_op_impl_complex128.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex128(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromComplex128(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<double>);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex128(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromComplex128(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, std::complex<double>);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_complex64.cc b/tensorflow/core/kernels/cast_op_impl_complex64.cc
index 617bda53d5..77bc620b46 100644
--- a/tensorflow/core/kernels/cast_op_impl_complex64.cc
+++ b/tensorflow/core/kernels/cast_op_impl_complex64.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromComplex64(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromComplex64(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, std::complex<float>);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromComplex64(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromComplex64(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, std::complex<float>);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_double.cc b/tensorflow/core/kernels/cast_op_impl_double.cc
index 7dc485ddad..ff9056897f 100644
--- a/tensorflow/core/kernels/cast_op_impl_double.cc
+++ b/tensorflow/core/kernels/cast_op_impl_double.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromDouble(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromDouble(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, double);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromDouble(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromDouble(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, double);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromDouble(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromDouble(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromDouble(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, double);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_float.cc b/tensorflow/core/kernels/cast_op_impl_float.cc
index 1c933914fd..f1e8f0e37b 100644
--- a/tensorflow/core/kernels/cast_op_impl_float.cc
+++ b/tensorflow/core/kernels/cast_op_impl_float.cc
@@ -22,15 +22,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromFloat(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromFloat(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, float);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromFloat(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromFloat(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, GPUDevice, float);
return nullptr;
}
@@ -38,8 +36,7 @@ GetGpuCastFromFloat(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromFloat(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromFloat(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, float);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_half.cc b/tensorflow/core/kernels/cast_op_impl_half.cc
index ef4b94e326..5da3a01352 100644
--- a/tensorflow/core/kernels/cast_op_impl_half.cc
+++ b/tensorflow/core/kernels/cast_op_impl_half.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromHalf(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromHalf(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, Eigen::half);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromHalf(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromHalf(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, Eigen::half);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int16.cc b/tensorflow/core/kernels/cast_op_impl_int16.cc
index 59360f7445..440ee88fb5 100644
--- a/tensorflow/core/kernels/cast_op_impl_int16.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int16.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt16(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt16(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int16);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt16(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt16(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int16);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt16(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt16(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt16(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int16);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int32.cc b/tensorflow/core/kernels/cast_op_impl_int32.cc
index a867392fde..4b3e7efddc 100644
--- a/tensorflow/core/kernels/cast_op_impl_int32.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int32.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt32(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt32(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int32);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt32(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt32(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int32);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt32(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt32(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt32(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int32);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int64.cc b/tensorflow/core/kernels/cast_op_impl_int64.cc
index 467a8f6c89..0f711aa560 100644
--- a/tensorflow/core/kernels/cast_op_impl_int64.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int64.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt64(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt64(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int64);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt64(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt64(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int64);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt64(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt64(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt64(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int64);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_int8.cc b/tensorflow/core/kernels/cast_op_impl_int8.cc
index 21002a4321..eac185d5a0 100644
--- a/tensorflow/core/kernels/cast_op_impl_int8.cc
+++ b/tensorflow/core/kernels/cast_op_impl_int8.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromInt8(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromInt8(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, int8);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromInt8(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromInt8(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, int8);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromInt8(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromInt8(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromInt8(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, int8);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint16.cc b/tensorflow/core/kernels/cast_op_impl_uint16.cc
index cd829bae2a..3aebbdc1f3 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint16.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint16.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint16(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint16(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint16);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint16(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint16(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint16);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint16(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint16(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint16(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint16);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint32.cc b/tensorflow/core/kernels/cast_op_impl_uint32.cc
index d1a854d98b..86f5961bcc 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint32.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint32.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint32(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint32);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint32(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint32);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint32(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint32(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint32);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint64.cc b/tensorflow/core/kernels/cast_op_impl_uint64.cc
index 604e0424fc..6478c266ee 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint64.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint64.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint64(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint64(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint64);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint64(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint64(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint64);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint64(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint64(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint64(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint64);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_impl_uint8.cc b/tensorflow/core/kernels/cast_op_impl_uint8.cc
index 2d1a6f3a4e..b22547a23e 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint8.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint8.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint8(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint8(DataType dst_dtype) {
CURRY_TYPES3(CAST_CASE, CPUDevice, uint8);
return nullptr;
}
#if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint8(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint8(DataType dst_dtype) {
CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint8);
return nullptr;
}
@@ -36,8 +34,7 @@ GetGpuCastFromUint8(DataType dst_dtype) {
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint8(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint8(DataType dst_dtype) {
CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint8);
return nullptr;
}
diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc
index 9bbf7afb16..cb305de5e3 100644
--- a/tensorflow/core/kernels/cast_op_test.cc
+++ b/tensorflow/core/kernels/cast_op_test.cc
@@ -40,17 +40,27 @@ static Graph* Cast(int num) {
class CastOpTest : public OpsTestBase {
protected:
- void MakeOp(DataType src, DataType dst) {
- TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
- .Input(FakeInput(src))
- .Attr("SrcT", src)
- .Attr("DstT", dst)
- .Finalize(node_def()));
+ void MakeOp(DataType src, DataType dst, bool trunc = false) {
+ if (trunc) {
+ TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
+ .Input(FakeInput(src))
+ .Attr("SrcT", src)
+ .Attr("DstT", dst)
+ .Attr("Truncate", true)
+ .Finalize(node_def()));
+ } else {
+ TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
+ .Input(FakeInput(src))
+ .Attr("SrcT", src)
+ .Attr("DstT", dst)
+ .Finalize(node_def()));
+ }
+
TF_EXPECT_OK(InitOp());
}
template <typename INPUT, typename OUTPUT>
- void CheckCast() {
+ void CheckCast(bool trunc = false) {
DataType in_type = DataTypeToEnum<INPUT>::v();
DataType out_type = DataTypeToEnum<OUTPUT>::v();
MakeOp(in_type, out_type);
@@ -64,8 +74,9 @@ class CastOpTest : public OpsTestBase {
}
};
-#define TEST_CAST(in, out) \
- TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); }
+#define TEST_CAST(in, out) \
+ TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } \
+ TEST_F(CastOpTest, TestCast2##_##in##_##out) { CheckCast<in, out>(true); }
#define TEST_ALL_CASTS_FROM(in) \
TEST_CAST(in, uint8); \
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
index c281153795..1236f27051 100644
--- a/tensorflow/core/kernels/conv_ops_test.cc
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -229,7 +229,7 @@ class FusedResizePadConvOpTest : public OpsTestBase {
std::vector<Tensor> fused_tensors;
TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));
- test::ExpectTensorNear<T>(unfused_tensors[0], fused_tensors[0], 1e-5);
+ test::ExpectClose(unfused_tensors[0], fused_tensors[0]);
}
template <typename T>
@@ -282,7 +282,7 @@ class FusedResizePadConvOpTest : public OpsTestBase {
std::vector<Tensor> fused_tensors;
TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));
- test::ExpectTensorNear<T>(unfused_tensors[0], fused_tensors[0], 1e-5);
+ test::ExpectClose(unfused_tensors[0], fused_tensors[0]);
}
};
diff --git a/tensorflow/core/kernels/crop_and_resize_op_benchmark_test.cc b/tensorflow/core/kernels/crop_and_resize_op_benchmark_test.cc
new file mode 100644
index 0000000000..d7ca64bea0
--- /dev/null
+++ b/tensorflow/core/kernels/crop_and_resize_op_benchmark_test.cc
@@ -0,0 +1,72 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+static Graph* BM_CropAndResize(int batches, int width, int height, int depth,
+ int crop_height, int crop_width) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor in(DT_FLOAT, TensorShape({batches, height, width, depth}));
+ in.flat<float>().setRandom();
+ Tensor boxes(DT_FLOAT, TensorShape({batches, 4}));
+ auto boxes_tensor = boxes.matrix<float>();
+ Tensor box_ind(DT_INT32, TensorShape({batches}));
+ auto box_ind_flat = box_ind.flat<int32>();
+ for (int i = 0; i < batches; ++i) {
+ boxes_tensor(i, 0) = 0.2;
+ boxes_tensor(i, 1) = 0.2;
+ boxes_tensor(i, 2) = 0.8;
+ boxes_tensor(i, 3) = 0.7;
+ box_ind_flat(i) = i;
+ }
+ Tensor crop_size(DT_INT32, TensorShape({2}));
+ auto crop_size_flat = crop_size.flat<int32>();
+ crop_size_flat(0) = crop_height;
+ crop_size_flat(1) = crop_width;
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CropAndResize")
+ .Input(test::graph::Constant(g, in))
+ .Input(test::graph::Constant(g, boxes))
+ .Input(test::graph::Constant(g, box_ind))
+ .Input(test::graph::Constant(g, crop_size))
+ .Finalize(g, &ret));
+ return g;
+}
+
+#define BM_CropAndResizeDev(DEVICE, B, W, H, D, CH, CW) \
+ static void BM_CropAndResize_##DEVICE##_##B##_##W##_##H##_##D##_##CH##_##CW( \
+ int iters) { \
+ testing::ItemsProcessed(iters* B* W* H* D); \
+ test::Benchmark(#DEVICE, BM_CropAndResize(B, W, H, D, CH, CW)).Run(iters); \
+ } \
+ BENCHMARK(BM_CropAndResize_##DEVICE##_##B##_##W##_##H##_##D##_##CH##_##CW);
+
+// Benchmark results using CPU:Intel Haswell with HyperThreading (6 cores)
+// Benchmark Time(ns) CPU(ns) Iterations
+// BM_CropAndResize_cpu_1_640_640_3_512_512 7078765 7173520 100 163.361M items/s
+// BM_CropAndResize_cpu_1_640_640_1_512_512 3801232 3914692 185 99.784M items/s
+// BM_CropAndResize_cpu_1_80_80_512_7_7 182470 241767 2941 1.372G items/s
+
+BM_CropAndResizeDev(cpu, 1, 640, 640, 3, 512, 512);
+BM_CropAndResizeDev(cpu, 1, 640, 640, 1, 512, 512);
+BM_CropAndResizeDev(cpu, 1, 80, 80, 512, 7, 7);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_tan.cc b/tensorflow/core/kernels/cwise_op_tan.cc
index c1a25767d3..90762fb1b0 100644
--- a/tensorflow/core/kernels/cwise_op_tan.cc
+++ b/tensorflow/core/kernels/cwise_op_tan.cc
@@ -16,7 +16,8 @@ limitations under the License.
#include "tensorflow/core/kernels/cwise_ops_common.h"
namespace tensorflow {
-REGISTER2(UnaryOp, CPU, "Tan", functor::tan, float, double);
+REGISTER4(UnaryOp, CPU, "Tan", functor::tan, float, double, complex64,
+ complex128);
#if GOOGLE_CUDA
REGISTER2(UnaryOp, GPU, "Tan", functor::tan, float, double);
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index e04fa20414..607a694dba 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -177,6 +177,19 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "filter_by_component_dataset_op",
+ srcs = ["filter_by_component_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "map_dataset_op",
srcs = ["map_dataset_op.cc"],
deps = [
@@ -204,12 +217,28 @@ tf_kernel_library(
],
)
+cc_library(
+ name = "parallel_map_iterator",
+ srcs = ["parallel_map_iterator.cc"],
+ hdrs = ["parallel_map_iterator.h"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
tf_kernel_library(
name = "parallel_map_dataset_op",
srcs = ["parallel_map_dataset_op.cc"],
deps = [
":captured_function",
":dataset",
+ ":parallel_map_iterator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -222,6 +251,7 @@ tf_kernel_library(
tf_kernel_library(
name = "generator_dataset_op",
srcs = ["generator_dataset_op.cc"],
+ hdrs = ["generator_dataset_op.h"],
deps = [
":captured_function",
"//tensorflow/core:core_cpu_internal",
@@ -314,6 +344,7 @@ tf_cc_test(
tf_kernel_library(
name = "prefetch_dataset_op",
srcs = ["prefetch_dataset_op.cc"],
+ hdrs = ["prefetch_dataset_op.h"],
deps = [
":dataset",
":prefetch_autotuner",
@@ -535,9 +566,11 @@ tf_kernel_library(
tf_kernel_library(
name = "iterator_ops",
srcs = ["iterator_ops.cc"],
+ hdrs = ["iterator_ops.h"],
deps = [
":dataset",
":dataset_utils",
+ ":optional_ops",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -550,6 +583,20 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "optional_ops",
+ srcs = ["optional_ops.cc"],
+ hdrs = ["optional_ops.h"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_kernel_library(
name = "cache_dataset_ops",
srcs = ["cache_dataset_ops.cc"],
deps = [
@@ -605,6 +652,7 @@ tf_kernel_library(
":dataset",
":dataset_ops",
":dense_to_sparse_batch_dataset_op",
+ ":filter_by_component_dataset_op",
":filter_dataset_op",
":flat_map_dataset_op",
":generator_dataset_op",
@@ -614,7 +662,9 @@ tf_kernel_library(
":iterator_ops",
":map_and_batch_dataset_op",
":map_dataset_op",
+ ":map_defun_op",
":optimize_dataset_op",
+ ":optional_ops",
":padded_batch_dataset_op",
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
@@ -655,3 +705,15 @@ tf_kernel_library(
"//tensorflow/core/kernels:ops_util",
],
)
+
+tf_kernel_library(
+ name = "map_defun_op",
+ srcs = ["map_defun_op.cc"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:functional_ops_op_lib",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index ed4932bf32..86b0840aea 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -39,7 +39,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
ParseScalarArgument<string>(ctx, "filename", &filename));
if (filename.empty()) {
- *output = new MemoryDataset(input);
+ *output = new MemoryDataset(ctx, input);
} else {
*output = new FileDataset(ctx, input, filename, ctx->env());
}
@@ -68,8 +68,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new FileCacheIterator(
- {this, strings::StrCat(prefix, "::FileCacheIterator")}));
+ return std::unique_ptr<IteratorBase>(
+ new FileIterator({this, strings::StrCat(prefix, "::FileIterator")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -105,9 +105,9 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
tensor_index);
}
- class FileCacheIterator : public DatasetIterator<FileDataset> {
+ class FileIterator : public DatasetIterator<FileDataset> {
public:
- explicit FileCacheIterator(const Params& params)
+ explicit FileIterator(const Params& params)
: DatasetIterator<FileDataset>(params) {
if (params.dataset->env_
->FileExists(MetaFilename(params.dataset->filename_))
@@ -526,7 +526,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
enum Mode { read, write };
Mode mode_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
- }; // FileCacheIterator
+ }; // FileIterator
const DatasetBase* const input_;
const string filename_;
@@ -538,9 +538,10 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
const string tensor_format_string_;
}; // FileDataset
- class MemoryDataset : public DatasetBase {
+ class MemoryDataset : public GraphDatasetBase {
public:
- explicit MemoryDataset(const DatasetBase* input) : input_(input) {
+ explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input)
+ : GraphDatasetBase(ctx), input_(input), cache_(new MemoryCache()) {
input->Ref();
}
@@ -548,18 +549,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- mutex_lock l(mu_);
- if (cache_) {
- return std::unique_ptr<IteratorBase>(new MemoryReaderIterator(
- {this, strings::StrCat(prefix, "::MemoryReader")}, cache_.get()));
- }
- if (!writer_iterator_created_) {
- writer_iterator_created_ = true;
- return std::unique_ptr<IteratorBase>(new MemoryWriterIterator(
- {this, strings::StrCat(prefix, "::MemoryWriter")}));
- }
- return std::unique_ptr<IteratorBase>(new DuplicateWriterIterator(
- {this, strings::StrCat(prefix, "::DuplicateWriter")}));
+ return std::unique_ptr<IteratorBase>(new MemoryIterator(
+ {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_));
}
const DataTypeVector& output_dtypes() const override {
@@ -574,114 +565,321 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
return "CacheDatasetOp::MemoryDataset";
}
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
+ Node* filename_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(string(""), &filename_node));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_node, filename_node}, output));
+ return Status::OK();
+ }
+
private:
- // MemoryWriterIterator passes through and appends items from the input
- // dataset to its vector.
+ // A thread-safe data structure for caching dataset elements.
//
- // This iterator is used when dataset->cache_ is null. After buffering
- // the tensors in memory, upon exhausing the underlying iterator, they are
- // updated into the parent dataset's cache_ pointer.
- class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
+ // The expected use is that a single `MemoryWriterIterator` populates the
+ // cache with dataset elements. Once all elements are cached, the cache can
+ // be used by one or more `MemoryReaderIterator`s.
+ class MemoryCache {
public:
- explicit MemoryWriterIterator(const Params& params)
- : DatasetIterator<MemoryDataset>(params),
- cache_(new std::vector<std::vector<Tensor>>) {}
+ MemoryCache() = default;
- ~MemoryWriterIterator() override {
+ // Marks the cache as completed.
+ void Complete() {
mutex_lock l(mu_);
- if (cache_) {
- LOG(ERROR)
- << "The calling iterator did not fully read the dataset we were "
- "attempting to cache. In order to avoid unexpected truncation "
- "of the sequence, the current [partially cached] sequence "
- "will be dropped. This can occur if you have a sequence "
- "similar to `dataset.cache().take(k).repeat()`. Instead, swap "
- "the order (i.e. `dataset.take(k).cache().repeat()`)";
- mutex_lock l2(dataset()->mu_);
- dataset()->writer_iterator_created_ = false;
- }
+ completed_ = true;
}
- Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ // Returns whether the cache is claimed.
+ bool IsClaimed() {
+ tf_shared_lock l(mu_);
+ return claimed_;
}
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
+ // Returns whether the cache is completed.
+ bool IsCompleted() {
+ tf_shared_lock l(mu_);
+ return completed_;
+ }
+
+ // Attempts to claim the cache, returning whether the cache was claimed.
+ bool MaybeClaim() {
mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(
- input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
- if (*end_of_sequence) {
- // Guard on cache_ to not crash if GetNext is called a second time
- // after *end_of_sequence == true
- if (cache_) {
- mutex_lock l(dataset()->mu_);
- DCHECK(dataset()->writer_iterator_created_);
- DCHECK(!dataset()->cache_);
- cache_.swap(dataset()->cache_);
- }
- return Status::OK();
+ if (!claimed_) {
+ claimed_ = true;
+ return true;
}
- cache_->emplace_back(*out_tensors);
- return Status::OK();
+ return false;
+ }
+
+ // Resets the cache.
+ void Reset() {
+ mutex_lock l(mu_);
+ claimed_ = false;
+ completed_ = false;
+ cache_.clear();
+ }
+
+ // Returns the element at the given index.
+ const std::vector<Tensor>& at(int64 index) {
+ tf_shared_lock l(mu_);
+ DCHECK(index < cache_.size());
+ return cache_[index];
+ }
+
+ // Adds the element to the cache.
+ void emplace_back(std::vector<Tensor> element) {
+ mutex_lock l(mu_);
+ cache_.emplace_back(std::move(element));
+ }
+
+ // Returns the size of the cache.
+ size_t size() {
+ tf_shared_lock l(mu_);
+ return cache_.size();
}
private:
mutex mu_;
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- std::unique_ptr<std::vector<std::vector<Tensor>>> cache_ GUARDED_BY(mu_);
- }; // MemoryWriterIterator
-
- class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
+ // Determines whether a writer has claimed the cache.
+ bool claimed_ GUARDED_BY(mu_) = false;
+ // Determines whether all elements of the dataset have been cached.
+ bool completed_ GUARDED_BY(mu_) = false;
+ std::vector<std::vector<Tensor>> cache_ GUARDED_BY(mu_);
+ };
+
+ class MemoryIterator : public DatasetIterator<MemoryDataset> {
public:
- explicit MemoryReaderIterator(
- const Params& params, const std::vector<std::vector<Tensor>>* cache)
- : DatasetIterator<MemoryDataset>(params), cache_(cache), index_(0) {
- CHECK(cache);
+ explicit MemoryIterator(const Params& params,
+ const std::shared_ptr<MemoryCache>& cache)
+ : DatasetIterator<MemoryDataset>(params), cache_(cache) {
+ mode_ = cache->MaybeClaim() ? Mode::write : Mode::read;
+ InitializeIterator();
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ if (mode_ == Mode::read && !cache_->IsCompleted()) {
+ return errors::Internal(
+ "Cache should only be read after it has been completed.");
+ }
+ return iterator_->Initialize(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
- if (index_ < cache_->size()) {
- const std::vector<Tensor>& cache_tensors = (*cache_)[index_];
- out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
- cache_tensors.end());
- index_++;
- *end_of_sequence = false;
- return Status::OK();
- } else {
- *end_of_sequence = true;
- return Status::OK();
+ return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("mode"), mode_));
+ if (cache_->IsClaimed()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cache_claimed"), ""));
+ size_t cache_size = cache_->size();
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cache_size"), cache_size));
+ for (size_t i = 0; i < cache_size; i++) {
+ auto& element = cache_->at(i);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("cache[", i, "].size")),
+ element.size()));
+ for (size_t j = 0; j < element.size(); ++j) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("cache[", i, "][", j, "]")),
+ element[j]));
+ }
+ }
+ if (cache_->IsCompleted()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cache_completed"), ""));
+ }
}
+ return SaveParent(writer, iterator_);
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ iterator_.reset();
+ cache_->Reset();
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("mode"), &temp));
+ mode_ = static_cast<Mode>(temp);
+ }
+ if (reader->Contains(full_name("cache_claimed"))) {
+ CHECK(cache_->MaybeClaim());
+ size_t cache_size;
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("cache_size"), &temp));
+ cache_size = static_cast<size_t>(temp);
+ }
+ for (size_t i = 0; i < cache_size; ++i) {
+ std::vector<Tensor> element;
+ size_t element_size;
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("cache[", i, "].size")), &temp));
+ element_size = static_cast<size_t>(temp);
+ }
+ element.reserve(element_size);
+ for (size_t j = 0; j < element_size; ++j) {
+ element.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("cache[", i, "][", j, "]")),
+ &element.back()));
+ }
+ cache_->emplace_back(std::move(element));
+ }
+ if (reader->Contains(full_name("cache_completed"))) {
+ cache_->Complete();
+ }
+ }
+ InitializeIterator();
+ TF_RETURN_IF_ERROR(iterator_->Initialize(ctx));
+ return RestoreParent(ctx, reader, iterator_);
}
private:
- mutex mu_;
- const std::vector<std::vector<Tensor>>* const cache_;
- size_t index_ GUARDED_BY(mu_);
- }; // MemoryReaderIterator
+ class MemoryWriterIterator : public DatasetIterator<MemoryDataset> {
+ public:
+ explicit MemoryWriterIterator(const Params& params,
+ const std::shared_ptr<MemoryCache>& cache)
+ : DatasetIterator<MemoryDataset>(params), cache_(cache) {
+ CHECK(cache_);
+ }
- class DuplicateWriterIterator : public DatasetIterator<MemoryDataset> {
- public:
- explicit DuplicateWriterIterator(const Params& params)
- : DatasetIterator<MemoryDataset>(params) {}
+ ~MemoryWriterIterator() override {
+ mutex_lock l(mu_);
+ if (cache_->size() > 0 && !cache_->IsCompleted()) {
+ LOG(WARNING)
+ << "The calling iterator did not fully read the dataset being "
+ "cached. In order to avoid unexpected truncation of the "
+ "dataset, the partially cached contents of the dataset"
+ "will be discarded. This can happen if you have an input "
+ "pipeline similar to `dataset.cache().take(k).repeat()`. "
+ "You should use `dataset.take(k).cache().repeat()` instead.";
+ cache_->Reset();
+ }
+ }
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- return errors::AlreadyExists(
- "There appears to be a concurrent caching iterator running.");
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
+ if (*end_of_sequence) {
+ cache_->Complete();
+ return Status::OK();
+ }
+ cache_->emplace_back(*out_tensors);
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ return SaveParent(writer, input_impl_);
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ return RestoreParent(ctx, reader, input_impl_);
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::shared_ptr<MemoryCache> cache_;
+ }; // MemoryWriterIterator
+
+ class MemoryReaderIterator : public DatasetIterator<MemoryDataset> {
+ public:
+ explicit MemoryReaderIterator(const Params& params,
+ const std::shared_ptr<MemoryCache>& cache)
+ : DatasetIterator<MemoryDataset>(params), cache_(cache), index_(0) {
+ CHECK(cache);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("index"), index_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("index"), &temp));
+ index_ = static_cast<size_t>(temp);
+ }
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (index_ < cache_->size()) {
+ const std::vector<Tensor>& cache_tensors = cache_->at(index_);
+ out_tensors->insert(out_tensors->begin(), cache_tensors.begin(),
+ cache_tensors.end());
+ index_++;
+ *end_of_sequence = false;
+ return Status::OK();
+ } else {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ }
+
+ private:
+ mutex mu_;
+ const std::shared_ptr<MemoryCache> cache_;
+ size_t index_ GUARDED_BY(mu_);
+ }; // MemoryReaderIterator
+
+ void InitializeIterator() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ switch (mode_) {
+ case Mode::read:
+ iterator_.reset(
+ new MemoryReaderIterator({dataset(), prefix()}, cache_));
+ break;
+ case Mode::write:
+ iterator_.reset(
+ new MemoryWriterIterator({dataset(), prefix()}, cache_));
+ }
}
- }; // DuplicateWriterIterator
+
+ mutex mu_;
+ std::shared_ptr<MemoryCache> cache_;
+ enum Mode { read, write };
+ Mode mode_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> iterator_ GUARDED_BY(mu_);
+ }; // MemoryIterator
const DatasetBase* const input_;
- mutable mutex mu_;
- mutable std::unique_ptr<std::vector<std::vector<Tensor>>> cache_
- GUARDED_BY(mu_);
- mutable bool writer_iterator_created_ GUARDED_BY(mu_) = false;
+ const std::shared_ptr<MemoryCache> cache_;
}; // MemoryDataset
}; // CacheDatasetOp
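The MemoryCache introduced above is shared between a single writer iterator and any later readers: one iterator wins MaybeClaim() and fills the cache, Complete() marks it safe to read, and readers then index into it. A minimal usage sketch of that contract follows; it is illustrative only, and MakeElement is a hypothetical helper, not part of the kernel:

    std::shared_ptr<MemoryCache> cache = std::make_shared<MemoryCache>();
    if (cache->MaybeClaim()) {
      // Writer path: exactly one iterator claims the cache and populates it.
      for (int i = 0; i < 3; ++i) {
        std::vector<Tensor> element = MakeElement(i);  // hypothetical producer
        cache->emplace_back(std::move(element));
      }
      cache->Complete();  // Marks the cache as fully populated.
    }
    if (cache->IsCompleted()) {
      // Reader path: any number of readers may walk the finished cache.
      for (size_t i = 0; i < cache->size(); ++i) {
        const std::vector<Tensor>& element = cache->at(i);
        // ... consume `element` ...
      }
    }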
diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
new file mode 100644
index 0000000000..8b29456354
--- /dev/null
+++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
@@ -0,0 +1,169 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/random/random.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+// TODO(prazek): Filter already has logic for filtering by a given tensor,
+// but it must return both components. We could introduce a kernel like
+// DropComponentDatasetOp and use FilterDataset for the filtering.
+class FilterByLastComponentDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit FilterByLastComponentDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input, output_types_, output_shapes_);
+ }
+
+ private:
+ const int graph_def_version_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const DataTypeVector& output_types,
+ std::vector<PartialTensorShape> output_shapes)
+ : GraphDatasetBase(ctx),
+ input_(input),
+ output_types_(output_types),
+ output_shapes_(std::move(output_shapes)) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<Iterator>(new Iterator(
+ {this, strings::StrCat(prefix, "::FilterByLastComponent")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "FilterByLastComponentDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {std::make_pair(0, input_graph_node)}, // Single tensor inputs.
+ {}, {}, output));
+ return Status::OK();
+ }
+
+ private:
+ const DatasetBase* const input_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ // NOTE(mrry): This method is thread-safe as long as `input_impl_` is
+ // thread-safe. However, if multiple threads enter this method, outputs
+ // may be observed in a non-deterministic order.
+ bool matched;
+ do {
+ {
+ tf_shared_lock l(mu_);
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
+ }
+ if (*end_of_sequence) {
+ mutex_lock l(mu_);
+ input_impl_.reset();
+ return Status::OK();
+ }
+
+ matched = out_tensors->back().scalar<bool>()();
+ out_tensors->pop_back();
+ if (!matched) {
+ // Clear the output tensor list since it didn't match.
+ out_tensors->clear();
+ }
+ } while (!matched);
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("FilterByLastComponentDataset").Device(DEVICE_CPU),
+ FilterByLastComponentDatasetOp);
+
+} // namespace
+
+} // namespace tensorflow
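The iterator above pulls elements whose last component is a scalar bool predicate, strips that component, and only emits elements for which it was true. The same contract, written against a plain element vector rather than an iterator (names here are hypothetical), looks like this:

    // Returns true if the element should be emitted; the predicate component is
    // never surfaced, and non-matching elements are dropped entirely.
    bool FilterByLastComponent(std::vector<Tensor>* element) {
      const bool matched = element->back().scalar<bool>()();
      element->pop_back();
      if (!matched) {
        element->clear();
      }
      return matched;
    }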
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index 0981e42ba1..c4dd849b8b 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -15,192 +15,174 @@ limitations under the License.
#include <iterator>
#include <vector>
-#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/kernels/data/generator_dataset_op.h"
+
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-namespace {
-
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class GeneratorDatasetOp : public DatasetOpKernel {
+class GeneratorDatasetOp::Dataset : public GraphDatasetBase {
public:
- explicit GeneratorDatasetOp(OpKernelConstruction* ctx)
- : DatasetOpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
+ std::unique_ptr<CapturedFunction> next_func,
+ std::unique_ptr<CapturedFunction> finalize_func,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : GraphDatasetBase(ctx),
+ init_func_(std::move(init_func)),
+ next_func_(std::move(next_func)),
+ finalize_func_(std::move(finalize_func)),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Generator")}));
}
- void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- OpInputList init_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
- &init_func_other_args_input));
- std::vector<Tensor> init_func_other_args;
- init_func_other_args.reserve(init_func_other_args_input.size());
- for (const Tensor& t : init_func_other_args_input) {
- init_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> init_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- init_func_, std::move(init_func_other_args), &init_func));
-
- OpInputList next_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
- &next_func_other_args_input));
- std::vector<Tensor> next_func_other_args;
- next_func_other_args.reserve(next_func_other_args_input.size());
- for (const Tensor& t : next_func_other_args_input) {
- next_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> next_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- next_func_, std::move(next_func_other_args), &next_func));
-
- OpInputList finalize_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
- &finalize_func_other_args_input));
- std::vector<Tensor> finalize_func_other_args;
- finalize_func_other_args.reserve(finalize_func_other_args_input.size());
- for (const Tensor& t : finalize_func_other_args_input) {
- finalize_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> finalize_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- finalize_func_, std::move(finalize_func_other_args),
- &finalize_func));
-
- *output =
- new Dataset(ctx, std::move(init_func), std::move(next_func),
- std::move(finalize_func), output_types_, output_shapes_);
+ const DataTypeVector& output_dtypes() const override { return output_types_; }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
}
+ string DebugString() const override { return "GeneratorDatasetOp::Dataset"; }
+
private:
- class Dataset : public GraphDatasetBase {
+ class Iterator : public DatasetIterator<Dataset> {
public:
- Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
- std::unique_ptr<CapturedFunction> next_func,
- std::unique_ptr<CapturedFunction> finalize_func,
- const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
- init_func_(std::move(init_func)),
- next_func_(std::move(next_func)),
- finalize_func_(std::move(finalize_func)),
- output_types_(output_types),
- output_shapes_(output_shapes) {}
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Generator")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return output_types_;
- }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return output_shapes_;
- }
-
- string DebugString() const override {
- return "GeneratorDatasetOp::Dataset";
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- ~Iterator() override {
- if (!finalized_) {
- std::vector<Tensor> ignored;
- Status s =
- dataset()->finalize_func_->RunInstantiated(state_, &ignored);
- if (!s.ok()) {
- LOG(WARNING)
- << "Error occurred when finalizing GeneratorDataset iterator: "
- << s;
- }
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ ~Iterator() override {
+ if (!finalized_) {
+ std::vector<Tensor> ignored;
+ Status s = dataset()->finalize_func_->RunInstantiated(state_, &ignored);
+ if (!s.ok()) {
+ LOG(WARNING)
+ << "Error occurred when finalizing GeneratorDataset iterator: "
+ << s;
}
}
+ }
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- mutex_lock l(mu_);
-
- if (!initialized_) {
- TF_RETURN_IF_ERROR(
- dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
- // Explicitly instantiate the finalize function here so that
- // we can invoke it in the destructor.
- TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
- initialized_ = true;
- }
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+
+ if (!initialized_) {
+ TF_RETURN_IF_ERROR(
+ dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ // Explicitly instantiate the finalize function here so that
+ // we can invoke it in the destructor.
+ TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
+ initialized_ = true;
+ }
- if (finalized_) {
- *end_of_sequence = true;
- return Status::OK();
- }
+ if (finalized_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
- Status s = dataset()->next_func_->RunWithBorrowedArgs(ctx, state_,
- out_tensors);
- if (s.ok()) {
- *end_of_sequence = false;
- } else if (errors::IsOutOfRange(s)) {
- // `next_func` may deliberately raise `errors::OutOfRange`
- // to indicate that we should terminate the iteration.
- s = Status::OK();
- *end_of_sequence = true;
-
- // NOTE(mrry): We ignore any tensors returned by the
- // finalize function.
- std::vector<Tensor> ignored;
- TF_RETURN_IF_ERROR(
- dataset()->finalize_func_->RunInstantiated(state_, &ignored));
- finalized_ = true;
- }
- return s;
+ Status s =
+ dataset()->next_func_->RunWithBorrowedArgs(ctx, state_, out_tensors);
+ if (s.ok()) {
+ *end_of_sequence = false;
+ } else if (errors::IsOutOfRange(s)) {
+ // `next_func` may deliberately raise `errors::OutOfRange`
+ // to indicate that we should terminate the iteration.
+ s = Status::OK();
+ *end_of_sequence = true;
+
+ // NOTE(mrry): We ignore any tensors returned by the
+ // finalize function.
+ std::vector<Tensor> ignored;
+ TF_RETURN_IF_ERROR(
+ dataset()->finalize_func_->RunInstantiated(state_, &ignored));
+ finalized_ = true;
}
+ return s;
+ }
- private:
- mutex mu_;
- bool initialized_ GUARDED_BY(mu_) = false;
- bool finalized_ GUARDED_BY(mu_) = false;
- std::vector<Tensor> state_ GUARDED_BY(mu_);
- };
-
- const std::unique_ptr<CapturedFunction> init_func_;
- const std::unique_ptr<CapturedFunction> next_func_;
- const std::unique_ptr<CapturedFunction> finalize_func_;
- const DataTypeVector output_types_;
- const std::vector<PartialTensorShape> output_shapes_;
+ private:
+ mutex mu_;
+ bool initialized_ GUARDED_BY(mu_) = false;
+ bool finalized_ GUARDED_BY(mu_) = false;
+ std::vector<Tensor> state_ GUARDED_BY(mu_);
};
- DataTypeVector output_types_;
- std::vector<PartialTensorShape> output_shapes_;
- NameAttrList init_func_;
- NameAttrList next_func_;
- NameAttrList finalize_func_;
+ const std::unique_ptr<CapturedFunction> init_func_;
+ const std::unique_ptr<CapturedFunction> next_func_;
+ const std::unique_ptr<CapturedFunction> finalize_func_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
};
+GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
+ : DatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+}
+
+void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
+ DatasetBase** output) {
+ OpInputList init_func_other_args_input;
+ OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
+ &init_func_other_args_input));
+ std::vector<Tensor> init_func_other_args;
+ init_func_other_args.reserve(init_func_other_args_input.size());
+ for (const Tensor& t : init_func_other_args_input) {
+ init_func_other_args.push_back(t);
+ }
+ std::unique_ptr<CapturedFunction> init_func;
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args),
+ &init_func));
+
+ OpInputList next_func_other_args_input;
+ OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
+ &next_func_other_args_input));
+ std::vector<Tensor> next_func_other_args;
+ next_func_other_args.reserve(next_func_other_args_input.size());
+ for (const Tensor& t : next_func_other_args_input) {
+ next_func_other_args.push_back(t);
+ }
+ std::unique_ptr<CapturedFunction> next_func;
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args),
+ &next_func));
+
+ OpInputList finalize_func_other_args_input;
+ OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
+ &finalize_func_other_args_input));
+ std::vector<Tensor> finalize_func_other_args;
+ finalize_func_other_args.reserve(finalize_func_other_args_input.size());
+ for (const Tensor& t : finalize_func_other_args_input) {
+ finalize_func_other_args.push_back(t);
+ }
+ std::unique_ptr<CapturedFunction> finalize_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(
+ finalize_func_, std::move(finalize_func_other_args),
+ &finalize_func));
+
+ *output =
+ new Dataset(ctx, std::move(init_func), std::move(next_func),
+ std::move(finalize_func), output_types_, output_shapes_);
+}
+
REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU),
GeneratorDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"),
GeneratorDatasetOp);
-} // namespace
-
} // namespace tensorflow
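The generator iterator above follows a three-function protocol: init_func builds the iteration state once, next_func produces one element per call and signals termination by returning errors::OutOfRange, and finalize_func runs exactly once (either on OutOfRange or from the destructor if iteration stops early). A compressed sketch of that driver loop, with RunInit/RunNext/RunFinalize standing in for the captured functions (hypothetical names), is:

    Status DriveGenerator() {
      std::vector<Tensor> state;
      TF_RETURN_IF_ERROR(RunInit(&state));        // init_func: build state once.
      while (true) {
        std::vector<Tensor> element;
        Status s = RunNext(state, &element);      // next_func: produce one element.
        if (errors::IsOutOfRange(s)) break;       // OutOfRange means end of sequence.
        TF_RETURN_IF_ERROR(s);
        // ... hand `element` to the consumer ...
      }
      return RunFinalize(state);                  // finalize_func: clean up once.
    }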
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h
new file mode 100644
index 0000000000..3f84fa9c2e
--- /dev/null
+++ b/tensorflow/core/kernels/data/generator_dataset_op.h
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
+
+namespace tensorflow {
+
+class GeneratorDatasetOp : public DatasetOpKernel {
+ public:
+ explicit GeneratorDatasetOp(OpKernelConstruction* ctx);
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
+
+ private:
+ class Dataset;
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ NameAttrList init_func_;
+ NameAttrList next_func_;
+ NameAttrList finalize_func_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index da489db7c8..e2df14337c 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/kernels/data/iterator_ops.h"
+
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
@@ -23,8 +24,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/data/optional_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -80,6 +81,8 @@ Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
return Status::OK();
}
+} // namespace
+
class IteratorResource : public ResourceBase {
public:
IteratorResource(const DataTypeVector& output_dtypes,
@@ -437,300 +440,179 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
// Note that IteratorHandleOp holds a reference to the resource it creates. If
// cleaning up resources with DestroyResourceOp is important, consider creating
// resource containers with AnonymousIteratorHandleOp instead.
-class IteratorHandleOp : public OpKernel {
- public:
- explicit IteratorHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
- }
+IteratorHandleOp::IteratorHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+}
- // The resource is deleted from the resource manager only when it is private
- // to kernel. Ideally the resource should be deleted when it is no longer held
- // by anyone, but it would break backward compatibility.
- ~IteratorHandleOp() override {
- if (resource_ != nullptr) {
- resource_->Unref();
- if (cinfo_.resource_is_private_to_kernel()) {
- if (!cinfo_.resource_manager()
- ->template Delete<IteratorResource>(cinfo_.container(),
- cinfo_.name())
- .ok()) {
- // Do nothing; the resource can have been deleted by session resets.
- }
+// The resource is deleted from the resource manager only when it is private
+// to the kernel. Ideally the resource would be deleted when it is no longer
+// held by anyone, but doing so would break backward compatibility.
+IteratorHandleOp::~IteratorHandleOp() {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<IteratorResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource may have been deleted by session resets.
}
}
}
+}
- void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- if (resource_ == nullptr) {
- FunctionLibraryRuntime* lib;
- std::unique_ptr<DeviceMgr> device_mgr(nullptr);
- std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
- // If the iterator is shared then we construct a new FLR, and pass that
- // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
- // functions from the iterator. We may add this functionality if there
- // is sufficient demand, but it will require a significant refactoring.
- if (!name_.empty()) {
- lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
- } else {
- OP_REQUIRES_OK(context, context->function_library()->Clone(
- &flib_def, &pflr, &lib));
- }
-
- ResourceMgr* mgr = context->resource_manager();
- OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
-
- IteratorResource* resource;
- OP_REQUIRES_OK(
- context,
- mgr->LookupOrCreate<IteratorResource>(
- cinfo_.container(), cinfo_.name(), &resource,
- [lib, &device_mgr, &flib_def, &pflr,
- this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- *ret = new IteratorResource(
- output_dtypes_, output_shapes_, graph_def_version_,
- std::move(device_mgr), std::move(flib_def),
- std::move(pflr), lib);
- return Status::OK();
- }));
+void IteratorHandleOp::Compute(OpKernelContext* context) LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ FunctionLibraryRuntime* lib;
+ std::unique_ptr<DeviceMgr> device_mgr(nullptr);
+ std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+ // If the iterator is shared then we construct a new FLR, and pass that
+ // in. NOTE(mrry,rohanj): In this case it is not possible to call remote
+ // functions from the iterator. We may add this functionality if there
+ // is sufficient demand, but it will require a significant refactoring.
+ if (!name_.empty()) {
+ lib = CreatePrivateFLR(context, &device_mgr, &flib_def, &pflr);
+ } else {
+ OP_REQUIRES_OK(context, context->function_library()->Clone(
+ &flib_def, &pflr, &lib));
+ }
- Status s = VerifyResource(resource);
- if (TF_PREDICT_FALSE(!s.ok())) {
- resource->Unref();
- context->SetStatus(s);
- return;
- }
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
- resource_ = resource;
+ IteratorResource* resource;
+ OP_REQUIRES_OK(
+ context,
+ mgr->LookupOrCreate<IteratorResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [lib, &device_mgr, &flib_def, &pflr, this](IteratorResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new IteratorResource(
+ output_dtypes_, output_shapes_, graph_def_version_,
+ std::move(device_mgr), std::move(flib_def),
+ std::move(pflr), lib);
+ return Status::OK();
+ }));
+
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
}
- }
- OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
- context, 0, cinfo_.container(), cinfo_.name(),
- MakeTypeIndex<IteratorResource>()));
- }
-
- private:
- // During the first Compute(), resource is either created or looked up using
- // shared_name. In the latter case, the resource found should be verified if
- // it is compatible with this op's configuration. The verification may fail in
- // cases such as two graphs asking queues of the same shared name to have
- // inconsistent capacities.
- Status VerifyResource(IteratorResource* resource) {
- TF_RETURN_IF_ERROR(
- VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
- TF_RETURN_IF_ERROR(
- VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
- return Status::OK();
- }
- template <typename To, typename From> // use like this: down_cast<T*>(foo);
- static inline To down_cast(From* f) { // so we only accept pointers
- static_assert(
- (std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
- "target type not derived from source type");
-
- // We skip the assert and hence the dynamic_cast if RTTI is disabled.
-#if !defined(__GNUC__) || defined(__GXX_RTTI)
- // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds.
- assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
-#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
- return static_cast<To>(f);
+ resource_ = resource;
+ }
}
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<IteratorResource>()));
+}
- FunctionLibraryRuntime* CreatePrivateFLR(
- OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
- std::unique_ptr<FunctionLibraryDefinition>* flib_def,
- std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
- // Wrap the existing device in order to see any captured resources
- // in its resource manager. The existing device will outlive the
- // IteratorResource, because we are storing the IteratorResource
- // in that device's resource manager.
- Device* wrapped_device = RenamedDevice::NewRenamedDevice(
- ctx->device()->name(), down_cast<Device*>(ctx->device()),
- false /* owns_underlying */, false /* isolate_session_state */);
- device_mgr->reset(new DeviceMgr({wrapped_device}));
- flib_def->reset(new FunctionLibraryDefinition(
- *ctx->function_library()->GetFunctionLibraryDefinition()));
- pflr->reset(new ProcessFunctionLibraryRuntime(
- device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(),
- {} /* TODO(mrry): OptimizerOptions? */,
- nullptr /* TODO(mrry): ClusterFLR */));
-
- return (*pflr)->GetFLR(ctx->device()->name());
- }
+Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+}
- mutex mu_;
- ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
- IteratorResource* resource_ GUARDED_BY(mu_) = nullptr;
- DataTypeVector output_dtypes_;
- std::vector<PartialTensorShape> output_shapes_;
- const int graph_def_version_;
- string name_;
-};
+FunctionLibraryRuntime* IteratorHandleOp::CreatePrivateFLR(
+ OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
+ std::unique_ptr<FunctionLibraryDefinition>* flib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr) {
+ // Wrap the existing device in order to see any captured resources
+ // in its resource manager. The existing device will outlive the
+ // IteratorResource, because we are storing the IteratorResource
+ // in that device's resource manager.
+ Device* wrapped_device = RenamedDevice::NewRenamedDevice(
+ ctx->device()->name(), down_cast<Device*>(ctx->device()),
+ false /* owns_underlying */, false /* isolate_session_state */);
+ device_mgr->reset(new DeviceMgr({wrapped_device}));
+ flib_def->reset(new FunctionLibraryDefinition(
+ *ctx->function_library()->GetFunctionLibraryDefinition()));
+ pflr->reset(new ProcessFunctionLibraryRuntime(
+ device_mgr->get(), ctx->env(), graph_def_version_, flib_def->get(),
+ {} /* TODO(mrry): OptimizerOptions? */,
+ nullptr /* TODO(mrry): ClusterFLR */));
+
+ return (*pflr)->GetFLR(ctx->device()->name());
+}
// Like IteratorHandleOp, but creates handles which are never shared, and does
// not hold a reference to these handles. The latter is important for eager
// execution, since OpKernel instances generally live as long as the program
// running them.
-class AnonymousIteratorHandleOp : public OpKernel {
- public:
- explicit AnonymousIteratorHandleOp(OpKernelConstruction* context)
- : OpKernel(context), graph_def_version_(context->graph_def_version()) {
- OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_));
- OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_));
- }
+AnonymousIteratorHandleOp::AnonymousIteratorHandleOp(
+ OpKernelConstruction* context)
+ : OpKernel(context), graph_def_version_(context->graph_def_version()) {
+ OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_dtypes_));
+ OP_REQUIRES_OK(context, context->GetAttr("output_shapes", &output_shapes_));
+}
- void Compute(OpKernelContext* context) override {
- FunctionLibraryRuntime* lib;
- std::unique_ptr<DeviceMgr> device_mgr(nullptr);
- std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
- OP_REQUIRES_OK(context,
- context->function_library()->Clone(&flib_def, &pflr, &lib));
+void AnonymousIteratorHandleOp::Compute(OpKernelContext* context) {
+ FunctionLibraryRuntime* lib;
+ std::unique_ptr<DeviceMgr> device_mgr(nullptr);
+ std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+ OP_REQUIRES_OK(context,
+ context->function_library()->Clone(&flib_def, &pflr, &lib));
- ResourceMgr* mgr = context->resource_manager();
+ ResourceMgr* mgr = context->resource_manager();
- const string container_name = "AnonymousIterator";
- string unique_name;
- {
- mutex_lock l(static_resource_lookup_mutex_);
- while (true) { // Find an unused name
- IteratorResource* existing_resource = nullptr;
- unique_name = strings::StrCat("AnonymousIterator", current_id_++);
- Status status = mgr->Lookup<IteratorResource>(
- container_name, unique_name, &existing_resource);
- if (status.code() == error::NOT_FOUND) {
- break;
- }
- OP_REQUIRES_OK(context, status);
- existing_resource->Unref();
+ const string container_name = "AnonymousIterator";
+ string unique_name;
+ {
+ mutex_lock l(static_resource_lookup_mutex_);
+ while (true) { // Find an unused name
+ IteratorResource* existing_resource = nullptr;
+ unique_name = strings::StrCat("AnonymousIterator", current_id_++);
+ Status status = mgr->Lookup<IteratorResource>(container_name, unique_name,
+ &existing_resource);
+ if (status.code() == error::NOT_FOUND) {
+ break;
}
- IteratorResource* new_resource = new IteratorResource(
- output_dtypes_, output_shapes_, graph_def_version_,
- std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
- // Create the resource with our chosen name under the resource lookup
- // mutex to avoid another kernel racily creating a resource with this
- // name.
- OP_REQUIRES_OK(context, mgr->Create<IteratorResource>(
- container_name, unique_name, new_resource));
+ OP_REQUIRES_OK(context, status);
+ existing_resource->Unref();
}
- OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
- context, 0, container_name, unique_name,
- MakeTypeIndex<IteratorResource>()));
+ IteratorResource* new_resource = new IteratorResource(
+ output_dtypes_, output_shapes_, graph_def_version_,
+ std::move(device_mgr), std::move(flib_def), std::move(pflr), lib);
+ // Create the resource with our chosen name under the resource lookup
+ // mutex to avoid another kernel racily creating a resource with this
+ // name.
+ OP_REQUIRES_OK(context, mgr->Create<IteratorResource>(
+ container_name, unique_name, new_resource));
}
-
- private:
- // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp
- // instances.
- static mutex static_resource_lookup_mutex_;
- // current_id_ is just a hint for creating unique names. If it turns out
- // there's a collision (e.g. because another AnonymousIteratorHandleOp
- // instance is generating handles) we'll just skip that id.
- static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_);
- DataTypeVector output_dtypes_;
- std::vector<PartialTensorShape> output_shapes_;
- const int graph_def_version_;
-};
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, container_name, unique_name,
+ MakeTypeIndex<IteratorResource>()));
+}
// Static initializers for AnonymousIteratorHandleOp id counting.
mutex AnonymousIteratorHandleOp::static_resource_lookup_mutex_{
LINKER_INITIALIZED};
int64 AnonymousIteratorHandleOp::current_id_(0);
-class MakeIteratorOp : public OpKernel {
- public:
- explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- DatasetBase* dataset;
- OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
- IteratorResource* iterator_resource;
- OP_REQUIRES_OK(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
- core::ScopedUnref unref(iterator_resource);
-
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
- std::unique_ptr<IteratorBase> iterator;
- OP_REQUIRES_OK(ctx,
- dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
- OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
- }
-};
-
-// A simple background worker that executes closures asynchronously and without
-// blocking.
-//
-// A `BackgroundWorker` is used to offload blocking work from an `AsyncOpKernel`
-// to avoid blocking an executor thread that may be required by the blocking
-// work.
-//
-// NOTE(mrry): We do not use a regular `tensorflow::thread::ThreadPool` for this
-// purpose because its current implementation (in Eigen) uses a finite-length
-// queue and will block the caller when full. This can lead to deadlock under
-// heavy load. Since the number of concurrent work items in each user of a
-// `BackgroundWorker` is at most one per op invocation, the dynamic allocation
-// overhead is tolerable.
-class BackgroundWorker {
- public:
- BackgroundWorker(Env* env, const string& name) {
- thread_.reset(env->StartThread({} /* thread_options */, name,
- [this]() { WorkerLoop(); }));
- }
-
- ~BackgroundWorker() {
- {
- mutex_lock l(mu_);
- cancelled_ = true;
- }
- cond_var_.notify_one();
- // Block until the background thread has terminated.
- //
- // NOTE(mrry): We explicitly free and join the thread here because
- // `WorkerLoop()` uses other members of this object, and so we must join
- // the thread before destroying them.
- thread_.reset();
- }
-
- void Schedule(std::function<void()> work_item) {
- {
- mutex_lock l(mu_);
- work_queue_.push_back(std::move(work_item));
- }
- cond_var_.notify_one();
- }
-
- private:
- void WorkerLoop() {
- while (true) {
- std::function<void()> work_item = nullptr;
- {
- mutex_lock l(mu_);
- while (!cancelled_ && work_queue_.empty()) {
- cond_var_.wait(l);
- }
- if (cancelled_) {
- return;
- }
- DCHECK(!work_queue_.empty());
- work_item = std::move(work_queue_.front());
- work_queue_.pop_front();
- }
- DCHECK(work_item != nullptr);
- work_item();
- }
- }
-
- std::unique_ptr<Thread> thread_;
- mutex mu_;
- condition_variable cond_var_;
- bool cancelled_ GUARDED_BY(mu_) = false;
- std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
-};
+void MakeIteratorOp::Compute(OpKernelContext* ctx) {
+ DatasetBase* dataset;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+ IteratorResource* iterator_resource;
+ OP_REQUIRES_OK(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
+ core::ScopedUnref unref(iterator_resource);
+
+ IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK(ctx, dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
+}
class ToSingleElementOp : public AsyncOpKernel {
public:
@@ -995,13 +877,92 @@ class OneShotIteratorOp : public AsyncOpKernel {
const int graph_def_version_;
};
-class IteratorGetNextOp : public AsyncOpKernel {
+void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
+ IteratorResource* iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
+ // The call to `iterator->GetNext()` may block and depend on an
+ // inter-op thread pool thread, so we issue the call from the
+ // owned thread pool.
+ background_worker_.Schedule(std::bind(
+ [ctx, iterator](DoneCallback done) {
+ std::vector<Tensor> components;
+ bool end_of_sequence = false;
+
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ IteratorContext iter_ctx(std::move(params));
+
+ Status s = iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ // NOTE(mrry): We must unref the iterator before calling `done()`, to
+ // avoid destruction races.
+ iterator->Unref();
+
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else if (end_of_sequence) {
+ ctx->SetStatus(errors::OutOfRange("End of sequence"));
+ } else {
+ for (int i = 0; i < components.size(); ++i) {
+ // TODO(mrry): Check that the shapes match the shape attrs.
+ ctx->set_output(i, components[i]);
+ }
+ }
+ done();
+ },
+ std::move(done)));
+}
+
+class IteratorGetNextSyncOp : public OpKernel {
+ public:
+ explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ IteratorResource* iterator;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
+ core::ScopedUnref unref_iterator(iterator);
+
+ std::vector<Tensor> components;
+ bool end_of_sequence = false;
+
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ IteratorContext iter_ctx(std::move(params));
+
+ OP_REQUIRES_OK(ctx,
+ iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
+ OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
+
+ for (int i = 0; i < components.size(); ++i) {
+ // TODO(mrry): Check that the shapes match the shape attrs.
+ ctx->set_output(i, components[i]);
+ }
+ }
+};
+
+class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
public:
- explicit IteratorGetNextOp(OpKernelConstruction* ctx)
+ explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
- background_worker_(ctx->env(),
- strings::StrCat("iterator_get_next_thread_",
- SanitizeThreadSuffix(name()))) {}
+ background_worker_(
+ ctx->env(), strings::StrCat("iterator_get_next_as_optional_thread_",
+ SanitizeThreadSuffix(name()))) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
IteratorResource* iterator;
@@ -1011,7 +972,7 @@ class IteratorGetNextOp : public AsyncOpKernel {
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
background_worker_.Schedule(std::bind(
- [ctx, iterator](DoneCallback done) {
+ [this, ctx, iterator](DoneCallback done) {
std::vector<Tensor> components;
bool end_of_sequence = false;
@@ -1034,12 +995,32 @@ class IteratorGetNextOp : public AsyncOpKernel {
if (!s.ok()) {
ctx->SetStatus(s);
} else if (end_of_sequence) {
- ctx->SetStatus(errors::OutOfRange("End of sequence"));
+ OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
} else {
for (int i = 0; i < components.size(); ++i) {
- // TODO(mrry): Check that the shapes match the shape attrs.
- ctx->set_output(i, components[i]);
+ OP_REQUIRES_ASYNC(
+ ctx, components[i].dtype() == output_types_[i],
+ errors::InvalidArgument(
+ "The given optional does not match the expected type for "
+ "component ",
+ i, ". Expected: ", DataTypeString(output_types_[i]),
+ ". Actual: ", DataTypeString(components[i].dtype()), "."),
+ done);
+ OP_REQUIRES_ASYNC(
+ ctx,
+ output_shapes_[i].IsCompatibleWith(components[i].shape()),
+ errors::InvalidArgument(
+ "The given optional does not match the expected shape "
+ "for component ",
+ i, ". Expected: ", output_shapes_[i].DebugString(),
+ ". Actual: ", components[i].shape().DebugString(), "."),
+ done);
}
+
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
+ done);
}
done();
},
@@ -1048,126 +1029,80 @@ class IteratorGetNextOp : public AsyncOpKernel {
private:
BackgroundWorker background_worker_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
};
-class IteratorGetNextSyncOp : public OpKernel {
- public:
- explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- IteratorResource* iterator;
- OP_REQUIRES_OK(ctx,
- LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
- core::ScopedUnref unref_iterator(iterator);
-
- std::vector<Tensor> components;
- bool end_of_sequence = false;
-
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.function_library = iterator->function_library();
- DeviceBase* device = ctx->function_library()->device();
- params.allocator_getter = [device](AllocatorAttributes attrs) {
- return device->GetAllocator(attrs);
- };
- IteratorContext iter_ctx(std::move(params));
-
- OP_REQUIRES_OK(ctx,
- iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
- OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
-
- for (int i = 0; i < components.size(); ++i) {
- // TODO(mrry): Check that the shapes match the shape attrs.
- ctx->set_output(i, components[i]);
- }
- }
-};
-
-class IteratorToStringHandleOp : public OpKernel {
- public:
- explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor& resource_handle_t = ctx->input(0);
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
- errors::InvalidArgument("resource_handle must be a scalar"));
-
- // Validate that the handle corresponds to a real resource, and
- // that it is an IteratorResource.
- IteratorResource* iterator_resource;
- OP_REQUIRES_OK(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
- iterator_resource->Unref();
+void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
+ const Tensor& resource_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
+ errors::InvalidArgument("resource_handle must be a scalar"));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is an IteratorResource.
+ IteratorResource* iterator_resource;
+ OP_REQUIRES_OK(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator_resource));
+ iterator_resource->Unref();
+
+ Tensor* string_handle_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &string_handle_t));
+ string_handle_t->scalar<string>()() =
+ resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
+}
- Tensor* string_handle_t;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_output(0, TensorShape({}), &string_handle_t));
- string_handle_t->scalar<string>()() =
- resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
- }
-};
+IteratorFromStringHandleOp::IteratorFromStringHandleOp(
+ OpKernelConstruction* ctx)
+ : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES(
+ ctx,
+ output_dtypes_.empty() || output_shapes_.empty() ||
+ output_dtypes_.size() == output_shapes_.size(),
+ errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
+ "are set, they must have the same length."));
+}
-class IteratorFromStringHandleOp : public OpKernel {
- public:
- explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx)
- : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- OP_REQUIRES(
- ctx,
- output_dtypes_.empty() || output_shapes_.empty() ||
- output_dtypes_.size() == output_shapes_.size(),
- errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
- "are set, they must have the same length."));
+void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
+ const Tensor& string_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
+ errors::InvalidArgument("string_handle must be a scalar"));
+
+ ResourceHandle resource_handle;
+ OP_REQUIRES(
+ ctx, resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
+ errors::InvalidArgument(
+ "Could not parse string_handle as a valid ResourceHandle"));
+
+ OP_REQUIRES(
+ ctx, resource_handle.device() == ctx->device()->attributes().name(),
+ errors::InvalidArgument("Attempted create an iterator on device \"",
+ ctx->device()->attributes().name(),
+ "\" from handle defined on device \"",
+ resource_handle.device(), "\""));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is an IteratorResource.
+ IteratorResource* iterator_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &iterator_resource));
+ core::ScopedUnref unref_iterator(iterator_resource);
+ if (!output_dtypes_.empty()) {
+ OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
+ iterator_resource->output_dtypes()));
}
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor& string_handle_t = ctx->input(0);
- OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
- errors::InvalidArgument("string_handle must be a scalar"));
-
- ResourceHandle resource_handle;
- OP_REQUIRES(
- ctx,
- resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
- errors::InvalidArgument(
- "Could not parse string_handle as a valid ResourceHandle"));
-
- OP_REQUIRES(
- ctx, resource_handle.device() == ctx->device()->attributes().name(),
- errors::InvalidArgument("Attempted create an iterator on device \"",
- ctx->device()->attributes().name(),
- "\" from handle defined on device \"",
- resource_handle.device(), "\""));
-
- // Validate that the handle corresponds to a real resource, and
- // that it is an IteratorResource.
- IteratorResource* iterator_resource;
+ if (!output_shapes_.empty()) {
OP_REQUIRES_OK(ctx,
- LookupResource(ctx, resource_handle, &iterator_resource));
- core::ScopedUnref unref_iterator(iterator_resource);
- if (!output_dtypes_.empty()) {
- OP_REQUIRES_OK(ctx, VerifyTypesMatch(output_dtypes_,
- iterator_resource->output_dtypes()));
- }
- if (!output_shapes_.empty()) {
- OP_REQUIRES_OK(
- ctx, VerifyShapesCompatible(output_shapes_,
- iterator_resource->output_shapes()));
- }
-
- Tensor* resource_handle_t;
- OP_REQUIRES_OK(
- ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
- resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
+ VerifyShapesCompatible(output_shapes_,
+ iterator_resource->output_shapes()));
}
- private:
- DataTypeVector output_dtypes_;
- std::vector<PartialTensorShape> output_shapes_;
-};
+ Tensor* resource_handle_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
+ resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
+}
class SerializeIteratorOp : public OpKernel {
public:
@@ -1240,6 +1175,10 @@ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_CPU),
IteratorGetNextSyncOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE_GPU),
IteratorGetNextSyncOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE_CPU),
+ IteratorGetNextAsOptionalOp);
+REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE_GPU),
+ IteratorGetNextAsOptionalOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle").Device(DEVICE_CPU),
IteratorToStringHandleOp);
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")
@@ -1259,6 +1198,4 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
DeserializeIteratorOp);
-} // namespace
-
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
new file mode 100644
index 0000000000..e426febcce
--- /dev/null
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/ops_util.h"
+
+namespace tensorflow {
+
+class IteratorResource;
+
+class IteratorHandleOp : public OpKernel {
+ public:
+ explicit IteratorHandleOp(OpKernelConstruction* ctx);
+
+  // The resource is deleted from the resource manager only when it is private
+  // to the kernel. Ideally the resource would be deleted when it is no longer
+  // held by anyone, but doing so would break backward compatibility.
+ ~IteratorHandleOp() override;
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_);
+
+ private:
+  // During the first Compute(), the resource is either created or looked up
+  // using shared_name. In the latter case, the resource that was found should
+  // be verified to be compatible with this op's configuration. The
+  // verification may fail, for example, when two graphs ask queues with the
+  // same shared name to have inconsistent capacities.
+ Status VerifyResource(IteratorResource* resource);
+
+ template <typename To, typename From> // use like this: down_cast<T*>(foo);
+ static inline To down_cast(From* f) { // so we only accept pointers
+ static_assert(
+ (std::is_base_of<From, typename std::remove_pointer<To>::type>::value),
+ "target type not derived from source type");
+
+ // We skip the assert and hence the dynamic_cast if RTTI is disabled.
+#if !defined(__GNUC__) || defined(__GXX_RTTI)
+    // Uses RTTI in dbg and fastbuild; asserts are disabled in opt builds.
+ assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
+#endif // !defined(__GNUC__) || defined(__GXX_RTTI)
+ return static_cast<To>(f);
+ }
+
+ FunctionLibraryRuntime* CreatePrivateFLR(
+ OpKernelContext* ctx, std::unique_ptr<DeviceMgr>* device_mgr,
+ std::unique_ptr<FunctionLibraryDefinition>* flib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime>* pflr);
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ IteratorResource* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_dtypes_;
+ std::vector<PartialTensorShape> output_shapes_;
+ const int graph_def_version_;
+ string name_;
+};
+
+// Like IteratorHandleOp, but creates handles which are never shared, and does
+// not hold a reference to these handles. The latter is important for eager
+// execution, since OpKernel instances generally live as long as the program
+// running them.
+class AnonymousIteratorHandleOp : public OpKernel {
+ public:
+ explicit AnonymousIteratorHandleOp(OpKernelConstruction* context);
+
+ void Compute(OpKernelContext* context) override;
+
+ private:
+ // Coordinates Iterator unique name creation across AnonymousIteratorHandleOp
+ // instances.
+ static mutex static_resource_lookup_mutex_;
+ // current_id_ is just a hint for creating unique names. If it turns out
+ // there's a collision (e.g. because another AnonymousIteratorHandleOp
+ // instance is generating handles) we'll just skip that id.
+ static int64 current_id_ GUARDED_BY(static_resource_lookup_mutex_);
+ DataTypeVector output_dtypes_;
+ std::vector<PartialTensorShape> output_shapes_;
+ const int graph_def_version_;
+};
+
+class MakeIteratorOp : public OpKernel {
+ public:
+ explicit MakeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override;
+};
+
+class IteratorGetNextOp : public AsyncOpKernel {
+ public:
+ explicit IteratorGetNextOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ background_worker_(ctx->env(),
+ strings::StrCat("iterator_get_next_thread_",
+ SanitizeThreadSuffix(name()))) {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
+
+ private:
+ BackgroundWorker background_worker_;
+};
+
+class IteratorToStringHandleOp : public OpKernel {
+ public:
+ explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override;
+};
+
+class IteratorFromStringHandleOp : public OpKernel {
+ public:
+ explicit IteratorFromStringHandleOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ DataTypeVector output_dtypes_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
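Editor's note: the `down_cast` helper above is a checked alternative to a bare `static_cast` for pointer downcasts: the `static_assert` enforces the base/derived relationship at compile time, and the RTTI-gated `assert` catches wrong-type casts in debug builds. A self-contained sketch of the same idiom with illustrative types (not TensorFlow classes):

#include <cassert>
#include <type_traits>

template <typename To, typename From>  // use like this: down_cast<Derived*>(base)
inline To down_cast(From* f) {
  static_assert(
      std::is_base_of<From, typename std::remove_pointer<To>::type>::value,
      "target type not derived from source type");
#if !defined(__GNUC__) || defined(__GXX_RTTI)
  assert(f == nullptr || dynamic_cast<To>(f) != nullptr);
#endif
  return static_cast<To>(f);
}

struct Resource { virtual ~Resource() = default; };
struct IteratorLikeResource : Resource {};

void Example() {
  Resource* base = new IteratorLikeResource;
  // Checked via dynamic_cast in debug builds; a plain static_cast otherwise.
  IteratorLikeResource* derived = down_cast<IteratorLikeResource*>(base);
  delete derived;
}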
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
new file mode 100644
index 0000000000..d66716ef66
--- /dev/null
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -0,0 +1,192 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/util/batch_util.h"
+#include "tensorflow/core/util/reffed_status_callback.h"
+
+namespace tensorflow {
+namespace {
+
+void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
+ bool always_collect_stats) {
+ opts->step_id = ctx->step_id();
+ opts->rendezvous = ctx->rendezvous();
+ opts->cancellation_manager = ctx->cancellation_manager();
+ if (always_collect_stats) {
+ opts->stats_collector = ctx->stats_collector();
+ }
+ opts->runner = ctx->runner();
+}
+
+class MapDefunOp : public AsyncOpKernel {
+ public:
+ explicit MapDefunOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ auto func_lib = ctx->function_library();
+ OP_REQUIRES(ctx, func_lib != nullptr,
+ errors::Internal("No function library."));
+ const NameAttrList* func;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func));
+ OP_REQUIRES_OK(ctx,
+ func_lib->Instantiate(func->name(), AttrSlice(&func->attr()),
+ &func_handle_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+
+    OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
+                errors::InvalidArgument("Must have at least one input."));
+    OP_REQUIRES(ctx, ctx->num_outputs() >= 1,
+                errors::InvalidArgument("Must have at least one output."));
+ OP_REQUIRES(ctx, ctx->num_outputs() == output_shapes_.size(),
+ errors::InvalidArgument(
+ "Length of output_shapes and output_types must match."));
+ }
+
+ ~MapDefunOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ int64 batch_size = ctx->input(0).dim_size(0);
+ // Inputs
+ auto* args = new std::vector<Tensor>;
+ auto* arg_shapes = new std::vector<TensorShape>;
+ arg_shapes->reserve(ctx->num_inputs());
+ args->reserve(ctx->num_inputs());
+
+ for (size_t i = 0; i < ctx->num_inputs(); ++i) {
+ args->push_back(ctx->input(i));
+ arg_shapes->push_back(ctx->input(i).shape());
+ arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
+ OP_REQUIRES_ASYNC(
+ ctx, batch_size == ctx->input(i).dim_size(0),
+ errors::InvalidArgument("All inputs must have the same dimension 0."),
+ done);
+ }
+
+ // Outputs
+ auto* output = new OpOutputList;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
+
+ for (size_t i = 0; i < output_types().size(); ++i) {
+ Tensor* out = nullptr;
+ TensorShape output_shape = output_shapes_.at(i);
+ output_shape.InsertDim(0, batch_size);
+ OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done);
+ }
+
+ SetRunOptions(ctx, &opts_, false);
+
+ // Run loop
+ StatusCallback callback = std::bind(
+ [](OpKernelContext* ctx, std::vector<Tensor>* args,
+ std::vector<TensorShape>* arg_shapes, OpOutputList* output,
+ DoneCallback& done, const Status& status) {
+ delete args;
+ delete arg_shapes;
+ delete output;
+ ctx->SetStatus(status);
+ done();
+ },
+ ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1);
+
+ auto* refcounted = new ReffedStatusCallback(std::move(callback));
+
+ for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) {
+ // Start from i = 1 because refcounted is initialized with refcount = 1
+ refcounted->Ref();
+ }
+ for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
+ auto* call_frame =
+ new MapFunctionCallFrame(*args, *arg_shapes, output, this, i);
+ ctx->function_library()->Run(
+ opts_, func_handle_, call_frame,
+ [call_frame, refcounted](const Status& func_status) {
+ delete call_frame;
+ refcounted->UpdateStatus(func_status);
+ refcounted->Unref();
+ });
+ }
+ }
+
+ private:
+ FunctionLibraryRuntime::Handle func_handle_;
+ FunctionLibraryRuntime::Options opts_;
+ std::vector<TensorShape> output_shapes_;
+
+ class MapFunctionCallFrame : public CallFrameInterface {
+ public:
+ MapFunctionCallFrame(const std::vector<Tensor>& args,
+ const std::vector<TensorShape>& arg_shapes,
+ OpOutputList* output, OpKernel* kernel, size_t iter)
+ : args_(args),
+ arg_shapes_(arg_shapes),
+ output_(output),
+ kernel_(kernel),
+ iter_(iter) {}
+
+ ~MapFunctionCallFrame() override {}
+
+ size_t num_args() const override { return args_.size(); }
+ size_t num_retvals() const override {
+ return static_cast<size_t>(kernel_->num_outputs());
+ }
+
+ Status GetArg(int index, Tensor* val) const override {
+ if (index < 0 || index >= args_.size()) {
+ return errors::InvalidArgument(
+ "Mismatch in number of function inputs.");
+ }
+ bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1),
+ arg_shapes_.at(index));
+ if (!result) {
+ return errors::Internal("GetArg failed.");
+ } else if (!val->IsAligned()) {
+ // Ensure alignment
+ *val = tensor::DeepCopy(*val);
+ }
+
+ return Status::OK();
+ }
+
+ Status SetRetval(int index, const Tensor& val) override {
+ if (index < 0 || index >= kernel_->num_outputs()) {
+ return errors::InvalidArgument(
+ "Mismatch in number of function outputs.");
+ }
+
+ if (val.dtype() != kernel_->output_type(index)) {
+ return errors::InvalidArgument(
+ "Mismatch in function return type and expected output type for "
+ "output: ",
+ index);
+ }
+ return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
+ }
+
+ private:
+ const std::vector<Tensor>& args_;
+ const std::vector<TensorShape>& arg_shapes_;
+ OpOutputList* output_;
+ const OpKernel* kernel_;
+ const size_t iter_;
+ };
+};  // class MapDefunOp
+
+REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp);
+} // namespace
+} // namespace tensorflow
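Editor's note: the per-slice dispatch in `MapDefunOp::ComputeAsync` above leans on the reference-counted status callback: `refcounted` starts with one reference, `Ref()` is called once for every additional batch slice, and the wrapped callback (which frees the shared state, sets the aggregated status, and calls `done()`) only runs once the last per-slice completion releases its reference. A minimal sketch of that fan-out/fan-in pattern using standard-library primitives (illustrative names, not the TensorFlow `ReffedStatusCallback` API):

#include <atomic>
#include <functional>
#include <thread>
#include <vector>

// Fires `on_all_done` and deletes itself after Done() has been called
// `num_pending` times. Assumes num_pending >= 1.
class JoinCounter {
 public:
  JoinCounter(int num_pending, std::function<void()> on_all_done)
      : pending_(num_pending), on_all_done_(std::move(on_all_done)) {}

  void Done() {
    if (pending_.fetch_sub(1) == 1) {
      on_all_done_();
      delete this;
    }
  }

 private:
  std::atomic<int> pending_;
  std::function<void()> on_all_done_;
};

void RunBatch(int batch_size) {
  auto* join = new JoinCounter(batch_size, [] {
    // In the kernel, this is where the aggregated Status is reported and the
    // AsyncOpKernel's done() callback is invoked.
  });
  std::vector<std::thread> workers;
  for (int i = 0; i < batch_size; ++i) {
    workers.emplace_back([join, i] {
      // ... run the mapped function on slice i ...
      (void)i;
      join->Done();
    });
  }
  for (auto& t : workers) t.join();
}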
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
new file mode 100644
index 0000000000..cfac45dbc7
--- /dev/null
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -0,0 +1,270 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/optional_ops.h"
+
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
+
+namespace tensorflow {
+namespace {
+const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
+
+// An `OptionalVariant` can represent either an "actual value" (a tuple of
+// tensors) or "none", and may be stored in a DT_VARIANT tensor.
+class OptionalVariant {
+ public:
+ // Create an `OptionalVariant` with no actual value.
+ OptionalVariant() : values_(nullptr) {}
+
+ // Create an `OptionalVariant` with the actual value given by the tuple of
+ // tensors in `values`.
+ explicit OptionalVariant(std::vector<Tensor> values)
+ : values_(new std::vector<Tensor>(std::move(values))) {}
+
+ OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
+
+ // Returns true if `this` represents an actual value.
+ bool has_value() const { return values_ != nullptr; }
+
+ // REQUIRES: `this->has_value()` must be true.
+ const std::vector<Tensor>& get_values() const {
+ CHECK(values_) << "Tried to get values from an empty OptionalVariant";
+ return *values_;
+ }
+
+ // Implementations of the necessary methods for using `OptionalVariant`
+ // objects in DT_VARIANT tensors.
+ string TypeName() const { return kOptionalVariantTypeName; }
+ void Encode(VariantTensorData* data) const {
+ data->set_metadata(values_ != nullptr);
+ if (values_ != nullptr) {
+ for (const auto& t : *values_) {
+ *(data->add_tensors()) = t;
+ }
+ }
+ }
+
+ bool Decode(const VariantTensorData& data) {
+ if (data.type_name() != TypeName()) {
+ return false;
+ }
+ bool has_value = false;
+ if (!data.get_metadata(&has_value)) {
+ return false;
+ }
+ if (has_value) {
+ values_.reset(new std::vector<Tensor>(data.tensors()));
+ } else {
+ values_.reset();
+ }
+ return true;
+ }
+
+ string DebugString() const {
+ if (values_) {
+ return strings::StrCat("OptionalVariant<", "values: (",
+ str_util::Join(*values_, ", ",
+ [](string* s, const Tensor& elem) {
+ *s = elem.DebugString();
+ }),
+ ")>");
+ } else {
+ return strings::StrCat("OptionalVariant<None>");
+ }
+ }
+
+ private:
+ std::shared_ptr<const std::vector<Tensor>> values_;
+};
+
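Editor's note: before the kernels below, a minimal sketch of how an `OptionalVariant` travels through a `DT_VARIANT` scalar, mirroring how the ops in this file store and inspect it (an illustrative fragment written in the context of this file, not additional code from this change):

void RoundTripExample(const Tensor& some_component) {
  // Store: wrap a tuple of tensors and assign it into a DT_VARIANT scalar.
  Tensor variant_t(DT_VARIANT, TensorShape({}));
  variant_t.scalar<Variant>()() = OptionalVariant({some_component});

  // Read back: get<OptionalVariant>() returns nullptr if the variant holds a
  // value of a different type.
  const OptionalVariant* opt =
      variant_t.scalar<Variant>()().get<OptionalVariant>();
  if (opt != nullptr && opt->has_value()) {
    const std::vector<Tensor>& values = opt->get_values();
    (void)values;
  }
}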
+class OptionalNoneOp : public OpKernel {
+ public:
+ explicit OptionalNoneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ OP_REQUIRES_OK(ctx, WriteOptionalNoneToOutput(ctx, 0));
+ }
+};
+
+class OptionalFromValueOp : public OpKernel {
+ public:
+ explicit OptionalFromValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ OpInputList components_input;
+ OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
+ std::vector<Tensor> components;
+ components.reserve(components_input.size());
+ for (const Tensor& component_t : components_input) {
+ components.push_back(component_t);
+ }
+ OP_REQUIRES_OK(
+ ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
+ }
+};
+
+class OptionalHasValueOp : public OpKernel {
+ public:
+ explicit OptionalHasValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* optional_input;
+ OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
+ errors::InvalidArgument(
+ "Input to OptionalHasValue must be a scalar tensor "
+ "containing an OptionalVariant object."));
+ const OptionalVariant* optional =
+ optional_input->scalar<Variant>()().get<OptionalVariant>();
+ OP_REQUIRES(
+ ctx, optional != nullptr,
+ errors::InvalidArgument(
+ "Input to OptionalHasValue must be an OptionalVariant object."));
+ Tensor* result;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &result));
+ result->scalar<bool>()() = optional->has_value();
+ }
+};
+
+class OptionalGetValueOp : public OpKernel {
+ public:
+ explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* optional_input;
+ OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
+ errors::InvalidArgument(
+ "Input to OptionalHasValue must be a scalar tensor "
+ "containing an OptionalVariant object."));
+ const OptionalVariant* optional =
+ optional_input->scalar<Variant>()().get<OptionalVariant>();
+ OP_REQUIRES(
+ ctx, optional != nullptr,
+ errors::InvalidArgument(
+ "Input to OptionalHasValue must be an OptionalVariant object."));
+ OP_REQUIRES(
+ ctx, optional->has_value(),
+ errors::InvalidArgument("The given optional does not have a value."));
+ const auto& components = optional->get_values();
+ for (int i = 0; i < components.size(); ++i) {
+ OP_REQUIRES(
+ ctx, components[i].dtype() == output_types_[i],
+ errors::InvalidArgument(
+ "The given optional does not match the expected type for "
+ "component ",
+ i, ". Expected: ", DataTypeString(output_types_[i]),
+ ". Actual: ", DataTypeString(components[i].dtype()), "."));
+ OP_REQUIRES(ctx,
+ output_shapes_[i].IsCompatibleWith(components[i].shape()),
+ errors::InvalidArgument(
+ "The given optional does not match the expected shape "
+ "for component ",
+ i, ". Expected: ", output_shapes_[i].DebugString(),
+ ". Actual: ", components[i].shape().DebugString(), "."));
+ ctx->set_output(i, components[i]);
+ }
+ }
+
+ private:
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_CPU),
+ OptionalNoneOp);
+REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_GPU),
+ OptionalNoneOp);
+REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE_CPU),
+ OptionalFromValueOp);
+REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE_GPU),
+ OptionalFromValueOp);
+
+REGISTER_KERNEL_BUILDER(Name("OptionalHasValue").Device(DEVICE_CPU),
+ OptionalHasValueOp);
+REGISTER_KERNEL_BUILDER(
+ Name("OptionalHasValue").Device(DEVICE_GPU).HostMemory("has_value"),
+ OptionalHasValueOp);
+REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_CPU),
+ OptionalGetValueOp);
+REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_GPU),
+ OptionalGetValueOp);
+
+static Status OptionalDeviceCopy(
+ const OptionalVariant& from, OptionalVariant* to,
+ const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
+ if (from.has_value()) {
+ const std::vector<Tensor>& from_values = from.get_values();
+ std::vector<Tensor> to_values;
+ to_values.reserve(from_values.size());
+ for (const Tensor& t : from_values) {
+ if (DMAHelper::CanUseDMA(&t)) {
+ Tensor tmp(t.dtype());
+ TF_RETURN_IF_ERROR(copy(t, &tmp));
+ to_values.push_back(std::move(tmp));
+ } else {
+ to_values.push_back(t);
+ }
+ }
+ *to = OptionalVariant(std::move(to_values));
+ } else {
+ *to = from;
+ }
+ return Status::OK();
+}
+
+#define REGISTER_OPTIONAL_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+ OptionalVariant, DIRECTION, kOptionalVariantTypeName, \
+ OptionalDeviceCopy)
+
+REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
+REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
+REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
+
+REGISTER_UNARY_VARIANT_DECODE_FUNCTION(OptionalVariant,
+ kOptionalVariantTypeName);
+
+} // namespace
+
+Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
+ std::vector<Tensor> value) {
+ OptionalVariant v(std::move(value));
+ Tensor* variant_t;
+ AllocatorAttributes cpu_alloc;
+ cpu_alloc.set_on_host(true);
+ TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
+ &variant_t, cpu_alloc));
+ variant_t->scalar<Variant>()() = v;
+ return Status::OK();
+}
+
+Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
+ OptionalVariant v;
+ Tensor* variant_t;
+ AllocatorAttributes cpu_alloc;
+ cpu_alloc.set_on_host(true);
+ TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
+ &variant_t, cpu_alloc));
+ variant_t->scalar<Variant>()() = v;
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h
new file mode 100644
index 0000000000..6f25567678
--- /dev/null
+++ b/tensorflow/core/kernels/data/optional_ops.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
+
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+
+namespace tensorflow {
+
+// Stores a DT_VARIANT value representing an Optional with the given value
+// in the `output_index`^th output of the given kernel execution context.
+Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
+ std::vector<Tensor> value);
+
+// Stores a DT_VARIANT value representing an Optional with no value
+// in the `output_index`^th output of the given kernel execution context.
+Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
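Editor's note: a small usage sketch for the two helpers declared above, following the pattern of `OptionalFromValueOp` and `OptionalNoneOp` in optional_ops.cc; the kernel name below is hypothetical:

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/data/optional_ops.h"

namespace tensorflow {

// Hypothetical kernel that emits Optional(value) on output 0 when it has
// inputs, and Optional(none) otherwise.
class MaybeWrapOp : public OpKernel {
 public:
  explicit MaybeWrapOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    if (ctx->num_inputs() > 0) {
      std::vector<Tensor> components;
      components.reserve(ctx->num_inputs());
      for (int i = 0; i < ctx->num_inputs(); ++i) {
        components.push_back(ctx->input(i));
      }
      OP_REQUIRES_OK(
          ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
    } else {
      OP_REQUIRES_OK(ctx, WriteOptionalNoneToOutput(ctx, 0));
    }
  }
};

}  // namespace tensorflow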
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 15f3dc3b1d..b736b33c2e 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/random/random.h"
@@ -87,8 +88,16 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::ParallelMap")}));
+ auto map_func = [this](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done));
+ };
+
+ return NewParallelMapIterator(
+ {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
+ std::move(map_func), num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
@@ -148,279 +157,6 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- ~Iterator() override {
- // TODO(mrry): Replace this cancellation logic with a
- // CancellationManager. The syntax would be more heavyweight,
- // but it would be possible to thread a cancellation manager
- // through the IteratorContext to upstream,
- // potentially-blocking iterators, when we add these.
- mutex_lock l(mu_);
- // Cancel the runner thread.
- cancelled_ = true;
- cond_var_.notify_all();
- // Wait for all in-flight calls to complete.
- while (num_calls_ > 0) {
- cond_var_.wait(l);
- }
- }
-
- Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- std::shared_ptr<InvocationResult> result;
- {
- mutex_lock l(mu_);
- EnsureRunnerThreadStarted(ctx);
- while (invocation_results_.empty()) {
- cond_var_.wait(l);
- }
- std::swap(result, invocation_results_.front());
- invocation_results_.pop_front();
- }
- cond_var_.notify_all();
- result->notification.WaitForNotification();
- return ProcessResult(result, out_tensors, end_of_sequence);
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- // Wait for all in-flight calls to complete.
- while (num_calls_ > 0) {
- cond_var_.wait(l);
- }
- CHECK_EQ(num_calls_, 0);
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name("invocation_results.size"), invocation_results_.size()));
- for (size_t i = 0; i < invocation_results_.size(); i++) {
- std::shared_ptr<InvocationResult> result = invocation_results_[i];
- TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("invocation_results[", i, "].size")),
- result->return_values.size()));
- for (size_t j = 0; j < result->return_values.size(); j++) {
- TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(
- strings::StrCat("invocation_results[", i, "][", j, "]")),
- result->return_values[j]));
- }
- if (result->end_of_input) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("invocation_results[", i,
- "].end_of_input")),
- ""));
- }
- }
- return Status::OK();
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
- int64 invocation_results_size;
- TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name("invocation_results.size"), &invocation_results_size));
- for (size_t i = 0; i < invocation_results_size; i++) {
- std::shared_ptr<InvocationResult> result(new InvocationResult());
- invocation_results_.push_back(result);
- TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
- size_t num_return_values;
- {
- int64 size;
- TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat("invocation_results[", i, "].size")),
- &size));
- num_return_values = static_cast<size_t>(size);
- if (num_return_values != size) {
- return errors::InvalidArgument(strings::StrCat(
- full_name(
- strings::StrCat("invocation_results[", i, "].size")),
- ": ", size, " is not a valid value of type size_t."));
- }
- }
- result->return_values.reserve(num_return_values);
- for (size_t j = 0; j < num_return_values; j++) {
- result->return_values.emplace_back();
- TF_RETURN_IF_ERROR(
- reader->ReadTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- &result->return_values.back()));
- }
- result->end_of_input = reader->Contains(full_name(
- strings::StrCat("invocation_results[", i, "].end_of_input")));
- result->notification.Notify();
- }
- return Status::OK();
- }
-
- private:
- struct InvocationResult {
- Notification notification;
- Status status;
- std::vector<Tensor> return_values;
- bool end_of_input;
- };
-
- void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (!runner_thread_) {
- std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
- runner_thread_.reset(ctx->env()->StartThread(
- {}, "runner_thread",
- std::bind(&Iterator::RunnerThread, this, ctx_copy)));
- }
- }
-
- void CallCompleted(const std::shared_ptr<InvocationResult>& result)
- LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- num_calls_--;
- }
- result->notification.Notify();
- cond_var_.notify_all();
- }
-
- void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
- const std::shared_ptr<InvocationResult>& result)
- LOCKS_EXCLUDED(mu_) {
- // Get the next input element.
- std::vector<Tensor> input_element;
- result->status = input_impl_->GetNext(ctx.get(), &input_element,
- &result->end_of_input);
- if (result->end_of_input || !result->status.ok()) {
- CallCompleted(result);
- return;
- }
-
- // Call `func_(input_element)`, store the result in
- // `result->return_values`, and notify `result->notification` to unblock
- // a consumer.
- auto done = [this, result](Status status) {
- result->status.Update(status);
- CallCompleted(result);
- };
- dataset()->captured_func_->RunAsync(ctx.get(), std::move(input_element),
- &result->return_values, done);
- }
-
- int64 MaxInvocationResults() { return dataset()->num_parallel_calls_; }
-
- Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) {
- if (!result->end_of_input && result->status.ok()) {
- *out_tensors = std::move(result->return_values);
- *end_of_sequence = false;
- return Status::OK();
- }
- if (errors::IsOutOfRange(result->status)) {
- // `f` may deliberately raise `errors::OutOfRange` to indicate that we
- // should terminate the iteration early.
- *end_of_sequence = true;
- return Status::OK();
- }
- *end_of_sequence = result->end_of_input;
- return result->status;
- }
-
- void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
- std::vector<std::shared_ptr<InvocationResult>> new_calls;
- new_calls.reserve(dataset()->num_parallel_calls_);
- while (true) {
- {
- mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= dataset()->num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
- cond_var_.wait(l);
- }
- if (cancelled_) {
- return;
- }
- while (num_calls_ < dataset()->num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
- invocation_results_.emplace_back(new InvocationResult());
- new_calls.push_back(invocation_results_.back());
- num_calls_++;
- }
- }
- cond_var_.notify_all();
- for (const auto& call : new_calls) {
- CallFunction(ctx, call);
- }
- new_calls.clear();
- }
- }
-
- Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
- const Status& status)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- CodeKey(index), static_cast<int64>(status.code())));
- if (!status.ok()) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
- status.error_message()));
- }
- return Status::OK();
- }
-
- Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- int64 code_int;
- TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
- error::Code code = static_cast<error::Code>(code_int);
-
- if (code != error::Code::OK) {
- string error_message;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(ErrorMessageKey(index), &error_message));
- *status = Status(code, error_message);
- } else {
- *status = Status::OK();
- }
- return Status::OK();
- }
-
- string CodeKey(size_t index) {
- return full_name(
- strings::StrCat("invocation_results[", index, "].code"));
- }
-
- string ErrorMessageKey(size_t index) {
- return full_name(
- strings::StrCat("invocation_results[", index, "].error_message"));
- }
-
- // Used for coordination between the main thread and the runner thread.
- mutex mu_;
- // Used for coordination between the main thread and the runner thread. In
- // particular, the runner thread should only schedule new calls when the
- // number of in-flight calls is less than the user specified level of
- // parallelism and there are slots available in the `invocation_results_`
- // buffer.
- condition_variable cond_var_;
- // Counts the number of outstanding calls.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
- std::unique_ptr<IteratorBase> input_impl_;
- // Buffer for storing the invocation results.
- std::deque<std::shared_ptr<InvocationResult>> invocation_results_
- GUARDED_BY(mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
- };
-
const DatasetBase* const input_;
const NameAttrList func_;
const int32 num_parallel_calls_;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
new file mode 100644
index 0000000000..10549df25e
--- /dev/null
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -0,0 +1,318 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
+
+#include <deque>
+#include <functional>
+#include <utility>
+#include <vector>
+
+namespace tensorflow {
+namespace {
+
+class ParallelMapIterator : public DatasetBaseIterator {
+ public:
+ explicit ParallelMapIterator(
+ const typename DatasetBaseIterator::BaseParams& params,
+ const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
+ int32 num_parallel_calls)
+ : DatasetBaseIterator(params),
+ input_dataset_(input_dataset),
+ map_func_(std::move(map_func)),
+ num_parallel_calls_(num_parallel_calls) {}
+
+ ~ParallelMapIterator() override {
+ // TODO(mrry): Replace this cancellation logic with a
+ // CancellationManager. The syntax would be more heavyweight,
+ // but it would be possible to thread a cancellation manager
+ // through the IteratorContext to upstream,
+ // potentially-blocking iterators, when we add these.
+ mutex_lock l(mu_);
+ // Cancel the runner thread.
+ cancelled_ = true;
+ cond_var_.notify_all();
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
+ }
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ return input_dataset_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ std::shared_ptr<InvocationResult> result;
+ {
+ mutex_lock l(mu_);
+ EnsureRunnerThreadStarted(ctx);
+ while (invocation_results_.empty()) {
+ cond_var_.wait(l);
+ }
+ std::swap(result, invocation_results_.front());
+ invocation_results_.pop_front();
+ }
+ cond_var_.notify_all();
+ result->notification.WaitForNotification();
+ return ProcessResult(result, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_.wait(l);
+ }
+ CHECK_EQ(num_calls_, 0);
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("invocation_results.size"),
+ invocation_results_.size()));
+ for (size_t i = 0; i < invocation_results_.size(); i++) {
+ std::shared_ptr<InvocationResult> result = invocation_results_[i];
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ result->return_values.size()));
+ for (size_t j = 0; j < result->return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteTensor(full_name(strings::StrCat(
+ "invocation_results[", i, "][", j, "]")),
+ result->return_values[j]));
+ }
+ if (result->end_of_input) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(
+ strings::StrCat("invocation_results[", i, "].end_of_input")),
+ ""));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ int64 invocation_results_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name("invocation_results.size"), &invocation_results_size));
+ for (size_t i = 0; i < invocation_results_size; i++) {
+ std::shared_ptr<InvocationResult> result(new InvocationResult());
+ invocation_results_.push_back(result);
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ size_t num_return_values;
+ {
+ int64 size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name(strings::StrCat(
+ "invocation_results[", i, "].size")),
+ &size));
+ num_return_values = static_cast<size_t>(size);
+ if (num_return_values != size) {
+ return errors::InvalidArgument(strings::StrCat(
+ full_name(
+ strings::StrCat("invocation_results[", i, "].size")),
+ ": ", size, " is not a valid value of type size_t."));
+ }
+ }
+ result->return_values.reserve(num_return_values);
+ for (size_t j = 0; j < num_return_values; j++) {
+ result->return_values.emplace_back();
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(full_name(strings::StrCat(
+ "invocation_results[", i, "][", j, "]")),
+ &result->return_values.back()));
+ }
+ result->end_of_input = reader->Contains(full_name(
+ strings::StrCat("invocation_results[", i, "].end_of_input")));
+ result->notification.Notify();
+ }
+ return Status::OK();
+ }
+
+ private:
+ struct InvocationResult {
+ Notification notification;
+ Status status;
+ std::vector<Tensor> return_values;
+ bool end_of_input;
+ };
+
+ void EnsureRunnerThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!runner_thread_) {
+ std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
+ }
+ }
+
+ void CallCompleted(const std::shared_ptr<InvocationResult>& result)
+ LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ num_calls_--;
+ }
+ result->notification.Notify();
+ cond_var_.notify_all();
+ }
+
+ void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
+ const std::shared_ptr<InvocationResult>& result)
+ LOCKS_EXCLUDED(mu_) {
+ // Get the next input element.
+ std::vector<Tensor> input_element;
+ result->status =
+ input_impl_->GetNext(ctx.get(), &input_element, &result->end_of_input);
+ if (result->end_of_input || !result->status.ok()) {
+ CallCompleted(result);
+ return;
+ }
+
+ // Call `func_(input_element)`, store the result in
+ // `result->return_values`, and notify `result->notification` to unblock
+ // a consumer.
+ auto done = [this, result](Status status) {
+ result->status.Update(status);
+ CallCompleted(result);
+ };
+
+ map_func_(ctx.get(), std::move(input_element), &result->return_values,
+ std::move(done));
+ }
+
+ int64 MaxInvocationResults() { return num_parallel_calls_; }
+
+ Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
+ if (!result->end_of_input && result->status.ok()) {
+ *out_tensors = std::move(result->return_values);
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ if (errors::IsOutOfRange(result->status)) {
+ // `f` may deliberately raise `errors::OutOfRange` to indicate that we
+ // should terminate the iteration early.
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ *end_of_sequence = result->end_of_input;
+ return result->status;
+ }
+
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ std::vector<std::shared_ptr<InvocationResult>> new_calls;
+ new_calls.reserve(num_parallel_calls_);
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ (num_calls_ >= num_parallel_calls_ ||
+ invocation_results_.size() >= MaxInvocationResults())) {
+ cond_var_.wait(l);
+ }
+ if (cancelled_) {
+ return;
+ }
+ while (num_calls_ < num_parallel_calls_ &&
+ invocation_results_.size() < MaxInvocationResults()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ new_calls.push_back(invocation_results_.back());
+ num_calls_++;
+ }
+ }
+ cond_var_.notify_all();
+ for (const auto& call : new_calls) {
+ CallFunction(ctx, call);
+ }
+ new_calls.clear();
+ }
+ }
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(ErrorMessageKey(index), status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].error_message"));
+ }
+
+ const DatasetBase* const input_dataset_; // Not owned.
+ const ParallelMapIteratorFunction map_func_;
+ const int32 num_parallel_calls_;
+ // Used for coordination between the main thread and the runner thread.
+ mutex mu_;
+ // Used for coordination between the main thread and the runner thread. In
+ // particular, the runner thread should only schedule new calls when the
+ // number of in-flight calls is less than the user specified level of
+ // parallelism and there are slots available in the `invocation_results_`
+ // buffer.
+ condition_variable cond_var_;
+ // Counts the number of outstanding calls.
+ int64 num_calls_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<IteratorBase> input_impl_;
+ // Buffer for storing the invocation results.
+ std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_) = false;
+};
+
+} // namespace
+
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+ const DatasetBaseIterator::BaseParams& params,
+ const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
+ int32 num_parallel_calls) {
+ return std::unique_ptr<IteratorBase>(new ParallelMapIterator(
+ params, input_dataset, std::move(map_func), num_parallel_calls));
+}
+
+} // namespace tensorflow
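Editor's note: the `mu_`/`cond_var_` comments above describe a bounded-buffer handoff: the runner thread schedules new calls only while both the number of in-flight calls and the size of `invocation_results_` are below the parallelism limit, consumers block while the buffer is empty, and each side notifies the shared condition variable after changing state. A minimal standalone sketch of that coordination pattern with standard-library primitives (illustrative, not the TensorFlow mutex/condition_variable wrappers):

#include <condition_variable>
#include <deque>
#include <mutex>

class BoundedResultBuffer {
 public:
  explicit BoundedResultBuffer(size_t capacity) : capacity_(capacity) {}

  // Producer (runner thread): blocks while the buffer is full.
  void Push(int result) {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return buffer_.size() < capacity_; });
    buffer_.push_back(result);
    cv_.notify_all();
  }

  // Consumer (GetNext caller): blocks while the buffer is empty.
  int Pop() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return !buffer_.empty(); });
    int result = buffer_.front();
    buffer_.pop_front();
    cv_.notify_all();
    return result;
  }

 private:
  const size_t capacity_;
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<int> buffer_;
};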
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
new file mode 100644
index 0000000000..2ce36c3869
--- /dev/null
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_
+
+#include <memory>
+
+#include "tensorflow/core/framework/dataset.h"
+
+namespace tensorflow {
+
+// A function that asynchronously maps an element of the input dataset to an
+// element of the output dataset. The arguments are:
+// 1. An `IteratorContext*` for the context in which the function should
+// execute.
+// 2. A `std::vector<Tensor>` containing the input element.
+// 3. A `std::vector<Tensor>*` to which the function will write the result.
+// 4. A `StatusCallback` that should be invoked when the function is complete.
+using ParallelMapIteratorFunction =
+ std::function<void(IteratorContext*, std::vector<Tensor>,
+ std::vector<Tensor>*, StatusCallback)>;
+
+// Returns a new iterator that applies `map_func` to the elements of
+// `input_dataset` using the given degree of parallelism.
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+ const DatasetBaseIterator::BaseParams& params,
+ const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
+ int32 num_parallel_calls);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_
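
As a usage sketch (not code from this change): a `ParallelMapIteratorFunction`
is any callable with the four parameters documented above, so the simplest
possible implementation just forwards the element and reports success through
the callback. The names below (`identity_map_func`, the `"::IllustrativeMap"`
prefix, the parallelism value) are illustrative, and the snippet assumes the
declarations from this header plus TensorFlow's usual `Status`/`Tensor` types:

    // Forwards the input element unchanged and signals completion. The
    // callback may be invoked from any thread once the result is ready.
    ParallelMapIteratorFunction identity_map_func =
        [](IteratorContext* ctx, std::vector<Tensor> input,
           std::vector<Tensor>* output, StatusCallback done) {
          (void)ctx;  // Unused in this trivial example.
          *output = std::move(input);
          done(Status::OK());
        };

    // A dataset's MakeIteratorInternal could then return:
    //   NewParallelMapIterator(
    //       {this, strings::StrCat(prefix, "::IllustrativeMap")},
    //       /*input_dataset=*/this, identity_map_func,
    //       /*num_parallel_calls=*/4);
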
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index cc16108dce..9000842840 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -14,347 +14,335 @@ limitations under the License.
==============================================================================*/
#include <deque>
+#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
+
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/data/dataset.h"
-#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
namespace tensorflow {
-namespace {
-
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class PrefetchDatasetOp : public UnaryDatasetOpKernel {
+class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
public:
- explicit PrefetchDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
-
- protected:
- void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
- DatasetBase** output) override {
- int64 buffer_size;
- OP_REQUIRES_OK(
- ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
- OP_REQUIRES(ctx,
- buffer_size >= 0 || buffer_size == PrefetchAutotuner::kAutoTune,
- errors::InvalidArgument("buffer_size must be >= 0"));
-
- *output = new Dataset(ctx, input, buffer_size);
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size)
+ : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) {
+ input_->Ref();
}
- private:
- class Dataset : public GraphDatasetBase {
- public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size)
- : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) {
- input_->Ref();
- }
+ ~Dataset() override { input_->Unref(); }
- ~Dataset() override { input_->Unref(); }
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Prefetch")}));
+ }
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Prefetch")}));
- }
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
- const DataTypeVector& output_dtypes() const override {
- return input_->output_dtypes();
- }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return input_->output_shapes();
- }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
- string DebugString() const override { return "PrefetchDatasetOp::Dataset"; }
+ string DebugString() const override { return "PrefetchDatasetOp::Dataset"; }
- protected:
- Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* input_graph_node = nullptr;
- TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
- Node* buffer_size = nullptr;
- TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
- TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, buffer_size}, output));
- return Status::OK();
- }
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_graph_node));
+ Node* buffer_size = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(buffer_size_, &buffer_size));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_graph_node, buffer_size}, output));
+ return Status::OK();
+ }
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params),
- auto_tuner_(params.dataset->buffer_size_) {}
-
- ~Iterator() override {
- // Signal the prefetch thread to terminate it. We will then
- // join that thread when we delete `this->prefetch_thread_`.
- //
- // TODO(mrry): Replace this cancellation logic with a
- // CancellationManager. The syntax would be more heavyweight,
- // but it would be possible to thread a cancellation manager
- // through the IteratorContext to upstream,
- // potentially-blocking iterators, when we add these.
- {
- mutex_lock l(mu_);
- cancelled_ = true;
- cond_var_.notify_all();
- }
- }
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ auto_tuner_(params.dataset->buffer_size_) {}
- Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ ~Iterator() override {
+      // Signal the prefetch thread to terminate. We will then
+ // join that thread when we delete `this->prefetch_thread_`.
+ //
+ // TODO(mrry): Replace this cancellation logic with a
+ // CancellationManager. The syntax would be more heavyweight,
+ // but it would be possible to thread a cancellation manager
+ // through the IteratorContext to upstream,
+ // potentially-blocking iterators, when we add these.
+ {
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ cond_var_.notify_all();
}
+ }
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
- // Wait until the next element in the buffer has been
- // produced, or we are shutting down.
- while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
- auto_tuner_.buffer_limit() != 0) {
- auto_tuner_.RecordEmpty();
- cond_var_.wait(l);
- }
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
- if (cancelled_) {
- return errors::Cancelled(
- "PrefetchDatasetOp::Dataset::Iterator::GetNext");
- }
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
+ // Wait until the next element in the buffer has been
+ // produced, or we are shutting down.
+ while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
+ auto_tuner_.buffer_limit() != 0) {
+ auto_tuner_.RecordEmpty();
+ cond_var_.wait(l);
+ }
- if (!buffer_.empty()) {
- return Consume(out_tensors, end_of_sequence);
- }
+ if (cancelled_) {
+ return errors::Cancelled(
+ "PrefetchDatasetOp::Dataset::Iterator::GetNext");
+ }
- if (prefetch_thread_finished_) {
- *end_of_sequence = true;
- return Status::OK();
- }
+ if (!buffer_.empty()) {
+ return Consume(out_tensors, end_of_sequence);
+ }
- DCHECK_EQ(auto_tuner_.buffer_limit(), 0);
+ if (prefetch_thread_finished_) {
+ *end_of_sequence = true;
+ return Status::OK();
}
- mutex_lock parent_l(parent_mu_);
- mutex_lock l(mu_);
- return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ DCHECK_EQ(auto_tuner_.buffer_limit(), 0);
}
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- // Acquire both locks to ensure that the prefetch thread and
- // all GetNext threads are blocked.
- mutex_lock parent_l(parent_mu_);
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
- for (size_t i = 0; i < buffer_.size(); i++) {
- auto& buffer_element = buffer_[i];
- TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
- if (buffer_element.status.ok()) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("buffer[", i, "].size")),
- buffer_element.value.size()));
- for (size_t j = 0; j < buffer_element.value.size(); j++) {
- TF_RETURN_IF_ERROR(writer->WriteTensor(
- full_name(strings::StrCat("buffer[", i, "][", j, "]")),
- buffer_element.value[j]));
- }
+ mutex_lock parent_l(parent_mu_);
+ mutex_lock l(mu_);
+ return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ // Acquire both locks to ensure that the prefetch thread and
+ // all GetNext threads are blocked.
+ mutex_lock parent_l(parent_mu_);
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("buffer_size"), buffer_.size()));
+ for (size_t i = 0; i < buffer_.size(); i++) {
+ auto& buffer_element = buffer_[i];
+ TF_RETURN_IF_ERROR(WriteStatus(writer, i, buffer_element.status));
+ if (buffer_element.status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("buffer[", i, "].size")),
+ buffer_element.value.size()));
+ for (size_t j = 0; j < buffer_element.value.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("buffer[", i, "][", j, "]")),
+ buffer_element.value[j]));
}
}
- return Status::OK();
}
+ return Status::OK();
+ }
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock parent_l(parent_mu_);
- mutex_lock l(mu_);
- buffer_.clear();
- TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
- size_t buffer_size;
- {
- int64 temp;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(full_name("buffer_size"), &temp));
- buffer_size = static_cast<size_t>(temp);
- }
- for (size_t i = 0; i < buffer_size; i++) {
- buffer_.emplace_back();
- auto& buffer_element = buffer_.back();
- TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status));
- if (buffer_element.status.ok()) {
- size_t value_size;
- {
- int64 temp;
- TF_RETURN_IF_ERROR(reader->ReadScalar(
- full_name(strings::StrCat("buffer[", i, "].size")), &temp));
- value_size = static_cast<size_t>(temp);
- }
- buffer_element.value.reserve(value_size);
- for (size_t j = 0; j < value_size; j++) {
- buffer_element.value.emplace_back();
- TF_RETURN_IF_ERROR(reader->ReadTensor(
- full_name(strings::StrCat("buffer[", i, "][", j, "]")),
- &buffer_element.value.back()));
- }
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock parent_l(parent_mu_);
+ mutex_lock l(mu_);
+ buffer_.clear();
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ size_t buffer_size;
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("buffer_size"), &temp));
+ buffer_size = static_cast<size_t>(temp);
+ }
+ for (size_t i = 0; i < buffer_size; i++) {
+ buffer_.emplace_back();
+ auto& buffer_element = buffer_.back();
+ TF_RETURN_IF_ERROR(ReadStatus(reader, i, &buffer_element.status));
+ if (buffer_element.status.ok()) {
+ size_t value_size;
+ {
+ int64 temp;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("buffer[", i, "].size")), &temp));
+ value_size = static_cast<size_t>(temp);
+ }
+ buffer_element.value.reserve(value_size);
+ for (size_t j = 0; j < value_size; j++) {
+ buffer_element.value.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("buffer[", i, "][", j, "]")),
+ &buffer_element.value.back()));
}
}
- return Status::OK();
}
+ return Status::OK();
+ }
- private:
- // A buffer element comprises a status and (if that status is
- // OK) a vector of tensors, representing an element of the input dataset.
- struct BufferElement {
- // The producer sets `status` if getting the input element fails.
- Status status;
- // The buffered data element.
- std::vector<Tensor> value;
- };
-
- Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- // A new element is available. Forward the status from computing it, and
- // (if we successfully got an element) the output values.
- Status s = buffer_.front().status;
- if (s.ok()) {
- *out_tensors = std::move(buffer_.front().value);
- }
- buffer_.pop_front();
- *end_of_sequence = false;
-
- // Wake the prefetch thread, in case it has been waiting for space
- // in the buffer. Also wake up threads from other calls to GetNext.
- //
- // TODO(mrry): Consider using different condition variables for
- // GetNext and Prefetch.
- cond_var_.notify_all();
- return s;
- }
+ private:
+ // A buffer element comprises a status and (if that status is
+ // OK) a vector of tensors, representing an element of the input dataset.
+ struct BufferElement {
+ // The producer sets `status` if getting the input element fails.
+ Status status;
+ // The buffered data element.
+ std::vector<Tensor> value;
+ };
- Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- if (!prefetch_thread_) {
- prefetch_thread_.reset(
- ctx->env()->StartThread({}, "prefetch_thread",
- std::bind(&Iterator::PrefetchThread, this,
- new IteratorContext(*ctx))));
- }
- return Status::OK();
+ Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ // A new element is available. Forward the status from computing it, and
+ // (if we successfully got an element) the output values.
+ Status s = buffer_.front().status;
+ if (s.ok()) {
+ *out_tensors = std::move(buffer_.front().value);
}
+ buffer_.pop_front();
+ *end_of_sequence = false;
- // Prefetches elements of the input, storing results in an internal
- // buffer.
+ // Wake the prefetch thread, in case it has been waiting for space
+ // in the buffer. Also wake up threads from other calls to GetNext.
//
- // It owns the iterator context passed to it.
- void PrefetchThread(IteratorContext* ctx) {
- std::unique_ptr<IteratorContext> cleanup(ctx);
- while (true) {
- std::vector<Tensor> value;
+ // TODO(mrry): Consider using different condition variables for
+ // GetNext and Prefetch.
+ cond_var_.notify_all();
+ return s;
+ }
- // 1. Wait for a slot in the buffer.
- {
- mutex_lock l(mu_);
- while (!cancelled_ &&
- buffer_.size() >= auto_tuner_.buffer_limit()) {
- cond_var_.wait(l);
- }
-
- if (cancelled_) {
- return;
- }
- }
+ Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!prefetch_thread_) {
+ prefetch_thread_.reset(
+ ctx->env()->StartThread({}, "prefetch_thread",
+ std::bind(&Iterator::PrefetchThread, this,
+ new IteratorContext(*ctx))));
+ }
+ return Status::OK();
+ }
- // 2. Read the next element.
- // Acquire the parent lock since we will be reading an element
- // from the input iterator. Note that we do not wish to release
- // this lock till we have added the fetched element to the
- // `buffer_` else there will be local state that may be missed
- // by SaveInternal.
- mutex_lock parent_l(parent_mu_);
- bool end_of_sequence;
- BufferElement buffer_element;
- buffer_element.status = input_impl_->GetNext(
- ctx, &buffer_element.value, &end_of_sequence);
- if (buffer_element.status.ok() && end_of_sequence) {
- mutex_lock l(mu_);
- prefetch_thread_finished_ = true;
- cond_var_.notify_all();
- return;
+ // Prefetches elements of the input, storing results in an internal
+ // buffer.
+ //
+ // It owns the iterator context passed to it.
+ void PrefetchThread(IteratorContext* ctx) {
+ std::unique_ptr<IteratorContext> cleanup(ctx);
+ while (true) {
+ std::vector<Tensor> value;
+
+ // 1. Wait for a slot in the buffer.
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) {
+ cond_var_.wait(l);
}
- // 3. Signal that the element has been produced.
- {
- mutex_lock l(mu_);
- buffer_.push_back(std::move(buffer_element));
- cond_var_.notify_all();
+ if (cancelled_) {
+ return;
}
}
- }
- Status WriteStatus(IteratorStateWriter* writer, size_t index,
- const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- CodeKey(index), static_cast<int64>(status.code())));
- if (!status.ok()) {
- TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
- status.error_message()));
+ // 2. Read the next element.
+ // Acquire the parent lock since we will be reading an element
+ // from the input iterator. Note that we do not wish to release
+        // this lock until we have added the fetched element to `buffer_`;
+        // otherwise there would be local state that could be missed by
+        // SaveInternal.
+ mutex_lock parent_l(parent_mu_);
+ bool end_of_sequence;
+ BufferElement buffer_element;
+ buffer_element.status =
+ input_impl_->GetNext(ctx, &buffer_element.value, &end_of_sequence);
+ if (buffer_element.status.ok() && end_of_sequence) {
+ mutex_lock l(mu_);
+ prefetch_thread_finished_ = true;
+ cond_var_.notify_all();
+ return;
}
- return Status::OK();
- }
- Status ReadStatus(IteratorStateReader* reader, size_t index,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- int64 code_int;
- TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
- error::Code code = static_cast<error::Code>(code_int);
-
- if (code != error::Code::OK) {
- string error_message;
- TF_RETURN_IF_ERROR(
- reader->ReadScalar(ErrorMessageKey(index), &error_message));
- *status = Status(code, error_message);
- } else {
- *status = Status::OK();
+ // 3. Signal that the element has been produced.
+ {
+ mutex_lock l(mu_);
+ buffer_.push_back(std::move(buffer_element));
+ cond_var_.notify_all();
}
- return Status::OK();
}
+ }
- string CodeKey(size_t index) {
- return full_name(strings::StrCat("status[", index, "].code"));
+ Status WriteStatus(IteratorStateWriter* writer, size_t index,
+ const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
}
+ return Status::OK();
+ }
- string ErrorMessageKey(size_t index) {
- return full_name(strings::StrCat("status[", index, "].error_message"));
+ Status ReadStatus(IteratorStateReader* reader, size_t index, Status* status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
}
+ return Status::OK();
+ }
- // This mutex is used to ensure exclusivity between multiple threads
- // reading/writing this iterator's local state.
- mutex mu_;
- // This mutex is used to ensure exclusivity between multiple threads
- // accessing the parent iterator. We keep this separate from `mu_` to
- // allow prefetching to run in parallel with GetNext calls.
- mutex parent_mu_ ACQUIRED_BEFORE(mu_);
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
- condition_variable cond_var_;
- PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
- std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
- std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
- bool prefetch_thread_finished_ GUARDED_BY(mu_) = false;
- };
+ string CodeKey(size_t index) {
+ return full_name(strings::StrCat("status[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(strings::StrCat("status[", index, "].error_message"));
+ }
- const DatasetBase* const input_;
- const int64 buffer_size_;
+ // This mutex is used to ensure exclusivity between multiple threads
+ // reading/writing this iterator's local state.
+ mutex mu_;
+ // This mutex is used to ensure exclusivity between multiple threads
+ // accessing the parent iterator. We keep this separate from `mu_` to
+ // allow prefetching to run in parallel with GetNext calls.
+ mutex parent_mu_ ACQUIRED_BEFORE(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
+ condition_variable cond_var_;
+ PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
+ std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ bool prefetch_thread_finished_ GUARDED_BY(mu_) = false;
};
+ const DatasetBase* const input_;
+ const int64 buffer_size_;
};
+void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) {
+ int64 buffer_size;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
+ OP_REQUIRES(ctx,
+ buffer_size >= 0 || buffer_size == PrefetchAutotuner::kAutoTune,
+ errors::InvalidArgument("buffer_size must be >= 0"));
+
+ *output = new Dataset(ctx, input, buffer_size);
+}
+
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU),
PrefetchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
@@ -363,6 +351,4 @@ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
.HostMemory("input_dataset")
.HostMemory("handle"),
PrefetchDatasetOp);
-} // namespace
-
} // namespace tensorflow
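
The `WriteStatus`/`ReadStatus` helpers above checkpoint each buffered element's
`Status` as an integer error code plus an error message that is only written
for non-OK statuses, under keys of the form `status[i].code` and
`status[i].error_message`. A minimal standalone sketch of that round trip, with
a plain `std::map` standing in for `IteratorStateWriter`/`IteratorStateReader`
and a simplified status struct (all names here are illustrative):

    #include <cstddef>
    #include <map>
    #include <string>

    using StateMap = std::map<std::string, std::string>;

    struct SimpleStatus {
      int code = 0;               // 0 plays the role of error::Code::OK.
      std::string error_message;  // Only meaningful when code != 0.
      bool ok() const { return code == 0; }
    };

    void WriteStatus(StateMap* writer, std::size_t index,
                     const SimpleStatus& status) {
      const std::string prefix = "status[" + std::to_string(index) + "]";
      (*writer)[prefix + ".code"] = std::to_string(status.code);
      if (!status.ok()) {
        // The message is only persisted for failures, as in the kernel.
        (*writer)[prefix + ".error_message"] = status.error_message;
      }
    }

    SimpleStatus ReadStatus(const StateMap& reader, std::size_t index) {
      const std::string prefix = "status[" + std::to_string(index) + "]";
      SimpleStatus status;
      status.code = std::stoi(reader.at(prefix + ".code"));
      if (!status.ok()) {
        status.error_message = reader.at(prefix + ".error_message");
      }
      return status;
    }

The same code/message layout appears in `ParallelMapIterator` above, just under
`invocation_results[i]` keys instead of `status[i]`.
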
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h
new file mode 100644
index 0000000000..c40c4b00da
--- /dev/null
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_
+
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
+
+namespace tensorflow {
+
+class PrefetchDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit PrefetchDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override;
+
+ private:
+ class Dataset;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_
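
The new header follows the structure this change applies to several dataset
kernels: the op class is declared in a header so other targets can reference
it, while the heavyweight nested `Dataset` class is only forward-declared here
and fully defined in the .cc file. In miniature, with hypothetical names and a
toy payload instead of the real dataset machinery:

    #include <memory>

    // "Header" part: the interface is visible, the nested class is not.
    class MyDatasetOp {
     public:
      void MakeDataset();

     private:
      class Dataset;  // Defined out of line (in the real code: in the .cc).
    };

    // "Implementation" part: the nested class and the method that uses it.
    class MyDatasetOp::Dataset {
     public:
      explicit Dataset(int buffer_size) : buffer_size_(buffer_size) {}
      int buffer_size() const { return buffer_size_; }

     private:
      int buffer_size_;
    };

    void MyDatasetOp::MakeDataset() {
      std::unique_ptr<Dataset> dataset(new Dataset(/*buffer_size=*/8));
      (void)dataset;  // The real kernel hands this to the framework instead.
    }

Keeping the definition in the .cc file lets the header stay light on
dependencies while still letting other code name `PrefetchDatasetOp`.
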
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 754c32b6ca..58ec3d4495 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -390,7 +390,7 @@ class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
for (const auto& feature_list :
example.feature_lists().feature_list()) {
- stats_aggregator->IncrementCounter("feature_lists_count", "reainer",
+ stats_aggregator->IncrementCounter("feature_lists_count", "trainer",
1);
for (const auto& feature : feature_list.second.feature()) {
feature_values_list_size_sum += AddStatsFeatureValues(feature);
diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index d5c33c0188..bfdabc3a9f 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -16,13 +16,13 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "tensorflow/core/kernels/function_ops.h"
+
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/memory_types.h"
-#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/gradients.h"
@@ -33,64 +33,40 @@ limitations under the License.
namespace tensorflow {
-static const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
-static const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static const char* const kGradientOp = FunctionLibraryDefinition::kGradientOp;
-class ArgOp : public OpKernel {
- public:
- explicit ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
- }
-
- void Compute(OpKernelContext* ctx) override {
- auto frame = ctx->call_frame();
- OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
- Tensor val;
- OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
- OP_REQUIRES(ctx, val.dtype() == dtype_,
- errors::InvalidArgument(
- "Type mismatch: actual ", DataTypeString(val.dtype()),
- " vs. expect ", DataTypeString(dtype_)));
- ctx->set_output(0, val);
- }
-
- bool IsExpensive() override { return false; }
-
- private:
- int index_;
- DataType dtype_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
-};
-
-class RetvalOp : public OpKernel {
- public:
- explicit RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
- }
-
- void Compute(OpKernelContext* ctx) override {
- const Tensor& val = ctx->input(0);
- OP_REQUIRES(ctx, val.dtype() == dtype_,
- errors::InvalidArgument(
- "Type mismatch: actual ", DataTypeString(val.dtype()),
- " vs. expect ", DataTypeString(dtype_)));
- auto frame = ctx->call_frame();
- OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
- OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
- }
-
- bool IsExpensive() override { return false; }
-
- private:
- int index_;
- DataType dtype_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
-};
+ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
+}
+
+void ArgOp::Compute(OpKernelContext* ctx) {
+ auto frame = ctx->call_frame();
+ OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
+ Tensor val;
+ OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
+ OP_REQUIRES(ctx, val.dtype() == dtype_,
+ errors::InvalidArgument("Type mismatch: actual ",
+ DataTypeString(val.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ ctx->set_output(0, val);
+}
+
+RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
+}
+
+void RetvalOp::Compute(OpKernelContext* ctx) {
+ const Tensor& val = ctx->input(0);
+ OP_REQUIRES(ctx, val.dtype() == dtype_,
+ errors::InvalidArgument("Type mismatch: actual ",
+ DataTypeString(val.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ auto frame = ctx->call_frame();
+ OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
+ OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
+}
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
@@ -304,123 +280,105 @@ REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL),
#endif // TENSORFLOW_USE_SYCL
-class RemoteCallOp : public AsyncOpKernel {
- public:
- explicit RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
- OP_REQUIRES_OK(ctx,
- ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
- }
-
- ~RemoteCallOp() override {}
-
- void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- FunctionLibraryRuntime* lib = ctx->function_library();
- OP_REQUIRES_ASYNC(ctx, lib != nullptr,
- errors::Internal("No function library is provided."),
- done);
-
- const string& source_device = lib->device()->name();
- const Tensor* target;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
- string target_device;
- OP_REQUIRES_OK_ASYNC(
- ctx,
- DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()(),
- source_device, &target_device),
- done);
-
- AttrValueMap attr_values = func_.attr();
- FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
- instantiate_opts.target = target_device;
-
- FunctionTarget function_target = {target_device, lib};
-
- FunctionLibraryRuntime::Handle handle;
- {
- mutex_lock l(mu_);
- auto cached_entry = handle_cache_.find(function_target);
- if (cached_entry != handle_cache_.end()) {
- handle = cached_entry->second;
- } else {
- VLOG(1) << "Instantiating " << func_.name() << " on " << target_device;
- tracing::ScopedActivity activity(strings::StrCat(
- "RemoteCall: Instantiate: ", func_.name(), " on ", target_device));
- OP_REQUIRES_OK_ASYNC(
- ctx,
- lib->Instantiate(func_.name(), AttrSlice(&attr_values),
- instantiate_opts, &handle),
- done);
- auto insert_result = handle_cache_.insert({function_target, handle});
- CHECK(insert_result.second) << "Insert unsuccessful.";
- VLOG(1) << "Instantiated " << func_.name() << " on " << target_device
- << ", resulting in handle: " << handle << " flr: " << lib;
- }
+RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx,
+ ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
+}
+
+void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
+ FunctionLibraryRuntime* lib = ctx->function_library();
+ OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."), done);
+
+ const string& source_device = lib->device()->name();
+ const Tensor* target;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
+ string target_device;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ DeviceNameUtils::CanonicalizeDeviceName(target->scalar<string>()(),
+ source_device, &target_device),
+ done);
+
+ AttrValueMap attr_values = func_.attr();
+ FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
+ instantiate_opts.target = target_device;
+
+ FunctionTarget function_target = {target_device, lib};
+
+ FunctionLibraryRuntime::Handle handle;
+ {
+ mutex_lock l(mu_);
+ auto cached_entry = handle_cache_.find(function_target);
+ if (cached_entry != handle_cache_.end()) {
+ handle = cached_entry->second;
+ } else {
+ VLOG(1) << "Instantiating " << func_.name() << " on " << target_device;
+ tracing::ScopedActivity activity(strings::StrCat(
+ "RemoteCall: Instantiate: ", func_.name(), " on ", target_device));
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ lib->Instantiate(func_.name(), AttrSlice(&attr_values),
+ instantiate_opts, &handle),
+ done);
+ auto insert_result = handle_cache_.insert({function_target, handle});
+ CHECK(insert_result.second) << "Insert unsuccessful.";
+ VLOG(1) << "Instantiated " << func_.name() << " on " << target_device
+ << ", resulting in handle: " << handle << " flr: " << lib;
}
+ }
- OpInputList arguments;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
+ OpInputList arguments;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
- FunctionLibraryRuntime::Options opts;
- opts.step_id = ctx->step_id();
- opts.runner = ctx->runner();
- opts.source_device = source_device;
- if (opts.source_device != target_device) {
- opts.remote_execution = true;
- }
- opts.create_rendezvous = true;
- std::vector<Tensor> args;
- args.reserve(arguments.size());
- for (const Tensor& argument : arguments) {
- args.push_back(argument);
- }
- for (const auto& dtype : input_dtypes_) {
- AllocatorAttributes arg_alloc_attrs;
- if (DataTypeAlwaysOnHost(dtype)) {
- arg_alloc_attrs.set_on_host(true);
- }
- opts.args_alloc_attrs.push_back(arg_alloc_attrs);
+ FunctionLibraryRuntime::Options opts;
+ opts.step_id = ctx->step_id();
+ opts.runner = ctx->runner();
+ opts.source_device = source_device;
+ if (opts.source_device != target_device) {
+ opts.remote_execution = true;
+ }
+ opts.create_rendezvous = true;
+ std::vector<Tensor> args;
+ args.reserve(arguments.size());
+ for (const Tensor& argument : arguments) {
+ args.push_back(argument);
+ }
+ for (const auto& dtype : input_dtypes_) {
+ AllocatorAttributes arg_alloc_attrs;
+ if (DataTypeAlwaysOnHost(dtype)) {
+ arg_alloc_attrs.set_on_host(true);
}
- for (const auto& dtype : output_dtypes_) {
- AllocatorAttributes ret_alloc_attrs;
- if (DataTypeAlwaysOnHost(dtype)) {
- ret_alloc_attrs.set_on_host(true);
- }
- opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
+ opts.args_alloc_attrs.push_back(arg_alloc_attrs);
+ }
+ for (const auto& dtype : output_dtypes_) {
+ AllocatorAttributes ret_alloc_attrs;
+ if (DataTypeAlwaysOnHost(dtype)) {
+ ret_alloc_attrs.set_on_host(true);
}
- auto* rets = new std::vector<Tensor>;
- auto* activity = new tracing::ScopedActivity(strings::StrCat(
- "RemoteCall: Run: ", func_.name(), " on ", target_device));
- VLOG(1) << "Running " << func_.name() << " on " << target_device
- << " with handle: " << handle;
- lib->Run(opts, handle, args, rets,
- [rets, activity, done, ctx](const Status& status) {
- if (!status.ok()) {
- ctx->SetStatus(status);
- } else {
- for (size_t i = 0; i < rets->size(); ++i) {
- ctx->set_output(i, (*rets)[i]);
- }
- }
- delete rets;
- delete activity;
- done();
- });
+ opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
}
-
- private:
- NameAttrList func_;
- DataTypeVector input_dtypes_;
- DataTypeVector output_dtypes_;
-
- mutex mu_;
- typedef std::pair<string, FunctionLibraryRuntime*> FunctionTarget;
- std::map<FunctionTarget, FunctionLibraryRuntime::Handle> handle_cache_
- GUARDED_BY(mu_);
-
- TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
-};
+ auto* rets = new std::vector<Tensor>;
+ auto* activity = new tracing::ScopedActivity(strings::StrCat(
+ "RemoteCall: Run: ", func_.name(), " on ", target_device));
+ VLOG(1) << "Running " << func_.name() << " on " << target_device
+ << " with handle: " << handle;
+ lib->Run(opts, handle, args, rets,
+ [rets, activity, done, ctx](const Status& status) {
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ } else {
+ for (size_t i = 0; i < rets->size(); ++i) {
+ ctx->set_output(i, (*rets)[i]);
+ }
+ }
+ delete rets;
+ delete activity;
+ done();
+ });
+}
REGISTER_KERNEL_BUILDER(
Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
diff --git a/tensorflow/core/kernels/function_ops.h b/tensorflow/core/kernels/function_ops.h
new file mode 100644
index 0000000000..9e88cc6d8c
--- /dev/null
+++ b/tensorflow/core/kernels/function_ops.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+static const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
+static const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
+
+class ArgOp : public OpKernel {
+ public:
+ explicit ArgOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ bool IsExpensive() override { return false; }
+
+ private:
+ int index_;
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
+};
+
+class RetvalOp : public OpKernel {
+ public:
+ explicit RetvalOp(OpKernelConstruction* ctx);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ bool IsExpensive() override { return false; }
+
+ private:
+ int index_;
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
+};
+
+class RemoteCallOp : public AsyncOpKernel {
+ public:
+ explicit RemoteCallOp(OpKernelConstruction* ctx);
+
+ ~RemoteCallOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
+
+ private:
+ NameAttrList func_;
+ DataTypeVector input_dtypes_;
+ DataTypeVector output_dtypes_;
+
+ mutex mu_;
+ typedef std::pair<string, FunctionLibraryRuntime*> FunctionTarget;
+ std::map<FunctionTarget, FunctionLibraryRuntime::Handle> handle_cache_
+ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_KERNELS_FUNCTION_OPS_H_
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index cb285bf732..1529d2e336 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -127,31 +127,47 @@ class IfOp : public AsyncOpKernel {
explicit IfOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
auto lib = ctx->function_library();
OP_REQUIRES(ctx, lib != nullptr, errors::Internal("No function library"));
- const NameAttrList* func;
- OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &func));
- OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &then_handle_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &func));
- OP_REQUIRES_OK(ctx, Instantiate(lib, *func, &else_handle_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("then_branch", &then_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("else_branch", &else_func_));
}
~IfOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ auto lib = ctx->function_library();
+ OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+ errors::Internal("No function library"), done);
+
+ // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
+ // registration, this kernel may be shared by multiple subgraphs, which have
+ // different associated `FunctionLibraryRuntime` objects and hence different
+ // `FHandle` namespaces. So we must call Instantiate() to make sure we get
+ // the correct function handles with respect to `lib`. Note the underlying
+ // `lib->Instantiate()` caches the created function handles, so calling
+ // `Instantiate()` repeatedly on the same `lib` and function is cheap.
+ FHandle then_handle;
+ FHandle else_handle;
+ OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, then_func_, &then_handle), done);
+ OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, else_func_, &else_handle), done);
+
bool cond;
OP_REQUIRES_OK(ctx, ToBool({ctx->input(0)}, &cond));
- (new State(this, ctx, cond, done))->Start();
+ (new State(this, ctx, cond, then_handle, else_handle, done))->Start();
}
private:
- FHandle then_handle_;
- FHandle else_handle_;
+ NameAttrList then_func_;
+ NameAttrList else_func_;
class State {
public:
- State(IfOp* kernel, OpKernelContext* ctx, bool cond, DoneCallback done)
+ State(IfOp* kernel, OpKernelContext* ctx, bool cond, FHandle then_handle,
+ FHandle else_handle, DoneCallback done)
: kernel_(kernel),
ctx_(ctx),
cond_(cond),
+ then_handle_(then_handle),
+ else_handle_(else_handle),
done_(std::move(done)),
lib_(CHECK_NOTNULL(ctx_->function_library())) {
SetRunOptions(ctx_, &opts_, true /* always_collect_stats */);
@@ -163,7 +179,7 @@ class IfOp : public AsyncOpKernel {
~State() {}
void Start() {
- FHandle handle = cond_ ? kernel_->then_handle_ : kernel_->else_handle_;
+ FHandle handle = cond_ ? then_handle_ : else_handle_;
rets_.clear();
lib_->Run(
          // Evaluate one of the branches.
@@ -184,6 +200,8 @@ class IfOp : public AsyncOpKernel {
IfOp* const kernel_;
OpKernelContext* const ctx_;
const bool cond_;
+ FHandle then_handle_;
+ FHandle else_handle_;
DoneCallback done_;
FunctionLibraryRuntime* const lib_;
FunctionLibraryRuntime::Options opts_;
@@ -200,6 +218,10 @@ REGISTER_KERNEL_BUILDER(Name("_If").Device(DEVICE_GPU).HostMemory("cond"),
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_CPU), IfOp);
REGISTER_KERNEL_BUILDER(Name("If").Device(DEVICE_GPU).HostMemory("cond"), IfOp);
+REGISTER_KERNEL_BUILDER(Name("StatelessIf").Device(DEVICE_CPU), IfOp);
+REGISTER_KERNEL_BUILDER(
+ Name("StatelessIf").Device(DEVICE_GPU).HostMemory("cond"), IfOp);
+
class WhileOp : public AsyncOpKernel {
public:
explicit WhileOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
@@ -214,30 +236,17 @@ class WhileOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, lib != nullptr,
errors::Internal("No function library"), done);
- // TODO(b/37549631): Because this op has `SetIsStateful()` in its
- // op registration, this kernel may be shared by multiple
- // subgraphs, which have different associated
- // `FunctionLibraryRuntime` objects and hence different `FHandle`
- // namespaces. We currently work around this by caching the map
- // from `FunctionLibraryRuntime*` to `FHandle` pairs for the two
- // functions this op uses.
+ // TODO(b/37549631): Because this op has `SetIsStateful()` in its op
+ // registration, this kernel may be shared by multiple subgraphs, which have
+ // different associated `FunctionLibraryRuntime` objects and hence different
+ // `FHandle` namespaces. So we must call Instantiate() to make sure we get
+ // the correct function handles with respect to `lib`. Note the underlying
+ // `lib->Instantiate()` caches the created function handles, so calling
+ // `Instantiate()` repeatedly on the same `lib` and function is cheap.
FHandle cond_handle;
FHandle body_handle;
- {
- mutex_lock l(mu_);
- const auto iter = handles_.find(lib);
- if (iter == handles_.end()) {
- OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle),
- done);
- OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle),
- done);
- handles_[lib] = {cond_handle, body_handle};
- } else {
- cond_handle = iter->second.first;
- body_handle = iter->second.second;
- }
- }
-
+ OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, cond_func_, &cond_handle), done);
+ OP_REQUIRES_OK_ASYNC(ctx, Instantiate(lib, body_func_, &body_handle), done);
(new State(this, ctx, cond_handle, body_handle, done))->Start();
}
@@ -245,10 +254,6 @@ class WhileOp : public AsyncOpKernel {
NameAttrList cond_func_;
NameAttrList body_func_;
- mutex mu_;
- std::unordered_map<FunctionLibraryRuntime*, std::pair<FHandle, FHandle>>
- handles_ GUARDED_BY(mu_);
-
class State {
public:
State(WhileOp* kernel, OpKernelContext* ctx, FHandle cond_handle,
@@ -378,6 +383,9 @@ REGISTER_KERNEL_BUILDER(Name("_While").Device(DEVICE_GPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_CPU), WhileOp);
REGISTER_KERNEL_BUILDER(Name("While").Device(DEVICE_GPU), WhileOp);
+REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_CPU), WhileOp);
+REGISTER_KERNEL_BUILDER(Name("StatelessWhile").Device(DEVICE_GPU), WhileOp);
+
Status GetScalar(OpKernelContext* ctx, int index, int32* value,
const char* label) {
Tensor t = ctx->input(index);
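
`IfOp::ComputeAsync` above resolves both branch handles, converts the predicate
with `ToBool`, and then drives the chosen branch through a heap-allocated
`State` object that owns the completion callback and deletes itself when the
asynchronous run finishes (`(new State(...))->Start()`). A minimal sketch of
that self-deleting async-state pattern with a generic callback-taking runner;
everything here is illustrative rather than TensorFlow API:

    #include <functional>
    #include <utility>

    class SelfDeletingState {
     public:
      // The runner invokes the supplied callback when the async work is done.
      using Runner = std::function<void(std::function<void()> on_finished)>;

      SelfDeletingState(Runner runner, std::function<void()> done)
          : runner_(std::move(runner)), done_(std::move(done)) {}

      void Start() {
        // Move the runner into a local so `this` can be deleted from inside
        // the completion callback even if the runner calls it synchronously.
        Runner runner = std::move(runner_);
        runner([this]() {
          done_();
          delete this;  // Last action: the state owns itself until completion.
        });
      }

     private:
      Runner runner_;
      std::function<void()> done_;
    };

    // Usage: (new SelfDeletingState(runner, done))->Start();
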
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index f99dd643f7..d89f1592bd 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -45,6 +45,24 @@ struct FusedBatchNorm;
template <typename Device, typename T, typename U>
struct FusedBatchNormGrad;
+template <bool IsSame, typename Y, typename X, typename T>
+struct CastIfNecessary {
+ static inline void process(
+ Y& y, X& x_shifted, const Eigen::DSizes<Eigen::Index, 2>& rest_by_depth,
+ const CPUDevice& d) {
+ y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>();
+ }
+};
+
+template <typename Y, typename X, typename T>
+struct CastIfNecessary<true, Y, X, T> {
+ static inline void process(
+ Y& y, X& x_shifted, const Eigen::DSizes<Eigen::Index, 2>& rest_by_depth,
+ const CPUDevice& d) {
+ y.reshape(rest_by_depth).device(d) = x_shifted;
+ }
+};
+
template <typename T, typename U>
struct FusedBatchNorm<CPUDevice, T, U> {
void operator()(OpKernelContext* context, const Tensor& x_input,
@@ -125,7 +143,11 @@ struct FusedBatchNorm<CPUDevice, T, U> {
auto x_shifted =
x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec);
- y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>();
+ // Explicitly checks the types of T and U and only casts x_shifted when
+ // T != U. (Not doing so caused a 35-50% performance slowdown for
+ // some compiler flags.)
+ CastIfNecessary<std::is_same<T, U>::value, decltype(y), decltype(x_shifted),
+ T>::process(y, x_shifted, rest_by_depth, d);
}
};
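
`CastIfNecessary` above avoids instantiating `x_shifted.template cast<T>()`
when `T` and `U` are the same type by selecting a partial specialization on
`std::is_same<T, U>::value` at compile time. The same technique on plain
values, stripped of the Eigen expression machinery and renamed to avoid
confusion with the kernel's struct (illustrative only):

    #include <type_traits>

    // Primary template: source and destination types differ, so convert.
    template <bool kSameType, typename To, typename From>
    struct MaybeCast {
      static To Cast(const From& value) { return static_cast<To>(value); }
    };

    // Specialization: the types are identical, so the value passes through
    // and no conversion code is generated.
    template <typename To, typename From>
    struct MaybeCast<true, To, From> {
      static To Cast(const From& value) { return value; }
    };

    template <typename To, typename From>
    To ConvertIfNeeded(const From& value) {
      return MaybeCast<std::is_same<To, From>::value, To, From>::Cast(value);
    }

    // ConvertIfNeeded<float>(3.0)  -> uses the casting primary template.
    // ConvertIfNeeded<double>(3.0) -> uses the pass-through specialization.

With C++17 an `if constexpr` branch would express the same intent, but the
specialization form also works with pre-C++17 compilers.
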
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 80376c61aa..5d4737549b 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -578,25 +578,41 @@ struct MatMulFunctor<SYCLDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
-#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
+#if defined(INTEL_MKL)
-// MKL does not support half and int32 types for matrix-multiplication, so
-// register the kernel to use default Eigen based implementations for these
-// types. Registration for NO-LABEL version is in mkl_matmul_op.cc
-TF_CALL_float(REGISTER_CPU_EIGEN);
-TF_CALL_double(REGISTER_CPU_EIGEN);
+// MKL does not support half, bfloat16, and int32 types for matrix
+// multiplication, so register the kernel to use the default Eigen-based
+// implementations for these types. REGISTER_CPU defines two versions: one
+// with the Eigen label and one with no label.
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
-
TF_CALL_int32(REGISTER_CPU);
+
+// Float is supported by both MKL DNN and MKL ML.
+// Registration of the no-label version for types supported by MKL is in
+// mkl_matmul_op.cc. However, we define the Eigen-label version here just to
+// pass a few unit tests.
+TF_CALL_float(REGISTER_CPU_EIGEN);
+
+// MKL DNN does not support complex64/complex128/double. If the user specifies
+// that only the open-source MKL DNN should be used, fall back to the default
+// implementation for these types; otherwise use GEMM from the MKL ML binary.
+
+#if defined(DO_NOT_USE_ML)
+TF_CALL_complex64(REGISTER_CPU);
+TF_CALL_complex128(REGISTER_CPU);
+TF_CALL_double(REGISTER_CPU);
+#else // DO_NOT_USE_ML
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
-#else
+TF_CALL_double(REGISTER_CPU_EIGEN);
+#endif
+
+#else  // INTEL_MKL
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);
-
TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
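
The net effect of the registrations above is that each dtype picks between an
MKL-backed kernel (registered without a label, in mkl_matmul_op.cc for the
types MKL supports) and the Eigen implementation, depending on whether
INTEL_MKL and DO_NOT_USE_ML are defined. A compilable toy illustration of that
preprocessor-driven selection, using hypothetical backend stand-ins rather than
the real registration macros:

    #include <iostream>

    // Hypothetical backends; only the #if structure mirrors the real code.
    struct EigenBackend { static const char* Name() { return "Eigen"; } };
    struct MklMlBackend { static const char* Name() { return "MKL ML"; } };

    #if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
    using DoubleMatMulBackend = MklMlBackend;  // GEMM from the MKL ML binary.
    #else
    using DoubleMatMulBackend = EigenBackend;  // Default Eigen implementation.
    #endif

    int main() {
      std::cout << "double MatMul backend: " << DoubleMatMulBackend::Name()
                << "\n";
      return 0;
    }
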
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc
index d545d34fdf..d3566c2e37 100644
--- a/tensorflow/core/kernels/mkl_avgpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc
@@ -442,7 +442,6 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
void Compute(OpKernelContext* context) override {
try {
- auto cpu_engine = engine(engine::cpu, 0);
const Tensor& input_tensor =
MklGetInput(context, this->kInputTensorIndexInput);
MklDnnShape dnn_shape_input;
@@ -450,14 +449,14 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
this->SanityCheckInput(context, input_tensor, dnn_shape_input);
if (!context->status().ok()) return;
- MklDnnData<T> dnn_data_input(&cpu_engine);
- MklDnnData<T> dnn_data_output(&cpu_engine);
+ MklDnnData<T> dnn_data_input(&cpu_engine_);
// initialize variables for the pooling op
MklPoolParameters pool_params;
// Get the input tensor and initialize the pooling parameters
- this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params,
- &dnn_data_input);
+ TensorShape input_tensor_shape = input_tensor.shape();
+ this->InitMklPoolParameters(context, &pool_params, dnn_shape_input,
+ input_tensor_shape);
OP_REQUIRES_OK(context, context->status());
// Declare output tensor
@@ -467,65 +466,62 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
// If input is an empty tensor, allocate an empty output tensor and return
if (input_tensor.NumElements() == 0) {
- MklDnnShape output_mkl_shape;
- output_mkl_shape.SetMklTensor(false);
- TensorShape output_tf_shape;
- if (pool_params.data_format == TensorFormat::FORMAT_NCHW) {
- output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
- } else {
- memory::dims output_dims_NHWC_order;
- output_dims_NHWC_order = {pool_params.tensor_in_batch,
- static_cast<int>(pool_params.out_height),
- static_cast<int>(pool_params.out_width),
- pool_params.out_depth};
- output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
- }
const int kOutputIndex = 0;
- AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor,
- output_tf_shape, output_mkl_shape);
- CHECK_NOTNULL(output_tensor);
+ this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params,
+ output_dims_mkl_order, &output_tensor);
return;
}
- // If input is in Mkl layout, then just get the memory format from it
- // directly, instead of using input data_format to AvgPool.
- if (dnn_shape_input.IsMklTensor()) {
- dnn_data_output.SetUsrMem(
- output_dims_mkl_order,
- static_cast<memory::format>(
- dnn_data_input.GetUsrMemDesc().data.format));
-
- } else {
- dnn_data_output.SetUsrMem(output_dims_mkl_order,
- this->data_format_mkldnn_);
- }
-
- // describe the memory layout
- dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
-
- // 3. create a pooling primitive descriptor
- auto pool_desc = pooling_forward::desc(
- prop_kind::forward, algorithm::pooling_avg_exclude_padding,
- dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(),
- memory::dims({pool_params.row_stride, pool_params.col_stride}),
- memory::dims({pool_params.window_rows, pool_params.window_cols}),
- memory::dims({static_cast<int>(pool_params.pad_top),
- static_cast<int>(pool_params.pad_left)}),
- memory::dims({static_cast<int>(pool_params.pad_bottom),
- static_cast<int>(pool_params.pad_right)}),
- TFPaddingToMklDnnPadding(this->padding_));
- auto pool_prim_desc =
- pooling_forward::primitive_desc(pool_desc, cpu_engine);
-
- this->AllocateOutputTensor(context, pool_prim_desc, output_dims_mkl_order,
+ memory::dims filter_dims, strides, padding_left, padding_right;
+ this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
+ &padding_left, &padding_right);
+
+ // Get the input memory descriptor
+ memory::desc input_md =
+ dnn_shape_input.IsMklTensor()
+ ? dnn_shape_input.GetMklLayout()
+ : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ this->data_format_tf_),
+ MklDnnType<T>(), this->data_format_mkldnn_);
+
+ // Get src/filter/stride/padding information
+ memory::dims src_dims =
+ dnn_shape_input.IsMklTensor()
+ ? dnn_shape_input.GetSizesAsMklDnnDims()
+ : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
+ this->data_format_tf_);
+
+ // Get an average pooling primitive from the op pool
+ MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr;
+ MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims,
+ strides, padding_left, padding_right,
+ algorithm::pooling_avg_exclude_padding);
+ pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
+
+ // allocate output tensor
+ this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
+ output_dims_mkl_order,
this->data_format_mkldnn_, &output_tensor);
CHECK_NOTNULL(output_tensor);
OP_REQUIRES_OK(context, context->status());
- dnn_data_output.SetUsrMemDataHandle(output_tensor);
- this->PrepareAndExecuteNet(pool_prim_desc, &dnn_data_input,
- &dnn_data_output);
+ // check whether we need to reorder src
+ const T* src_data = input_tensor.flat<T>().data();
+ if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) {
+ dnn_data_input.SetUsrMem(input_md, &input_tensor);
+ auto src_target_primitive_desc = memory::primitive_desc(
+ {{src_dims}, MklDnnType<T>(), pooling_fwd->GetSrcMemoryFormat()},
+ cpu_engine_);
+ dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc);
+ src_data = const_cast<T*>(
+ reinterpret_cast<T*>(dnn_data_input.GetOpMem().get_data_handle()));
+ }
+
+ T* dst_data = output_tensor->flat<T>().data();
+
+ // execute pooling
+ pooling_fwd->Execute(src_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -535,9 +531,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
errors::Aborted("Operation received an exception:", error_msg));
}
} // Compute
-}; // MklAvgPoolingOp
-//-----------------------------------------------------------------------------
+ private:
+ engine cpu_engine_ = engine(engine::cpu, 0);
+}; // MklAvgPoolingOp
template <class Device, class T>
class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
@@ -547,91 +544,78 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
void Compute(OpKernelContext* context) override {
try {
- auto cpu_engine = engine(engine::cpu, 0);
- MklDnnShape original_input_mkl_shape, input_gradient_mkl_shape;
- const Tensor& tensor_in_shape =
+ const Tensor& orig_input_tensor =
MklGetInput(context, kInputTensorIndexInputShape);
- const Tensor& input_gradient_tensor =
+ const Tensor& grad_tensor =
MklGetInput(context, kInputTensorIndexInputGradient);
- GetMklShape(context, kInputTensorIndexInputShape,
- &original_input_mkl_shape);
- GetMklShape(context, kInputTensorIndexInputGradient,
- &input_gradient_mkl_shape);
- SanityCheckInputs(context, tensor_in_shape, input_gradient_tensor,
- original_input_mkl_shape, input_gradient_mkl_shape);
+ MklDnnShape orig_input_mkl_shape, grad_mkl_shape;
+ GetMklShape(context, kInputTensorIndexInputShape, &orig_input_mkl_shape);
+ GetMklShape(context, kInputTensorIndexInputGradient, &grad_mkl_shape);
if (!context->status().ok()) return;
// Used to allocate output_diff_src/diff_src
- // and create pool_fwd mdm desc
- // 0. Input("orig_input_shape: int32") //NOT a T Tensor!
- // 1. Input("grad: T")
-
- MklDnnData<T> input_gradient_diff_dst(&cpu_engine);
- MklDnnData<T> output_diff_src(&cpu_engine);
- Tensor* output_tensor_diff_src = nullptr;
- TensorShape original_input_shape;
+ MklDnnData<T> grad_dnn_data(&cpu_engine_);
MklPoolParameters pool_params;
- memory::dims output_dims_mkl_order, original_input_dims_nchw;
- // Configure the original input memory descriptor
- memory::desc original_input_md = ConfigureOriginalInput(
- context, tensor_in_shape, original_input_mkl_shape,
- &original_input_dims_nchw, &pool_params, &original_input_shape);
-
- // configure the original output memory descriptor
- // by definition, the shape of the original output is the same
- // as the shape of the gradient diff_dst
- memory::desc original_output_md = this->ConfigureOriginalOutput(
- pool_params, input_gradient_mkl_shape, output_dims_mkl_order);
-
- memory::desc target_diff_dst_md = this->ConfigureInputGradient(
- input_gradient_mkl_shape, input_gradient_tensor,
- &input_gradient_diff_dst, original_output_md);
- // The shape of the output diff src needs to be the same shape as the
- // original input. But we will set its format to be same as the format of
- // input gradient. We won't use format of original input since it will
- // always be in Tensorflow layout (given that AvgPoolGrad gets shape of
- // the input rather than actual input).
- output_diff_src.SetUsrMem(
- original_input_dims_nchw,
- static_cast<memory::format>(target_diff_dst_md.data.format));
-
- // Create the forward pooling primitive descriptor so we can reference it
- // in the backward pooling primitive descriptor
- auto pool_fwd_desc = pooling_forward::desc(
- prop_kind::forward, algorithm::pooling_avg_exclude_padding,
- original_input_md, original_output_md,
- memory::dims({pool_params.row_stride, pool_params.col_stride}),
- memory::dims({pool_params.window_rows, pool_params.window_cols}),
- memory::dims({static_cast<int>(pool_params.pad_top),
- static_cast<int>(pool_params.pad_left)}),
- memory::dims({static_cast<int>(pool_params.pad_bottom),
- static_cast<int>(pool_params.pad_right)}),
- TFPaddingToMklDnnPadding(this->padding_));
- auto pool_fwd_prim_desc =
- pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine);
-
- auto pool_bkwd_desc = pooling_backward::desc(
- algorithm::pooling_avg_exclude_padding,
- output_diff_src.GetUsrMemDesc(), target_diff_dst_md,
- memory::dims({pool_params.row_stride, pool_params.col_stride}),
- memory::dims({pool_params.window_rows, pool_params.window_cols}),
- memory::dims({static_cast<int>(pool_params.pad_top),
- static_cast<int>(pool_params.pad_left)}),
- memory::dims({static_cast<int>(pool_params.pad_bottom),
- static_cast<int>(pool_params.pad_right)}),
- TFPaddingToMklDnnPadding(this->padding_));
- auto pool_bkwd_prim_desc = pooling_backward::primitive_desc(
- pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc);
- this->AllocateOutputTensor(
- context, pool_bkwd_prim_desc, original_input_dims_nchw,
- this->data_format_mkldnn_, &output_tensor_diff_src);
-
- output_diff_src.SetUsrMemDataHandle(output_tensor_diff_src);
-
- this->PrepareAndExecuteNet(
- pool_bkwd_prim_desc, &input_gradient_diff_dst, &output_diff_src,
- memory::primitive_desc(target_diff_dst_md, cpu_engine));
+ auto shape_vec = orig_input_tensor.vec<int32>();
+ TensorShape orig_input_shape;
+ for (int i = 0; i < orig_input_tensor.NumElements(); i++) {
+ orig_input_shape.AddDim(shape_vec(i));
+ }
+ this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
+ orig_input_shape);
+
+ memory::dims filter_dims, strides, padding_left, padding_right;
+ this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
+ &padding_left, &padding_right);
+
+ memory::dims orig_input_dims_mkl_order =
+ orig_input_mkl_shape.IsMklTensor()
+ ? orig_input_mkl_shape.GetSizesAsMklDnnDims()
+ : TFShapeToMklDnnDimsInNCHW(orig_input_shape,
+ this->data_format_tf_);
+
+ memory::dims diff_dst_dims =
+ grad_mkl_shape.IsMklTensor()
+ ? grad_mkl_shape.GetSizesAsMklDnnDims()
+ : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
+ this->data_format_tf_);
+ memory::dims output_dims_mkl_order;
+ this->GetOutputDims(pool_params, &output_dims_mkl_order);
+
+ MklPoolingParams bwdParams(orig_input_dims_mkl_order,
+ output_dims_mkl_order, filter_dims, strides,
+ padding_left, padding_right,
+ algorithm::pooling_avg_exclude_padding);
+ MklPoolingBwdPrimitive<T>* pooling_bwd =
+ MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
+
+ Tensor* output_tensor = nullptr;
+ this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
+ orig_input_dims_mkl_order,
+ this->data_format_mkldnn_, &output_tensor);
+ // get diff_dst memory::desc
+ memory::desc diff_dst_md =
+ grad_mkl_shape.IsMklTensor()
+ ? grad_mkl_shape.GetMklLayout()
+ : memory::desc(diff_dst_dims, MklDnnType<T>(),
+ this->data_format_mkldnn_);
+ // Check whether we need to reorder diff_dst
+ const T* diff_dst_data = grad_tensor.flat<T>().data();
+ if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) {
+ auto target_diff_dst = memory::primitive_desc(
+ {{diff_dst_dims}, MklDnnType<T>(), pooling_bwd->GetDiffDstFormat()},
+ cpu_engine_);
+ grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
+ grad_dnn_data.CheckReorderToOpMem(target_diff_dst);
+ diff_dst_data = const_cast<T*>(
+ reinterpret_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle()));
+ }
+
+ T* diff_src_data = output_tensor->flat<T>().data();
+
+ // execute pooling op
+ pooling_bwd->Execute(diff_dst_data, diff_src_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -639,33 +623,14 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
error_msg));
}
- } // Compute
+ }
private:
// 0. Input("orig_input_shape: int32")
// 1. Input("grad: T")
const int kInputTensorIndexInputShape = 0;
const int kInputTensorIndexInputGradient = 1;
-
- memory::desc ConfigureOriginalInput(
- OpKernelContext* context, const Tensor& tensor_original_input_shape,
- const MklDnnShape& original_input_mkl_shape,
- memory::dims* original_input_dims_mkl_order,
- MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
- CHECK_NOTNULL(original_input_dims_mkl_order);
- CHECK_NOTNULL(pool_params);
- CHECK_NOTNULL(input_tensor_shape);
- // For AvgPoolGrad, we only get the size of the original input because
- // The original data is irrelvant.
- auto shape_vec = tensor_original_input_shape.vec<int32>();
- for (int64 i = 0; i < tensor_original_input_shape.NumElements(); ++i) {
- input_tensor_shape->AddDim(shape_vec(i));
- }
-
- return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput(
- context, tensor_original_input_shape, original_input_mkl_shape,
- original_input_dims_mkl_order, pool_params, *input_tensor_shape);
- }
+ engine cpu_engine_ = engine(engine::cpu, 0);
void SanityCheckInputs(OpKernelContext* context,
const Tensor& tensor_in_shape,
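The rewritten AvgPoolGrad path above rebuilds the original input TensorShape from the int32 shape tensor and then derives MKL-DNN dims in NCHW order (via TFShapeToMklDnnDimsInNCHW when the input is in TensorFlow layout). A minimal standalone sketch of that dim reordering, assuming an NHWC input shape (this is not the TensorFlow helper itself, just the idea):

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: reorder an NHWC shape vector (as AvgPoolGrad receives it
// in its int32 "orig_input_shape" input) into NCHW order, the dimension order
// that MKL-DNN pooling descriptors expect.
std::vector<int64_t> NhwcShapeToNchwDims(const std::vector<int32_t>& nhwc) {
  // nhwc = {batch, height, width, channels}
  return {static_cast<int64_t>(nhwc[0]),   // N
          static_cast<int64_t>(nhwc[3]),   // C
          static_cast<int64_t>(nhwc[1]),   // H
          static_cast<int64_t>(nhwc[2])};  // W
}

int main() {
  std::vector<int32_t> shape_vec = {32, 28, 28, 64};  // hypothetical NHWC shape
  for (int64_t d : NhwcShapeToNchwDims(shape_vec)) std::cout << d << " ";
  std::cout << "\n";  // prints: 32 64 28 28
}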
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index 6f490cdc23..d8efb1be3e 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -308,11 +308,9 @@ class MklConcatOp : public OpKernel {
}
if (invoke_eigen) {
- string msg = std::string("Invoking Eigen version of Concat. Reason:") +
- (!is_concat_dim_channel
- ? std::string("Concat dimension is not channel")
- : std::string("Not all tensors are in Mkl layout"));
- VLOG(1) << "_MklConcatOp: " << msg;
+ VLOG(1) << "_MklConcatOp: Invoking Eigen version of Concat. Reason:"
+ << (!is_concat_dim_channel ? "Concat dimension is not channel"
+ : "Not all tensors are in Mkl layout");
CallEigenVersion(context, input_tensors, input_shapes);
return;
}
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index a370037d97..b73a119a88 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -328,9 +328,8 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- static std::string CreateKey(
- const MklConvBwdFilterParams& convBwdFilterDims) {
- std::string prefix = "conv2d_bwd_filter";
+ static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) {
+ string prefix = "conv2d_bwd_filter";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdFilterDims.src_dims);
@@ -346,13 +345,13 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
MklPrimitive* GetConv2dBwdFilter(
const MklConvBwdFilterParams& convBwdFilterDims) {
- std::string key = CreateKey(convBwdFilterDims);
+ string key = CreateKey(convBwdFilterDims);
return this->GetOp(key);
}
void SetConv2dBwdFilter(
const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) {
- std::string key = CreateKey(convBwdFilterDims);
+ string key = CreateKey(convBwdFilterDims);
this->SetOp(key, op);
}
};
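The CreateKey/GetConv2dBwdFilter/SetConv2dBwdFilter pattern above caches one MKL-DNN primitive per unique parameter combination, keyed by a string built from a prefix and the convolution dimensions. A minimal sketch of that caching idea, using a hypothetical Primitive type and a simple delimiter-joined key (the real code uses FactoryKeyCreator and MklPrimitiveFactory):

#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

struct Primitive { /* stand-in for a compiled MKL-DNN primitive */ };

class PrimitiveCache {
 public:
  // Build a reusable lookup key from a prefix and the shape parameters.
  static std::string CreateKey(const std::string& prefix,
                               const std::vector<int>& dims) {
    std::ostringstream key;
    key << prefix;
    for (int d : dims) key << ":" << d;
    return key.str();
  }

  // Return the cached primitive for this key, creating it on first use.
  std::shared_ptr<Primitive> GetOrCreate(const std::string& key) {
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;
    auto prim = std::make_shared<Primitive>();  // expensive setup in reality
    cache_[key] = prim;
    return prim;
  }

 private:
  std::map<std::string, std::shared_ptr<Primitive>> cache_;
};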
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index b0f7faaa1a..39498f1a80 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -265,9 +265,8 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- static std::string CreateKey(
- const MklConvBwdInputParams& convBwdInputDims) {
- std::string prefix = "conv2d_bwd_input";
+ static string CreateKey(const MklConvBwdInputParams& convBwdInputDims) {
+ string prefix = "conv2d_bwd_input";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdInputDims.diff_src_dims);
@@ -282,13 +281,13 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
MklPrimitive* GetConv2dBwdInput(
const MklConvBwdInputParams& convBwdInputDims) {
- std::string key = CreateKey(convBwdInputDims);
+ string key = CreateKey(convBwdInputDims);
return this->GetOp(key);
}
void SetConv2dBwdInput(
const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) {
- std::string key = CreateKey(convBwdInputDims);
+ string key = CreateKey(convBwdInputDims);
this->SetOp(key, op);
}
};
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index b568973220..62396eeb8b 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <string.h>
#include <map>
-#include <string>
#include <vector>
#include <memory>
@@ -35,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
@@ -298,8 +298,8 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- static std::string CreateKey(const MklConvFwdParams& convFwdDims) {
- std::string prefix = "conv2d_fwd_";
+ static string CreateKey(const MklConvFwdParams& convFwdDims) {
+ string prefix = "conv2d_fwd_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convFwdDims.src_dims);
@@ -314,12 +314,12 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
}
MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) {
- std::string key = CreateKey(convFwdDims);
+ string key = CreateKey(convFwdDims);
return this->GetOp(key);
}
void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
- std::string key = CreateKey(convFwdDims);
+ string key = CreateKey(convFwdDims);
this->SetOp(key, op);
}
};
@@ -930,10 +930,9 @@ class MklConv2DOp : public OpKernel {
conv2d_fwd->Execute(src_data, filter_data, dst_data);
}
} catch (mkldnn::error &e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + std::string(e.message) +
- ", in file " + std::string(__FILE__) + ":" +
- std::to_string(__LINE__);
+ string error_msg = tensorflow::strings::StrCat(
+ "Status: ", e.status, ", message: ", string(e.message), ", in file ",
+ __FILE__, ":", __LINE__);
OP_REQUIRES_OK(context,
errors::Aborted("Operation received an exception:", error_msg));
}
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 5e1a5001dc..3f154ff33b 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
#include <limits>
-#include <string>
#include <vector>
#include <memory>
diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
index 3fe660cf96..0149e78db5 100644
--- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
@@ -262,6 +262,7 @@ class MklFusedBatchNormOp : public OpKernel {
}
void MklCreateInputLayout(OpKernelContext* context) {
+ const Tensor& input = MklGetInput(context, 0);
bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
if (input_in_mkl_format) {
mkl_lt_input =
@@ -544,6 +545,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
}
void MklCreateInputLayout(OpKernelContext* context) {
+ const Tensor& input = MklGetInput(context, 0);
bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
if (input_in_mkl_format) {
mkl_lt_input =
@@ -684,6 +686,466 @@ class MklFusedBatchNormGradOp : public OpKernel {
#ifndef INTEL_MKL_ML
+struct MklBatchNormFwdParams {
+ memory::dims src_dims;
+ int depth;
+ float eps;
+ bool training;
+
+ MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps,
+ bool training)
+ : src_dims(src_dims), depth(depth), eps(eps), training(training) {}
+};
+
+template <typename T>
+class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams)
+ : cpu_engine_(engine::cpu, 0) {
+ context_.fwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
+ if (context_.bn_fwd == nullptr) Setup(fwdParams);
+ }
+
+ ~MklFusedBatchNormFwdPrimitive() {}
+
+ // BatchNormalization forward execute
+ // src_data: input data buffer of src
+ // weights_data: input data buffer of weights
+ // dst_data: output data buffer of dst
+ // mean_data: output data buffer of means
+ // variance_data: output data buffer of variances
+ void Execute(const T* src_data, const T* weights_data, T* dst_data,
+ T* mean_data, T* variance_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
+
+ if (context_.flags & use_scale_shift)
+ context_.weights_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(weights_data)));
+
+ if ((context_.pkind == prop_kind::forward_training) ||
+ (context_.flags & use_global_stats)) {
+ context_.mean_mem->set_data_handle(static_cast<void*>(mean_data));
+ context_.variance_mem->set_data_handle(static_cast<void*>(variance_data));
+ }
+
+ // execution
+ context_.fwd_stream->submit(context_.fwd_primitives);
+
+ context_.src_mem->set_data_handle(DummyData);
+ context_.dst_mem->set_data_handle(DummyData);
+
+ if (context_.flags & use_scale_shift)
+ context_.weights_mem->set_data_handle(DummyData);
+
+ if ((context_.pkind == prop_kind::forward_training) ||
+ (context_.flags & use_global_stats)) {
+ context_.mean_mem->set_data_handle(DummyData);
+ context_.variance_mem->set_data_handle(DummyData);
+ }
+ }
+
+ memory::primitive_desc GetDstPd() const {
+ return (*context_.dst_mem).get_primitive_desc();
+ }
+
+ mkldnn_memory_format_t GetSrcFmt() const {
+ return (*context_.src_mem).get_primitive_desc().desc().data.format;
+ }
+
+ mkldnn_memory_format_t GetDstFmt() const {
+ return (*context_.dst_mem).get_primitive_desc().desc().data.format;
+ }
+
+ private:
+ // Primitive reuse context for BatchNorm fwd op
+ struct BatchNormFwdContext {
+    // flags indicating whether this is training or inference mode
+ int64 flags;
+
+ // algorithm
+ mkldnn::prop_kind pkind;
+
+ // Mkldnn Memory
+ std::shared_ptr<mkldnn::memory> src_mem;
+ std::shared_ptr<mkldnn::memory> weights_mem;
+ std::shared_ptr<mkldnn::memory> dst_mem;
+ std::shared_ptr<mkldnn::memory> mean_mem;
+ std::shared_ptr<mkldnn::memory> variance_mem;
+
+ // BatchNorm forward primitive
+ std::shared_ptr<mkldnn::primitive> bn_fwd;
+ std::shared_ptr<mkldnn::stream> fwd_stream;
+ std::vector<mkldnn::primitive> fwd_primitives;
+
+ BatchNormFwdContext()
+ : flags(0),
+ pkind(mkldnn::forward_training),
+ src_mem(nullptr),
+ weights_mem(nullptr),
+ dst_mem(nullptr),
+ mean_mem(nullptr),
+ variance_mem(nullptr),
+ bn_fwd(nullptr),
+ fwd_stream(nullptr) {}
+ };
+
+ void Setup(const MklBatchNormFwdParams& fwdParams) {
+ context_.flags = fwdParams.training ? use_scale_shift
+ : (use_scale_shift | use_global_stats);
+ context_.pkind = fwdParams.training ? prop_kind::forward_training
+ : prop_kind::forward_scoring;
+
+ // memory desc
+ auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
+ get_desired_format(fwdParams.src_dims[1]));
+
+ // fwd desc & primitive desc
+ auto fwd_desc = batch_normalization_forward::desc(
+ context_.pkind, src_md, fwdParams.eps, context_.flags);
+ auto fwd_pd =
+ batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_);
+
+ // memory primitive
+ context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData));
+ context_.dst_mem.reset(new memory(fwd_pd.dst_primitive_desc(), DummyData));
+
+ if (context_.flags & use_scale_shift) {
+ auto weights_desc = memory::desc({2, fwdParams.depth}, MklDnnType<T>(),
+ memory::format::nc);
+ context_.weights_mem.reset(
+ new memory({weights_desc, cpu_engine_}, DummyData));
+ }
+
+ if (fwdParams.training || (context_.flags & use_global_stats)) {
+ auto mean_desc = memory::desc({1, fwdParams.depth}, MklDnnType<T>(),
+ memory::format::nc);
+ context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData));
+
+ auto variance_desc =
+          memory::desc({1, fwdParams.depth}, MklDnnType<T>(),
+                       memory::format::nc);
+ context_.variance_mem.reset(
+ new memory({variance_desc, cpu_engine_}, DummyData));
+ }
+
+ // BatchNorm forward primitive
+ if (!fwdParams.training && !(context_.flags & use_global_stats)) {
+ if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) {
+ context_.bn_fwd.reset(new batch_normalization_forward(
+ fwd_pd, *context_.src_mem, *context_.weights_mem,
+ *context_.dst_mem));
+ } else {
+ context_.bn_fwd.reset(new batch_normalization_forward(
+ fwd_pd, *context_.src_mem, *context_.dst_mem));
+ }
+ } else if (context_.flags & use_global_stats) {
+ if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) {
+ context_.bn_fwd.reset(new batch_normalization_forward(
+ fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem,
+ (const primitive::at)*context_.variance_mem, *context_.weights_mem,
+ *context_.dst_mem));
+ } else {
+ context_.bn_fwd.reset(new batch_normalization_forward(
+ fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem,
+ (const primitive::at)*context_.variance_mem, *context_.dst_mem));
+ }
+ } else {
+ if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) {
+ context_.bn_fwd.reset(new batch_normalization_forward(
+ fwd_pd, *context_.src_mem, *context_.weights_mem, *context_.dst_mem,
+ *context_.mean_mem, *context_.variance_mem));
+ } else {
+ context_.bn_fwd.reset(new batch_normalization_forward(
+ fwd_pd, *context_.src_mem, *context_.dst_mem, *context_.mean_mem,
+ *context_.variance_mem));
+ }
+ }
+
+ context_.fwd_primitives.push_back(*context_.bn_fwd);
+ }
+
+ mkldnn::memory::desc get_desc_data(const mkldnn::memory& m) const {
+ return m.get_primitive_desc().desc().data;
+ }
+
+ struct BatchNormFwdContext context_;
+ engine cpu_engine_;
+};
+
+template <typename T>
+class MklFusedBatchNormFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklFusedBatchNormFwdPrimitive<T>* Get(
+ const MklBatchNormFwdParams& fwdParams) {
+ auto bn_fwd = static_cast<MklFusedBatchNormFwdPrimitive<T>*>(
+ MklFusedBatchNormFwdPrimitiveFactory<T>::GetInstance().GetBatchNormFwd(
+ fwdParams));
+
+ if (bn_fwd == nullptr) {
+ bn_fwd = new MklFusedBatchNormFwdPrimitive<T>(fwdParams);
+ MklFusedBatchNormFwdPrimitiveFactory<T>::GetInstance().SetBatchNormFwd(
+ fwdParams, bn_fwd);
+ }
+ return bn_fwd;
+ }
+
+ static MklFusedBatchNormFwdPrimitiveFactory& GetInstance() {
+ static MklFusedBatchNormFwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklFusedBatchNormFwdPrimitiveFactory() {}
+ ~MklFusedBatchNormFwdPrimitiveFactory() {}
+
+ static std::string CreateKey(const MklBatchNormFwdParams& fwdParams) {
+ std::string prefix = "bn_fwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(fwdParams.src_dims);
+ key_creator.AddAsKey<int>(fwdParams.depth);
+ key_creator.AddAsKey<float>(fwdParams.eps);
+ key_creator.AddAsKey<bool>(fwdParams.training);
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetBatchNormFwd(const MklBatchNormFwdParams& fwdParams) {
+ std::string key = CreateKey(fwdParams);
+ return this->GetOp(key);
+ }
+
+ void SetBatchNormFwd(const MklBatchNormFwdParams& fwdParams,
+ MklPrimitive* op) {
+ std::string key = CreateKey(fwdParams);
+ this->SetOp(key, op);
+ }
+};
+
+struct MklBatchNormBwdParams {
+ memory::dims src_dims;
+ memory::dims diff_dst_dims;
+ int depth;
+ float eps;
+ bool training;
+
+ MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims,
+ int depth, float eps, bool training)
+ : src_dims(src_dims),
+ diff_dst_dims(diff_dst_dims),
+ depth(depth),
+ eps(eps),
+ training(training) {}
+};
+
+template <typename T>
+class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams)
+ : cpu_engine_(engine::cpu, 0) {
+ context_.bwd_stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
+ if (context_.bn_bwd == nullptr) Setup(bwdParams);
+ }
+
+ ~MklFusedBatchNormBwdPrimitive() {}
+
+ // BatchNormalization backward execute
+ // src_data: input data buffer of src
+ // mean_data: input data buffer of mean
+ // variance_data: input data buffer of variance
+ // diff_dst_data: input data buffer of diff_dst
+ // weights_data: input data buffer of weights
+ // diff_src_data: output data buffer of diff_src
+ // diff_weights_data: output data buffer of diff_weights
+ void Execute(const T* src_data, const T* mean_data, const T* variance_data,
+ const T* diff_dst_data, const T* weights_data, T* diff_src_data,
+ T* diff_weights_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.mean_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(mean_data)));
+ context_.variance_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(variance_data)));
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_dst_data)));
+
+ if (context_.flags & use_scale_shift) {
+ context_.weights_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(weights_data)));
+ context_.diff_weights_mem->set_data_handle(
+ static_cast<void*>(diff_weights_data));
+ }
+
+ context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
+
+ // execution
+ context_.bwd_stream->submit(context_.bwd_primitives);
+
+ context_.src_mem->set_data_handle(DummyData);
+ context_.mean_mem->set_data_handle(DummyData);
+ context_.variance_mem->set_data_handle(DummyData);
+ context_.diff_dst_mem->set_data_handle(DummyData);
+ if (context_.flags & use_scale_shift) {
+ context_.weights_mem->set_data_handle(DummyData);
+ context_.diff_weights_mem->set_data_handle(DummyData);
+ }
+ context_.diff_src_mem->set_data_handle(DummyData);
+ }
+
+ mkldnn_memory_format_t GetSrcFmt() {
+ return (*context_.src_mem).get_primitive_desc().desc().data.format;
+ }
+
+ mkldnn_memory_format_t GetDiffDstFmt() {
+ return (*context_.diff_dst_mem).get_primitive_desc().desc().data.format;
+ }
+
+ memory::primitive_desc GetDiffSrcPd() {
+ return (*context_.diff_src_mem).get_primitive_desc();
+ }
+
+ private:
+ struct BatchNormBwdContext {
+ // Flags to indicate whether it is training or inference
+ int64 flags;
+
+ // MKLDNN memory
+ std::shared_ptr<mkldnn::memory> src_mem;
+ std::shared_ptr<mkldnn::memory> mean_mem;
+ std::shared_ptr<mkldnn::memory> variance_mem;
+ std::shared_ptr<mkldnn::memory> diff_dst_mem;
+ std::shared_ptr<mkldnn::memory> weights_mem;
+ std::shared_ptr<mkldnn::memory> diff_weights_mem;
+ std::shared_ptr<mkldnn::memory> diff_src_mem;
+
+ // Batch Norm primitive
+ std::shared_ptr<mkldnn::primitive> bn_bwd;
+ std::vector<mkldnn::primitive> bwd_primitives;
+ std::shared_ptr<mkldnn::stream> bwd_stream;
+
+ BatchNormBwdContext()
+ : src_mem(nullptr),
+ mean_mem(nullptr),
+ variance_mem(nullptr),
+ diff_dst_mem(nullptr),
+ weights_mem(nullptr),
+ diff_weights_mem(nullptr),
+ diff_src_mem(nullptr),
+ bwd_stream(nullptr) {}
+ };
+
+ void Setup(const MklBatchNormBwdParams& bwdParams) {
+ context_.flags = bwdParams.training ? use_scale_shift
+ : (use_scale_shift | use_global_stats);
+
+ // memory desc
+ auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType<T>(),
+ get_desired_format(bwdParams.src_dims[1]));
+ auto diff_dst_md =
+ memory::desc({bwdParams.diff_dst_dims}, MklDnnType<T>(),
+ get_desired_format(bwdParams.diff_dst_dims[1]));
+ auto variance_desc =
+        memory::desc({1, bwdParams.depth}, MklDnnType<T>(),
+                     memory::format::nc);
+ auto mean_desc =
+ memory::desc({1, bwdParams.depth}, MklDnnType<T>(), memory::format::nc);
+ auto weights_desc =
+ memory::desc({2, bwdParams.depth}, MklDnnType<T>(), memory::format::nc);
+ auto diff_weights_desc = weights_desc;
+
+ // fwd desc & primitive desc
+ auto fwd_desc = batch_normalization_forward::desc(
+ prop_kind::forward_training, src_md, bwdParams.eps,
+ bwdParams.training ? use_scale_shift
+ : (use_scale_shift | use_global_stats));
+ auto fwd_pd =
+ batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_);
+
+    // BatchNorm backward primitive
+ //
+ // For inference, specify use_global_stats
+ // 1. on fwd propagation, use mean and variance provided as inputs.
+ // 2. on bwd propagation, mean and variance are considered as constants.
+ // Thus, reduce the amount of MKL computation.
+ auto bwd_desc = batch_normalization_backward::desc(
+ prop_kind::backward, diff_dst_md, src_md, bwdParams.eps,
+ bwdParams.training ? use_scale_shift
+ : (use_scale_shift | use_global_stats));
+ auto bn_bwd_pd = batch_normalization_backward::primitive_desc(
+ bwd_desc, cpu_engine_, fwd_pd);
+
+ // memory primitive
+ context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData));
+ context_.diff_dst_mem.reset(
+ new memory({diff_dst_md, cpu_engine_}, DummyData));
+ context_.variance_mem.reset(
+ new memory({variance_desc, cpu_engine_}, DummyData));
+ context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData));
+ context_.weights_mem.reset(
+ new memory({weights_desc, cpu_engine_}, DummyData));
+ context_.diff_weights_mem.reset(
+ new memory({diff_weights_desc, cpu_engine_}, DummyData));
+ context_.diff_src_mem.reset(new memory({src_md, cpu_engine_}, DummyData));
+
+ context_.bn_bwd.reset(new batch_normalization_backward(
+ bn_bwd_pd, *context_.src_mem, *context_.mean_mem,
+ *context_.variance_mem, *context_.diff_dst_mem, *context_.weights_mem,
+ *context_.diff_src_mem, *context_.diff_weights_mem));
+ context_.bwd_primitives.push_back(*context_.bn_bwd);
+ }
+
+ struct BatchNormBwdContext context_;
+ engine cpu_engine_;
+};
+
+template <typename T>
+class MklFusedBatchNormBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklFusedBatchNormBwdPrimitive<T>* Get(
+ const MklBatchNormBwdParams& bwdParams) {
+ auto bn_bwd = static_cast<MklFusedBatchNormBwdPrimitive<T>*>(
+ MklFusedBatchNormBwdPrimitiveFactory<T>::GetInstance().GetBatchNormBwd(
+ bwdParams));
+ if (bn_bwd == nullptr) {
+ bn_bwd = new MklFusedBatchNormBwdPrimitive<T>(bwdParams);
+ MklFusedBatchNormBwdPrimitiveFactory<T>::GetInstance().SetBatchNormBwd(
+ bwdParams, bn_bwd);
+ }
+ return bn_bwd;
+ }
+
+ static MklFusedBatchNormBwdPrimitiveFactory& GetInstance() {
+ static MklFusedBatchNormBwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklFusedBatchNormBwdPrimitiveFactory() {}
+ ~MklFusedBatchNormBwdPrimitiveFactory() {}
+
+ static std::string CreateKey(const MklBatchNormBwdParams& bwdParams) {
+ std::string prefix = "bn_bwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(bwdParams.src_dims);
+ key_creator.AddAsKey(bwdParams.diff_dst_dims);
+ key_creator.AddAsKey<int>(bwdParams.depth);
+ key_creator.AddAsKey<float>(bwdParams.eps);
+ key_creator.AddAsKey<bool>(bwdParams.training);
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetBatchNormBwd(const MklBatchNormBwdParams& bwdParams) {
+ std::string key = CreateKey(bwdParams);
+ return this->GetOp(key);
+ }
+
+ void SetBatchNormBwd(const MklBatchNormBwdParams& bwdParams,
+ MklPrimitive* op) {
+ std::string key = CreateKey(bwdParams);
+ this->SetOp(key, op);
+ }
+};
+
template <typename Device, typename T>
class MklFusedBatchNormOp : public OpKernel {
public:
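The Execute() methods added above follow a reuse pattern: the cached primitive keeps placeholder memory objects, Execute() temporarily binds the caller's buffers via set_data_handle(), submits the recorded stream, and then rebinds DummyData so the cached primitive never retains pointers into tensors that may be freed. A language-level sketch of that handle-swap idea, with a hypothetical CachedKernel standing in for the MKL-DNN primitive:

#include <cstddef>

// Hypothetical stand-in for a cached MKL-DNN primitive: it is configured once
// (shapes, formats) and then re-executed on different buffers.
class CachedKernel {
 public:
  explicit CachedKernel(size_t n) : n_(n) {}

  // Bind the caller's buffers, run, then reset the handles so the cached
  // object does not keep pointers into caller-owned memory.
  void Execute(const float* src, float* dst) {
    src_ = src;
    dst_ = dst;
    Run();           // uses the pre-built configuration with the bound buffers
    src_ = nullptr;  // analogous to set_data_handle(DummyData)
    dst_ = nullptr;
  }

 private:
  void Run() {
    for (size_t i = 0; i < n_; ++i) dst_[i] = src_[i];  // trivial placeholder op
  }
  size_t n_;
  const float* src_ = nullptr;
  float* dst_ = nullptr;
};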
@@ -701,7 +1163,6 @@ class MklFusedBatchNormOp : public OpKernel {
void Compute(OpKernelContext* context) override {
try {
- auto cpu_engine = engine(engine::cpu, 0);
const size_t kSrcIndex = 0; // index of src input tensor
const size_t kScaleIndex = 1; // index of scale tensor
const size_t kShiftIndex = 2; // index of shift tensor
@@ -786,7 +1247,7 @@ class MklFusedBatchNormOp : public OpKernel {
SetMeanVariance(est_mean_tensor, est_variance_tensor);
MklDnnData<T> src(&cpu_engine);
- MklDnnData<T> dst(&cpu_engine);
+ MklDnnData<T> weights(&cpu_engine);
memory::format format_m;
if (dnn_shape_src.IsMklTensor()) {
@@ -800,123 +1261,102 @@ class MklFusedBatchNormOp : public OpKernel {
}
// set src primitive
- memory::dims src_dims;
- if (dnn_shape_src.IsMklTensor()) {
- src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(),
- tensor_format_);
- } else {
- src_dims =
- TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
- }
+ memory::dims src_dims =
+ dnn_shape_src.IsMklTensor()
+ ? dnn_shape_src.GetSizesAsMklDnnDims()
+ : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
auto src_md = dnn_shape_src.IsMklTensor()
? dnn_shape_src.GetMklLayout()
: memory::desc(src_dims, MklDnnType<T>(), format_m);
- src.SetUsrMem(src_md, &src_tensor);
- // set weights primitive
// MKL-DNN packs scale & shift as "weights":
// <scale>...<scale><shift>...<shift>
- auto weights_desc = memory::desc({2, static_cast<int>(depth_)},
- MklDnnType<T>(), memory::format::nc);
- auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine);
- auto weights_m = memory(weights_pd);
- T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle());
- T* scale_tf =
- reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data()));
- T* shift_tf =
- reinterpret_cast<T*>(const_cast<T*>(shift_tensor.flat<T>().data()));
+ weights.AllocateBuffer(2 * depth_ * sizeof(T));
+ T* weights_data = reinterpret_cast<T*>(weights.GetAllocatedBuffer());
+ const T* scale_tf = scale_tensor.flat<T>().data();
+ const T* shift_tf = shift_tensor.flat<T>().data();
- for (int k = 0; k < depth_; k++) {
- weights_data[k] = scale_tf[k];
- weights_data[k + depth_] = shift_tf[k];
- }
-
- // set mean primitive
- auto mean_desc = memory::desc({1, static_cast<int>(depth_)},
- MklDnnType<T>(), memory::format::nc);
- auto mean_pd = memory::primitive_desc(mean_desc, cpu_engine);
+ std::memcpy(weights_data, scale_tf, depth_ * sizeof(T));
+ std::memcpy(weights_data + depth_, shift_tf, depth_ * sizeof(T));
char* saved_mean_data_tf =
reinterpret_cast<char*>(saved_mean_tensor->flat<T>().data());
std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_),
depth_ * sizeof(T));
- auto mean_m =
- memory(mean_pd, reinterpret_cast<void*>(saved_mean_data_tf));
- // set variance primitive
- auto variance_desc = memory::desc({1, static_cast<int>(depth_)},
- MklDnnType<T>(), memory::format::nc);
- auto variance_pd = memory::primitive_desc(variance_desc, cpu_engine);
char* saved_variance_data_tf =
reinterpret_cast<char*>(saved_variance_tensor->flat<T>().data());
std::memcpy(saved_variance_data_tf,
reinterpret_cast<char*>(variance_values_),
depth_ * sizeof(T));
- auto variance_m = memory(variance_pd, saved_variance_data_tf);
-
- prop_kind pk = (is_training_) ? prop_kind::forward_training
- : prop_kind::forward_scoring;
- auto bnrm_fwd_desc = batch_normalization_forward::desc(
- pk, src.GetUsrMemDesc(), epsilon_,
- is_training_ ? use_scale_shift
- : (use_scale_shift | use_global_stats));
- auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(
- bnrm_fwd_desc, cpu_engine);
-
- // allocate dst tensor
+
+ // get batchnorm op from the pool
+ MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_);
+ MklFusedBatchNormFwdPrimitive<T>* bn_fwd =
+ MklFusedBatchNormFwdPrimitiveFactory<T>::Get(fwdParams);
+
+ // check if reorder is needed for src, weights, mean, variance
+ const T* src_data = src_tensor.flat<T>().data();
+ if (src_md.data.format != bn_fwd->GetSrcFmt()) {
+ src.SetUsrMem(src_md, &src_tensor);
+ auto src_target = memory::primitive_desc(
+ {{src_dims},
+ MklDnnType<T>(),
+ static_cast<memory::format>(bn_fwd->GetSrcFmt())},
+ cpu_engine);
+ src.CheckReorderToOpMem(src_target);
+ src_data = const_cast<T*>(
+ reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
+ }
+
+ // allocate output (dst) tensor; always set it as MKL-DNN layout
MklDnnShape dnn_shape_dst;
TensorShape tf_shape_dst;
- if (dnn_shape_src.IsMklTensor()) {
- dnn_shape_dst.SetMklTensor(true);
- auto dst_pd = bnrm_fwd_pd.dst_primitive_desc();
- dnn_shape_dst.SetMklLayout(&dst_pd);
- dnn_shape_dst.SetElemType(MklDnnType<T>());
- dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), src_dims,
- format_m);
- tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
- } else {
- dnn_shape_dst.SetMklTensor(false);
- tf_shape_dst = src_tensor.shape();
- }
+ dnn_shape_dst.SetMklTensor(true);
+ auto dst_pd = bn_fwd->GetDstPd();
+ dnn_shape_dst.SetMklLayout(&dst_pd);
+ dnn_shape_dst.SetElemType(MklDnnType<T>());
+ auto ndims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetDimension()
+ : src_tensor.shape().dims();
+ dnn_shape_dst.SetTfLayout(ndims, src_dims, format_m);
+ tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst,
dnn_shape_dst);
- // Output of batchnorm has same shape as input.
- dst.SetUsrMem(src_md, dst_tensor);
+ T* weights_op_data = weights_data;
+ T* mean_op_data = saved_mean_tensor->flat<T>().data();
+ T* variance_op_data = saved_variance_tensor->flat<T>().data();
+ T* dst_data = dst_tensor->flat<T>().data();
- primitive bnrm_fwd_op;
- if (is_training_) {
- bnrm_fwd_op =
- batch_normalization_forward(bnrm_fwd_pd, src.GetOpMem(), weights_m,
- dst.GetOpMem(), mean_m, variance_m);
- } else {
- bnrm_fwd_op = batch_normalization_forward(
- bnrm_fwd_pd, src.GetOpMem(), mean_m, variance_m,
- (const primitive::at)weights_m, dst.GetOpMem());
- }
- std::vector<primitive> net;
- net.push_back(bnrm_fwd_op);
- stream(stream::kind::eager).submit(net).wait();
+ // execution
+ bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data,
+ variance_op_data);
// copy batch_mean data
- T* batch_mean_data_tf =
- reinterpret_cast<T*>(batch_mean_tensor->flat<T>().data());
+ T* batch_mean_data_tf = batch_mean_tensor->flat<T>().data();
std::memcpy(reinterpret_cast<char*>(batch_mean_data_tf),
- reinterpret_cast<char*>(mean_m.get_data_handle()),
+ reinterpret_cast<char*>(saved_mean_data_tf),
depth_ * sizeof(T));
+      // TODO(yli135): OpMem is same as usr mem since its format is
+      // hard-coded as nc when the primitive is created.
// copy batch_variance data with Bessel's correction
- // if training mode is on
float adjust_factor = 1.0;
if (is_training_) {
size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
size_t adjust_size = orig_size - 1;
adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
}
- for (int k = 0; k < depth_; k++)
- batch_variance_tensor->flat<T>().data()[k] =
- (reinterpret_cast<T*>(variance_m.get_data_handle()))[k] *
- adjust_factor;
+
+ auto variance_data = reinterpret_cast<T*>(saved_variance_data_tf);
+ auto batch_variance_data = batch_variance_tensor->flat<T>().data();
+ if (is_training_) {
+ for (int k = 0; k < depth_; k++) {
+ batch_variance_data[k] = variance_data[k] * adjust_factor;
+ }
+ } else {
+ std::memcpy(batch_variance_data, variance_data, depth_ * sizeof(T));
+ }
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -933,7 +1373,8 @@ class MklFusedBatchNormOp : public OpKernel {
bool is_training_;
T* mean_values_;
T* variance_values_;
- int depth_; // batch normalization is done for per channel.
+ size_t depth_; // batch normalization is done for per channel.
+ engine cpu_engine = engine(engine::cpu, 0);
void ExtractParams(OpKernelContext* context) {
const Tensor& input = MklGetInput(context, 0);
@@ -990,8 +1431,9 @@ class MklFusedBatchNormOp : public OpKernel {
tf_shape_scale, mkl_shape_batch_mean);
CHECK_NOTNULL(*batch_mean_tensor);
// set NAN mean value in case of empty input tensor
- for (int k = 0; k < tf_shape_scale.num_elements(); k++)
- (*batch_mean_tensor)->flat<T>().data()[k] = NAN;
+ int num_elements = tf_shape_scale.num_elements();
+ auto batch_mean_data = (*batch_mean_tensor)->flat<T>().data();
+ std::fill_n(batch_mean_data, num_elements, NAN);
// allocate batch variance output tensor
MklDnnShape mkl_shape_batch_variance;
@@ -1001,8 +1443,8 @@ class MklFusedBatchNormOp : public OpKernel {
mkl_shape_batch_variance);
CHECK_NOTNULL(*batch_variance_tensor);
// set NAN variance value in case of empty input tensor
- for (int k = 0; k < tf_shape_scale.num_elements(); k++)
- (*batch_variance_tensor)->flat<T>().data()[k] = NAN;
+ auto batch_variance_data = (*batch_variance_tensor)->flat<T>().data();
+ std::fill_n(batch_variance_data, num_elements, NAN);
// Mean and variance (without Bessel's correction) saved for backward
// computation to serve as pre-computed mean and variance.
@@ -1012,8 +1454,8 @@ class MklFusedBatchNormOp : public OpKernel {
tf_shape_scale, mkl_shape_saved_mean);
CHECK_NOTNULL(*saved_mean_tensor);
// set NAN mean value in case of empty input tensor
- for (int k = 0; k < tf_shape_scale.num_elements(); k++)
- (*saved_mean_tensor)->flat<T>().data()[k] = NAN;
+ auto saved_mean_data = (*saved_mean_tensor)->flat<T>().data();
+ std::fill_n(saved_mean_data, num_elements, NAN);
MklDnnShape mkl_shape_saved_variance;
mkl_shape_saved_variance.SetMklTensor(false);
@@ -1022,8 +1464,8 @@ class MklFusedBatchNormOp : public OpKernel {
mkl_shape_saved_variance);
CHECK_NOTNULL(*saved_variance_tensor);
// set NAN variance value in case of empty input tensor
- for (int k = 0; k < tf_shape_scale.num_elements(); k++)
- (*saved_variance_tensor)->flat<T>().data()[k] = NAN;
+ auto saved_variance_data = (*saved_variance_tensor)->flat<T>().data();
+ std::fill_n(saved_variance_data, num_elements, NAN);
}
};
@@ -1044,12 +1486,12 @@ class MklFusedBatchNormGradOp : public OpKernel {
void Compute(OpKernelContext* context) override {
try {
- auto cpu_engine = engine(engine::cpu, 0);
const size_t kDiffDstIndex = 0; // index of diff_dst tensor
const size_t kSrcIndex = 1; // index of src input tensor
const size_t kScaleIndex = 2; // index of scale tensor
const size_t kMeanIndex = 3; // index of saved_mean tensor
const size_t kVarianceIndex = 4; // index of saved_variance tensor
+
const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
@@ -1060,8 +1502,8 @@ class MklFusedBatchNormGradOp : public OpKernel {
MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
GetMklShape(context, kSrcIndex, &dnn_shape_src);
GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst);
- TensorShape tf_shape_src, tf_shape_diff_dst;
+ TensorShape tf_shape_src, tf_shape_diff_dst;
if (dnn_shape_diff_dst.IsMklTensor()) {
tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
OP_REQUIRES(
@@ -1102,6 +1544,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
saved_variance_tensor.shape().DebugString()));
Tensor* diff_src_tensor = nullptr;
+      // special case: input with 0 elements or 0 batch size
if (tf_shape_src.num_elements() == 0 ||
tf_shape_diff_dst.num_elements() == 0) {
HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
@@ -1117,189 +1560,127 @@ class MklFusedBatchNormGradOp : public OpKernel {
ExtractParams(context);
}
- MklDnnData<T> src(&cpu_engine);
- MklDnnData<T> mean(&cpu_engine);
- MklDnnData<T> variance(&cpu_engine);
- MklDnnData<T> diff_dst(&cpu_engine);
- MklDnnData<T> diff_src(&cpu_engine);
-
- memory::dims src_dims, diff_dst_dims;
- if (dnn_shape_src.IsMklTensor())
- src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(),
- tensor_format_);
- else
- src_dims =
- TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
-
- if (dnn_shape_diff_dst.IsMklTensor())
- diff_dst_dims = TFShapeToMklDnnDimsInNCHW(
- dnn_shape_diff_dst.GetTfShape(), tensor_format_);
- else
- diff_dst_dims =
- TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_);
-
- // set src and diff_dst primitives according to input layout
- memory::desc src_md({}, memory::data_undef, memory::format_undef);
- memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
+ memory::format format_m;
if (dnn_shape_src.IsMklTensor()) {
- src_md = dnn_shape_src.GetMklLayout();
- } else {
- src_md = memory::desc(src_dims, MklDnnType<T>(),
- TFDataFormatToMklDnnDataFormat(tensor_format_));
- }
- if (dnn_shape_diff_dst.IsMklTensor()) {
- diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
+ if (dnn_shape_src.IsTensorInNCHWFormat())
+ format_m = memory::format::nchw;
+ else
+ format_m = memory::format::nhwc;
} else {
- diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(),
- TFDataFormatToMklDnnDataFormat(tensor_format_));
+ format_m = TFDataFormatToMklDnnDataFormat(tensor_format_);
}
- src.SetUsrMem(src_md, &src_tensor);
- diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
-
- // weights -- DNN packs scales/shifts as weights in order of
- // scale, ..., scale, shift, ..., shift
- auto weights_desc =
- memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc);
- auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine);
- auto weights_m = memory(weights_pd);
- T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle());
- T* scale_tf =
- reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data()));
+
+ MklDnnData<T> src(&cpu_engine);
+ MklDnnData<T> diff_dst(&cpu_engine);
+ MklDnnData<T> weights(&cpu_engine);
+ MklDnnData<T> diff_weights(&cpu_engine);
+
+ memory::dims src_dims =
+ dnn_shape_src.IsMklTensor()
+ ? dnn_shape_src.GetSizesAsMklDnnDims()
+ : TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
+ memory::dims diff_dst_dims =
+ dnn_shape_diff_dst.IsMklTensor()
+ ? dnn_shape_diff_dst.GetSizesAsMklDnnDims()
+ : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
+ tensor_format_);
+
+ // set src and diff_dst primitive descriptors
+ memory::desc src_md =
+ dnn_shape_src.IsMklTensor()
+ ? dnn_shape_src.GetMklLayout()
+ : memory::desc(src_dims, MklDnnType<T>(), format_m);
+ memory::desc diff_dst_md =
+ dnn_shape_diff_dst.IsMklTensor()
+ ? dnn_shape_diff_dst.GetMklLayout()
+ : memory::desc(diff_dst_dims, MklDnnType<T>(), format_m);
+
+      // weights -- MKL DNN packs scales/shifts as weights in order
+      // of scale, ..., scale, shift, ..., shift
+ weights.AllocateBuffer(2 * depth_ * sizeof(T));
+ T* weights_data_tf = reinterpret_cast<T*>(weights.GetAllocatedBuffer());
+ const T* scale_tf = scale_tensor.flat<T>().data();
for (int k = 0; k < depth_; k++) {
- weights_data[k] = scale_tf[k];
- weights_data[k + depth_] = 0;
+ weights_data_tf[k] = scale_tf[k];
+ weights_data_tf[k + depth_] = 0;
}
- // set mean primitive
- memory::dims mv_dims = GetMeanVarianceDims();
- mean.SetUsrMem(mv_dims, memory::format::nc,
- const_cast<void*>(static_cast<const void*>(
- saved_mean_tensor.flat<T>().data())));
- mean.SetOpMemDesc(mv_dims, memory::format::nc);
-
- // set variance primitive
- variance.SetUsrMem(mv_dims, memory::format::nc,
- const_cast<void*>(static_cast<const void*>(
- saved_variance_tensor.flat<T>().data())));
- variance.SetOpMemDesc(mv_dims, memory::format::nc);
-
- // set diff_weight primitive
- auto diff_weights_desc =
- memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc);
- auto diff_weights_pd =
- memory::primitive_desc(diff_weights_desc, cpu_engine);
- auto diff_weights_m = memory(diff_weights_pd);
-
- auto bnrm_fwd_desc = batch_normalization_forward::desc(
- prop_kind::forward_training, src.GetUsrMemDesc(), epsilon_,
- is_training_ ? use_scale_shift
- : (use_scale_shift | use_global_stats));
- auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(
- bnrm_fwd_desc, cpu_engine);
+ diff_weights.AllocateBuffer(2 * depth_ * sizeof(T));
+
+ MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_,
+ is_training_);
+ MklFusedBatchNormBwdPrimitive<T>* bn_bwd =
+ MklFusedBatchNormBwdPrimitiveFactory<T>::Get(bwdParams);
+
+ // check if src/diff_dst need to be reordered
+ const T* src_data = src_tensor.flat<T>().data();
+ if (src_md.data.format != bn_bwd->GetSrcFmt()) {
+ src.SetUsrMem(src_md, &src_tensor);
+ auto src_target = memory::primitive_desc(
+ {{src_dims},
+ MklDnnType<T>(),
+ static_cast<memory::format>(bn_bwd->GetSrcFmt())},
+ cpu_engine);
+ src.CheckReorderToOpMem(src_target);
+ src_data = const_cast<T*>(
+ reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
+ }
+
+ const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
+ if (diff_dst_md.data.format != bn_bwd->GetDiffDstFmt()) {
+ diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+ auto diff_dst_target = memory::primitive_desc(
+ {{diff_dst_dims},
+ MklDnnType<T>(),
+ static_cast<memory::format>(bn_bwd->GetDiffDstFmt())},
+ cpu_engine);
+ diff_dst.CheckReorderToOpMem(diff_dst_target);
+ diff_dst_data = const_cast<T*>(
+ reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
+ }
// Indices of output tensors
const size_t kDiffSrcIndex = 0; // index of diff_src tensor
- // allocate diff_src tensor
+ // allocate output tensor: diff_src, always set as MKL-DNN layout
MklDnnShape dnn_shape_diff_src;
TensorShape tf_shape_diff_src;
-
- // MKL-DNN's BN primitive not provide API to fetch internal format
- // set common_md as OpMem
- // src and diff_dst will reorder to common_md
- // diff_src will set as common_md
- memory::desc common_md({}, memory::data_undef, memory::format_undef);
- if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
- if (dnn_shape_src.IsMklTensor()) {
- common_md = dnn_shape_src.GetMklLayout();
- } else {
- common_md = dnn_shape_diff_dst.GetMklLayout();
- }
- } else {
- common_md = memory::desc(src_dims, MklDnnType<T>(),
- TFDataFormatToMklDnnDataFormat(tensor_format_));
- }
- // if any of src and diff_dst as mkl layout,
- // then we set diff_src as mkl layout
- if (dnn_shape_src.IsMklTensor() ||
- dnn_shape_diff_dst.IsMklTensor()) {
- dnn_shape_diff_src.SetMklTensor(true);
- // set diff_src's mkl layout as common_md
- auto diff_src_pd = memory::primitive_desc(common_md, cpu_engine);
- dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
- dnn_shape_diff_src.SetElemType(MklDnnType<T>());
- if (dnn_shape_src.IsMklTensor()) {
- dnn_shape_diff_src.SetTfLayout(
- dnn_shape_src.GetDimension(),
- src_dims,
- dnn_shape_src.GetTfDataFormat());
- dnn_shape_diff_src.SetTfDimOrder(
- dnn_shape_src.GetDimension(),
- tensor_format_);
- } else {
- dnn_shape_diff_src.SetTfLayout(
- dnn_shape_diff_dst.GetDimension(),
- src_dims,
- dnn_shape_diff_dst.GetTfDataFormat());
- dnn_shape_diff_src.SetTfDimOrder(
- dnn_shape_diff_dst.GetDimension(),
- tensor_format_);
- }
- tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
- } else {
- dnn_shape_diff_src.SetMklTensor(false);
- // both src and diff_dst are TensorFlow layout,
- // so it is OK to get TensorFlow shape.
- tf_shape_diff_src = src_tensor.shape();
- }
+ dnn_shape_diff_src.SetMklTensor(true);
+ auto diff_src_pd = bn_bwd->GetDiffSrcPd();
+ dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
+ dnn_shape_diff_src.SetElemType(MklDnnType<T>());
+ dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, format_m);
+ dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_);
+ tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
tf_shape_diff_src, dnn_shape_diff_src);
- // set diff_src
- diff_src.SetUsrMem(common_md, diff_src_tensor);
-
- prop_kind pk = prop_kind::backward;
- auto bnrm_bwd_desc = batch_normalization_backward::desc(
- pk, common_md, common_md, epsilon_,
- /* for inference, specify use_global_stats
- 1. on fwd prop, use mean and variance
- provided as inputs
- 2. on bwd prop, mean and variance are
- considered as constants. Thus,
- reduce the amout of MKL computations
- */
- is_training_ ? use_scale_shift
- : (use_scale_shift | use_global_stats));
- auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc(
- bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd);
-
- std::vector<primitive> net;
- src.CheckReorderToOpMem(memory::primitive_desc(common_md,
- cpu_engine), &net);
- diff_dst.CheckReorderToOpMem(memory::primitive_desc(common_md,
- cpu_engine), &net);
-
- auto bnrm_bwd_op = batch_normalization_backward(
- bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(),
- diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m);
-
- net.push_back(bnrm_bwd_op);
- stream(stream::kind::eager).submit(net).wait();
-
- // allocate 4 output TF tensors
+ T* mean_data =
+ static_cast<T*>(const_cast<T*>(saved_mean_tensor.flat<T>().data()));
+ T* variance_data = static_cast<T*>(
+ const_cast<T*>(saved_variance_tensor.flat<T>().data()));
+ T* weights_data = weights_data_tf;
+ T* diff_src_data = static_cast<T*>(diff_src_tensor->flat<T>().data());
+ T* diff_weights_data = static_cast<T*>(diff_weights.GetAllocatedBuffer());
+ // Execute
+ bn_bwd->Execute(src_data, mean_data, variance_data, diff_dst_data,
+ weights_data, diff_src_data, diff_weights_data);
+
+ // allocate output TF tensors: diff_scale and diff_shift
Tensor* diff_scale_tensor = nullptr;
Tensor* diff_shift_tensor = nullptr;
AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor,
&diff_shift_tensor);
// copy data: diff_scale and diff_shift
- T* diff_weights_data_dnn =
- reinterpret_cast<T*>(diff_weights_m.get_data_handle());
- for (int i = 0; i < depth_; i++) {
- diff_scale_tensor->flat<T>().data()[i] = diff_weights_data_dnn[i];
- diff_shift_tensor->flat<T>().data()[i] =
- diff_weights_data_dnn[i + depth_];
- }
+ auto diff_scale_data = diff_scale_tensor->flat<T>().data();
+ auto diff_shift_data = diff_shift_tensor->flat<T>().data();
+ std::memcpy(reinterpret_cast<char*>(diff_scale_data),
+ reinterpret_cast<char*>(diff_weights_data),
+ depth_ * sizeof(T));
+ std::memcpy(reinterpret_cast<char*>(diff_shift_data),
+ reinterpret_cast<char*>(diff_weights_data + depth_),
+ depth_ * sizeof(T));
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -1315,6 +1696,7 @@ class MklFusedBatchNormGradOp : public OpKernel {
TensorFormat tensor_format_;
int depth_; // batch normalization is done for per channel.
bool is_training_;
+ engine cpu_engine = engine(engine::cpu, 0);
void ExtractParams(OpKernelContext* context) {
const Tensor& input = MklGetInput(context, 0);
@@ -1330,8 +1712,8 @@ class MklFusedBatchNormGradOp : public OpKernel {
dnn_shape_diff_src.SetMklTensor(false);
AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
tf_shape_src, dnn_shape_diff_src);
- for (size_t i = 0; i < (*diff_src_tensor)->shape().num_elements(); i++)
- (*diff_src_tensor)->flat<T>().data()[i] = 0;
+ auto diff_src_data = (*diff_src_tensor)->flat<T>().data();
+ std::fill_n(diff_src_data, (*diff_src_tensor)->shape().num_elements(), 0);
Tensor* diff_scale_tensor = nullptr;
Tensor* diff_shift_tensor = nullptr;
@@ -1357,16 +1739,18 @@ class MklFusedBatchNormGradOp : public OpKernel {
AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
tf_shape_scale_shift, mkl_shape_diff_scale);
CHECK_NOTNULL(*diff_scale_tensor);
- for (size_t i = 0; i < (*diff_scale_tensor)->shape().num_elements(); i++)
- (*diff_scale_tensor)->flat<T>().data()[i] = 0;
+ auto diff_scale_data = (*diff_scale_tensor)->flat<T>().data();
+ std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(),
+ 0);
MklDnnShape mkl_shape_diff_shift;
mkl_shape_diff_shift.SetMklTensor(false);
AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
tf_shape_scale_shift, mkl_shape_diff_shift);
CHECK_NOTNULL(*diff_shift_tensor);
- for (size_t i = 0; i < (*diff_shift_tensor)->shape().num_elements(); i++)
- (*diff_shift_tensor)->flat<T>().data()[i] = 0;
+ auto diff_shift_data = (*diff_shift_tensor)->flat<T>().data();
+ std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(),
+ 0);
// Placeholders for estimated_mean and estimated_variance, which are
// used for inference and thus not needed here for gradient computation.
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 62c0404891..fd261433a0 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -23,14 +23,20 @@ limitations under the License.
// and when it is undefined at build time, this file becomes an empty
// compilation unit
-#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
+#if defined(INTEL_MKL)
-#include "mkl_cblas.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
+// This header file is part of MKL ML; an equivalent file is needed for MKL DNN.
+#ifndef DO_NOT_USE_ML
+#include "mkl_cblas.h"
+#else
+#include "mkldnn.h"
+#endif
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -100,7 +106,6 @@ class MklMatMulOp : public OpKernel {
private:
bool transpose_a_;
bool transpose_b_;
-
// --------------------------------------------------------------------------
//
// @brief Matrix-Matrix Multiplication with FP32 tensors, a, b, c using CBLAS
@@ -150,11 +155,26 @@ class MklMatMulOp : public OpKernel {
// 1.0 and 0.0 respectively.
const float alpha = 1.0f;
const float beta = 0.0f;
+#if defined(DO_NOT_USE_ML)
+ const char* const ftrans[] = {"N", "T", "C"};
+ int index_transa = transa ? 1 : 0;
+ int index_transb = transb ? 1 : 0;
+ VLOG(2) << "MKL DNN SGEMM called";
+    // MKL DNN only supports the Fortran API and requires column major,
+    // while TensorFlow uses row major, so we reverse the order of A and B.
+ mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa], &n, &m, &k, &alpha,
+ b, &ldb, a, &lda, &beta, c, &ldc);
+#else
+ // MKL ML binary uses CBLAS API
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc);
+#endif
}
+  // MKL DNN only supports SGEMM; the FP64 and complex paths below need MKL ML.
+#ifndef DO_NOT_USE_ML
+
// Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
// parameters, look at FP32 function description.
void MklBlasGemm(bool transa, bool transb, const int m, const int n,
@@ -197,6 +217,7 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
reinterpret_cast<MKL_Complex16*>(c), ldc);
}
+#endif
};
#define REGISTER_CPU(T) \
@@ -207,9 +228,12 @@ class MklMatMulOp : public OpKernel {
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
+
+#ifndef DO_NOT_USE_ML
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
+#endif
} // namespace tensorflow
#endif // INTEL_MKL
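The mkldnn_sgemm branch above relies on the identity that row-major C = A*B is the same computation as column-major C^T = B^T * A^T: a row-major buffer reinterpreted as column-major is the transpose, so passing B and A in swapped order, with m and n swapped, to a column-major GEMM writes the row-major product directly into C. A standalone sketch of that argument swap, with a naive column-major GEMM standing in for the library call (not the MKL routine):

#include <iostream>
#include <vector>

// Naive column-major GEMM: C(m x n) = A(m x k) * B(k x n), no transposes.
void ColMajorGemm(int m, int n, int k, const float* a, int lda, const float* b,
                  int ldb, float* c, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int p = 0; p < k; ++p) acc += a[i + p * lda] * b[p + j * ldb];
      c[i + j * ldc] = acc;
    }
}

// Row-major C = A * B via the column-major routine: the row-major buffers of
// B and A read as column-major are exactly B^T and A^T, and the column-major
// result C^T = B^T * A^T occupies memory exactly as row-major C.
void RowMajorGemm(int m, int n, int k, const float* a, const float* b,
                  float* c) {
  ColMajorGemm(n, m, k, b, n, a, k, c, n);
}

int main() {
  // A is 2x3, B is 3x2, both row-major.
  std::vector<float> a = {1, 2, 3, 4, 5, 6};
  std::vector<float> b = {7, 8, 9, 10, 11, 12};
  std::vector<float> c(4);
  RowMajorGemm(2, 2, 3, a.data(), b.data(), c.data());
  for (float v : c) std::cout << v << " ";  // prints: 58 64 139 154
  std::cout << "\n";
}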
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index ea537524b1..0a2151566e 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -119,6 +119,7 @@ class MklMaxPoolingOp : public OpKernel {
mkl_out_shape);
Tensor* workspace_tensor;
+ void* workspace_buf = nullptr;
TensorShape workspace_shape;
mkl_workspace_shape.SetMklTensor(false);
@@ -510,7 +511,6 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
void Compute(OpKernelContext* context) override {
try {
- auto cpu_engine = engine(engine::cpu, 0);
const Tensor& input_tensor =
MklGetInput(context, this->kInputTensorIndexInput);
MklDnnShape dnn_shape_input;
@@ -525,8 +525,9 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
// initialize variables for the pooling op
MklPoolParameters pool_params;
// Get the input tensor and initialize the pooling parameters
- this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params,
- &dnn_data_input);
+ TensorShape input_tensor_shape = input_tensor.shape();
+ this->InitMklPoolParameters(context, &pool_params, dnn_shape_input,
+ input_tensor_shape);
OP_REQUIRES_OK(context, context->status());
// Declare output tensor
@@ -534,44 +535,70 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);
- // If input is in Mkl layout, then just get the memory format from it
- // directly, instead of using input data_format to MaxPool.
- if (dnn_shape_input.IsMklTensor()) {
- dnn_data_output.SetUsrMem(
- output_dims_mkl_order,
- static_cast<memory::format>(
- dnn_data_input.GetUsrMemDesc().data.format));
- } else {
- dnn_data_output.SetUsrMem(output_dims_mkl_order,
- this->data_format_mkldnn_);
+ // If input is an empty tensor, allocate an empty output tensor and return
+ if (input_tensor.NumElements() == 0) {
+ const int kOutputIndex = 0;
+ this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params,
+ output_dims_mkl_order, &output_tensor);
+ return;
}
- // describe the memory layout; let mkl-dnn choose the best for the op
- dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
-
- auto pool_desc = pooling_forward::desc(
- prop_kind::forward, algorithm::pooling_max,
- dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(),
- memory::dims({pool_params.row_stride, pool_params.col_stride}),
- memory::dims({pool_params.window_rows, pool_params.window_cols}),
- memory::dims({static_cast<int>(pool_params.pad_top),
- static_cast<int>(pool_params.pad_left)}),
- memory::dims({static_cast<int>(pool_params.pad_bottom),
- static_cast<int>(pool_params.pad_right)}),
- TFPaddingToMklDnnPadding(this->padding_));
- auto pool_fwd_desc =
- pooling_forward::primitive_desc(pool_desc, cpu_engine);
-
- this->AllocateOutputTensor(context, pool_fwd_desc, output_dims_mkl_order,
+ // Get the input memory descriptor
+ memory::desc input_md =
+ dnn_shape_input.IsMklTensor()
+ ? dnn_shape_input.GetMklLayout()
+ : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ this->data_format_tf_),
+ MklDnnType<T>(), this->data_format_mkldnn_);
+
+ // Get src/filter/stride/padding information
+ memory::dims src_dims =
+ dnn_shape_input.IsMklTensor()
+ ? dnn_shape_input.GetSizesAsMklDnnDims()
+ : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
+ this->data_format_tf_);
+
+ memory::dims filter_dims, strides, padding_left, padding_right;
+ this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
+ &padding_left, &padding_right);
+
+ // Get a pooling op from the cached pool
+ MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr;
+ MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims,
+ strides, padding_left, padding_right,
+ algorithm::pooling_max);
+ pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
+
+ // allocate output tensor
+ this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
+ output_dims_mkl_order,
this->data_format_mkldnn_, &output_tensor);
OP_REQUIRES_OK(context, context->status());
- dnn_data_output.SetUsrMemDataHandle(output_tensor);
+ dnn_data_output.SetUsrMem(output_dims_mkl_order,
+ pooling_fwd->GetDstMemoryFormat(),
+ output_tensor);
- AllocateWorkspaceTensor(context, pool_fwd_desc, &dnn_data_wksp);
+ AllocateWorkspaceTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
+ &dnn_data_wksp);
OP_REQUIRES_OK(context, context->status());
- this->PrepareAndExecuteNet(pool_fwd_desc, &dnn_data_input,
- &dnn_data_output, &dnn_data_wksp);
+ // Check whether we need to reorder src
+ const T* src_data = input_tensor.flat<T>().data();
+ if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) {
+ dnn_data_input.SetUsrMem(input_md, &input_tensor);
+ auto src_target_primitive_desc = memory::primitive_desc(
+ {{src_dims}, MklDnnType<T>(), pooling_fwd->GetSrcMemoryFormat()},
+ cpu_engine);
+ dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc);
+ src_data = const_cast<T*>(
+ reinterpret_cast<T*>(dnn_data_input.GetOpMem().get_data_handle()));
+ }
+
+ T* dst_data = output_tensor->flat<T>().data();
+ void* ws_data = dnn_data_wksp.GetOpMem().get_data_handle();
+
+ // execute pooling op
+ pooling_fwd->Execute(src_data, dst_data, ws_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -579,10 +606,11 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
error_msg));
}
- } // Compute
+ }
private:
const int kOutputTensorIndexWorkspace = 1;
+ engine cpu_engine = engine(engine::cpu, 0);
void AllocateWorkspaceTensor(
OpKernelContext* context,
@@ -616,98 +644,105 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
public:
explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
: MklPoolingBackwardOpBase<T>(context) {}
-
void Compute(OpKernelContext* context) override {
try {
auto cpu_engine = engine(engine::cpu, 0);
const Tensor& orig_input_tensor =
MklGetInput(context, kInputTensorIndexOrigInput);
- const Tensor& orig_output_tensor =
- MklGetInput(context, kInputTensorIndexOrigOutput);
const Tensor& grad_tensor =
MklGetInput(context, kInputTensorIndexGradient);
const Tensor& workspace_tensor =
MklGetInput(context, kInputTensorIndexWorkspace);
- MklDnnShape orig_input_mkl_shape, orig_output_mkl_shape, grad_mkl_shape,
- workspace_mkl_shape;
+ MklDnnShape orig_input_mkl_shape, grad_mkl_shape;
GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape);
- GetMklShape(context, kInputTensorIndexOrigOutput, &orig_output_mkl_shape);
GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape);
- GetMklShape(context, kInputTensorIndexWorkspace, &workspace_mkl_shape);
-
- SanityCheckInputs(context, orig_input_tensor, orig_output_tensor,
- grad_tensor, workspace_tensor, orig_input_mkl_shape,
- orig_output_mkl_shape, grad_mkl_shape,
- workspace_mkl_shape);
if (!context->status().ok()) return;
MklDnnData<T> grad_dnn_data(&cpu_engine);
MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
- MklDnnData<T> output_dnn_data(&cpu_engine);
- Tensor* output_tensor = nullptr;
+
MklPoolParameters pool_params;
- TensorShape orig_input_shape;
- memory::dims output_dims_mkl_order, orig_input_dims_mkl_order;
- memory::desc original_input_md = ConfigureOriginalInput(
- context, orig_input_tensor, orig_input_mkl_shape,
- &orig_input_dims_mkl_order, &pool_params, &orig_input_shape);
-
- memory::desc original_output_md = this->ConfigureOriginalOutput(
- pool_params, orig_output_mkl_shape, output_dims_mkl_order);
-
- memory::desc target_diff_dst_md = this->ConfigureInputGradient(
- grad_mkl_shape, grad_tensor, &grad_dnn_data, original_output_md);
-
- output_dnn_data.SetUsrMem(original_input_md);
-
- // Create the forward pooling primitive descriptor so we can
- // pass it as a hint to the backward pooling primitive descriptor
- auto pool_fwd_desc = pooling_forward::desc(
- prop_kind::forward, algorithm::pooling_max, original_input_md,
- original_output_md,
- memory::dims({pool_params.row_stride, pool_params.col_stride}),
- memory::dims({pool_params.window_rows, pool_params.window_cols}),
- memory::dims({static_cast<int>(pool_params.pad_top),
- static_cast<int>(pool_params.pad_left)}),
- memory::dims({static_cast<int>(pool_params.pad_bottom),
- static_cast<int>(pool_params.pad_right)}),
- TFPaddingToMklDnnPadding(this->padding_));
- auto pool_fwd_prim_desc =
- pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine);
-
- auto pool_bkwd_desc = pooling_backward::desc(
- algorithm::pooling_max, output_dnn_data.GetUsrMemDesc(),
- target_diff_dst_md,
- memory::dims({pool_params.row_stride, pool_params.col_stride}),
- memory::dims({pool_params.window_rows, pool_params.window_cols}),
- memory::dims({static_cast<int>(pool_params.pad_top),
- static_cast<int>(pool_params.pad_left)}),
- memory::dims({static_cast<int>(pool_params.pad_bottom),
- static_cast<int>(pool_params.pad_right)}),
- TFPaddingToMklDnnPadding(this->padding_));
- auto pool_bkwd_prim_desc = pooling_backward::primitive_desc(
- pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc);
-
- this->AllocateOutputTensor(context, pool_bkwd_prim_desc,
+ TensorShape orig_input_shape = orig_input_tensor.shape();
+ this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
+ orig_input_shape);
+
+ memory::dims filter_dims, strides, padding_left, padding_right;
+ this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
+ &padding_left, &padding_right);
+
+ memory::dims diff_dst_dims =
+ grad_mkl_shape.IsMklTensor()
+ ? grad_mkl_shape.GetSizesAsMklDnnDims()
+ : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
+ this->data_format_tf_);
+ memory::dims orig_input_dims_mkl_order =
+ orig_input_mkl_shape.IsMklTensor()
+ ? orig_input_mkl_shape.GetSizesAsMklDnnDims()
+ : TFShapeToMklDnnDimsInNCHW(orig_input_shape,
+ this->data_format_tf_);
+
+ memory::dims output_dims_mkl_order;
+ this->GetOutputDims(pool_params, &output_dims_mkl_order);
+
+ MklPoolingParams bwdParams(
+ orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
+ strides, padding_left, padding_right, algorithm::pooling_max);
+ MklPoolingBwdPrimitive<T>* pooling_bwd =
+ MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
+
+ // allocate output tensor and memory primitive
+ Tensor* output_tensor = nullptr;
+ this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
orig_input_dims_mkl_order,
this->data_format_mkldnn_, &output_tensor);
- output_dnn_data.SetUsrMemDataHandle(output_tensor);
-
- ConfigureWorkspace(workspace_tensor,
- pool_fwd_prim_desc.workspace_primitive_desc(),
- &workspace_dnn_data);
- this->PrepareAndExecuteNet(
- pool_bkwd_prim_desc, &grad_dnn_data, &output_dnn_data,
- memory::primitive_desc(target_diff_dst_md, cpu_engine),
- &workspace_dnn_data);
+ // get diff_dst mem desc
+ memory::desc diff_dst_md =
+ grad_mkl_shape.IsMklTensor()
+ ? grad_mkl_shape.GetMklLayout()
+ : memory::desc(diff_dst_dims, MklDnnType<T>(),
+ this->data_format_mkldnn_);
+ // check if diff_dst needs to be reordered
+ const T* diff_dst_data = grad_tensor.flat<T>().data();
+ if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) {
+ auto target_diff_dst = memory::primitive_desc(
+ {{diff_dst_dims}, MklDnnType<T>(), pooling_bwd->GetDiffDstFormat()},
+ cpu_engine);
+ grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
+ grad_dnn_data.CheckReorderToOpMem(target_diff_dst);
+ diff_dst_data = const_cast<T*>(
+ reinterpret_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle()));
+ }
+
+ void* ws_data = static_cast<void*>(
+ const_cast<uint8*>(workspace_tensor.flat<uint8>().data()));
+ auto ws_md =
+ pooling_bwd->GetPoolingFwdPd()->workspace_primitive_desc().desc();
+ if (ws_md.data.format != pooling_bwd->GetWorkspaceFormat()) {
+ memory::dims ws_dims;
+ ws_dims.assign(ws_md.data.dims, ws_md.data.dims + ws_md.data.ndims);
+ auto target_ws =
+ memory::primitive_desc({{ws_dims},
+ pooling_bwd->GetWorkspaceDataType(),
+ pooling_bwd->GetWorkspaceFormat()},
+ cpu_engine);
+ workspace_dnn_data.SetUsrMem(ws_md, &workspace_tensor);
+ workspace_dnn_data.CheckReorderToOpMem(target_ws);
+ ws_data = workspace_dnn_data.GetOpMem().get_data_handle();
+ }
+
+ T* diff_src_data = output_tensor->flat<T>().data();
+
+ // execute pooling
+ pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);
} catch (mkldnn::error& e) {
- string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) + ", in file " +
+ string error_msg = "Status:" + std::to_string(e.status) +
+ ", message: " + string(e.message) + ". in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
error_msg));
}
- } // Compute
+ }
private:
// .Input("orig_input: T")
@@ -718,18 +753,6 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
const int kInputTensorIndexOrigOutput = 1;
const int kInputTensorIndexGradient = 2;
const int kInputTensorIndexWorkspace = 3;
- // Output("output: T") in Base Class
-
- memory::desc ConfigureOriginalInput(
- OpKernelContext* context, const Tensor& tensor_original_input,
- const MklDnnShape& original_input_mkl_shape,
- memory::dims* original_input_dims_mkl_order,
- MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
- *input_tensor_shape = tensor_original_input.shape();
- return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput(
- context, tensor_original_input, original_input_mkl_shape,
- original_input_dims_mkl_order, pool_params, *input_tensor_shape);
- }
void ConfigureWorkspace(const Tensor& workspace_tensor,
memory::primitive_desc workspace_pd,
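The rewritten forward and backward Compute() paths above fetch a cached pooling primitive and reorder the source (or gradient) tensor only when its memory format differs from what the primitive expects. The standalone sketch below models that reorder-on-mismatch pattern; the Format enum, ReorderIfNeeded, and the plain copy standing in for a real MKL-DNN reorder are illustrative stand-ins, not the library API.

// Standalone sketch: feed the user buffer directly unless a reorder is needed.
#include <cstdio>
#include <vector>

enum class Format { nchw, nChw16c };  // toy stand-ins for mkldnn memory formats

std::vector<float> ReorderIfNeeded(const std::vector<float>& src,
                                   Format src_fmt, Format primitive_fmt) {
  if (src_fmt == primitive_fmt) return src;  // no reorder: use the buffer as-is
  std::vector<float> reordered(src);         // placeholder for a real reorder
  std::printf("reordered %zu elements to the primitive's format\n",
              reordered.size());
  return reordered;
}

int main() {
  std::vector<float> user_input(16, 1.0f);
  auto op_input = ReorderIfNeeded(user_input, Format::nchw, Format::nChw16c);
  std::printf("feeding %zu elements to the pooling primitive\n",
              op_input.size());
  return 0;
}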
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index 5ef6ce2a57..915878d9ea 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -24,6 +24,187 @@ limitations under the License.
namespace tensorflow {
+#ifndef INTEL_MKL_ML
+
+using mkldnn::pooling_avg;
+using mkldnn::pooling_avg_exclude_padding;
+using mkldnn::pooling_avg_include_padding;
+using mkldnn::pooling_max;
+using mkldnn::prop_kind;
+
+template <typename T>
+void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
+ if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg &&
+ fwdParams.alg_kind != pooling_avg_include_padding &&
+ fwdParams.alg_kind != pooling_avg_exclude_padding) {
+ assert(false && "Pooling algorithm kind is not supported");
+ }
+
+ context_.alg_kind = fwdParams.alg_kind;
+ // create memory desc
+ // FIXME: Pooling does not expose an API to query the src_primitive_desc,
+ // so the src format is currently hard-coded via a utility function
+ // (get_desired_format), which may break on future CPU architectures.
+ context_.src_md.reset(
+ new memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
+ get_desired_format(fwdParams.src_dims[1])));
+ context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
+ memory::format::any));
+
+ // create a pooling descriptor
+ context_.fwd_desc.reset(new pooling_forward::desc(
+ prop_kind::forward_training, fwdParams.alg_kind, *context_.src_md,
+ *context_.dst_md, fwdParams.strides, fwdParams.filter_dims,
+ fwdParams.padding_left, fwdParams.padding_right, padding_kind::zero));
+ context_.fwd_pd.reset(
+ new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_));
+
+ // store expected primitive format
+ context_.src_fmt = get_desired_format(fwdParams.src_dims[1]);
+ context_.dst_fmt = static_cast<mkldnn::memory::format>(
+ context_.fwd_pd.get()->dst_primitive_desc().desc().data.format);
+
+ // create MKL-DNN internal memory object with dummy data
+ context_.src_mem.reset(new memory(
+ {{{fwdParams.src_dims}, MklDnnType<T>(), context_.src_fmt}, cpu_engine_},
+ DummyData));
+ context_.dst_mem.reset(
+ new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
+
+ // for max pooling, need to return workspace(ws) for backward computing
+ if (fwdParams.alg_kind == pooling_max) {
+ auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data;
+ // store workspace's dims and format to create workspace tensor
+ context_.ws_fmt = static_cast<mkldnn::memory::format>(ws_pd.format);
+ context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
+ context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
+ context_.ws_size =
+ context_.fwd_pd.get()->workspace_primitive_desc().get_size();
+ context_.ws_mem.reset(new memory(
+ context_.fwd_pd.get()->workspace_primitive_desc(), DummyData));
+ context_.fwd.reset(new pooling_forward(*context_.fwd_pd, *context_.src_mem,
+ *context_.dst_mem,
+ *context_.ws_mem));
+ } else {
+ context_.fwd.reset(new pooling_forward(*context_.fwd_pd, *context_.src_mem,
+ *context_.dst_mem));
+ }
+
+ context_.fwd_primitives.push_back(*context_.fwd);
+}
+
+template <typename T>
+void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
+ void* ws_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
+ if (context_.alg_kind == pooling_max) { // max pooling must have ws
+ assert(ws_data != nullptr);
+ context_.ws_mem->set_data_handle(ws_data);
+ }
+ context_.fwd_stream->submit(context_.fwd_primitives);
+
+ // set back data handle
+ context_.src_mem->set_data_handle(DummyData);
+ context_.dst_mem->set_data_handle(DummyData);
+ if (context_.alg_kind == pooling_max) { // max pooling must have ws
+ assert(ws_data != nullptr);
+ context_.ws_mem->set_data_handle(DummyData);
+ }
+}
+
+template class MklPoolingFwdPrimitive<float>;
+
+template <typename T>
+void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
+ if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg &&
+ bwdParams.alg_kind != pooling_avg_include_padding &&
+ bwdParams.alg_kind != pooling_avg_exclude_padding) {
+ assert(false && "Pooling algorithm kind is not supported");
+ }
+ context_.alg_kind = bwdParams.alg_kind;
+
+ // Create memory desc
+ context_.diff_src_md.reset(new memory::desc(
+ {bwdParams.src_dims}, MklDnnType<T>(), memory::format::any));
+ context_.diff_dst_md.reset(
+ new memory::desc({bwdParams.dst_dims}, MklDnnType<T>(),
+ get_desired_format(bwdParams.dst_dims[1])));
+ context_.bwd_desc.reset(new pooling_backward::desc(
+ bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md,
+ bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left,
+ bwdParams.padding_right, padding_kind::zero));
+
+ // create a forward primitive,
+ // which will be used as a hint for creating backward primitive
+ context_.fwd_desc.reset(new pooling_forward::desc(
+ prop_kind::forward_training, bwdParams.alg_kind, *context_.diff_src_md,
+ *context_.diff_dst_md, bwdParams.strides, bwdParams.filter_dims,
+ bwdParams.padding_left, bwdParams.padding_right, padding_kind::zero));
+ context_.fwd_pd.reset(
+ new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine));
+ context_.bwd_pd.reset(new pooling_backward::primitive_desc(
+ *context_.bwd_desc, cpu_engine, *context_.fwd_pd));
+
+ // store expected primitive format
+ context_.diff_src_fmt = static_cast<mkldnn::memory::format>(
+ context_.bwd_pd.get()->diff_src_primitive_desc().desc().data.format);
+ context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1]);
+
+ // create MKL-DNN internal memory object with dummy data
+ context_.diff_src_mem.reset(
+ new memory(context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));
+ context_.diff_dst_mem.reset(new memory(
+ {{{bwdParams.dst_dims}, MklDnnType<T>(), context_.diff_dst_fmt},
+ cpu_engine},
+ DummyData));
+
+ // for max pooling, need to return workspace for backward
+ if (bwdParams.alg_kind == pooling_max) {
+ auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data;
+ context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
+ context_.ws_fmt = get_desired_format(context_.ws_dims[1]);
+ context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
+ context_.ws_mem.reset(new memory(
+ {{{context_.ws_dims}, context_.ws_dt, context_.ws_fmt}, cpu_engine},
+ DummyData));
+ context_.bwd.reset(
+ new pooling_backward(*context_.bwd_pd, *context_.diff_dst_mem,
+ *context_.ws_mem, *context_.diff_src_mem));
+ } else {
+ context_.bwd.reset(new pooling_backward(
+ *context_.bwd_pd, *context_.diff_dst_mem, *context_.diff_src_mem));
+ }
+ context_.bwd_primitives.push_back(*context_.bwd);
+}
+
+template <typename T>
+void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
+ T* diff_src_data, const void* ws_data) {
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_dst_data)));
+ context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
+ if (context_.alg_kind == pooling_max) {
+ assert(ws_data != nullptr);
+ context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
+ }
+
+ context_.bwd_stream->submit(context_.bwd_primitives);
+ // set back data handle
+ context_.diff_dst_mem->set_data_handle(DummyData);
+ context_.diff_src_mem->set_data_handle(DummyData);
+ if (context_.alg_kind == pooling_max) {
+ assert(ws_data != nullptr);
+ context_.ws_mem->set_data_handle(DummyData);
+ }
+}
+
+template class MklPoolingBwdPrimitive<float>;
+
+#endif
+
// Initialization for TensorFlow format
void MklPoolParameters::Init(OpKernelContext* context,
const std::vector<int32>& ksize,
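The Setup()/Execute() split added above builds each pooling primitive once against DummyData placeholders, then Execute() temporarily points the memory objects at the caller's buffers, submits the stream, and restores the placeholders so the cached primitive can be reused. The standalone model below illustrates that lifecycle; MemoryHandle and CachedPrimitive are simplified stand-ins, and the element-wise copy merely stands in for submitting the real primitive.

// Standalone model of the set_data_handle / submit / reset-to-dummy pattern.
#include <cstddef>
#include <cstdio>

struct MemoryHandle {
  void* data = nullptr;
  void set_data_handle(void* p) { data = p; }
};

struct CachedPrimitive {
  MemoryHandle src, dst;
  void Setup() { /* build descriptors and the primitive once, on dummy data */ }
  void Execute(const float* src_data, float* dst_data, std::size_t n) {
    src.set_data_handle(const_cast<float*>(src_data));
    dst.set_data_handle(dst_data);
    for (std::size_t i = 0; i < n; ++i) dst_data[i] = src_data[i];  // "submit"
    src.set_data_handle(nullptr);  // reset so the cached primitive is reusable
    dst.set_data_handle(nullptr);
  }
};

int main() {
  CachedPrimitive p;
  p.Setup();
  float in[4] = {1, 2, 3, 4}, out[4] = {0};
  p.Execute(in, out, 4);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
  return 0;
}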
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index c0dfed7d7d..9c516afbd0 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
#ifdef INTEL_MKL
-#include <string>
+#include <memory>
#include <vector>
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
@@ -32,6 +32,326 @@ using mkldnn::stream;
namespace tensorflow {
+#ifndef INTEL_MKL_ML
+
+using mkldnn::memory;
+using mkldnn::pooling_avg;
+using mkldnn::pooling_avg_exclude_padding;
+using mkldnn::pooling_avg_include_padding;
+using mkldnn::pooling_max;
+using mkldnn::prop_kind;
+
+struct MklPoolingParams {
+ memory::dims src_dims;
+ memory::dims dst_dims;
+ memory::dims filter_dims;
+ memory::dims strides;
+ memory::dims padding_left;
+ memory::dims padding_right;
+ mkldnn::algorithm alg_kind;
+
+ MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
+ memory::dims filter_dims, memory::dims strides,
+ memory::dims padding_left, memory::dims padding_right,
+ mkldnn::algorithm alg_kind)
+ : src_dims(src_dims),
+ dst_dims(dst_dims),
+ filter_dims(filter_dims),
+ strides(strides),
+ padding_left(padding_left),
+ padding_right(padding_right),
+ alg_kind(alg_kind) {}
+};
+
+template <typename T>
+class MklPoolingFwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
+ : cpu_engine_(engine::cpu, 0) {
+ context_.fwd_stream.reset(new stream(stream::kind::eager));
+ if (context_.fwd == nullptr) Setup(fwdParams);
+ }
+
+ ~MklPoolingFwdPrimitive() {}
+
+ // Pooling forward execute
+ // src_data: input data buffer of src
+ // ws_data: output data buffer of workspace
+ // dst_data: output data buffer of dst
+ void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr);
+
+ std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd()
+ const {
+ return context_.fwd_pd;
+ }
+
+ memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
+
+ memory::format GetDstMemoryFormat() const { return context_.dst_fmt; }
+
+ private:
+ void Setup(const MklPoolingParams& fwdParams);
+
+ struct PoolingFwdContext {
+ // algorithm
+ mkldnn::algorithm alg_kind;
+
+ // expected memory format
+ memory::format src_fmt;
+ memory::format dst_fmt;
+ memory::format ws_fmt;
+
+ // workspace shape
+ memory::dims ws_dims;
+ memory::data_type ws_dt;
+ size_t ws_size;
+
+ // MKL-DNN memory, just dummy data
+ std::shared_ptr<mkldnn::memory> ws_mem;
+ std::shared_ptr<mkldnn::memory> src_mem;
+ std::shared_ptr<mkldnn::memory> dst_mem;
+
+ // desc & primitive desc
+ std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;
+
+ // memory desc
+ std::shared_ptr<mkldnn::memory::desc> src_md;
+ std::shared_ptr<mkldnn::memory::desc> dst_md;
+
+ // Pooling primitive
+ std::shared_ptr<mkldnn::pooling_forward> fwd;
+ std::shared_ptr<mkldnn::stream> fwd_stream;
+ std::vector<mkldnn::primitive> fwd_primitives;
+
+ PoolingFwdContext()
+ : src_fmt(memory::format::any),
+ dst_fmt(memory::format::any),
+ ws_fmt(memory::format::any),
+ ws_mem(nullptr),
+ src_mem(nullptr),
+ dst_mem(nullptr),
+ fwd_desc(nullptr),
+ fwd_pd(nullptr),
+ src_md(nullptr),
+ dst_md(nullptr),
+ fwd(nullptr),
+ fwd_stream(nullptr) {}
+ };
+
+ struct PoolingFwdContext context_;
+ engine cpu_engine_;
+};
+
+template <typename T>
+class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklPoolingFwdPrimitive<T>* Get(const MklPoolingParams& fwdParams) {
+ MklPoolingFwdPrimitive<T>* pooling_forward = nullptr;
+
+ // Get pooling primitive from the pool
+ pooling_forward = static_cast<MklPoolingFwdPrimitive<T>*>(
+ MklPoolingFwdPrimitiveFactory<T>::GetInstance().GetPoolingFwd(
+ fwdParams));
+
+ if (pooling_forward == nullptr) {
+ pooling_forward = new MklPoolingFwdPrimitive<T>(fwdParams);
+ MklPoolingFwdPrimitiveFactory<T>::GetInstance().SetPoolingFwd(
+ fwdParams, pooling_forward);
+ }
+ return pooling_forward;
+ }
+
+ static MklPoolingFwdPrimitiveFactory& GetInstance() {
+ static MklPoolingFwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklPoolingFwdPrimitiveFactory() {}
+ ~MklPoolingFwdPrimitiveFactory() {}
+
+ // The created key is used to look up and cache pooling
+ // primitives for reuse.
+ // A pooling key is a string that concatenates the key parameters
+ // as well as the algorithm kind (max versus avg).
+ static std::string CreateKey(const MklPoolingParams& fwdParams) {
+ std::string prefix = "pooling_fwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(fwdParams.src_dims);
+ key_creator.AddAsKey(fwdParams.dst_dims);
+ key_creator.AddAsKey(fwdParams.filter_dims);
+ key_creator.AddAsKey(fwdParams.strides);
+ key_creator.AddAsKey(fwdParams.padding_left);
+ key_creator.AddAsKey(fwdParams.padding_right);
+ key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) {
+ std::string key = CreateKey(fwdParams);
+ return this->GetOp(key);
+ }
+
+ void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive* op) {
+ std::string key = CreateKey(fwdParams);
+ this->SetOp(key, op);
+ }
+};
+
+template <typename T>
+class MklPoolingBwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
+ : cpu_engine(engine::cpu, 0) {
+ context_.bwd_stream.reset(new stream(stream::kind::eager));
+ if (context_.bwd == nullptr) Setup(bwdParams);
+ }
+
+ ~MklPoolingBwdPrimitive() {}
+
+ // Pooling backward execute
+ // diff_dst_data: input data buffer of diff_dst
+ // diff_src_data: output data buffer of diff_src
+ // ws_data: input data buffer of workspace
+ void Execute(const T* diff_dst_data, T* diff_src_data,
+ const void* ws_data = nullptr);
+
+ public:
+ std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd()
+ const {
+ return context_.fwd_pd;
+ }
+ std::shared_ptr<mkldnn::pooling_backward::primitive_desc> GetPoolingBwdPd()
+ const {
+ return context_.bwd_pd;
+ }
+
+ memory::format GetDiffDstFormat() const { return context_.diff_dst_fmt; }
+
+ mkldnn::memory::data_type GetWorkspaceDataType() const {
+ return context_.ws_dt;
+ }
+ memory::format GetWorkspaceFormat() const { return context_.ws_fmt; }
+
+ private:
+ void Setup(const MklPoolingParams& bwdParams);
+
+ // Primitive reuse context for pooling bwd ops
+ struct PoolingBwdContext {
+ // algorithm
+ mkldnn::algorithm alg_kind;
+
+ // expected memory format
+ mkldnn::memory::format diff_src_fmt;
+ mkldnn::memory::format diff_dst_fmt;
+ mkldnn::memory::format ws_fmt;
+
+ // workspace attribute
+ mkldnn::memory::dims ws_dims;
+ mkldnn::memory::data_type ws_dt;
+
+ // MKL-DNN memory
+ std::shared_ptr<mkldnn::memory> ws_mem;
+ std::shared_ptr<mkldnn::memory> diff_src_mem;
+ std::shared_ptr<mkldnn::memory> diff_dst_mem;
+
+ // memory desc
+ std::shared_ptr<mkldnn::memory::desc> diff_src_md;
+ std::shared_ptr<mkldnn::memory::desc> diff_dst_md;
+
+ // desc & primitive desc
+ std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::pooling_backward::desc> bwd_desc;
+ std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;
+ std::shared_ptr<mkldnn::pooling_backward::primitive_desc> bwd_pd;
+
+ // pooling primitive
+ std::shared_ptr<mkldnn::pooling_backward> bwd;
+ std::shared_ptr<mkldnn::stream> bwd_stream;
+
+ std::vector<mkldnn::primitive> bwd_primitives;
+
+ PoolingBwdContext()
+ : diff_src_fmt(memory::format::any),
+ diff_dst_fmt(memory::format::any),
+ ws_fmt(memory::format::any),
+ ws_mem(nullptr),
+ diff_src_mem(nullptr),
+ diff_dst_mem(nullptr),
+ diff_src_md(nullptr),
+ diff_dst_md(nullptr),
+ fwd_desc(nullptr),
+ bwd_desc(nullptr),
+ fwd_pd(nullptr),
+ bwd_pd(nullptr),
+ bwd(nullptr),
+ bwd_stream(nullptr) {}
+ };
+
+ struct PoolingBwdContext context_;
+ engine cpu_engine;
+};
+
+template <typename T>
+class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklPoolingBwdPrimitive<T>* Get(const MklPoolingParams& bwdParams) {
+ MklPoolingBwdPrimitive<T>* pooling_backward = nullptr;
+
+ // Find a pooling backward primitive from the pool
+ // If it does not exist, create a new one
+ pooling_backward = static_cast<MklPoolingBwdPrimitive<T>*>(
+ MklPoolingBwdPrimitiveFactory<T>::GetInstance().GetPoolingBwd(
+ bwdParams));
+ if (pooling_backward == nullptr) {
+ pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams);
+ MklPoolingBwdPrimitiveFactory<T>::GetInstance().SetPoolingBwd(
+ bwdParams, pooling_backward);
+ }
+ return pooling_backward;
+ }
+
+ static MklPoolingBwdPrimitiveFactory& GetInstance() {
+ static MklPoolingBwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklPoolingBwdPrimitiveFactory() {}
+ ~MklPoolingBwdPrimitiveFactory() {}
+
+ // The created key is used to look up and cache pooling
+ // primitives for reuse.
+ // A pooling key is a string that concatenates the key parameters
+ // as well as the algorithm kind (max versus avg).
+ static std::string CreateKey(const MklPoolingParams& bwdParams) {
+ std::string prefix = "pooling_bwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(bwdParams.src_dims);
+ key_creator.AddAsKey(bwdParams.dst_dims);
+ key_creator.AddAsKey(bwdParams.filter_dims);
+ key_creator.AddAsKey(bwdParams.strides);
+ key_creator.AddAsKey(bwdParams.padding_left);
+ key_creator.AddAsKey(bwdParams.padding_right);
+ key_creator.AddAsKey<int>(static_cast<int>(bwdParams.alg_kind));
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) {
+ std::string key = CreateKey(bwdParams);
+ return this->GetOp(key);
+ }
+
+ void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive* op) {
+ std::string key = CreateKey(bwdParams);
+ this->SetOp(key, op);
+ }
+};
+#endif
+
typedef Eigen::ThreadPoolDevice CPUDevice;
struct MklPoolParameters {
@@ -163,6 +483,41 @@ class MklPoolingOpBase : public OpKernel {
}
}
+ void PoolParamsToDims(const MklPoolParameters* pool_params,
+ memory::dims* filter_dims, memory::dims* strides,
+ memory::dims* padding_left,
+ memory::dims* padding_right) {
+ *filter_dims = {pool_params->window_rows, pool_params->window_cols};
+ *strides = {pool_params->row_stride, pool_params->col_stride};
+ *padding_left = {static_cast<int>(pool_params->pad_top),
+ static_cast<int>(pool_params->pad_left)};
+ *padding_right = {static_cast<int>(pool_params->pad_bottom),
+ static_cast<int>(pool_params->pad_right)};
+ }
+
+ void AllocateEmptyOutputTensor(OpKernelContext* context,
+ const int kOutputIndex,
+ MklPoolParameters* pool_params,
+ const memory::dims output_dims_mkl_order,
+ Tensor** output_tensor) {
+ MklDnnShape output_mkl_shape;
+ output_mkl_shape.SetMklTensor(false);
+ TensorShape output_tf_shape;
+ if (pool_params->data_format == TensorFormat::FORMAT_NCHW) {
+ output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
+ } else {
+ memory::dims output_dims_NHWC_order;
+ output_dims_NHWC_order = {pool_params->tensor_in_batch,
+ static_cast<int>(pool_params->out_height),
+ static_cast<int>(pool_params->out_width),
+ pool_params->out_depth};
+ output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
+ }
+ AllocateOutputSetMklShape(context, kOutputIndex, output_tensor,
+ output_tf_shape, output_mkl_shape);
+ CHECK_NOTNULL(output_tensor);
+ }
+
// Checks to make sure that the memory we need to allocate
// is a multiple of sizeof(T)
// returns the number of elements
@@ -235,23 +590,6 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
CHECK_NOTNULL(*output_tensor);
}
- void PrepareAndExecuteNet(
- const pooling_forward::primitive_desc& pool_fwd_desc,
- const MklDnnData<T>* src, MklDnnData<T>* dst,
- MklDnnData<uint8>* wksp = nullptr) {
- std::vector<primitive> net;
-
- // Create pooling primitive and add it to net
- if (wksp != nullptr) {
- net.push_back(pooling_forward(pool_fwd_desc, src->GetOpMem(),
- dst->GetOpMem(), wksp->GetOpMem()));
- } else {
- net.push_back(
- pooling_forward(pool_fwd_desc, src->GetOpMem(), dst->GetOpMem()));
- }
- stream(stream::kind::eager).submit(net).wait();
- }
-
void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
const MklDnnShape& input_mkl_shape) {
if (!input_mkl_shape.IsMklTensor()) {
@@ -301,67 +639,6 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
CHECK_NOTNULL(*output_tensor);
}
- void PrepareAndExecuteNet(
- const pooling_backward::primitive_desc& pool_bkwd_desc,
- MklDnnData<T>* input_gradient_diff_dst, MklDnnData<T>* output_diff_src,
- const memory::primitive_desc& target_diff_dst_pd,
- const MklDnnData<uint8>* workspace = nullptr) {
- std::vector<primitive> net;
-
- // If the input gradient isn't in the same format as the output
- // reorder it to the same format as the output
- input_gradient_diff_dst->CheckReorderToOpMem(target_diff_dst_pd, &net);
-
- // Create pooling primitive and add it to net
- if (nullptr == workspace) {
- net.push_back(pooling_backward(pool_bkwd_desc,
- input_gradient_diff_dst->GetOpMem(),
- output_diff_src->GetOpMem()));
- } else {
- net.push_back(
- pooling_backward(pool_bkwd_desc, input_gradient_diff_dst->GetOpMem(),
- workspace->GetOpMem(), output_diff_src->GetOpMem()));
- }
- stream(stream::kind::eager).submit(net).wait();
- }
-
- // Max Pooling and Avg Pooling have slightly different implementations
- // Takes the Tensor containing original input data and the original
- // mkl Dnn Shape and populates other data
- memory::desc ConfigureOriginalInput(
- OpKernelContext* context, const Tensor& tensor_original_input_shape,
- const MklDnnShape& original_input_mkl_shape,
- memory::dims* original_input_dims_nchw, MklPoolParameters* pool_params,
- const TensorShape& input_tensor_shape) {
- CHECK_NOTNULL(original_input_dims_nchw);
- CHECK_NOTNULL(pool_params);
- this->InitMklPoolParameters(context, pool_params, original_input_mkl_shape,
- input_tensor_shape);
-
- *original_input_dims_nchw =
- original_input_mkl_shape.IsMklTensor()
- ? original_input_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
- this->data_format_tf_);
-
- return original_input_mkl_shape.IsMklTensor()
- ? original_input_mkl_shape.GetMklLayout()
- : memory::desc(*original_input_dims_nchw, MklDnnType<T>(),
- this->data_format_mkldnn_);
- }
-
- memory::desc ConfigureOriginalOutput(
- const MklPoolParameters& pool_params,
- const MklDnnShape& original_output_mkl_shape,
- memory::dims output_dims_mkl_order) {
- this->GetOutputDims(pool_params, &output_dims_mkl_order);
-
- return original_output_mkl_shape.IsMklTensor()
- ? original_output_mkl_shape.GetMklLayout()
- : memory::desc(output_dims_mkl_order, MklDnnType<T>(),
- this->data_format_mkldnn_);
- }
-
memory::desc ConfigureInputGradient(
const MklDnnShape& input_gradient_mkl_shape,
const Tensor& input_gradient_tensor,
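The factory classes added to this header cache primitives in a process-wide singleton keyed by a string built from the pooling parameters, so identical shapes reuse the same primitive. The sketch below is a standalone simplification of that pattern; PrimitiveCache, FakePrimitive, and the hand-built key are illustrative names, whereas the real code goes through FactoryKeyCreator and MklPrimitiveFactory.

// Standalone sketch of a string-keyed primitive cache with a singleton factory.
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

struct FakePrimitive { std::string key; };

class PrimitiveCache {
 public:
  static PrimitiveCache& GetInstance() {
    static PrimitiveCache instance;
    return instance;
  }
  FakePrimitive* Get(const std::vector<int>& src_dims, int alg_kind) {
    const std::string key = CreateKey(src_dims, alg_kind);
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second.get();   // reuse
    auto created = std::make_unique<FakePrimitive>();  // create once
    created->key = key;
    return (cache_[key] = std::move(created)).get();
  }

 private:
  static std::string CreateKey(const std::vector<int>& src_dims, int alg_kind) {
    std::ostringstream key;
    key << "pooling_fwd";
    for (int d : src_dims) key << "_" << d;
    key << "_alg" << alg_kind;
    return key.str();
  }
  std::unordered_map<std::string, std::unique_ptr<FakePrimitive>> cache_;
};

int main() {
  auto* p1 = PrimitiveCache::GetInstance().Get({1, 16, 28, 28}, /*alg_kind=*/0);
  auto* p2 = PrimitiveCache::GetInstance().Get({1, 16, 28, 28}, /*alg_kind=*/0);
  return p1 == p2 ? 0 : 1;  // same parameters resolve to the same primitive
}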
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index 02ea9fc068..9c536df215 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -152,8 +152,12 @@ class MklReshapeOp : public OpKernel {
// If Tensorflow's data format and the underlying format maintained by
// MKLDNN are equivalent (both are NHWC or both are NCHW), then we can
// safely return true.
+ // TODO: In the future, do not force-skip the reorder for all blocked
+ // formats. Use blocking_desc_is_equal() to compare all the stride arrays
+ // (see mkl-dnn/blob/master/src/common/type_helpers.hpp).
auto input_mkl_md = mkl_shape_input.GetMklLayout();
- if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format) {
+ if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format &&
+ mkl_shape_input.GetTfDataFormat() != memory::format::blocked) {
ret = true;
}
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index f59843a07a..c7d0d4de0d 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -121,10 +121,11 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
std::placeholders::_1, std::placeholders::_2, threshold);
}
-void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
- int num_boxes, const Tensor& max_output_size,
- const float score_threshold,
- std::function<bool(int, int)> suppress_check_fn) {
+void DoNonMaxSuppressionOp(
+ OpKernelContext* context, const Tensor& scores, int num_boxes,
+ const Tensor& max_output_size, const float score_threshold,
+ const std::function<bool(int, int)>& suppress_check_fn,
+ bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
std::vector<float> scores_data(num_boxes);
@@ -172,6 +173,15 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& scores,
}
}
+ int num_valid_outputs = selected.size();
+ if (pad_to_max_output_size) {
+ selected.resize(output_size, 0);
+ selected_scores.resize(output_size, 0);
+ }
+ if (ptr_num_valid_outputs) {
+ *ptr_num_valid_outputs = num_valid_outputs;
+ }
+
// Allocate output tensors
Tensor* output_indices = nullptr;
TensorShape output_shape({static_cast<int>(selected.size())});
@@ -262,54 +272,106 @@ class NonMaxSuppressionV2Op : public OpKernel {
}
};
-template <typename Device>
-class NonMaxSuppressionV3Op : public OpKernel {
+class NonMaxSuppressionV3V4Base : public OpKernel {
public:
- explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
+ explicit NonMaxSuppressionV3V4Base(OpKernelConstruction* context)
: OpKernel(context) {}
void Compute(OpKernelContext* context) override {
// boxes: [num_boxes, 4]
- const Tensor& boxes = context->input(0);
+ boxes_ = context->input(0);
// scores: [num_boxes]
- const Tensor& scores = context->input(1);
+ scores_ = context->input(1);
// max_output_size: scalar
- const Tensor& max_output_size = context->input(2);
+ max_output_size_ = context->input(2);
OP_REQUIRES(
- context, TensorShapeUtils::IsScalar(max_output_size.shape()),
+ context, TensorShapeUtils::IsScalar(max_output_size_.shape()),
errors::InvalidArgument("max_output_size must be 0-D, got shape ",
- max_output_size.shape().DebugString()));
+ max_output_size_.shape().DebugString()));
// iou_threshold: scalar
const Tensor& iou_threshold = context->input(3);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(iou_threshold.shape()),
errors::InvalidArgument("iou_threshold must be 0-D, got shape ",
iou_threshold.shape().DebugString()));
- const float iou_threshold_val = iou_threshold.scalar<float>()();
-
+ iou_threshold_val_ = iou_threshold.scalar<float>()();
+ OP_REQUIRES(context, iou_threshold_val_ >= 0 && iou_threshold_val_ <= 1,
+ errors::InvalidArgument("iou_threshold must be in [0, 1]"));
// score_threshold: scalar
const Tensor& score_threshold = context->input(4);
OP_REQUIRES(
context, TensorShapeUtils::IsScalar(score_threshold.shape()),
errors::InvalidArgument("score_threshold must be 0-D, got shape ",
score_threshold.shape().DebugString()));
- const float score_threshold_val = score_threshold.scalar<float>()();
+ score_threshold_val_ = score_threshold.scalar<float>()();
- OP_REQUIRES(context, iou_threshold_val >= 0 && iou_threshold_val <= 1,
- errors::InvalidArgument("iou_threshold must be in [0, 1]"));
- int num_boxes = 0;
- ParseAndCheckBoxSizes(context, boxes, &num_boxes);
- CheckScoreSizes(context, num_boxes, scores);
+ num_boxes_ = 0;
+ ParseAndCheckBoxSizes(context, boxes_, &num_boxes_);
+ CheckScoreSizes(context, num_boxes_, scores_);
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoComputeAndPostProcess(context);
+ }
+
+ protected:
+ virtual void DoComputeAndPostProcess(OpKernelContext* context) = 0;
+
+ Tensor boxes_;
+ Tensor scores_;
+ Tensor max_output_size_;
+ int num_boxes_;
+ float iou_threshold_val_;
+ float score_threshold_val_;
+};
+
+template <typename Device>
+class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
+ public:
+ explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
+ : NonMaxSuppressionV3V4Base(context) {}
+
+ protected:
+ void DoComputeAndPostProcess(OpKernelContext* context) override {
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+
+ DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn);
}
};
template <typename Device>
+class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
+ public:
+ explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
+ : NonMaxSuppressionV3V4Base(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("pad_to_max_output_size",
+ &pad_to_max_output_size_));
+ }
+
+ protected:
+ void DoComputeAndPostProcess(OpKernelContext* context) override {
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ int num_valid_outputs;
+
+ DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn,
+ pad_to_max_output_size_, &num_valid_outputs);
+
+ // Allocate scalar output tensor for number of indices computed.
+ Tensor* num_outputs_t = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(
+ 1, tensorflow::TensorShape{}, &num_outputs_t));
+ num_outputs_t->scalar<int32>().setConstant(num_valid_outputs);
+ }
+
+ private:
+ bool pad_to_max_output_size_;
+};
+
+template <typename Device>
class NonMaxSuppressionWithOverlapsOp : public OpKernel {
public:
explicit NonMaxSuppressionWithOverlapsOp(OpKernelConstruction* context)
@@ -365,6 +427,9 @@ REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice>);
+
REGISTER_KERNEL_BUILDER(
Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
NonMaxSuppressionWithOverlapsOp<CPUDevice>);
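The new pad_to_max_output_size path in DoNonMaxSuppressionOp pads the selected indices with zeros up to output_size and reports the count of valid entries through a separate output. The standalone sketch below shows just that padding step after greedy selection; PadSelection and the sample values are made up, and they mirror the V4 test expectations later in this diff (indices {3, 0, 5, 0, 0}, num_valid = 3).

// Standalone sketch: pad the NMS selection and report the valid count.
#include <cstdio>
#include <vector>

void PadSelection(std::vector<int>* selected, int output_size,
                  bool pad_to_max_output_size, int* num_valid_outputs) {
  *num_valid_outputs = static_cast<int>(selected->size());
  if (pad_to_max_output_size) selected->resize(output_size, 0);
}

int main() {
  std::vector<int> selected = {3, 0, 5};  // boxes kept by the greedy pass
  int num_valid = 0;
  PadSelection(&selected, /*output_size=*/5, /*pad_to_max_output_size=*/true,
               &num_valid);
  std::printf("num_valid=%d, padded size=%zu\n", num_valid, selected.size());
  return 0;
}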
diff --git a/tensorflow/core/kernels/non_max_suppression_op_test.cc b/tensorflow/core/kernels/non_max_suppression_op_test.cc
index 055161a35f..c321849f40 100644
--- a/tensorflow/core/kernels/non_max_suppression_op_test.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op_test.cc
@@ -570,6 +570,61 @@ TEST_F(NonMaxSuppressionV3OpTest, TestEmptyInput) {
}
//
+// NonMaxSuppressionV4Op Tests
+//
+
+class NonMaxSuppressionV4OpTest : public OpsTestBase {
+ protected:
+ void MakeOp() {
+ TF_EXPECT_OK(NodeDefBuilder("non_max_suppression_op", "NonMaxSuppressionV4")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("pad_to_max_output_size", true)
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ }
+};
+
+TEST_F(NonMaxSuppressionV4OpTest, TestSelectFromThreeClustersPadFive) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {5});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.0f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ const auto expected_indices = test::AsTensor<int>({3, 0, 5, 0, 0});
+ test::ExpectTensorEqual<int>(expected_indices, *GetOutput(0));
+ Tensor expected_num_valid = test::AsScalar<int>(3);
+ test::ExpectTensorEqual<int>(expected_num_valid, *GetOutput(1));
+}
+
+TEST_F(NonMaxSuppressionV4OpTest, TestSelectFromThreeClustersPadFiveScoreThr) {
+ MakeOp();
+ AddInputFromArray<float>(
+ TensorShape({6, 4}),
+ {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1, 0.9f,
+ 0, 10, 1, 11, 0, 10.1f, 1, 11.1f, 0, 100, 1, 101});
+ AddInputFromArray<float>(TensorShape({6}), {.9f, .75f, .6f, .95f, .5f, .3f});
+ AddInputFromArray<int>(TensorShape({}), {6});
+ AddInputFromArray<float>(TensorShape({}), {.5f});
+ AddInputFromArray<float>(TensorShape({}), {0.4f});
+ TF_ASSERT_OK(RunOpKernel());
+
+ const auto expected_indices = test::AsTensor<int>({3, 0, 0, 0, 0, 0});
+ test::ExpectTensorEqual<int>(expected_indices, *GetOutput(0));
+ Tensor expected_num_valid = test::AsScalar<int>(2);
+ test::ExpectTensorEqual<int>(expected_num_valid, *GetOutput(1));
+}
+
+//
// NonMaxSuppressionWithOverlapsOp Tests
//
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index a7a9609c21..33ed044dae 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -114,8 +114,16 @@ class PartitionedCallOp : public AsyncOpKernel {
// The FunctionLibraryRuntime's library cannot be mutated from within
// an OpKernel, so functions are instantiated in an overlay library.
- overlay_lib_.reset(new FunctionLibraryDefinition(
- *lib->GetFunctionLibraryDefinition()));
+ OP_REQUIRES_ASYNC(
+ ctx, overlay_libs_.find(lib) == overlay_libs_.end(),
+ errors::Internal("Found an overlay library but did not "
+ "find cached function partitions; "
+ "this indicates a bug."),
+ done);
+ FunctionLibraryDefinition* overlay_lib =
+ new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition());
+ overlay_libs_.emplace(lib, overlay_lib);
+
auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>();
for (const auto& pair : subgraphs) {
// TODO(akshayka): Fail gracefully if the set of devices corresponds
@@ -125,13 +133,13 @@ class PartitionedCallOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(
ctx, UpdateArgAndRetMetadata(target, subgraph.get()), done);
FunctionDef shard;
- string unique_name = UniquifyFunctionName(func_.name());
+ string unique_name = UniquifyFunctionName(overlay_lib, func_.name());
OP_REQUIRES_OK_ASYNC(
ctx, GraphToFunctionDef(*subgraph, unique_name, &shard), done);
- OP_REQUIRES_OK_ASYNC(ctx, overlay_lib_->AddFunctionDef(shard), done);
+ OP_REQUIRES_OK_ASYNC(ctx, overlay_lib->AddFunctionDef(shard), done);
FunctionLibraryRuntime::InstantiateOptions opts;
opts.target = target;
- opts.overlay_lib = overlay_lib_.get();
+ opts.overlay_lib = overlay_lib;
FHandle handle;
OP_REQUIRES_OK_ASYNC(
ctx,
@@ -399,10 +407,11 @@ class PartitionedCallOp : public AsyncOpKernel {
}
}
- string UniquifyFunctionName(const string& name) {
+ string UniquifyFunctionName(const FunctionLibraryDefinition* function_library,
+ const string& name) {
for (;; ++suffix_) {
const string candidate = strings::StrCat(name, "_", suffix_);
- if (overlay_lib_->Find(candidate) == nullptr) {
+ if (function_library->Find(candidate) == nullptr) {
return candidate;
}
}
@@ -410,14 +419,16 @@ class PartitionedCallOp : public AsyncOpKernel {
NameAttrList func_;
string local_device_name_;
- // Function shards are added to `overlay_lib_`.
- std::unique_ptr<FunctionLibraryDefinition> overlay_lib_;
- // Contains maps from device names to handles of function shards, keyed by
+ // Contains maps from device names to handles of function partitions, keyed by
// FunctionLibraryRuntime pointers. (Because this kernel may be instantiated
// for a stateful op, different invocations of it may use different FLRs.)
gtl::FlatMap<FunctionLibraryRuntime*,
std::unique_ptr<gtl::FlatMap<string, FHandle>>>
function_handles_ GUARDED_BY(mu_);
+ // Function partitions are added to overlay libraries.
+ gtl::FlatMap<FunctionLibraryRuntime*,
+ std::unique_ptr<FunctionLibraryDefinition>>
+ overlay_libs_ GUARDED_BY(mu_);
// Map from device name to the indices of the arguments and return values
// placed on that device. Read-only after the first invocation.
gtl::FlatMap<string, ArgAndRetIndices> arg_and_ret_indices_;
@@ -427,7 +438,7 @@ class PartitionedCallOp : public AsyncOpKernel {
mutex mu_;
- // Used to uniquify function names in `overlay_lib_`.
+ // Used to uniquify function names in `overlay_libs_`.
uint32 suffix_ = 0;
};
REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_CPU),
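The change above moves from a single overlay library to one overlay library per FunctionLibraryRuntime, kept in a map keyed by the runtime pointer, and UniquifyFunctionName now searches the library it is given. The standalone sketch below models that bookkeeping; FakeRuntime, FakeLibrary, and Partitioner are simplified stand-ins for the TensorFlow types.

// Standalone sketch: one overlay library per runtime, plus name uniquification.
#include <cstdint>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>

struct FakeRuntime {};  // stands in for FunctionLibraryRuntime
struct FakeLibrary {    // stands in for FunctionLibraryDefinition
  std::set<std::string> functions;
  bool Find(const std::string& name) const { return functions.count(name) > 0; }
};

class Partitioner {
 public:
  FakeLibrary* OverlayFor(FakeRuntime* lib) {
    auto& slot = overlay_libs_[lib];
    if (!slot) slot = std::make_unique<FakeLibrary>();  // one overlay per runtime
    return slot.get();
  }
  std::string UniquifyFunctionName(const FakeLibrary* library,
                                   const std::string& name) {
    for (;; ++suffix_) {
      const std::string candidate = name + "_" + std::to_string(suffix_);
      if (!library->Find(candidate)) return candidate;
    }
  }

 private:
  std::unordered_map<FakeRuntime*, std::unique_ptr<FakeLibrary>> overlay_libs_;
  uint32_t suffix_ = 0;
};

int main() {
  Partitioner p;
  FakeRuntime flr;
  FakeLibrary* overlay = p.OverlayFor(&flr);
  overlay->functions.insert("f_0");
  return p.UniquifyFunctionName(overlay, "f") == "f_1" ? 0 : 1;
}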
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op.h b/tensorflow/core/kernels/quantize_and_dequantize_op.h
index 782263e4e9..6b0c5e5a46 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op.h
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/cwise_ops.h"
namespace tensorflow {
namespace functor {
@@ -89,17 +90,14 @@ struct QuantizeAndDequantizeOneScaleImpl {
// min_range and max_range - because we may have changed either min_range
// or max_range.
out.device(d) =
- ((input.cwiseMin(max_range).cwiseMax(min_range) - min_range) * scale +
- T(0.5))
- .floor() *
- inverse_scale +
- min_range;
+ (input.cwiseMin(max_range).cwiseMax(min_range) * scale)
+ .unaryExpr(Eigen::internal::scalar_round_op_google<T>()) *
+ inverse_scale;
} else {
- // No need to clamp to min_range and max_range in this case as they were
- // measured from the tensor.
out.device(d) =
- ((input - min_range) * scale + T(0.5)).floor() * inverse_scale +
- min_range;
+ (input * scale)
+ .unaryExpr(Eigen::internal::scalar_round_op_google<T>()) *
+ inverse_scale;
}
}
};
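The hunk above replaces floor(x * scale + 0.5) (round half up, after shifting by min_range) with Eigen's scalar_round_op_google, which rounds halves to even like tf.round. The standalone check below reproduces the two test expectations that change later in this diff, using std::nearbyint under the default round-to-nearest-even mode as a stand-in for that functor; the scales 127 and 255 follow from the [-1, 1] and [0, 1] ranges in those tests.

// Standalone check: why -63 becomes -64 and 77 becomes 76 in the tests.
#include <cfenv>
#include <cmath>
#include <cstdio>

int main() {
  std::fesetround(FE_TONEAREST);  // round to nearest, ties to even (the default)

  // Signed int8 case, range [-1, 1], scale = 127, input = -0.5.
  float x = -0.5f, scale = 127.0f, min_range = -1.0f;
  float old_q = std::floor((x - min_range) * scale + 0.5f) + min_range * scale;  // -63
  float new_q = std::nearbyint(x * scale);  // nearbyint(-63.5) = -64, ties to even

  // Unsigned uint8 case, range [0, 1], scale = 255, input = 0.3.
  // Note: 0.3f * 255.0f rounds to exactly 76.5 at float precision.
  float y = 0.3f, yscale = 255.0f;
  float old_u = std::floor(y * yscale + 0.5f);  // floor(77.0) = 77
  float new_u = std::nearbyint(y * yscale);     // nearbyint(76.5) = 76, ties to even

  std::printf("%g -> %g, %g -> %g\n", old_q, new_q, old_u, new_u);
  return 0;
}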
diff --git a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
index 629c698503..cddabf8a99 100644
--- a/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
+++ b/tensorflow/core/kernels/quantize_and_dequantize_op_test.cc
@@ -226,13 +226,13 @@ TEST_F(QuantizeAndDequantizeTest, Convert_2D_tensor_with_int8_range_given) {
AddInputFromArray<float>(TensorShape({}), {1.0}); // Max
// Note that the range is given as [-1, 1].
- // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -128,
+ // With int8, the tensor is quantized to {-102, -64, 0, 38, 102, 70, -128,
// 127}.
// Scale is: 1/127
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 4}));
test::FillValues<float>(
- &expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127,
+ &expected, {-102.0 / 127, -64.0 / 127, 0, 38.0 / 127, 102.0 / 127,
70.0 / 127, -128.0 / 127, 1});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
}
@@ -257,13 +257,13 @@ TEST_F(QuantizeAndDequantizeTest, Convert_2D_tensor_with_int8_range_given_V3) {
AddInputFromArray<int32>(TensorShape({}), {8}); // num_bits
// Note that the range is given as [-1, 1].
- // With int8, the tensor is quantized to {-102, -63, 0, 38, 102, 70, -128,
+ // With int8, the tensor is quantized to {-102, -64, 0, 38, 102, 70, -128,
// 127}.
// Scale is: 1/127
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 4}));
test::FillValues<float>(
- &expected, {-102.0 / 127, -63.0 / 127, 0, 38.0 / 127, 102.0 / 127,
+ &expected, {-102.0 / 127, -64.0 / 127, 0, 38.0 / 127, 102.0 / 127,
70.0 / 127, -128.0 / 127, 1});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
}
@@ -285,11 +285,11 @@ TEST_F(QuantizeAndDequantizeTest, Convert_4D_tensor_with_uint8_range_given) {
AddInputFromArray<float>(TensorShape({}), {1.0}); // Max
// Note that the range is given as [0, 1].
- // With int8, the tensor is quantized to {0, 0, 77, 204}
+ // With int8, the tensor is quantized to {0, 0, 76, 204}
// Scale is: 1/255
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 1, 1}));
- test::FillValues<float>(&expected, {0, 0, 77.0 / 255, 204.0 / 255});
+ test::FillValues<float>(&expected, {0, 0, 76.0 / 255, 204.0 / 255});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
}
@@ -311,11 +311,11 @@ TEST_F(QuantizeAndDequantizeTest, Convert_4D_tensor_with_uint8_range_given_V3) {
AddInputFromArray<int32>(TensorShape({}), {8}); // num_bits
// Note that the range is given as [0, 1].
- // With int8, the tensor is quantized to {0, 0, 77, 204}
+ // With int8, the tensor is quantized to {0, 0, 76, 204}
// Scale is: 1/255
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, TensorShape({2, 2, 1, 1}));
- test::FillValues<float>(&expected, {0, 0, 77.0 / 255, 204.0 / 255});
+ test::FillValues<float>(&expected, {0, 0, 76.0 / 255, 204.0 / 255});
test::ExpectTensorNear<float>(expected, *GetOutput(0), 1e-5);
}
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index c5292e1ae1..cab9eb729d 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -213,64 +213,32 @@ class AssignVariableOp : public OpKernel {
"Variable and value dtypes don't match; respectively, ",
dtype_, " and ", context->input(1).dtype()));
Var* variable = nullptr;
- OP_REQUIRES_OK(
- context,
- LookupOrCreateResource<Var>(
- context, HandleFromInput(context, 0), &variable,
- [this, context](Var** ptr) {
- *ptr = new Var(dtype_);
- PersistentTensor unused;
- Tensor* tmp;
- AllocatorAttributes attr;
- if (!relax_constraints_) {
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
- }
- TF_RETURN_IF_ERROR(context->allocate_persistent(
- dtype_, context->input(1).shape(), &unused, &tmp, attr));
- *(*ptr)->tensor() = *tmp;
- return Status::OK();
- }));
+ const Tensor& value = context->input(1);
+ // Note: every resource-variable-manipulating op assumes copy-on-write
+ // semantics, and creates a copy of the variable's Tensor if its refcount is
+ // bigger than 1 when we try to modify it. This means we never need to copy
+ // the original tensor for AssignVariableOp; even if there are other live
+ // users of it we know none can modify it so this is always safe (even in
+ // esoteric cases where the same tensor is used to initialize multiple
+ // variables or the tensor is a constant this is safe, as future writes will
+ // trigger copies).
+ OP_REQUIRES_OK(context, LookupOrCreateResource<Var>(
+ context, HandleFromInput(context, 0), &variable,
+ [this, &value](Var** ptr) {
+ *ptr = new Var(dtype_);
+ *(*ptr)->tensor() = value;
+ (*ptr)->is_initialized = true;
+ return Status::OK();
+ }));
core::ScopedUnref s(variable);
-
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
-
- const Tensor& value = context->input(1);
- AllocatorAttributes attr;
- if (!relax_constraints_) {
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
- }
-
- // Copying is unnecessary if we are the last user of the value
- // tensor, we can just adopt the input tensor's buffer instead.
- std::unique_ptr<Tensor> input_alias = context->forward_input(
- 1, OpKernelContext::Params::kNoReservation /*output_index*/, dtype_,
- value.shape(), DEVICE_MEMORY, attr);
mutex_lock ml(*variable->mu());
variable->is_initialized = true;
- if (input_alias) {
- *variable->tensor() = *input_alias;
- return;
- }
-
- // Need to copy, but maybe we can re-use variable's buffer?
- if (!variable->tensor()->RefCountIsOne() ||
- !variable->tensor()->shape().IsSameSize(value.shape())) {
- // Copy to new buffer
- PersistentTensor unused;
- Tensor* tmp;
- OP_REQUIRES_OK(context, context->allocate_persistent(
- dtype_, value.shape(), &unused, &tmp, attr));
- *variable->tensor() = *tmp;
- }
- functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
- copy_functor(context->eigen_device<Device>(), variable->tensor()->flat<T>(),
- value.flat<T>());
+ *variable->tensor() = value;
}
private:
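The comment in the hunk above relies on a copy-on-write convention: AssignVariableOp simply aliases the input tensor into the variable, and any op that later mutates the variable is expected to check whether the buffer is uniquely owned (RefCountIsOne() in TensorFlow) and copy it first. The standalone model below illustrates that convention; Variable, Assign, and MutateInPlace are simplified stand-ins, with std::shared_ptr standing in for TensorFlow's reference-counted buffers.

// Standalone model of assign-by-aliasing plus copy-on-write mutation.
#include <cstdio>
#include <memory>
#include <vector>

using Buffer = std::vector<float>;

struct Variable {
  std::shared_ptr<Buffer> tensor;
};

void Assign(Variable* var, const std::shared_ptr<Buffer>& value) {
  var->tensor = value;  // no copy: just share the value's buffer
}

void MutateInPlace(Variable* var, float delta) {
  if (var->tensor.use_count() > 1) {                       // not uniquely owned
    var->tensor = std::make_shared<Buffer>(*var->tensor);  // copy on write
  }
  for (float& v : *var->tensor) v += delta;
}

int main() {
  auto value = std::make_shared<Buffer>(Buffer{1, 2, 3});
  Variable var;
  Assign(&var, value);
  MutateInPlace(&var, 10);  // copies first, so `value` stays untouched
  std::printf("value[0]=%g var[0]=%g\n", (*value)[0], (*var.tensor)[0]);
  return 0;
}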
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index 990bd2bff9..e335e38bdc 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -23,7 +23,9 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
@@ -95,7 +97,7 @@ void SaveTensors(
return tensor_names_flat(a) < tensor_names_flat(b);
});
- for (size_t i : sorted_name_idx) {
+ for (const size_t i : sorted_name_idx) {
const string& name = tensor_names_flat(i);
const Tensor& input = context->input(i + kFixedInputs);
TensorShape shape(input.shape());
@@ -226,43 +228,53 @@ void RestoreTensor(OpKernelContext* context,
#undef READER_COPY
}
-Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
- const Tensor& tensor_names,
- const Tensor& shape_and_slices,
- gtl::ArraySlice<DataType> dtypes) {
- const string& prefix_string = prefix.scalar<string>()();
+namespace {
- const auto& tensor_names_flat = tensor_names.flat<string>();
- const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
+// Tensors larger than this threshold will be restored from a thread-pool.
+const int64 kLargeShapeThreshold = 16 << 20; // 16M
- // Sort lookup keys to improve locality when reading multiple tensors.
- std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
- std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
- std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
- [&tensor_names_flat](size_t a, size_t b) {
- return tensor_names_flat(a) < tensor_names_flat(b);
- });
+// A restore operation for a single tensor. Small tensors may be restored
+// directly on the op thread to improve read locality. Large tensors can be
+// restored in a thread pool; this requires creating a separate BundleReader
+// for each such restore.
+struct RestoreOp {
+ RestoreOp& operator=(const RestoreOp&) = delete;
- BundleReader reader(Env::Default(), prefix_string);
- TF_RETURN_IF_ERROR(reader.status());
+ bool should_run_in_pool(BundleReader* reader) const {
+ TensorShape restored_full_shape;
- // TODO(zongheng): potential optimization: one Seek() in first lookup.
- // TODO(zongheng): consider measuring speed and issuing concurrent lookups
- // within a fixed memory budget.
- TensorShape restored_full_shape;
- Tensor* restored_tensor = nullptr;
- for (auto i : sorted_name_idx) {
- const string& tensor_name = tensor_names_flat(i);
- const string& shape_and_slice = shape_and_slices_flat(i);
+ // Ignore status here; we'll catch the error later.
+ if (!reader->LookupTensorShape(tensor_name, &restored_full_shape).ok()) {
+ return false;
+ }
+
+ return restored_full_shape.num_elements() > kLargeShapeThreshold;
+ }
+
+ // Run this restore operation using a new BundleReader.
+ void run_with_new_reader() {
+ BundleReader reader(Env::Default(), reader_prefix);
+ if (!reader.status().ok()) {
+ status = reader.status();
+ return;
+ }
+
+ status = run(&reader);
+ }
+ Status run(BundleReader* reader) {
+ TensorShape restored_full_shape;
TF_RETURN_IF_ERROR(
- reader.LookupTensorShape(tensor_name, &restored_full_shape));
+ reader->LookupTensorShape(tensor_name, &restored_full_shape));
+ VLOG(1) << "Restoring tensor " << idx << " : " << tensor_name << " : "
+ << restored_full_shape.num_elements();
+ Tensor* restored_tensor;
if (shape_and_slice.empty()) {
// Lookup the full tensor.
TF_RETURN_IF_ERROR(
- context->allocate_output(i, restored_full_shape, &restored_tensor));
- TF_RETURN_IF_ERROR(reader.Lookup(tensor_name, restored_tensor));
+ context->allocate_output(idx, restored_full_shape, &restored_tensor));
+ TF_RETURN_IF_ERROR(reader->Lookup(tensor_name, restored_tensor));
} else {
// Lookup the slice.
TensorShape parsed_full_shape;
@@ -272,6 +284,7 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
TF_RETURN_IF_ERROR(
checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
&parsed_slice, &parsed_slice_shape));
+
if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
return errors::InvalidArgument(
"tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
@@ -279,19 +292,113 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
" does not match the shape stored in checkpoint: ",
restored_full_shape.DebugString());
}
-
TF_RETURN_IF_ERROR(
- context->allocate_output(i, parsed_slice_shape, &restored_tensor));
+ context->allocate_output(idx, parsed_slice_shape, &restored_tensor));
TF_RETURN_IF_ERROR(
- reader.LookupSlice(tensor_name, parsed_slice, restored_tensor));
+ reader->LookupSlice(tensor_name, parsed_slice, restored_tensor));
+ }
+ return Status::OK();
+ }
+
+ OpKernelContext* context;
+ size_t idx;
+ string tensor_name;
+ string shape_and_slice;
+ string reader_prefix;
+
+ ::tensorflow::Status status;
+};
+
+} // namespace
+
+Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
+ const Tensor& tensor_names,
+ const Tensor& shape_and_slices,
+ gtl::ArraySlice<DataType> dtypes) {
+ const string& prefix_string = prefix.scalar<string>()();
+
+ const auto& tensor_names_flat = tensor_names.flat<string>();
+ const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
+
+ // Sort lookup keys to improve locality when reading multiple tensors.
+ std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
+ std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
+ std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
+ [&tensor_names_flat](size_t a, size_t b) {
+ return tensor_names_flat(a) < tensor_names_flat(b);
+ });
+
+ std::vector<std::unique_ptr<RestoreOp> > pool_restore_ops;
+ std::vector<std::unique_ptr<RestoreOp> > direct_restore_ops;
+
+ BundleReader default_reader(Env::Default(), prefix_string);
+ TF_RETURN_IF_ERROR(default_reader.status());
+
+ std::vector<string> mismatched_errors;
+ for (const size_t i : sorted_name_idx) {
+ TensorShape restored_full_shape;
+ DataType original_dtype;
+ const string& tensor_name = tensor_names_flat(i);
+ TF_RETURN_IF_ERROR(default_reader.LookupDtypeAndShape(
+ tensor_name, &original_dtype, &restored_full_shape));
+ if (dtypes[i] != original_dtype) {
+ string error_msg = strings::StrCat(
+ "tensor_name = ", tensor_name, "; expected dtype ",
+ DataTypeString(dtypes[i]), " does not equal original dtype ",
+ DataTypeString(original_dtype));
+ mismatched_errors.emplace_back(error_msg);
+ }
+ }
+ if (!mismatched_errors.empty()) {
+ const string error_msg = str_util::Join(mismatched_errors, "\n");
+ return errors::InvalidArgument(error_msg);
+ }
+
+ for (auto i : sorted_name_idx) {
+ const string& tensor_name = tensor_names_flat(i);
+ const string& shape_and_slice = shape_and_slices_flat(i);
+ auto op =
+ new RestoreOp{context, i, tensor_name, shape_and_slice, prefix_string};
+ if (op->should_run_in_pool(&default_reader)) {
+ pool_restore_ops.emplace_back(op);
+ } else {
+ direct_restore_ops.emplace_back(op);
+ }
+ }
+
+ {
+ // Schedule any threaded operations first, skipping thread pool creation if
+ // we don't have any expensive operations.
+ std::unique_ptr<thread::ThreadPool> reader_pool;
+ if (!pool_restore_ops.empty()) {
+ reader_pool.reset(
+ new thread::ThreadPool(Env::Default(), "restore_tensors", 8));
+ for (auto& op : pool_restore_ops) {
+ reader_pool->Schedule([&op]() { op->run_with_new_reader(); });
+ }
}
- if (dtypes[i] != restored_tensor->dtype()) {
+
+ // Read small tensors directly on the op thread.
+ for (auto& op : direct_restore_ops) {
+ TF_RETURN_IF_ERROR(op->run(&default_reader));
+ }
+ }
+
+ // Check status of pool ops; this must come after the pool shuts down.
+ for (auto& op : pool_restore_ops) {
+ TF_RETURN_IF_ERROR(op->status);
+ }
+
+ for (auto i : sorted_name_idx) {
+ const string& tensor_name = tensor_names_flat(i);
+ if (dtypes[i] != context->mutable_output(i)->dtype()) {
return errors::InvalidArgument(
"tensor_name = ", tensor_name, "; expected dtype ",
DataTypeString(dtypes[i]), " does not equal restored dtype ",
- DataTypeString(restored_tensor->dtype()));
+ DataTypeString(context->mutable_output(i)->dtype()));
}
}
+
return Status::OK();
}
diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc
index e72608945b..93a753787a 100644
--- a/tensorflow/core/kernels/softmax_op.cc
+++ b/tensorflow/core/kernels/softmax_op.cc
@@ -61,15 +61,16 @@ class SoftmaxOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& logits_in = context->input(0);
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
- errors::InvalidArgument("logits must be 2-dimensional"));
+ OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(logits_in.shape()),
+ errors::InvalidArgument("logits must have >= 1 dimension, got ",
+ logits_in.shape().DebugString()));
Tensor* softmax_out = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, logits_in.shape(), &softmax_out));
if (logits_in.NumElements() > 0) {
functor::SoftmaxFunctor<Device, T> functor;
- functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
- softmax_out->matrix<T>(), log_);
+ functor(context->eigen_device<Device>(), logits_in.flat_inner_dims<T>(),
+ softmax_out->flat_inner_dims<T>(), log_);
}
}
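flat_inner_dims<T>() is what lets the op accept any rank >= 1: it views the tensor as a [rows, cols] matrix whose columns are the last dimension. A small sketch of that shape collapse; InnerMatrixShape is a hypothetical helper shown only to illustrate the mapping:

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Collapse a shape to the 2-D view used by flat_inner_dims: the last dimension
// becomes the columns, everything before it is folded into the rows.
std::pair<int64_t, int64_t> InnerMatrixShape(const std::vector<int64_t>& dims) {
  const int64_t cols = dims.empty() ? 1 : dims.back();
  int64_t rows = 1;
  for (std::size_t i = 0; i + 1 < dims.size(); ++i) rows *= dims[i];
  return {rows, cols};
}

// e.g. {2, 3, 5} -> {6, 5}; {7} -> {1, 7}, so a rank-1 logits vector is one row.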
diff --git a/tensorflow/core/kernels/softmax_op_gpu.cu.cc b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
index b63dcbb163..d1e677feb0 100644
--- a/tensorflow/core/kernels/softmax_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/softmax_op_gpu.cu.cc
@@ -134,11 +134,12 @@ class SoftmaxOpGPU : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& logits_in_ = context->input(0);
- auto logits_in = logits_in_.matrix<T>();
+ OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(logits_in_.shape()),
+ errors::InvalidArgument("logits must have >= 1 dimension, got ",
+ logits_in_.shape().DebugString()));
+ auto logits_in = logits_in_.flat_inner_dims<T>();
const int rows = logits_in.dimension(0);
const int cols = logits_in.dimension(1);
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in_.shape()),
- errors::InvalidArgument("logits must be 2-dimensional"));
Tensor* softmax_out = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{0}, 0, logits_in_.shape(), &softmax_out));
diff --git a/tensorflow/core/kernels/spacetobatch_op.cc b/tensorflow/core/kernels/spacetobatch_op.cc
index fdc08ec8e3..64f1b0d661 100644
--- a/tensorflow/core/kernels/spacetobatch_op.cc
+++ b/tensorflow/core/kernels/spacetobatch_op.cc
@@ -42,29 +42,29 @@ typedef Eigen::GpuDevice GPUDevice;
namespace {
template <typename Device, typename T>
-void SpaceToBatchOpCompute(OpKernelContext* context,
- const Tensor& orig_input_tensor,
- const Tensor& orig_block_shape,
- const Tensor& orig_paddings) {
+Status SpaceToBatchOpCompute(OpKernelContext* context,
+ const Tensor& orig_input_tensor,
+ const Tensor& orig_block_shape,
+ const Tensor& orig_paddings) {
const int input_dims = orig_input_tensor.dims();
- OP_REQUIRES(
- context, TensorShapeUtils::IsVector(orig_block_shape.shape()),
- errors::InvalidArgument("block_shape rank should be 1 instead of ",
- orig_block_shape.dims()));
+ if (!TensorShapeUtils::IsVector(orig_block_shape.shape())) {
+ return errors::InvalidArgument("block_shape rank should be 1 instead of ",
+ orig_block_shape.dims());
+ }
const int block_dims = orig_block_shape.dim_size(0);
- OP_REQUIRES(
- context, orig_input_tensor.dims() >= 1 + block_dims,
- errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
- " instead of ", orig_input_tensor.dims()));
-
- OP_REQUIRES(context,
- TensorShapeUtils::IsMatrix(orig_paddings.shape()) &&
- block_dims == orig_paddings.dim_size(0) &&
- 2 == orig_paddings.dim_size(1),
- errors::InvalidArgument("paddings should have shape [",
- block_dims, ", 2] instead of ",
- orig_paddings.shape().DebugString()));
+ if (orig_input_tensor.dims() < 1 + block_dims) {
+ return errors::InvalidArgument("input rank should be >= ", 1 + block_dims,
+ " instead of ", orig_input_tensor.dims());
+ }
+
+ if (!(TensorShapeUtils::IsMatrix(orig_paddings.shape()) &&
+ block_dims == orig_paddings.dim_size(0) &&
+ 2 == orig_paddings.dim_size(1))) {
+ return errors::InvalidArgument("paddings should have shape [", block_dims,
+ ", 2] instead of ",
+ orig_paddings.shape().DebugString());
+ }
// To avoid out-of-bounds access in the case that the block_shape and/or
// paddings tensors are concurrently modified, we must copy the values.
@@ -101,22 +101,23 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
for (int block_dim = 0; block_dim < block_dims; ++block_dim) {
block_shape_product *= block_shape[block_dim];
}
- OP_REQUIRES(
- context, block_shape_product > 0,
- errors::InvalidArgument("Product of block sizes must be positive, got ",
- block_shape_product));
+ if (block_shape_product <= 0) {
+ return errors::InvalidArgument(
+ "Product of block sizes must be positive, got ", block_shape_product);
+ }
const int internal_block_dims =
block_dims - removed_prefix_block_dims - removed_suffix_block_dims;
- OP_REQUIRES(context, internal_block_dims <= kMaxSpaceToBatchBlockDims,
- errors::InvalidArgument(
- "Maximum number of non-combined block dimensions is ",
- internal_block_dims, " but must not exceed ",
- kMaxSpaceToBatchBlockDims));
+ if (internal_block_dims > kMaxSpaceToBatchBlockDims) {
+ return errors::InvalidArgument(
+ "Maximum number of non-combined block dimensions is ",
+ internal_block_dims, " but must not exceed ",
+ kMaxSpaceToBatchBlockDims);
+ }
if (internal_block_dims == 0) {
context->set_output(0, orig_input_tensor);
- return;
+ return Status::OK();
}
// For the purpose of computing the result, the input will be treated as
@@ -146,16 +147,18 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
block_dim < block_dims - removed_suffix_block_dims; ++block_dim) {
const int64 pad_start = paddings[2 * block_dim],
pad_end = paddings[2 * block_dim + 1];
- OP_REQUIRES(context, pad_start >= 0 && pad_end >= 0,
- errors::InvalidArgument("Paddings must be non-negative"));
+ if (pad_start < 0 || pad_end < 0) {
+ return errors::InvalidArgument("Paddings must be non-negative");
+ }
const int64 input_size = orig_input_tensor.dim_size(block_dim + 1);
const int64 block_shape_value = block_shape[block_dim];
const int64 padded_size = input_size + pad_start + pad_end;
- OP_REQUIRES(
- context, padded_size % block_shape_value == 0,
- errors::InvalidArgument("padded_shape[", block_dim, "]=", padded_size,
- " is not divisible by block_shape[", block_dim,
- "]=", block_shape_value));
+ if (padded_size % block_shape_value != 0) {
+ return errors::InvalidArgument("padded_shape[", block_dim,
+ "]=", padded_size,
+ " is not divisible by block_shape[",
+ block_dim, "]=", block_shape_value);
+ }
internal_input_shape.AddDim(input_size);
const int64 output_size = padded_size / block_shape_value;
internal_output_shape.AddDim(output_size);
@@ -174,29 +177,29 @@ void SpaceToBatchOpCompute(OpKernelContext* context,
// Allocate output tensor.
Tensor* output_tensor = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(0, external_output_shape,
- &output_tensor));
+ TF_RETURN_IF_ERROR(
+ context->allocate_output(0, external_output_shape, &output_tensor));
const int64* internal_paddings = &paddings[2 * removed_prefix_block_dims];
const int64* internal_block_shape = &block_shape[removed_prefix_block_dims];
switch (internal_block_dims) {
-#define TF_SPACETOBATCH_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS) \
- case NUM_BLOCK_DIMS: { \
- OP_REQUIRES_OK( \
- context, \
- (functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, false>()( \
- context->eigen_device<Device>(), \
- orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>( \
- internal_input_shape.dim_sizes()), \
- internal_block_shape, internal_paddings, \
- output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>( \
- internal_output_shape.dim_sizes())))); \
- } break; \
+#define TF_SPACETOBATCH_BLOCK_DIMS_CASE(NUM_BLOCK_DIMS) \
+ case NUM_BLOCK_DIMS: { \
+ TF_RETURN_IF_ERROR( \
+ functor::SpaceToBatchFunctor<Device, T, NUM_BLOCK_DIMS, false>()( \
+ context->eigen_device<Device>(), \
+ orig_input_tensor.shaped<T, NUM_BLOCK_DIMS + 2>( \
+ internal_input_shape.dim_sizes()), \
+ internal_block_shape, internal_paddings, \
+ output_tensor->shaped<T, NUM_BLOCK_DIMS + 2>( \
+ internal_output_shape.dim_sizes()))); \
+ } break; \
/**/
TF_SPACETOBATCH_FOR_EACH_NUM_BLOCK_DIMS(TF_SPACETOBATCH_BLOCK_DIMS_CASE)
#undef TF_SPACETOBATCH_BLOCK_DIMS_CASE
}
+ return Status::OK();
}
} // namespace
@@ -211,8 +214,9 @@ class SpaceToBatchNDOp : public OpKernel {
const Tensor& orig_input_tensor = context->input(0);
const Tensor& orig_block_shape = context->input(1);
const Tensor& orig_paddings = context->input(2);
- SpaceToBatchOpCompute<Device, T>(context, orig_input_tensor,
- orig_block_shape, orig_paddings);
+ OP_REQUIRES_OK(context, SpaceToBatchOpCompute<Device, T>(
+ context, orig_input_tensor, orig_block_shape,
+ orig_paddings));
}
};
@@ -241,7 +245,8 @@ class SpaceToBatchOp : public OpKernel {
OP_REQUIRES(context, kRequiredDims == dims,
errors::InvalidArgument("Input rank should be: ", kRequiredDims,
"instead of: ", dims));
- SpaceToBatchOpCompute<Device, T>(context, in0, block_shape_, in1);
+ OP_REQUIRES_OK(context, SpaceToBatchOpCompute<Device, T>(
+ context, in0, block_shape_, in1));
}
private:
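The refactor above turns the shared helper into a Status-returning function, so validation failures propagate with early returns and both kernels surface them through OP_REQUIRES_OK instead of each check needing an OpKernelContext. A toy sketch of the pattern; ValidateBlockShape and the string-as-status convention are illustrative only:

#include <string>

// An empty string plays the role of Status::OK(); the first failed check
// returns an error message, mirroring the early returns in SpaceToBatchOpCompute.
std::string ValidateBlockShape(int block_shape_rank, int input_rank, int block_dims) {
  if (block_shape_rank != 1) {
    return "block_shape rank should be 1 instead of " + std::to_string(block_shape_rank);
  }
  if (input_rank < 1 + block_dims) {
    return "input rank should be >= " + std::to_string(1 + block_dims);
  }
  return "";  // OK
}

Returning a status also leaves the caller in charge of reporting, which is why SpaceToBatchNDOp and SpaceToBatchOp can both wrap the same helper in OP_REQUIRES_OK.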
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 1e3e92a68a..59fdc2262a 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -304,6 +305,9 @@ class StridedSliceAssignOp : public OpKernel {
Var* v;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
+ mutex_lock ml(*v->mu());
+ OP_REQUIRES_OK(context,
+ PrepareToUpdateVariable<Device, T>(context, v->tensor()));
old_lhs = *v->tensor();
OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
errors::InvalidArgument(
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc
index f288e124ee..d3c4f62071 100644
--- a/tensorflow/core/kernels/training_op_helpers.cc
+++ b/tensorflow/core/kernels/training_op_helpers.cc
@@ -39,8 +39,15 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
// GetInputTensor which will signal a failure.
std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
+ bool any_resource = false;
+ for (auto i : input_ids) {
+ if (ctx->input_dtype(i) == DT_RESOURCE) {
+ any_resource = true;
+ break;
+ }
+ }
std::vector<mutex_lock> locks;
- if (!do_lock) {
+ if (!do_lock && !any_resource) {
return locks;
}
std::vector<mutex*> mutexes;
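MaybeLockVariableInputMutexesInOrder now takes the locks whenever any input is a resource variable, and the "InOrder" part is what keeps concurrent ops from deadlocking: every op acquires the mutexes in one global order. A generic sketch of that discipline with std::mutex; LockAllInOrder is a hypothetical helper, not the TensorFlow function:

#include <algorithm>
#include <mutex>
#include <vector>

// Deduplicate and sort the mutexes by address, then lock them in that fixed
// order; two ops sharing any subset of variables always agree on the order.
std::vector<std::unique_lock<std::mutex>> LockAllInOrder(
    std::vector<std::mutex*> mutexes) {
  std::sort(mutexes.begin(), mutexes.end());
  mutexes.erase(std::unique(mutexes.begin(), mutexes.end()), mutexes.end());
  std::vector<std::unique_lock<std::mutex>> locks;
  locks.reserve(mutexes.size());
  for (std::mutex* m : mutexes) locks.emplace_back(*m);  // locks on construction
  return locks;
}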
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index 7e56e15450..765335d3a0 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -80,18 +80,8 @@ Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
Var* var;
TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var));
core::ScopedUnref unref_var(var);
- if (lock_held) {
- TF_RETURN_IF_ERROR(
- PrepareToUpdateVariable<Device, T>(ctx, var->tensor()));
- *out = *var->tensor();
- } else {
- mutex_lock ml(*var->mu());
- if (!sparse) {
- TF_RETURN_IF_ERROR(
- PrepareToUpdateVariable<Device, T>(ctx, var->tensor()));
- }
- *out = *var->tensor();
- }
+ TF_RETURN_IF_ERROR(PrepareToUpdateVariable<Device, T>(ctx, var->tensor()));
+ *out = *var->tensor();
return Status::OK();
}
*out = ctx->mutable_input(input, lock_held);