diff options
-rw-r--r-- | tensorflow/contrib/framework/python/ops/prettyprint_ops.py | 3 | ||||
-rw-r--r-- | tensorflow/core/kernels/control_flow_ops.cc | 2 | ||||
-rw-r--r-- | tensorflow/core/kernels/tensor_array.h | 17 | ||||
-rw-r--r-- | tensorflow/core/kernels/tensor_array_ops.cc | 180 | ||||
-rw-r--r-- | tensorflow/core/ops/data_flow_ops.cc | 229 | ||||
-rw-r--r-- | tensorflow/core/public/version.h | 2 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/tensor_array_ops_test.py | 21 | ||||
-rw-r--r-- | tensorflow/python/ops/hidden_ops.txt | 12 | ||||
-rw-r--r-- | tensorflow/python/ops/tensor_array_grad.py | 43 | ||||
-rw-r--r-- | tensorflow/python/ops/tensor_array_ops.py | 152 |
10 files changed, 492 insertions, 169 deletions
diff --git a/tensorflow/contrib/framework/python/ops/prettyprint_ops.py b/tensorflow/contrib/framework/python/ops/prettyprint_ops.py index 5dc6d1aec0..2637aa15ea 100644 --- a/tensorflow/contrib/framework/python/ops/prettyprint_ops.py +++ b/tensorflow/contrib/framework/python/ops/prettyprint_ops.py @@ -161,7 +161,8 @@ def print_op(input_, with ops.control_dependencies([p]): input_ = tensor_array_ops.TensorArray(dtype=input_.dtype, - handle=input_.handle) + handle=input_.handle, + flow=input_.flow) else: raise ValueError("input_ must be of type " "Tensor, SparseTensor or TensorArray") diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 5241a4d916..1a73a3d0f8 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -276,6 +276,7 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); REGISTER_GPU_HOST_KERNEL(int32); REGISTER_GPU_HOST_KERNEL(string); +REGISTER_GPU_HOST_KERNEL(ResourceHandle); #undef REGISTER_GPU_HOST_KERNEL @@ -345,6 +346,7 @@ REGISTER_GPU_HOST_KERNEL(int32); REGISTER_GPU_HOST_REF_KERNEL(int32); REGISTER_GPU_HOST_KERNEL(string); REGISTER_GPU_HOST_REF_KERNEL(string); +REGISTER_GPU_HOST_KERNEL(ResourceHandle); #undef REGISTER_GPU_HOST_KERNEL #undef REGISTER_GPU_HOST_REF_KERNEL diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index ae1700cd0a..4704130994 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -132,11 +132,12 @@ class TensorArray : public ResourceBase { // 'N' elements. While the underlying storage is a std::vector and // can hold more than MAX_INT entries, in practice we do not expect // users to construct this many Tensors for storage in a TensorArray. - TensorArray(const DataType& dtype, const Tensor& handle, int32 N, - const PartialTensorShape& element_shape, bool dynamic_size, - bool multiple_writes_aggregate, bool is_grad, int32 marked_size, - bool clear_after_read) - : dtype_(dtype), + TensorArray(const string& key, const DataType& dtype, const Tensor& handle, + int32 N, const PartialTensorShape& element_shape, + bool dynamic_size, bool multiple_writes_aggregate, bool is_grad, + int32 marked_size, bool clear_after_read) + : key_(key), + dtype_(dtype), handle_(handle), closed_(false), dynamic_size_(dynamic_size), @@ -334,6 +335,10 @@ class TensorArray : public ResourceBase { mutex* mu() { return &mu_; } Tensor* handle() { return &handle_; } + ResourceHandle resource_handle(OpKernelContext* ctx) { + return MakePerStepResourceHandle<TensorArray>(ctx, key_); + } + private: Status LockedWrite(OpKernelContext* ctx, const int32 index, PersistentTensor* value) EXCLUSIVE_LOCKS_REQUIRED(mu_); @@ -355,6 +360,8 @@ class TensorArray : public ResourceBase { return Status::OK(); } + const string key_; + const DataType dtype_; Tensor handle_; diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 3fed57a125..166aa8fb34 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -73,12 +73,16 @@ Status GetHandle(OpKernelContext* ctx, string* container, string* ta_handle) { Status GetTensorArray(OpKernelContext* ctx, TensorArray** tensor_array) { string container; string ta_handle; - TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle)); - ResourceMgr* rm = ctx->resource_manager(); - if (rm == nullptr) return errors::Internal("No resource manager."); - TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(), - container + ta_handle, tensor_array)); - return Status::OK(); + if (ctx->input_dtype(0) != DT_RESOURCE) { + TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &ta_handle)); + ResourceMgr* rm = ctx->resource_manager(); + if (rm == nullptr) return errors::Internal("No resource manager."); + TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(), + container + ta_handle, tensor_array)); + return Status::OK(); + } else { + return LookupResource(ctx, HandleFromInput(ctx, 0), tensor_array); + } } Status SetupFlowControlInputs(OpKernelContext* ctx, bool set_output) { @@ -117,8 +121,18 @@ class TensorArrayCreationOp : public OpKernel { if (IsRefType(ctx->expected_output_dtype(0))) { ctx->set_output_ref(0, output_tensor_array->mu(), output_tensor_array->handle()); - } else { + } else if (ctx->expected_output_dtype(0) == DT_STRING) { ctx->set_output(0, *output_tensor_array->handle()); + } else { + Tensor* handle; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); + handle->flat<ResourceHandle>()(0) = + output_tensor_array->resource_handle(ctx); + } + if (ctx->num_outputs() == 2) { + // Create the flow output. + Tensor* flow; + OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &flow)); } } @@ -165,14 +179,15 @@ class TensorArrayOp : public TensorArrayCreationOp { handle(0) = "_tensor_arrays"; handle(1) = unique_tensor_array_name; + auto key = strings::StrCat(handle(0), unique_tensor_array_name); + TensorArray* tensor_array = new TensorArray( - dtype_, *tensor_array_output_handle, size, element_shape_, + key, dtype_, *tensor_array_output_handle, size, element_shape_, dynamic_size_, false /* multiple_writes_aggregate */, false /* is_grad */, -1 /* marked_size */, clear_after_read_); - TF_RETURN_IF_ERROR(rm->Create( - ctx->step_container()->name(), - strings::StrCat(handle(0), unique_tensor_array_name), tensor_array)); + TF_RETURN_IF_ERROR( + rm->Create(ctx->step_container()->name(), key, tensor_array)); *output_tensor_array = tensor_array; @@ -192,6 +207,8 @@ class TensorArrayOp : public TensorArrayCreationOp { REGISTER_KERNEL_BUILDER(Name("TensorArray").Device(DEVICE_CPU), TensorArrayOp); REGISTER_KERNEL_BUILDER(Name("TensorArrayV2").Device(DEVICE_CPU), TensorArrayOp); +REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU), + TensorArrayOp); #if GOOGLE_CUDA @@ -207,6 +224,12 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayV2").Device(DEVICE_CPU), .TypeConstraint<type>("dtype") \ .HostMemory("size") \ .HostMemory("handle"), \ + TensorArrayOp); \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("dtype") \ + .HostMemory("size") \ + .HostMemory("handle"), \ TensorArrayOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -229,12 +252,23 @@ class TensorArrayGradOp : public TensorArrayCreationOp { TensorArray** output_tensor_array) override { string container; string tensor_array_name; - TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &tensor_array_name)); - - if (container != "_tensor_arrays") { - return errors::InvalidArgument( - "Input container should be '_tensor_arrays', but received '", - container, "'"); + if (ctx->input_dtype(0) != DT_RESOURCE) { + TF_RETURN_IF_ERROR(GetHandle(ctx, &container, &tensor_array_name)); + if (container != "_tensor_arrays") { + return errors::InvalidArgument( + "Input container should be '_tensor_arrays', but received '", + container, "'"); + } + } else { + container = "_tensor_arrays"; + auto resource = ctx->input(0).flat<ResourceHandle>()(0); + if (StringPiece(resource.name()).substr(0, container.size()) != + container) { + return errors::InvalidArgument("Wrong input container. ", + resource.name()); + } + tensor_array_name = + StringPiece(resource.name()).substr(container.size()).ToString(); } auto output_handle = tensor_array_output_handle->flat<string>(); @@ -264,11 +298,13 @@ class TensorArrayGradOp : public TensorArrayCreationOp { "writes are performed to the same index."); } - auto creator = [this, tensor_array, array_size, marked_size, - tensor_array_output_handle](TensorArray** ret) -> Status { + const auto key = strings::StrCat(output_handle(0), output_handle(1)); + auto creator = [this, key, tensor_array, array_size, marked_size, + tensor_array_output_handle, + output_handle](TensorArray** ret) -> Status { *ret = new TensorArray( - tensor_array->ElemType(), *tensor_array_output_handle, array_size, - tensor_array->ElemShape(), false /* dynamic_size */, + key, tensor_array->ElemType(), *tensor_array_output_handle, + array_size, tensor_array->ElemShape(), false /* dynamic_size */, true /* multiple_writes_aggregate */, true /* is_grad */, marked_size /* marked_size */, true /* close_after_read */); TF_RETURN_IF_ERROR((*ret)->CopyShapesFrom(tensor_array)); @@ -276,9 +312,7 @@ class TensorArrayGradOp : public TensorArrayCreationOp { }; Status s = rm->LookupOrCreate<TensorArray>( - ctx->step_container()->name(), - strings::StrCat(output_handle(0), output_handle(1)), - output_tensor_array, creator); + ctx->step_container()->name(), key, output_tensor_array, creator); (*output_tensor_array)->Unref(); return s; @@ -297,6 +331,8 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad").Device(DEVICE_CPU), TensorArrayGradOp); REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2").Device(DEVICE_CPU), TensorArrayGradOp); +REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3").Device(DEVICE_CPU), + TensorArrayGradOp); REGISTER_KERNEL_BUILDER(Name("TensorArrayGrad") .Device(DEVICE_GPU) @@ -308,6 +344,11 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV2") .HostMemory("handle") .HostMemory("grad_handle"), TensorArrayGradOp); +REGISTER_KERNEL_BUILDER(Name("TensorArrayGradV3") + .Device(DEVICE_GPU) + .HostMemory("handle") + .HostMemory("grad_handle"), + TensorArrayGradOp); // WRITE ********************************************************************** @@ -353,6 +394,9 @@ class TensorArrayWriteOp : public OpKernel { TensorArrayWriteOp<CPUDevice, type>); \ REGISTER_KERNEL_BUILDER( \ Name("TensorArrayWriteV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + TensorArrayWriteOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArrayWriteV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ TensorArrayWriteOp<CPUDevice, type>); TF_CALL_ALL_TYPES(REGISTER_WRITE); @@ -373,6 +417,12 @@ TF_CALL_ALL_TYPES(REGISTER_WRITE); .TypeConstraint<type>("T") \ .HostMemory("handle") \ .HostMemory("index"), \ + TensorArrayWriteOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayWriteV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("handle") \ + .HostMemory("index"), \ TensorArrayWriteOp<GPUDevice, type>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -430,6 +480,10 @@ class TensorArrayReadOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV2") \ .Device(DEVICE_CPU) \ .TypeConstraint<type>("dtype"), \ + TensorArrayReadOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("dtype"), \ TensorArrayReadOp<CPUDevice, type>); TF_CALL_ALL_TYPES(REGISTER_READ) @@ -450,6 +504,12 @@ TF_CALL_ALL_TYPES(REGISTER_READ) .TypeConstraint<type>("dtype") \ .HostMemory("handle") \ .HostMemory("index"), \ + TensorArrayReadOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayReadV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("dtype") \ + .HostMemory("handle") \ + .HostMemory("index"), \ TensorArrayReadOp<GPUDevice, type>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -599,6 +659,11 @@ class TensorArrayPackOrGatherOp : public OpKernel { Name("TensorArrayGatherV2") \ .Device(DEVICE_CPU) \ .TypeConstraint<type>("dtype"), \ + TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArrayGatherV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("dtype"), \ TensorArrayPackOrGatherOp<CPUDevice, type, false /* LEGACY_PACK */>); TF_CALL_POD_STRING_TYPES(REGISTER_GATHER_AND_PACK); @@ -631,6 +696,13 @@ REGISTER_GATHER_AND_PACK(bfloat16); .TypeConstraint<type>("dtype") \ .HostMemory("indices") \ .HostMemory("handle"), \ + TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArrayGatherV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("dtype") \ + .HostMemory("indices") \ + .HostMemory("handle"), \ TensorArrayPackOrGatherOp<GPUDevice, type, false /* LEGACY_PACK */>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -654,6 +726,13 @@ REGISTER_KERNEL_BUILDER( .HostMemory("indices") .HostMemory("handle"), TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>); +REGISTER_KERNEL_BUILDER( + Name("TensorArrayGatherV3") + .Device(DEVICE_GPU) + .TypeConstraint<int32>("dtype") + .HostMemory("indices") + .HostMemory("handle"), + TensorArrayPackOrGatherOp<CPUDevice, int32, false /* LEGACY_PACK */>); #endif // GOOGLE_CUDA @@ -808,6 +887,12 @@ class TensorArrayConcatOp : public OpKernel { .TypeConstraint<type>("dtype") \ .HostMemory("lengths") \ .HostMemory("handle"), \ + TensorArrayConcatOp<CPUDevice, type>) \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("dtype") \ + .HostMemory("lengths") \ + .HostMemory("handle"), \ TensorArrayConcatOp<CPUDevice, type>) TF_CALL_POD_STRING_TYPES(REGISTER_CONCAT); @@ -832,6 +917,12 @@ REGISTER_CONCAT(bfloat16); .TypeConstraint<type>("dtype") \ .HostMemory("lengths") \ .HostMemory("handle"), \ + TensorArrayConcatOp<GPUDevice, type>) \ + REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("dtype") \ + .HostMemory("lengths") \ + .HostMemory("handle"), \ TensorArrayConcatOp<GPUDevice, type>) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -853,6 +944,12 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV2") .HostMemory("lengths") .HostMemory("handle"), TensorArrayConcatOp<CPUDevice, int32>); +REGISTER_KERNEL_BUILDER(Name("TensorArrayConcatV3") + .Device(DEVICE_GPU) + .TypeConstraint<int32>("dtype") + .HostMemory("lengths") + .HostMemory("handle"), + TensorArrayConcatOp<CPUDevice, int32>); #endif // GOOGLE_CUDA @@ -999,6 +1096,12 @@ class TensorArrayUnpackOrScatterOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint<type>("T"), \ TensorArrayUnpackOrScatterOp<CPUDevice, type, \ + false /* LEGACY_UNPACK */>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArrayScatterV3") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T"), \ + TensorArrayUnpackOrScatterOp<CPUDevice, type, \ false /* LEGACY_UNPACK */>); TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK); @@ -1029,6 +1132,14 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK); .HostMemory("indices") \ .HostMemory("handle"), \ TensorArrayUnpackOrScatterOp<GPUDevice, type, \ + false /* LEGACY_UNPACK */>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArrayScatterV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("indices") \ + .HostMemory("handle"), \ + TensorArrayUnpackOrScatterOp<GPUDevice, type, \ false /* LEGACY_UNPACK */>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -1166,6 +1277,9 @@ class TensorArraySplitOp : public OpKernel { TensorArraySplitOp<CPUDevice, type>); \ REGISTER_KERNEL_BUILDER( \ Name("TensorArraySplitV2").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + TensorArraySplitOp<CPUDevice, type>); \ + REGISTER_KERNEL_BUILDER( \ + Name("TensorArraySplitV3").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ TensorArraySplitOp<CPUDevice, type>); TF_CALL_ALL_TYPES(REGISTER_SPLIT); @@ -1185,6 +1299,12 @@ TF_CALL_ALL_TYPES(REGISTER_SPLIT); .TypeConstraint<type>("T") \ .HostMemory("lengths") \ .HostMemory("handle"), \ + TensorArraySplitOp<GPUDevice, type>); \ + REGISTER_KERNEL_BUILDER(Name("TensorArraySplitV3") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("lengths") \ + .HostMemory("handle"), \ TensorArraySplitOp<GPUDevice, type>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); @@ -1214,6 +1334,8 @@ REGISTER_KERNEL_BUILDER(Name("TensorArraySize").Device(DEVICE_CPU), TensorArraySizeOp); REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV2").Device(DEVICE_CPU), TensorArraySizeOp); +REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV3").Device(DEVICE_CPU), + TensorArraySizeOp); REGISTER_KERNEL_BUILDER(Name("TensorArraySize") .Device(DEVICE_GPU) @@ -1225,6 +1347,11 @@ REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV2") .HostMemory("handle") .HostMemory("size"), TensorArraySizeOp); +REGISTER_KERNEL_BUILDER(Name("TensorArraySizeV3") + .Device(DEVICE_GPU) + .HostMemory("handle") + .HostMemory("size"), + TensorArraySizeOp); // CLOSE // ********************************************************************** @@ -1257,6 +1384,8 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayClose").Device(DEVICE_CPU), TensorArrayCloseOp); REGISTER_KERNEL_BUILDER(Name("TensorArrayCloseV2").Device(DEVICE_CPU), TensorArrayCloseOp); +REGISTER_KERNEL_BUILDER(Name("TensorArrayCloseV3").Device(DEVICE_CPU), + TensorArrayCloseOp); REGISTER_KERNEL_BUILDER( Name("TensorArrayClose").Device(DEVICE_GPU).HostMemory("handle"), @@ -1264,5 +1393,8 @@ REGISTER_KERNEL_BUILDER( REGISTER_KERNEL_BUILDER( Name("TensorArrayCloseV2").Device(DEVICE_GPU).HostMemory("handle"), TensorArrayCloseOp); +REGISTER_KERNEL_BUILDER( + Name("TensorArrayCloseV3").Device(DEVICE_GPU).HostMemory("handle"), + TensorArrayCloseOp); } // namespace tensorflow diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index bb1856058c..eb7f72f532 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -771,14 +771,15 @@ handle: The handle to a stack. // -------------------------------------------------------------------------- -REGISTER_OP("TensorArrayV2") +REGISTER_OP("TensorArrayV3") .Input("size: int32") .Attr("dtype: type") .Attr("element_shape: shape = { unknown_rank: true }") .Attr("dynamic_size: bool = false") .Attr("clear_after_read: bool = true") .Attr("tensor_array_name: string = ''") - .Output("handle: string") + .Output("handle: resource") + .Output("flow: float") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { ShapeHandle unused; @@ -791,6 +792,7 @@ An array of Tensors of given size, with data written via Write and read via Read or Pack. handle: The handle to the TensorArray. +flow: A scalar used to control gradient flow. size: The size of the array. dtype: The type of the elements on the tensor_array. element_shape: The expected shape of an element, if known. Used to @@ -806,10 +808,11 @@ tensor_array_name: Overrides the name used for the temporary tensor_array is guaranteed unique). )doc"); -REGISTER_OP("TensorArrayGradV2") - .Input("handle: string") +REGISTER_OP("TensorArrayGradV3") + .Input("handle: resource") .Input("flow_in: float") - .Output("grad_handle: string") + .Output("grad_handle: resource") + .Output("flow_out: float") .Attr("source: string") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { @@ -866,8 +869,8 @@ source: The gradient source string, used to decide which gradient TensorArray to return. )doc"); -REGISTER_OP("TensorArrayWriteV2") - .Input("handle: string") +REGISTER_OP("TensorArrayWriteV3") + .Input("handle: resource") .Input("index: int32") .Input("value: T") .Input("flow_in: float") @@ -894,8 +897,8 @@ flow_in: A float scalar that enforces proper chaining of operations. flow_out: A float scalar that enforces proper chaining of operations. )doc"); -REGISTER_OP("TensorArrayReadV2") - .Input("handle: string") +REGISTER_OP("TensorArrayReadV3") + .Input("handle: resource") .Input("index: int32") .Input("flow_in: float") .Output("value: dtype") @@ -919,8 +922,8 @@ flow_in: A float scalar that enforces proper chaining of operations. value: The tensor that is read from the TensorArray. )doc"); -REGISTER_OP("TensorArrayGatherV2") - .Input("handle: string") +REGISTER_OP("TensorArrayGatherV3") + .Input("handle: resource") .Input("indices: int32") .Input("flow_in: float") .Output("value: dtype") @@ -951,8 +954,8 @@ value: All of the elements in the TensorArray, concatenated along a new axis (the new dimension 0). )doc"); -REGISTER_OP("TensorArrayScatterV2") - .Input("handle: string") +REGISTER_OP("TensorArrayScatterV3") + .Input("handle: resource") .Input("indices: int32") .Input("value: T") .Input("flow_in: float") @@ -979,8 +982,8 @@ flow_in: A float scalar that enforces proper chaining of operations. flow_out: A float scalar that enforces proper chaining of operations. )doc"); -REGISTER_OP("TensorArrayConcatV2") - .Input("handle: string") +REGISTER_OP("TensorArrayConcatV3") + .Input("handle: resource") .Input("flow_in: float") .Output("value: dtype") .Output("lengths: int64") @@ -1026,8 +1029,8 @@ lengths: A vector of the row sizes of the original T elements in the `(n1, n2, ..., n(T-1))`. )doc"); -REGISTER_OP("TensorArraySplitV2") - .Input("handle: string") +REGISTER_OP("TensorArraySplitV3") + .Input("handle: resource") .Input("value: T") .Input("lengths: int64") .Input("flow_in: float") @@ -1072,8 +1075,8 @@ flow_in: A float scalar that enforces proper chaining of operations. flow_out: A float scalar that enforces proper chaining of operations. )doc"); -REGISTER_OP("TensorArraySizeV2") - .Input("handle: string") +REGISTER_OP("TensorArraySizeV3") + .Input("handle: resource") .Input("flow_in: float") .Output("size: int32") .SetShapeFn([](InferenceContext* c) { @@ -1091,8 +1094,8 @@ flow_in: A float scalar that enforces proper chaining of operations. size: The current size of the TensorArray. )doc"); -REGISTER_OP("TensorArrayCloseV2") - .Input("handle: string") +REGISTER_OP("TensorArrayCloseV3") + .Input("handle: resource") .SetShapeFn([](InferenceContext* c) { ShapeHandle handle; DimensionHandle unused_dim; @@ -1121,7 +1124,23 @@ REGISTER_OP("TensorArray") .Output("handle: Ref(string)") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArrayV2"); + .Deprecated(16, "Use TensorArrayV3"); +REGISTER_OP("TensorArrayV2") + .Input("size: int32") + .Attr("dtype: type") + .Attr("element_shape: shape = { unknown_rank: true }") + .Attr("dynamic_size: bool = false") + .Attr("clear_after_read: bool = true") + .Attr("tensor_array_name: string = ''") + .Output("handle: string") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + c->set_output(0, c->Vector(2)); + return Status::OK(); + }) + .Deprecated(20, "Use TensorArrayV3"); REGISTER_OP("TensorArrayGrad") .Input("handle: string") .Input("flow_in: float") @@ -1129,7 +1148,22 @@ REGISTER_OP("TensorArrayGrad") .Attr("source: string") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArrayGradV2"); + .Deprecated(16, "Use TensorArrayGradV3"); +REGISTER_OP("TensorArrayGradV2") + .Input("handle: string") + .Input("flow_in: float") + .Output("grad_handle: string") + .Attr("source: string") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + c->set_output(0, c->Vector(2)); + return Status::OK(); + }) + .Deprecated(20, "Use TensorArrayGradV3"); REGISTER_OP("TensorArrayWrite") .Input("handle: Ref(string)") .Input("index: int32") @@ -1138,7 +1172,26 @@ REGISTER_OP("TensorArrayWrite") .Output("flow_out: float") .Attr("T: type") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArrayWriteV2"); + .Deprecated(16, "Use TensorArrayWriteV3"); +REGISTER_OP("TensorArrayWriteV2") + .Input("handle: string") + .Input("index: int32") + .Input("value: T") + .Input("flow_in: float") + .Output("flow_out: float") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return shape_inference::ScalarShape(c); + }) + .Deprecated(20, "Use TensorArrayGradV3"); REGISTER_OP("TensorArrayRead") .Input("handle: Ref(string)") .Input("index: int32") @@ -1146,7 +1199,24 @@ REGISTER_OP("TensorArrayRead") .Output("value: dtype") .Attr("dtype: type") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArrayReadV2"); + .Deprecated(16, "Use TensorArrayReadV3"); +REGISTER_OP("TensorArrayReadV2") + .Input("handle: string") + .Input("index: int32") + .Input("flow_in: float") + .Output("value: dtype") + .Attr("dtype: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::UnknownShape(c); + }) + .Deprecated(20, "Use TensorArrayReadV3"); REGISTER_OP("TensorArrayPack") .Input("handle: Ref(string)") .Input("flow_in: float") @@ -1154,7 +1224,7 @@ REGISTER_OP("TensorArrayPack") .Attr("dtype: type") .Attr("element_shape: shape = { unknown_rank: true }") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArrayGatherV2 with RangeOp"); + .Deprecated(16, "Use TensorArrayGatherV3 with RangeOp"); REGISTER_OP("TensorArrayUnpack") .Input("handle: Ref(string)") .Input("value: T") @@ -1162,7 +1232,7 @@ REGISTER_OP("TensorArrayUnpack") .Output("flow_out: float") .Attr("T: type") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArrayScatterV2 with RangeOp"); + .Deprecated(20, "Use TensorArrayScatterV3 with RangeOp"); REGISTER_OP("TensorArrayGather") .Input("handle: Ref(string)") .Input("indices: int32") @@ -1171,7 +1241,24 @@ REGISTER_OP("TensorArrayGather") .Attr("dtype: type") .Attr("element_shape: shape = { unknown_rank: true }") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArrayGatherV2"); + .Deprecated(16, "Use TensorArrayGatherV3"); +REGISTER_OP("TensorArrayGatherV2") + .Input("handle: string") + .Input("indices: int32") + .Input("flow_in: float") + .Output("value: dtype") + .Attr("dtype: type") + .Attr("element_shape: shape = { unknown_rank: true }") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return shape_inference::UnknownShape(c); + }) + .Deprecated(20, "Use TensorArrayGatherV3"); REGISTER_OP("TensorArrayScatter") .Input("handle: Ref(string)") .Input("indices: int32") @@ -1180,7 +1267,24 @@ REGISTER_OP("TensorArrayScatter") .Output("flow_out: float") .Attr("T: type") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArrayScatterV2"); + .Deprecated(19, "Use TensorArrayGradV3"); +REGISTER_OP("TensorArrayScatterV2") + .Input("handle: string") + .Input("indices: int32") + .Input("value: T") + .Input("flow_in: float") + .Output("flow_out: float") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), 0), 2, &unused_dim)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return shape_inference::ScalarShape(c); + }) + .Deprecated(20, "Use TensorArrayScatterV3"); REGISTER_OP("TensorArrayConcat") .Input("handle: Ref(string)") .Input("flow_in: float") @@ -1189,7 +1293,26 @@ REGISTER_OP("TensorArrayConcat") .Attr("dtype: type") .Attr("element_shape_except0: shape = { unknown_rank: true }") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArrayConcatV2"); + .Deprecated(16, "Use TensorArrayGradV3"); +REGISTER_OP("TensorArrayConcatV2") + .Input("handle: string") + .Input("flow_in: float") + .Output("value: dtype") + .Output("lengths: int64") + .Attr("dtype: type") + .Attr("element_shape_except0: shape = { unknown_rank: true }") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + c->set_output(0, c->UnknownShape()); + c->set_output(1, c->Vector(c->UnknownDim())); + return Status::OK(); + }) + .Deprecated(20, "Use TensorArrayConcatV3"); REGISTER_OP("TensorArraySplit") .Input("handle: Ref(string)") .Input("value: T") @@ -1198,17 +1321,57 @@ REGISTER_OP("TensorArraySplit") .Output("flow_out: float") .Attr("T: type") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArraySplitV2"); + .Deprecated(16, "Use TensorArraySplitV3"); +REGISTER_OP("TensorArraySplitV2") + .Input("handle: string") + .Input("value: T") + .Input("lengths: int64") + .Input("flow_in: float") + .Output("flow_out: float") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return shape_inference::ScalarShape(c); + }) + .Deprecated(20, "Use TensorArraySplitV3"); REGISTER_OP("TensorArraySize") .Input("handle: Ref(string)") .Input("flow_in: float") .Output("size: int32") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArraySizeV2"); + .Deprecated(16, "Use TensorArraySizeV3"); +REGISTER_OP("TensorArraySizeV2") + .Input("handle: string") + .Input("flow_in: float") + .Output("size: int32") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + return shape_inference::ScalarShape(c); + }) + .Deprecated(20, "Use TensorArraySizeV3"); REGISTER_OP("TensorArrayClose") .Input("handle: Ref(string)") .SetShapeFn([](InferenceContext* c) { return Status::OK(); }) - .Deprecated(16, "Use TensorArrayCloseV2"); + .Deprecated(16, "Use TensorArrayCloseV3"); +REGISTER_OP("TensorArrayCloseV2") + .Input("handle: string") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle handle; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &handle)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim)); + return Status::OK(); + }) + .Deprecated(20, "Use TensorArrayCloseV3"); // -------------------------------------------------------------------------- diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 7e8f5fb20e..d960b8dd42 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -78,7 +78,7 @@ limitations under the License. // 20. Catch all version 1.0 changes to Python API generation. SplitV is now // used for tf.split, ReverseV2 is now used by tf.reverse, ConcatV2 is // now used by tf.concat_v2 (and soon tf.concat). Graphs use flooring -// division and mod semantics. (12dec2016) +// division and mod semantics. TensorArrayV3. (12dec2016) #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index 454f1ceb61..6d8eeb39aa 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -344,7 +344,7 @@ class TensorArrayTest(test.TestCase): r1_0 = g_ta_1.read(0) t_g_ta_0, t_g_ta_1, d_r1_0 = session.run( - [g_ta_0.handle, g_ta_1.handle, r1_0]) + [g_ta_0.handle.op, g_ta_1.handle.op, r1_0]) self.assertAllEqual(t_g_ta_0, t_g_ta_1) self.assertAllEqual([[4.0, 5.0]], d_r1_0) @@ -378,7 +378,7 @@ class TensorArrayTest(test.TestCase): w0 = ta.write(0, [[4.0, 5.0]]) # Test reading wrong datatype - r0_bad = gen_data_flow_ops._tensor_array_read_v2( + r0_bad = gen_data_flow_ops._tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow) with self.assertRaisesOpError( "TensorArray dtype is float but Op requested dtype double."): @@ -531,23 +531,6 @@ class TensorArrayTest(test.TestCase): r = r1 + r2 self.assertAllClose(9.0, r.eval()) - def testDuplicateTensorArrayHasDifferentName(self): - with self.test_session(use_gpu=True) as session: - h1 = tensor_array_ops.TensorArray( - size=1, dtype=dtypes.float32, tensor_array_name="foo") - c1 = h1.write(0, 4.0) - h2 = tensor_array_ops.TensorArray( - size=1, dtype=dtypes.float32, tensor_array_name="foo") - c2 = h2.write(0, 5.0) - _, _, c1h, c2h = session.run([c1.flow, c2.flow, c1.handle, c2.handle]) - c1h = [x.decode("ascii") for x in c1h] - c2h = [x.decode("ascii") for x in c2h] - self.assertEqual(c1h[0], "_tensor_arrays") - self.assertEqual(c2h[0], "_tensor_arrays") - self.assertTrue(c1h[1].startswith("foo_")) - self.assertTrue(c2h[1].startswith("foo_")) - self.assertNotEqual(c1h[1], c2h[1]) - def _testTensorArrayGradientWriteReadType(self, dtype): with self.test_session(use_gpu=True) as session: ta = tensor_array_ops.TensorArray( diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index 1e982e1197..6f147dcaf0 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -107,6 +107,18 @@ TensorArraySizeV2 TensorArraySplitV2 TensorArrayUnpackV2 TensorArrayWriteV2 +TensorArrayV3 +TensorArrayCloseV3 +TensorArrayConcatV3 +TensorArrayGatherV3 +TensorArrayGradV3 +TensorArrayReadV3 +TensorArrayPackV3 +TensorArrayScatterV3 +TensorArraySizeV3 +TensorArraySplitV3 +TensorArrayUnpackV3 +TensorArrayWriteV3 GetSessionHandle GetSessionTensor DeleteSessionTensor diff --git a/tensorflow/python/ops/tensor_array_grad.py b/tensorflow/python/ops/tensor_array_grad.py index 1b1f3926d4..0e7d1880ce 100644 --- a/tensorflow/python/ops/tensor_array_grad.py +++ b/tensorflow/python/ops/tensor_array_grad.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Gradients for operators defined in tensor_array_ops.py.""" from __future__ import absolute_import from __future__ import division @@ -21,7 +20,6 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops import tensor_array_ops - # TODO(b/31222613): These ops may be differentiable, and there may be # latent bugs here. ops.NotDifferentiable("TensorArray") @@ -34,6 +32,11 @@ ops.NotDifferentiable("TensorArrayGradV2") ops.NotDifferentiable("TensorArraySizeV2") ops.NotDifferentiable("TensorArrayCloseV2") +ops.NotDifferentiable("TensorArrayV3") +ops.NotDifferentiable("TensorArrayGradV3") +ops.NotDifferentiable("TensorArraySizeV3") +ops.NotDifferentiable("TensorArrayCloseV3") + def _GetGradSource(op_or_tensor): """Identify which call to tf.gradients created this gradient op or tensor. @@ -74,6 +77,7 @@ def _GetGradSource(op_or_tensor): @ops.RegisterGradient("TensorArrayRead") @ops.RegisterGradient("TensorArrayReadV2") +@ops.RegisterGradient("TensorArrayReadV3") def _TensorArrayReadGrad(op, grad): """Gradient for TensorArrayRead. @@ -95,14 +99,16 @@ def _TensorArrayReadGrad(op, grad): flow = op.inputs[2] dtype = op.get_attr("dtype") grad_source = _GetGradSource(grad) - g = tensor_array_ops.TensorArray(dtype=dtype, handle=handle).grad( - source=grad_source, flow=flow) + g = tensor_array_ops.TensorArray( + dtype=dtype, handle=handle, flow=flow).grad( + source=grad_source, flow=flow) w_g = g.write(index, grad) return [None, None, w_g.flow] @ops.RegisterGradient("TensorArrayWrite") @ops.RegisterGradient("TensorArrayWriteV2") +@ops.RegisterGradient("TensorArrayWriteV3") def _TensorArrayWriteGrad(op, flow): """Gradient for TensorArrayWrite. @@ -119,14 +125,16 @@ def _TensorArrayWriteGrad(op, flow): index = op.inputs[1] dtype = op.get_attr("T") grad_source = _GetGradSource(flow) - g = tensor_array_ops.TensorArray(dtype=dtype, handle=handle).grad( - source=grad_source, flow=flow) + g = tensor_array_ops.TensorArray( + dtype=dtype, handle=handle, flow=flow).grad( + source=grad_source, flow=flow) grad = g.read(index) return [None, None, grad, flow] @ops.RegisterGradient("TensorArrayGather") @ops.RegisterGradient("TensorArrayGatherV2") +@ops.RegisterGradient("TensorArrayGatherV3") def _TensorArrayGatherGrad(op, grad): """Gradient for TensorArrayGather. @@ -148,14 +156,16 @@ def _TensorArrayGatherGrad(op, grad): flow = op.inputs[2] dtype = op.get_attr("dtype") grad_source = _GetGradSource(grad) - g = tensor_array_ops.TensorArray(dtype=dtype, handle=handle).grad( - source=grad_source, flow=flow) + g = tensor_array_ops.TensorArray( + dtype=dtype, handle=handle, flow=flow).grad( + source=grad_source, flow=flow) u_g = g.scatter(indices, grad) return [None, None, u_g.flow] @ops.RegisterGradient("TensorArrayScatter") @ops.RegisterGradient("TensorArrayScatterV2") +@ops.RegisterGradient("TensorArrayScatterV3") def _TensorArrayScatterGrad(op, flow): """Gradient for TensorArrayScatter. @@ -170,14 +180,16 @@ def _TensorArrayScatterGrad(op, flow): indices = op.inputs[1] dtype = op.get_attr("T") grad_source = _GetGradSource(flow) - g = tensor_array_ops.TensorArray(dtype=dtype, handle=handle).grad( - source=grad_source, flow=flow) + g = tensor_array_ops.TensorArray( + dtype=dtype, handle=handle, flow=flow).grad( + source=grad_source, flow=flow) grad = g.gather(indices) return [None, None, grad, flow] @ops.RegisterGradient("TensorArrayConcat") @ops.RegisterGradient("TensorArrayConcatV2") +@ops.RegisterGradient("TensorArrayConcatV3") def _TensorArrayConcatGrad(op, grad, unused_lengths_grad): """Gradient for TensorArrayConcat. @@ -199,8 +211,9 @@ def _TensorArrayConcatGrad(op, grad, unused_lengths_grad): lengths = op.outputs[1] dtype = op.get_attr("dtype") grad_source = _GetGradSource(grad) - g = tensor_array_ops.TensorArray(dtype=dtype, handle=handle).grad( - source=grad_source, flow=flow) + g = tensor_array_ops.TensorArray( + dtype=dtype, handle=handle, flow=flow).grad( + source=grad_source, flow=flow) u_g = g.split(grad, lengths=lengths) # handle, flow_in return [None, u_g.flow] @@ -208,6 +221,7 @@ def _TensorArrayConcatGrad(op, grad, unused_lengths_grad): @ops.RegisterGradient("TensorArraySplit") @ops.RegisterGradient("TensorArraySplitV2") +@ops.RegisterGradient("TensorArraySplitV3") def _TensorArraySplitGrad(op, flow): """Gradient for TensorArraySplit. @@ -221,8 +235,9 @@ def _TensorArraySplitGrad(op, flow): handle = op.inputs[0] dtype = op.get_attr("T") grad_source = _GetGradSource(flow) - g = tensor_array_ops.TensorArray(dtype=dtype, handle=handle).grad( - source=grad_source, flow=flow) + g = tensor_array_ops.TensorArray( + dtype=dtype, handle=handle, flow=flow).grad( + source=grad_source, flow=flow) grad = g.concat() # handle, value, lengths, flow_in return [None, grad, None, flow] diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 043fd9b1ab..7433afc85f 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """TensorArray operations. ## Classes containing dynamically sized arrays of Tensors. @@ -25,8 +24,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util @@ -82,9 +79,17 @@ class TensorArray(object): @@grad """ - def __init__(self, dtype, size=None, dynamic_size=None, - clear_after_read=None, tensor_array_name=None, handle=None, - flow=None, infer_shape=True, element_shape=None, name=None): + def __init__(self, + dtype, + size=None, + dynamic_size=None, + clear_after_read=None, + tensor_array_name=None, + handle=None, + flow=None, + infer_shape=True, + element_shape=None, + name=None): """Construct a new TensorArray or wrap an existing TensorArray handle. A note about the parameter `name`: @@ -159,29 +164,22 @@ class TensorArray(object): with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope: if handle is not None: self._handle = handle - else: - if flow is not None: - with ops.colocate_with(flow): - self._handle = gen_data_flow_ops._tensor_array_v2( - dtype=dtype, size=size, element_shape=element_shape, - dynamic_size=dynamic_size, - clear_after_read=clear_after_read, - tensor_array_name=tensor_array_name, name=scope) - else: - # Construct the TensorArray with an empty device. The first - # write into the TensorArray from a Tensor with a set device - # will retroactively set the device value of this op. - with ops.device(None), ops.colocate_with(None, ignore_existing=True): - self._handle = gen_data_flow_ops._tensor_array_v2( - dtype=dtype, size=size, element_shape=element_shape, - dynamic_size=dynamic_size, - clear_after_read=clear_after_read, - tensor_array_name=tensor_array_name, name=scope) - if flow is not None: + if flow is None: + raise ValueError("flow must not be None if handle is not None.") self._flow = flow else: - with ops.colocate_with(self._handle): - self._flow = constant_op.constant(0, dtype=_dtypes.float32) + # Construct the TensorArray with an empty device. The first + # write into the TensorArray from a Tensor with a set device + # will retroactively set the device value of this op. + with ops.device(None), ops.colocate_with(None, ignore_existing=True): + self._handle, self._flow = gen_data_flow_ops._tensor_array_v3( + dtype=dtype, + size=size, + element_shape=element_shape, + dynamic_size=dynamic_size, + clear_after_read=clear_after_read, + tensor_array_name=tensor_array_name, + name=scope) @property def flow(self): @@ -207,12 +205,15 @@ class TensorArray(object): flow = self.flow with ops.name_scope(name, "TensorArrayGrad", [self._handle]): with ops.colocate_with(self._handle): - g_handle = gen_data_flow_ops._tensor_array_grad_v2( + g_handle, unused_flow = gen_data_flow_ops._tensor_array_grad_v3( handle=self._handle, source=source, flow_in=flow, name=name) with ops.control_dependencies([g_handle]): flow = array_ops.identity(flow, name="gradient_flow") - g = TensorArray(dtype=self._dtype, handle=g_handle, flow=flow, - infer_shape=self._infer_shape) + g = TensorArray( + dtype=self._dtype, + handle=g_handle, + flow=flow, + infer_shape=self._infer_shape) g._element_shape = self._element_shape return g @@ -227,9 +228,12 @@ class TensorArray(object): The tensor at index `index`. """ with ops.colocate_with(self._handle): - value = gen_data_flow_ops._tensor_array_read_v2( - handle=self._handle, index=index, flow_in=self._flow, - dtype=self._dtype, name=name) + value = gen_data_flow_ops._tensor_array_read_v3( + handle=self._handle, + index=index, + flow_in=self._flow, + dtype=self._dtype, + name=name) if self._element_shape: value.set_shape(self._element_shape[0].dims) return value @@ -253,20 +257,22 @@ class TensorArray(object): value = ops.convert_to_tensor(value, name="value") _maybe_set_device(self._handle.op, value) with ops.colocate_with(self._handle): - flow_out = gen_data_flow_ops._tensor_array_write_v2( - handle=self._handle, index=index, value=value, flow_in=self._flow, + flow_out = gen_data_flow_ops._tensor_array_write_v3( + handle=self._handle, + index=index, + value=value, + flow_in=self._flow, name=name) - ta = TensorArray(dtype=self._dtype, handle=self._handle) - ta._flow = flow_out + ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow_out) ta._infer_shape = self._infer_shape ta._element_shape = self._element_shape if ta._infer_shape: val_shape = value.get_shape() if ta._element_shape: if not val_shape == ta._element_shape[0]: - raise ValueError( - "Inconsistent shapes: saw %s but expected %s " - "(and infer_shape=True)" % (val_shape, ta._element_shape[0])) + raise ValueError("Inconsistent shapes: saw %s but expected %s " + "(and infer_shape=True)" % + (val_shape, ta._element_shape[0])) else: ta._element_shape.append(val_shape) return ta @@ -287,12 +293,12 @@ class TensorArray(object): with ops.name_scope(name, "TensorArrayStack", [self._handle]): return self.gather(math_ops.range(0, self.size()), name=name) - @deprecated( - "2016-12-12", - "This op will be removed after the deprecation date. " - "Please switch to tf.stack.") + @deprecated("2016-12-12", + "This op will be removed after the deprecation date. " + "Please switch to tf.stack.") def pack(self, name=None): return self.stack(name) + pack.__doc__ = stack.__doc__ def gather(self, indices, name=None): @@ -314,7 +320,7 @@ class TensorArray(object): element_shape = self._element_shape[0] else: element_shape = tensor_shape.TensorShape(None) - value = gen_data_flow_ops._tensor_array_gather_v2( + value = gen_data_flow_ops._tensor_array_gather_v3( handle=self._handle, indices=indices, flow_in=self._flow, @@ -343,7 +349,7 @@ class TensorArray(object): else: element_shape_except0 = tensor_shape.TensorShape(None) with ops.colocate_with(self._handle): - value, _ = gen_data_flow_ops._tensor_array_concat_v2( + value, _ = gen_data_flow_ops._tensor_array_concat_v3( handle=self._handle, flow_in=self._flow, dtype=self._dtype, @@ -374,12 +380,12 @@ class TensorArray(object): return self.scatter( indices=math_ops.range(0, num_elements), value=value, name=name) - @deprecated( - "2016-12-12", - "This op will be removed after the deprecation date. " - "Please switch to tf.unstack.") + @deprecated("2016-12-12", + "This op will be removed after the deprecation date. " + "Please switch to tf.unstack.") def unpack(self, value, name=None): return self.unstack(value, name) + unpack.__doc__ = unstack.__doc__ def scatter(self, indices, value, name=None): @@ -403,11 +409,13 @@ class TensorArray(object): value = ops.convert_to_tensor(value, name="value") _maybe_set_device(self._handle.op, value) with ops.colocate_with(self._handle): - flow_out = gen_data_flow_ops._tensor_array_scatter_v2( - handle=self._handle, indices=indices, value=value, - flow_in=self._flow, name=name) - ta = TensorArray(dtype=self._dtype, handle=self._handle) - ta._flow = flow_out + flow_out = gen_data_flow_ops._tensor_array_scatter_v3( + handle=self._handle, + indices=indices, + value=value, + flow_in=self._flow, + name=name) + ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow_out) ta._infer_shape = self._infer_shape ta._element_shape = self._element_shape if ta._infer_shape: @@ -417,10 +425,9 @@ class TensorArray(object): element_shape = tensor_shape.TensorShape(val_shape.dims[1:]) if ta._element_shape: if not element_shape == ta._element_shape[0]: - raise ValueError( - "Inconsistent shapes: saw %s but expected %s " - "(and infer_shape=True)" - % (element_shape, ta._element_shape[0])) + raise ValueError("Inconsistent shapes: saw %s but expected %s " + "(and infer_shape=True)" % + (element_shape, ta._element_shape[0])) else: ta._element_shape.append(element_shape) return ta @@ -447,11 +454,13 @@ class TensorArray(object): _maybe_set_device(self._handle.op, value) lengths_64 = math_ops.to_int64(lengths) with ops.colocate_with(self._handle): - flow_out = gen_data_flow_ops._tensor_array_split_v2( - handle=self._handle, value=value, lengths=lengths_64, - flow_in=self._flow, name=name) - ta = TensorArray(dtype=self._dtype, handle=self._handle) - ta._flow = flow_out + flow_out = gen_data_flow_ops._tensor_array_split_v3( + handle=self._handle, + value=value, + lengths=lengths_64, + flow_in=self._flow, + name=name) + ta = TensorArray(dtype=self._dtype, handle=self._handle, flow=flow_out) ta._infer_shape = self._infer_shape ta._element_shape = self._element_shape if ta._infer_shape: @@ -460,14 +469,13 @@ class TensorArray(object): element_shape = tensor_shape.unknown_shape() if val_shape.dims is not None: if clengths is not None and clengths.max() == clengths.min(): - element_shape = tensor_shape.TensorShape( - [clengths[0]] + val_shape.dims[1:]) + element_shape = tensor_shape.TensorShape([clengths[0]] + + val_shape.dims[1:]) if ta._element_shape: if not element_shape == ta._element_shape[0]: - raise ValueError( - "Inconsistent shapes: saw %s but expected %s " - "(and infer_shape=True)" - % (element_shape, ta._element_shape[0])) + raise ValueError("Inconsistent shapes: saw %s but expected %s " + "(and infer_shape=True)" % + (element_shape, ta._element_shape[0])) else: ta._element_shape.append(element_shape) return ta @@ -475,13 +483,13 @@ class TensorArray(object): def size(self, name=None): """Return the size of the TensorArray.""" with ops.colocate_with(self._handle): - return gen_data_flow_ops._tensor_array_size_v2( + return gen_data_flow_ops._tensor_array_size_v3( handle=self._handle, flow_in=self.flow, name=name) def close(self, name=None): """Close the current TensorArray.""" with ops.colocate_with(self._handle): - return gen_data_flow_ops._tensor_array_close_v2( + return gen_data_flow_ops._tensor_array_close_v3( handle=self._handle, name=name) # pylint: enable=protected-access |