diff options
author | 2016-12-29 15:04:03 -0800 | |
---|---|---|
committer | 2016-12-29 15:24:53 -0800 | |
commit | 029539120a186073506a09983c72e7b4dc24ee74 (patch) | |
tree | 1d6c07695a24dcbd11d1904170e923a86b095956 /tensorflow/core/kernels/tensor_array_ops.cc | |
parent | 6cbc3696a14fe909843ca4fc0e9d1670861110b4 (diff) |
Adds V3 version of TensorArray ops. All use resource handles and TensorArrayV3 has a flow output.
Change: 143210240
Diffstat (limited to 'tensorflow/core/kernels/tensor_array_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/tensor_array_ops.cc | 180 |
1 file changed, 156 insertions, 24 deletions
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 3fed57a125..166aa8fb34 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -73,12 +73,16 @@ Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) { Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) { string container; string ta_handle; - TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle)); - ResourceMgr* rm = ctx->resource_manager(); - if (rm == nullptr) return errors::Internal("No resource manager."); - TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(), - container + ta_handle, tensor_array)); - return Status::OK(); + if (ctx->input_dtype(0) != DT_RESOURCE) { + TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle)); + ResourceMgr* rm = ctx->resource_manager(); + if (rm == nullptr) return errors::Internal("No resource manager."); + TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(), + container + ta_handle, tensor_array)); + return Status::OK(); + } else { + return LookupResource(ctx, HandleFromInput(ctx, 0), tensor_array); + } } Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) { @@ -117,8 +121,18 @@ class TensorArrayCreationOp : public OpKernel { if (IsRefType(ctx->expected_output_dtype(0))) { ctx->set_output_ref(0, output_tensor_array->mu(), output_tensor_array->handle()); - } else { + } else if (ctx->expected_output_dtype(0) == DT_STRING) { ctx->set_output(0, *output_tensor_array->handle()); + } else { + Tensor* handle; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); + handle->flat<ResourceHandle>()(0) = + output_tensor_array->resource_handle(ctx); + } + if (ctx->num_outputs() == 2) { + // Create the flow output. 
+ Tensor* flow; + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &flow)); } } @@ -165,14 +179,15 @@ class TensorArrayOp : public TensorArrayCreationOp { handle(0) = "_tensor_arrays"; handle(1) = unique_tensor_array_name; + auto key = strings::StrCat(handle(0), unique_tensor_array_name); + TensorArray* tensor_array = new TensorArray( - dtype_, *tensor_array_output_handle, size, element_shape_, + key, dtype_, *tensor_array_output_handle, size, element_shape_, dynamic_size_, false /* multiple_writes_aggregate */, false /* is_grad */, -1 /* marked_size */, clear_after_read_); - TF_RETURN_IF_ERROR(rm->Create( - ctx->step_container()->name(), - strings::StrCat(handle(0), unique_tensor_array_name), tensor_array)); + TF_RETURN_IF_ERROR( + rm->Create(ctx->step_container()->name(), key, tensor_array)); *output_tensor_array = tensor_array; @@ -192,6 +207,8 @@ class TensorArrayOp : public TensorArrayCreationOp { REGISTER_KERNEL_BUILDER(Name("TensorArray").Device(DEVICE_CPU), TensorArrayOp); REGISTER_KERNEL_BUILDER(Name("TensorArrayV2").Device(DEVICE_CPU), TensorArrayOp); +REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU), + TensorArrayOp); #if GOOGLE_CUDA @@ -207,6 +224,12 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayV2").Device(DEVICE_CPU), .TypeConstraint<type>("dtype") \ .HostMemory("size") \ .HostMemory("handle"), \ + TensorArrayOp); \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("dtype") \ + .HostMemory("size") \ + .HostMemory("handle"), \ TensorArrayOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -229,12 +252,23 @@ class TensorArrayGradOp : public TensorArrayCreationOp { TensorArray** output_tensor_array) override { string container; string tensor_array_name; - TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &tensor_array_name)); - - if (container != "_tensor_arrays") { - return errors::InvalidArgument( - "Input container should be '_tensor_arrays', but received '", - container, "'"); + if 
(ctx->input_dtype(0) != DT_RESOURCE) { + TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &tensor_array_name)); + if (container != "_tensor_arrays") { + return errors::InvalidArgument( + "Input container should be '_tensor_arrays', but received '", + container, "'"); + } + } else { + container = "_tensor_arrays"; + auto resource = ctx->input(0).flat<ResourceHandle>()(0); + if (StringPiece(resource.name()).substr(0, container.size()) != + container) { + return errors::InvalidArgument("Wrong input container. ", + resource.name()); + } + tensor_array_name = + StringPiece(resource.name()).substr(container.size()).ToString(); } auto output_handle = tensor_array_output_handle->flat<string>(); @@ -264,11 +298,13 @@ class TensorArrayGradOp : public TensorArrayCreationOp { "writes are performed to the same index."); } - auto creator = [this, tensor_array, array_size, marked_size, - tensor_array_output_handle](TensorArray** ret) -> Status { + const auto key = strings::StrCat(output_handle(0), output_handle(1)); + auto creator = [this, key, tensor_array, array_size, marked_size, + tensor_array_output_handle, + output_handle](TensorArray** ret) -> Status { *ret = new TensorArray( - tensor_array->ElemType(), *tensor_array_output_handle, array_size, - tensor_array->ElemShape(), false /* dynamic_size */, + key, tensor_array->ElemType(), *tensor_array_output_handle, + array_size, tensor_array->ElemShape(), false /* dynamic_size */, true /* multiple_writes_aggregate */, true /* is_grad */, marked_size /* marked_size */, true /* close_after_read */); TF_RETURN_IF_ERROR((*ret)->CopyShapesFrom(tensor_array)); @@ -276,9 +312,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp { }; Status s = rm->LookupOrCreate<TensorArray>( - ctx->step_container()->name(), - strings::StrCat(output_handle(0), output_handle(1)), - output_tensor_array, creator); + ctx->step_container()->name(), key, output_tensor_array, creator); (*output_tensor_array)->Unref(); return s; @@ -297,6 +331,8 @@ 
REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad").Device(DEVICE_CPU), TensorArrayGradOp); REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2").Device(DEVICE_CPU), TensorArrayGradOp); +REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3").Device(DEVICE_CPU), + TensorArrayGradOp); REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad") .Device(DEVICE_GPU) @@ -308,6 +344,11 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2") .HostMemory("handle") .HostMemory("grad_handle"), TensorArrayGradOp); +REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3") + .Device(DEVICE_GPU) + .HostMemory("handle") + .HostMemory("grad_handle"), + TensorArrayGradOp); // WRITE ********************************************************************** @@ -353,6 +394,9 @@ class TensorArrayWriteOp : public OpKernel { TensorArrayWriteOp<CPUDevice, type>); \ REGISTER_KERNEL_BUILDER( \ Name("TensorArrayWriteV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + TensorArrayWriteOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArrayWriteV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ TensorArrayWriteOp<CPUDevice, type>); TF_CALL_ALL_TYPES(REGISTER_WRITE); @@ -373,6 +417,12 @@ TF_CALL_ALL_TYPES(REGISTER_WRITE); .TypeConstraint<type>("T") \ .HostMemory("handle") \ .HostMemory("index"), \ + TensorArrayWriteOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayWriteV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("handle") \ + .HostMemory("index"), \ TensorArrayWriteOp<GPUDevice, type>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -430,6 +480,10 @@ class TensorArrayReadOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV2") \ .Device(DEVICE_CPU) \ .TypeConstraint<type>("dtype"), \ + TensorArrayReadOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("dtype"), \ TensorArrayReadOp<CPUDevice, type>); TF_CALL_ALL_TYPES(REGISTER_READ) @@ -450,6 +504,12 @@ 
TF_CALL_ALL_TYPES(REGISTER_READ) .TypeConstraint<type>("dtype") \ .HostMemory("handle") \ .HostMemory("index"), \ + TensorArrayReadOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("dtype") \ + .HostMemory("handle") \ + .HostMemory("index"), \ TensorArrayReadOp<GPUDevice, type>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -599,6 +659,11 @@ class TensorArrayPackOrGatherOp : public OpKernel { Name("TensorArrayGatherV2") \ .Device(DEVICE_CPU) \ .TypeConstraint<type>("dtype"), \ + TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArrayGatherV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("dtype"), \ TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); TF_CALL_POD_STRING_TYPES(REGISTER_GATHER_AND_PACK); @@ -631,6 +696,13 @@ REGISTER_GATHER_AND_PACK(bfloat16); .TypeConstraint<type>("dtype") \ .HostMemory("indices") \ .HostMemory("handle"), \ + TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArrayGatherV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("dtype") \ + .HostMemory("indices") \ + .HostMemory("handle"), \ TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -654,6 +726,13 @@ REGISTER_KERNEL_BUILDER( .HostMemory("indices") .HostMemory("handle"), TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>); +REGISTER_KERNEL_BUILDER( + Name("TensorArrayGatherV3") + .Device(DEVICE_GPU) + .TypeConstraint<int32>("dtype") + .HostMemory("indices") + .HostMemory("handle"), + TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>); #endif // GOOGLE_CUDA @@ -808,6 +887,12 @@ class TensorArrayConcatOp : public OpKernel { .TypeConstraint<type>("dtype") \ .HostMemory("lengths") \ .HostMemory("handle"), \ + TensorArrayConcatOp<CPUDevice, type>) \ + 
REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("dtype") \ + .HostMemory("lengths") \ + .HostMemory("handle"), \ TensorArrayConcatOp<CPUDevice, type>) TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT); @@ -832,6 +917,12 @@ REGISTER_CONCAT(bfloat16); .TypeConstraint<type>("dtype") \ .HostMemory("lengths") \ .HostMemory("handle"), \ + TensorArrayConcatOp<GPUDevice, type>) \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("dtype") \ + .HostMemory("lengths") \ + .HostMemory("handle"), \ TensorArrayConcatOp<GPUDevice, type>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -853,6 +944,12 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV2") .HostMemory("lengths") .HostMemory("handle"), TensorArrayConcatOp<CPUDevice, int32>); +REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3") + .Device(DEVICE_GPU) + .TypeConstraint<int32>("dtype") + .HostMemory("lengths") + .HostMemory("handle"), + TensorArrayConcatOp<CPUDevice, int32>); #endif // GOOGLE_CUDA @@ -999,6 +1096,12 @@ class TensorArrayUnpackOrScatterOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint<type>("T"), \ TensorArrayUnpackOrScatterOp<CPUDevice, type, \ + false /* LEGACY_UNPACK */>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArrayScatterV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T"), \ + TensorArrayUnpackOrScatterOp<CPUDevice, type, \ false /* LEGACY_UNPACK */>); TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK); @@ -1029,6 +1132,14 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK); .HostMemory("indices") \ .HostMemory("handle"), \ TensorArrayUnpackOrScatterOp<GPUDevice, type, \ + false /* LEGACY_UNPACK */>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArrayScatterV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("indices") \ + .HostMemory("handle"), \ + TensorArrayUnpackOrScatterOp<GPUDevice, type, \ false /* LEGACY_UNPACK */>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ 
-1166,6 +1277,9 @@ class TensorArraySplitOp : public OpKernel { TensorArraySplitOp<CPUDevice, type>); \ REGISTER_KERNEL_BUILDER( \ Name("TensorArraySplitV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + TensorArraySplitOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArraySplitV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ TensorArraySplitOp<CPUDevice, type>); TF_CALL_ALL_TYPES(REGISTER_SPLIT); @@ -1185,6 +1299,12 @@ TF_CALL_ALL_TYPES(REGISTER_SPLIT); .TypeConstraint<type>("T") \ .HostMemory("lengths") \ .HostMemory("handle"), \ + TensorArraySplitOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("TensorArraySplitV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("lengths") \ + .HostMemory("handle"), \ TensorArraySplitOp<GPUDevice, type>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -1214,6 +1334,8 @@ REGISTER_KERNEL_BUILDER(Name("TensorArraySize").Device(DEVICE_CPU), TensorArraySizeOp); REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV2").Device(DEVICE_CPU), TensorArraySizeOp); +REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV3").Device(DEVICE_CPU), + TensorArraySizeOp); REGISTER_KERNEL_BUILDER(Name("TensorArraySize") .Device(DEVICE_GPU) @@ -1225,6 +1347,11 @@ REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV2") .HostMemory("handle") .HostMemory("size"), TensorArraySizeOp); +REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV3") + .Device(DEVICE_GPU) + .HostMemory("handle") + .HostMemory("size"), + TensorArraySizeOp); // CLOSE // ********************************************************************** @@ -1257,6 +1384,8 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayClose").Device(DEVICE_CPU), TensorArrayCloseOp); REGISTER_KERNEL_BUILDER(Name("TensorArrayCloseV2").Device(DEVICE_CPU), TensorArrayCloseOp); +REGISTER_KERNEL_BUILDER(Name("TensorArrayCloseV3").Device(DEVICE_CPU), + TensorArrayCloseOp); REGISTER_KERNEL_BUILDER( Name("TensorArrayClose").Device(DEVICE_GPU).HostMemory("handle"), @@ -1264,5 +1393,8 @@ 
REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER( Name("TensorArrayCloseV2").Device(DEVICE_GPU).HostMemory("handle"), TensorArrayCloseOp); +REGISTER_KERNEL_BUILDER( + Name("TensorArrayCloseV3").Device(DEVICE_GPU).HostMemory("handle"), + TensorArrayCloseOp); } // namespace tensorflow |