author A. Unique TensorFlower <gardener@tensorflow.org> 2016-12-29 15:04:03 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-12-29 15:24:53 -0800
commit 029539120a186073506a09983c72e7b4dc24ee74 (patch)
tree 1d6c07695a24dcbd11d1904170e923a86b095956 /tensorflow/core/kernels/tensor_array_ops.cc
parent 6cbc3696a14fe909843ca4fc0e9d1670861110b4 (diff)
Adds V3 version of TensorArray ops. All use resource handles and TensorArrayV3 has a flow output.
Change: 143210240
Diffstat (limited to 'tensorflow/core/kernels/tensor_array_ops.cc')
-rw-r--r--  tensorflow/core/kernels/tensor_array_ops.cc | 180
1 file changed, 156 insertions(+), 24 deletions(-)
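
Note for readers of the diff below: the V3 ops change how a TensorArray is addressed rather than what the kernels compute. The creation op now emits a scalar DT_RESOURCE handle (and, for TensorArrayV3, a scalar float "flow" output), and every consumer resolves either the legacy two-string handle or the new resource handle. The following standalone sketch illustrates that plumbing with a generic resource against the 2016-era TensorFlow kernel API; ExampleResource, ExampleCreateOp, ExampleUseOp, the "example" name, and the 0.0f flow value are illustrative assumptions, not part of the commit, and the matching op defs (handle: resource, flow: float) are assumed to exist.

// Standalone sketch only; not part of this commit.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace tensorflow {

class ExampleResource : public ResourceBase {
 public:
  string DebugString() override { return "ExampleResource"; }
};

// Producer side: mirrors the V3 branch of TensorArrayCreationOp::Compute.
class ExampleCreateOp : public OpKernel {
 public:
  explicit ExampleCreateOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const string container = ctx->step_container()->name();
    auto* resource = new ExampleResource;
    OP_REQUIRES_OK(ctx, ctx->resource_manager()->Create(container, "example",
                                                        resource));
    // Output 0: a scalar DT_RESOURCE handle naming the per-step resource.
    Tensor* handle_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle_t));
    handle_t->flat<ResourceHandle>()(0) =
        MakeResourceHandle<ExampleResource>(ctx, container, "example");
    // Output 1: the scalar float "flow" edge used to order reads after writes.
    Tensor* flow_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &flow_t));
    flow_t->scalar<float>()() = 0.0f;
  }
};

// Consumer side: mirrors the dual lookup path added to GetTensorArray below.
class ExampleUseOp : public OpKernel {
 public:
  explicit ExampleUseOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    ExampleResource* resource = nullptr;
    if (ctx->input_dtype(0) == DT_RESOURCE) {
      // V3: the ResourceHandle itself carries container + name.
      OP_REQUIRES_OK(ctx,
                     LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
    } else {
      // V1/V2: a string handle tensor holding {container, name}.
      auto h = ctx->input(0).flat<string>();
      OP_REQUIRES_OK(ctx, ctx->resource_manager()->Lookup(
                              ctx->step_container()->name(), h(0) + h(1),
                              &resource));
    }
    core::ScopedUnref unref(resource);
    // ... operate on the resource ...
  }
};

}  // namespace tensorflow
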
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index 3fed57a125..166aa8fb34 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -73,12 +73,16 @@ Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) {
Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) {
string container;
string ta_handle;
- TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle));
- ResourceMgr* rm = ctx->resource_manager();
- if (rm == nullptr) return errors::Internal("No resource manager.");
- TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
- container + ta_handle, tensor_array));
- return Status::OK();
+ if (ctx->input_dtype(0) != DT_RESOURCE) {
+ TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle));
+ ResourceMgr* rm = ctx->resource_manager();
+ if (rm == nullptr) return errors::Internal("No resource manager.");
+ TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
+ container + ta_handle, tensor_array));
+ return Status::OK();
+ } else {
+ return LookupResource(ctx, HandleFromInput(ctx, 0), tensor_array);
+ }
}
Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) {
@@ -117,8 +121,18 @@ class TensorArrayCreationOp : public OpKernel {
if (IsRefType(ctx->expected_output_dtype(0))) {
ctx->set_output_ref(0, output_tensor_array->mu(),
output_tensor_array->handle());
- } else {
+ } else if (ctx->expected_output_dtype(0) == DT_STRING) {
ctx->set_output(0, *output_tensor_array->handle());
+ } else {
+ Tensor* handle;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
+ handle->flat<ResourceHandle>()(0) =
+ output_tensor_array->resource_handle(ctx);
+ }
+ if (ctx->num_outputs() == 2) {
+ // Create the flow output.
+ Tensor* flow;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &flow));
}
}
@@ -165,14 +179,15 @@ class TensorArrayOp : public TensorArrayCreationOp {
handle(0) = "_tensor_arrays";
handle(1) = unique_tensor_array_name;
+ auto key = strings::StrCat(handle(0), unique_tensor_array_name);
+
TensorArray* tensor_array = new TensorArray(
- dtype_, *tensor_array_output_handle, size, element_shape_,
+ key, dtype_, *tensor_array_output_handle, size, element_shape_,
dynamic_size_, false /* multiple_writes_aggregate */,
false /* is_grad */, -1 /* marked_size */, clear_after_read_);
- TF_RETURN_IF_ERROR(rm->Create(
- ctx->step_container()->name(),
- strings::StrCat(handle(0), unique_tensor_array_name), tensor_array));
+ TF_RETURN_IF_ERROR(
+ rm->Create(ctx->step_container()->name(), key, tensor_array));
*output_tensor_array = tensor_array;
@@ -192,6 +207,8 @@ class TensorArrayOp : public TensorArrayCreationOp {
REGISTER_KERNEL_BUILDER(Name("TensorArray").Device(DEVICE_CPU), TensorArrayOp);
REGISTER_KERNEL_BUILDER(Name("TensorArrayV2").Device(DEVICE_CPU),
TensorArrayOp);
+REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU),
+ TensorArrayOp);
#if GOOGLE_CUDA
@@ -207,6 +224,12 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayV2").Device(DEVICE_CPU),
.TypeConstraint<type>("dtype") \
.HostMemory("size") \
.HostMemory("handle"), \
+ TensorArrayOp); \
+ REGISTER_KERNEL_BUILDER(Name("TensorArrayV3") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("dtype") \
+ .HostMemory("size") \
+ .HostMemory("handle"), \
TensorArrayOp);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
@@ -229,12 +252,23 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
TensorArray** output_tensor_array) override {
string container;
string tensor_array_name;
- TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &tensor_array_name));
-
- if (container != "_tensor_arrays") {
- return errors::InvalidArgument(
- "Input container should be '_tensor_arrays', but received '",
- container, "'");
+ if (ctx->input_dtype(0) != DT_RESOURCE) {
+ TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &tensor_array_name));
+ if (container != "_tensor_arrays") {
+ return errors::InvalidArgument(
+ "Input container should be '_tensor_arrays', but received '",
+ container, "'");
+ }
+ } else {
+ container = "_tensor_arrays";
+ auto resource = ctx->input(0).flat<ResourceHandle>()(0);
+ if (StringPiece(resource.name()).substr(0, container.size()) !=
+ container) {
+ return errors::InvalidArgument("Wrong input container. ",
+ resource.name());
+ }
+ tensor_array_name =
+ StringPiece(resource.name()).substr(container.size()).ToString();
}
auto output_handle = tensor_array_output_handle->flat<string>();
@@ -264,11 +298,13 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
"writes are performed to the same index.");
}
- auto creator = [this, tensor_array, array_size, marked_size,
- tensor_array_output_handle](TensorArray** ret) -> Status {
+ const auto key = strings::StrCat(output_handle(0), output_handle(1));
+ auto creator = [this, key, tensor_array, array_size, marked_size,
+ tensor_array_output_handle,
+ output_handle](TensorArray** ret) -> Status {
*ret = new TensorArray(
- tensor_array->ElemType(), *tensor_array_output_handle, array_size,
- tensor_array->ElemShape(), false /* dynamic_size */,
+ key, tensor_array->ElemType(), *tensor_array_output_handle,
+ array_size, tensor_array->ElemShape(), false /* dynamic_size */,
true /* multiple_writes_aggregate */, true /* is_grad */,
marked_size /* marked_size */, true /* close_after_read */);
TF_RETURN_IF_ERROR((*ret)->CopyShapesFrom(tensor_array));
@@ -276,9 +312,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
};
Status s = rm->LookupOrCreate<TensorArray>(
- ctx->step_container()->name(),
- strings::StrCat(output_handle(0), output_handle(1)),
- output_tensor_array, creator);
+ ctx->step_container()->name(), key, output_tensor_array, creator);
(*output_tensor_array)->Unref();
return s;
@@ -297,6 +331,8 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad").Device(DEVICE_CPU),
TensorArrayGradOp);
REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2").Device(DEVICE_CPU),
TensorArrayGradOp);
+REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3").Device(DEVICE_CPU),
+ TensorArrayGradOp);
REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad")
.Device(DEVICE_GPU)
@@ -308,6 +344,11 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2")
.HostMemory("handle")
.HostMemory("grad_handle"),
TensorArrayGradOp);
+REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3")
+ .Device(DEVICE_GPU)
+ .HostMemory("handle")
+ .HostMemory("grad_handle"),
+ TensorArrayGradOp);
// WRITE **********************************************************************
@@ -353,6 +394,9 @@ class TensorArrayWriteOp : public OpKernel {
TensorArrayWriteOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("TensorArrayWriteV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ TensorArrayWriteOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TensorArrayWriteV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
TensorArrayWriteOp<CPUDevice, type>);
TF_CALL_ALL_TYPES(REGISTER_WRITE);
@@ -373,6 +417,12 @@ TF_CALL_ALL_TYPES(REGISTER_WRITE);
.TypeConstraint<type>("T") \
.HostMemory("handle") \
.HostMemory("index"), \
+ TensorArrayWriteOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("TensorArrayWriteV3") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("handle") \
+ .HostMemory("index"), \
TensorArrayWriteOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
@@ -430,6 +480,10 @@ class TensorArrayReadOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
+ TensorArrayReadOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV3") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("dtype"), \
TensorArrayReadOp<CPUDevice, type>);
TF_CALL_ALL_TYPES(REGISTER_READ)
@@ -450,6 +504,12 @@ TF_CALL_ALL_TYPES(REGISTER_READ)
.TypeConstraint<type>("dtype") \
.HostMemory("handle") \
.HostMemory("index"), \
+ TensorArrayReadOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV3") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("dtype") \
+ .HostMemory("handle") \
+ .HostMemory("index"), \
TensorArrayReadOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
@@ -599,6 +659,11 @@ class TensorArrayPackOrGatherOp : public OpKernel {
Name("TensorArrayGatherV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
+ TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TensorArrayGatherV3") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("dtype"), \
TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>);
TF_CALL_POD_STRING_TYPES(REGISTER_GATHER_AND_PACK);
@@ -631,6 +696,13 @@ REGISTER_GATHER_AND_PACK(bfloat16);
.TypeConstraint<type>("dtype") \
.HostMemory("indices") \
.HostMemory("handle"), \
+ TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TensorArrayGatherV3") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("dtype") \
+ .HostMemory("indices") \
+ .HostMemory("handle"), \
TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
@@ -654,6 +726,13 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("indices")
.HostMemory("handle"),
TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>);
+REGISTER_KERNEL_BUILDER(
+ Name("TensorArrayGatherV3")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("dtype")
+ .HostMemory("indices")
+ .HostMemory("handle"),
+ TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>);
#endif // GOOGLE_CUDA
@@ -808,6 +887,12 @@ class TensorArrayConcatOp : public OpKernel {
.TypeConstraint<type>("dtype") \
.HostMemory("lengths") \
.HostMemory("handle"), \
+ TensorArrayConcatOp<CPUDevice, type>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("dtype") \
+ .HostMemory("lengths") \
+ .HostMemory("handle"), \
TensorArrayConcatOp<CPUDevice, type>)
TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT);
@@ -832,6 +917,12 @@ REGISTER_CONCAT(bfloat16);
.TypeConstraint<type>("dtype") \
.HostMemory("lengths") \
.HostMemory("handle"), \
+ TensorArrayConcatOp<GPUDevice, type>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("dtype") \
+ .HostMemory("lengths") \
+ .HostMemory("handle"), \
TensorArrayConcatOp<GPUDevice, type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
@@ -853,6 +944,12 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV2")
.HostMemory("lengths")
.HostMemory("handle"),
TensorArrayConcatOp<CPUDevice, int32>);
+REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int32>("dtype")
+ .HostMemory("lengths")
+ .HostMemory("handle"),
+ TensorArrayConcatOp<CPUDevice, int32>);
#endif // GOOGLE_CUDA
@@ -999,6 +1096,12 @@ class TensorArrayUnpackOrScatterOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
TensorArrayUnpackOrScatterOp<CPUDevice, type, \
+ false /* LEGACY_UNPACK */>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TensorArrayScatterV3") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T"), \
+ TensorArrayUnpackOrScatterOp<CPUDevice, type, \
false /* LEGACY_UNPACK */>);
TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK);
@@ -1029,6 +1132,14 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK);
.HostMemory("indices") \
.HostMemory("handle"), \
TensorArrayUnpackOrScatterOp<GPUDevice, type, \
+ false /* LEGACY_UNPACK */>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TensorArrayScatterV3") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("indices") \
+ .HostMemory("handle"), \
+ TensorArrayUnpackOrScatterOp<GPUDevice, type, \
false /* LEGACY_UNPACK */>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
@@ -1166,6 +1277,9 @@ class TensorArraySplitOp : public OpKernel {
TensorArraySplitOp<CPUDevice, type>); \
REGISTER_KERNEL_BUILDER( \
Name("TensorArraySplitV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ TensorArraySplitOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TensorArraySplitV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
TensorArraySplitOp<CPUDevice, type>);
TF_CALL_ALL_TYPES(REGISTER_SPLIT);
@@ -1185,6 +1299,12 @@ TF_CALL_ALL_TYPES(REGISTER_SPLIT);
.TypeConstraint<type>("T") \
.HostMemory("lengths") \
.HostMemory("handle"), \
+ TensorArraySplitOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER(Name("TensorArraySplitV3") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("lengths") \
+ .HostMemory("handle"), \
TensorArraySplitOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
@@ -1214,6 +1334,8 @@ REGISTER_KERNEL_BUILDER(Name("TensorArraySize").Device(DEVICE_CPU),
TensorArraySizeOp);
REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV2").Device(DEVICE_CPU),
TensorArraySizeOp);
+REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV3").Device(DEVICE_CPU),
+ TensorArraySizeOp);
REGISTER_KERNEL_BUILDER(Name("TensorArraySize")
.Device(DEVICE_GPU)
@@ -1225,6 +1347,11 @@ REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV2")
.HostMemory("handle")
.HostMemory("size"),
TensorArraySizeOp);
+REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV3")
+ .Device(DEVICE_GPU)
+ .HostMemory("handle")
+ .HostMemory("size"),
+ TensorArraySizeOp);
// CLOSE
// **********************************************************************
@@ -1257,6 +1384,8 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayClose").Device(DEVICE_CPU),
TensorArrayCloseOp);
REGISTER_KERNEL_BUILDER(Name("TensorArrayCloseV2").Device(DEVICE_CPU),
TensorArrayCloseOp);
+REGISTER_KERNEL_BUILDER(Name("TensorArrayCloseV3").Device(DEVICE_CPU),
+ TensorArrayCloseOp);
REGISTER_KERNEL_BUILDER(
Name("TensorArrayClose").Device(DEVICE_GPU).HostMemory("handle"),
@@ -1264,5 +1393,8 @@ REGISTER_KERNEL_BUILDER(
REGISTER_KERNEL_BUILDER(
Name("TensorArrayCloseV2").Device(DEVICE_GPU).HostMemory("handle"),
TensorArrayCloseOp);
+REGISTER_KERNEL_BUILDER(
+ Name("TensorArrayCloseV3").Device(DEVICE_GPU).HostMemory("handle"),
+ TensorArrayCloseOp);
} // namespace tensorflow