aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/data/map_defun_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/data/map_defun_op.cc')
-rw-r--r--tensorflow/core/kernels/data/map_defun_op.cc10
1 files changed, 7 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index 607d0ca028..3c562fc7f3 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -23,13 +23,13 @@ limitations under the License.
#include "tensorflow/core/util/reffed_status_callback.h"
namespace tensorflow {
+namespace data {
namespace {
void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
bool always_collect_stats) {
opts->step_id = ctx->step_id();
opts->rendezvous = ctx->rendezvous();
- opts->cancellation_manager = ctx->cancellation_manager();
if (always_collect_stats) {
opts->stats_collector = ctx->stats_collector();
}
@@ -117,10 +117,13 @@ class MapDefunOp : public AsyncOpKernel {
for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
auto* call_frame =
new MapFunctionCallFrame(*args, *arg_shapes, output, this, i);
+ CancellationManager* c_mgr = new CancellationManager;
+ opts_.cancellation_manager = c_mgr;
ctx->function_library()->Run(
opts_, func_handle_, call_frame,
- [call_frame, refcounted](const Status& func_status) {
+ [call_frame, refcounted, c_mgr](const Status& func_status) {
delete call_frame;
+ delete c_mgr;
refcounted->UpdateStatus(func_status);
refcounted->Unref();
});
@@ -189,8 +192,9 @@ class MapDefunOp : public AsyncOpKernel {
const OpKernel* kernel_;
const size_t iter_;
};
-}; // namespace
+};
REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp);
} // namespace
+} // namespace data
} // namespace tensorflow