aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt1
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.cc26
-rw-r--r--tensorflow/core/framework/shape_inference.cc78
-rw-r--r--tensorflow/core/framework/shape_inference.h11
-rw-r--r--tensorflow/core/framework/shape_inference_test.cc13
-rw-r--r--tensorflow/core/kernels/BUILD1
-rw-r--r--tensorflow/core/kernels/dense_update_functor.cc56
-rw-r--r--tensorflow/core/kernels/dense_update_functor.h14
-rw-r--r--tensorflow/core/kernels/gather_functor.h13
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc118
-rw-r--r--tensorflow/core/kernels/scatter_functor.h118
-rw-r--r--tensorflow/core/kernels/training_op_helpers.h30
-rw-r--r--tensorflow/core/ops/list_ops.cc4
-rw-r--r--tensorflow/python/framework/tensor_util.py19
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py121
-rw-r--r--tensorflow/python/ops/list_ops.py11
16 files changed, 491 insertions, 143 deletions
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 0bc4c5d473..d4c3f2eda8 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -151,6 +151,7 @@ tensorflow/core/kernels/decode_bmp_op.cc
tensorflow/core/kernels/depthtospace_op.cc
tensorflow/core/kernels/data_format_ops.cc
tensorflow/core/kernels/spacetodepth_op.cc
+tensorflow/core/kernels/dense_update_functor.cc
tensorflow/core/kernels/dense_update_ops.cc
tensorflow/core/kernels/deep_conv2d.cc
tensorflow/core/kernels/decode_wav_op.cc
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index 1b7e3138ee..06dbe04986 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -431,6 +431,32 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
InferenceContext* src_context = GetContext(input_edge->src());
if (src_context == nullptr) return errors::Internal("Missing src context");
ShapeHandle src_shape = src_context->output(input_edge->src_output());
+
+ if (src_context->Value(src_context->Rank(src_shape)) == 0) {
+ Tensor t;
+ bool evaluated = false;
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantTensorForEdge(node, dst_idx, &evaluated, &t));
+ if (!evaluated) {
+ return errors::InvalidArgument(
+ "Received a shape scalar with unknown static value. A static value "
+ "of '-1' is required to represent an unknown shape.");
+ }
+ if (t.dims() == 0) {
+ if (t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) {
+ *result = target_context->UnknownShape();
+ return Status::OK();
+ } else if (t.dtype() == DT_INT64 && t.scalar<int64>()() == -1) {
+ *result = target_context->UnknownShape();
+ return Status::OK();
+ }
+ }
+ return errors::InvalidArgument(
+ "Received an invalid shape scalar with a static value that is not "
+ "'-1': ",
+ t.DebugString());
+ }
+
TF_RETURN_IF_ERROR(src_context->WithRank(src_shape, 1, &src_shape));
const string& src_op = input_edge->src()->type_string();
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 54ecaa5dd4..cc1ec47a83 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -726,6 +726,24 @@ ShapeHandle InferenceContext::Matrix(DimensionOrConstant dim1,
return MakeShape({dim1, dim2});
}
+Status InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+ int input_idx, ShapeHandle* out) {
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(WithRankAtMost(input(input_idx), 1, &input_shape));
+
+ requested_input_tensor_as_partial_shape_[input_idx] = true;
+ if (input_idx < input_tensors_as_shapes_.size() &&
+ input_tensors_as_shapes_[input_idx].IsSet() &&
+ RankKnown(input_tensors_as_shapes_[input_idx])) {
+ *out = input_tensors_as_shapes_[input_idx];
+ return Status::OK();
+ }
+
+ return InternalMakeShapeFromTensor(
+ true /* treat_unknown_scalar_tensor_as_unknown_shape */,
+ input_tensor(input_idx), input_shape, out);
+}
+
Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
ShapeHandle* out) {
ShapeHandle input_shape;
@@ -739,13 +757,31 @@ Status InferenceContext::MakeShapeFromShapeTensor(int input_idx,
return Status::OK();
}
- return MakeShapeFromTensor(input_tensor(input_idx), input_shape, out);
+ return InternalMakeShapeFromTensor(
+ false /* treat_unknown_scalar_tensor_as_unknown_shape */,
+ input_tensor(input_idx), input_shape, out);
}
Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
ShapeHandle tensor_shape,
ShapeHandle* out) {
+ return InternalMakeShapeFromTensor(
+ false /* treat_unknown_scalar_tensor_as_unknown_shape */, t, tensor_shape,
+ out);
+}
+
+Status InferenceContext::InternalMakeShapeFromTensor(
+ bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
+ ShapeHandle tensor_shape, ShapeHandle* out) {
+ // Only callers who have set
+ if (!treat_unknown_scalar_tensor_as_unknown_shape) {
+ TF_RETURN_IF_ERROR(WithRank(tensor_shape, 1, &tensor_shape));
+ }
if (t == nullptr) {
+ // This is guarded by the check above.
+ if (Rank(tensor_shape) == 0) {
+ return ReturnUnknownShape(out);
+ }
// Shape tensor is not known, but if the shape of the shape tensor is then
// the right number of unknown dims can be created.
DimensionHandle shape_dim = Dim(tensor_shape, 0);
@@ -759,10 +795,46 @@ Status InferenceContext::MakeShapeFromTensor(const Tensor* t,
return ReturnCreatedShape(dims, out);
}
+ if (t->shape().dims() == 0) {
+ if (t->dtype() == DataType::DT_INT32) {
+ auto flat_t = t->scalar<int32>();
+ if (flat_t() != -1) {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Input tensor must be rank 1, or if its rank 0 it must have value "
+ "-1 "
+ "(representing an unknown shape). Saw value: ",
+ flat_t());
+ }
+ return ReturnUnknownShape(out);
+ } else if (t->dtype() == DataType::DT_INT64) {
+ auto flat_t = t->scalar<int64>();
+ if (flat_t() != -1) {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Input tensor must be rank 1, or if its rank 0 it must have value "
+ "-1 "
+ "(representing an unknown shape). Saw value: ",
+ flat_t());
+ }
+ return ReturnUnknownShape(out);
+ } else {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Input tensor must be int32 or int64, but was ",
+ DataTypeString(t->dtype()));
+ }
+ }
+
if (t->shape().dims() != 1) {
*out = nullptr;
- return errors::InvalidArgument("Input tensor must be rank 1, but was rank ",
- t->shape().dims());
+ return errors::InvalidArgument(
+ "Input tensor must be rank 1, but was rank ", t->shape().dims(), ".",
+ ((t->shape().dims() == 0)
+ ? "If it is rank 0 rank 0 it must have statically known value -1 "
+ "(representing an unknown shape). "
+ : " "),
+ "Saw tensor shape ", t->shape().DebugString());
}
std::vector<DimensionHandle> dims;
if (t->dtype() == DataType::DT_INT32) {
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index accc587000..cdb4bd79bb 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -463,6 +463,12 @@ class InferenceContext {
// the input tensor is NULL, then an unknown shape is returned.
Status MakeShapeFromShapeTensor(int input_idx, ShapeHandle* out);
+ // Like the function above, but treats scalar values as unknown
+ // shapes. **NOTE** If the scalar is statically known, its value
+ // must be -1 or an error is returned.
+ Status MakeShapeFromShapeTensorTreatScalarAsUnknownShape(int input_idx,
+ ShapeHandle* out);
+
// Returns in <out> a new shape corresponding to <proto>.
Status MakeShapeFromShapeProto(const TensorShapeProto& proto,
ShapeHandle* out);
@@ -708,6 +714,11 @@ class InferenceContext {
merged_dims_.clear();
}
+ // Helper method for MakeShapeFromTensor and MakeShapeFromShapeTensor.
+ Status InternalMakeShapeFromTensor(
+ bool treat_unknown_scalar_tensor_as_unknown_shape, const Tensor* t,
+ ShapeHandle tensor_shape, ShapeHandle* out);
+
ShapeManager shape_manager_;
// inputs_, outputs_, and input_tensors_as_shapes_ refer to values from
diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc
index da103bfec9..586c38e43b 100644
--- a/tensorflow/core/framework/shape_inference_test.cc
+++ b/tensorflow/core/framework/shape_inference_test.cc
@@ -1081,17 +1081,26 @@ TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
t = ::tensorflow::test::AsTensor<int64>({});
EXPECT_EQ("[]", create(&t));
+ // Test negative scalar
+ t = ::tensorflow::test::AsScalar<int32>(-1);
+ EXPECT_EQ("?", create(&t));
+
t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
EXPECT_TRUE(str_util::StrContains(
create(&t), "Input tensor must be int32 or int64, but was float"));
t = ::tensorflow::test::AsScalar<int32>(1);
+ auto s_scalar = create(&t);
EXPECT_TRUE(str_util::StrContains(
- create(&t), "Input tensor must be rank 1, but was rank 0"));
+ s_scalar,
+ "Input tensor must be rank 1, or if its rank 0 it must have value -1"))
+ << s_scalar;
t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
+ auto s_matrix = create(&t);
EXPECT_TRUE(str_util::StrContains(
- create(&t), "Input tensor must be rank 1, but was rank 2"));
+ s_matrix, "Input tensor must be rank 1, but was rank 2"))
+ << s_matrix;
// Test negative values for the dims.
t = ::tensorflow::test::AsTensor<int64>({3, -2, 1});
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 783de6af88..b931f79b72 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1395,6 +1395,7 @@ tf_kernel_library(
visibility = [":friends"],
deps = [
":bounds_check",
+ ":dense_update_functor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",
diff --git a/tensorflow/core/kernels/dense_update_functor.cc b/tensorflow/core/kernels/dense_update_functor.cc
index a878fe9a97..3ed3794e01 100644
--- a/tensorflow/core/kernels/dense_update_functor.cc
+++ b/tensorflow/core/kernels/dense_update_functor.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -70,4 +71,59 @@ struct DenseUpdate<CPUDevice, string, ASSIGN> {
} // namespace functor
+#define CPU_DENSE_COPY(T) \
+ case DataTypeToEnum<T>::value: { \
+ functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor_; \
+ copy_functor_(context->eigen_device<CPUDevice>(), tensor->flat<T>(), \
+ from.flat<T>()); \
+ break; \
+ }
+
+#define INSTANTIATE_GET_VARIANT_COPY_FN(DEVICE, TYPE_CALLER, TYPE_DENSE_COPY) \
+ template <> \
+ Status VariantCopyFn<DEVICE>(OpKernelContext * context, const Tensor& from, \
+ Tensor* to) { \
+ PersistentTensor tmp; \
+ Tensor* tensor; \
+ AllocatorAttributes attr; \
+ attr.set_gpu_compatible(true); \
+ attr.set_nic_compatible(true); \
+ TF_RETURN_IF_ERROR(context->allocate_persistent( \
+ from.dtype(), from.shape(), &tmp, &tensor, attr)); \
+ switch (from.dtype()) { \
+ TYPE_CALLER(TYPE_DENSE_COPY); \
+ default: \
+ return errors::InvalidArgument( \
+ "VariantCopyFn: Could not perform a deep copy of variant " \
+ "element of type: ", \
+ DataTypeString(from.dtype()), \
+ " using device: ", context->device()->name()); \
+ } \
+ *to = *tensor; \
+ return Status::OK(); \
+ }
+
+INSTANTIATE_GET_VARIANT_COPY_FN(CPUDevice, TF_CALL_ALL_TYPES, CPU_DENSE_COPY);
+
+#if GOOGLE_CUDA
+#define GPU_DENSE_COPY(T) \
+ case DataTypeToEnum<T>::value: { \
+ functor::DenseUpdate<GPUDevice, T, ASSIGN> copy_functor_; \
+ copy_functor_(context->eigen_device<GPUDevice>(), tensor->flat<T>(), \
+ from.flat<T>()); \
+ break; \
+ }
+#define TF_CALL_GPU_AND_ADDITIONAL_TYPES(T) \
+ TF_CALL_GPU_ALL_TYPES(T); \
+ TF_CALL_int32(T); \
+ TF_CALL_int64(T);
+INSTANTIATE_GET_VARIANT_COPY_FN(GPUDevice, TF_CALL_GPU_AND_ADDITIONAL_TYPES,
+ GPU_DENSE_COPY);
+#undef TF_CALL_GPU_AND_ADDITIONAL_TYPES
+#undef GPU_DENSE_COPY
+#endif // GOOGLE_CUDA
+
+#undef CPU_DENSE_COPY
+#undef INSTANTIATE_GET_VARIANT_COPY_FN
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/dense_update_functor.h b/tensorflow/core/kernels/dense_update_functor.h
index 4aefe26c54..240c13261e 100644
--- a/tensorflow/core/kernels/dense_update_functor.h
+++ b/tensorflow/core/kernels/dense_update_functor.h
@@ -19,11 +19,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
@@ -89,6 +92,17 @@ struct DenseUpdate<SYCLDevice, T, ASSIGN> {
#endif // TENSORFLOW_USE_SYCL
} // end namespace functor
+
+template <typename Device>
+Status VariantCopyFn(OpKernelContext* context, const Tensor& from, Tensor* to);
+
+template <>
+Status VariantCopyFn<CPUDevice>(OpKernelContext* context, const Tensor& from,
+ Tensor* to);
+template <>
+Status VariantCopyFn<GPUDevice>(OpKernelContext* context, const Tensor& from,
+ Tensor* to);
+
} // end namespace tensorflow
#endif // TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h
index 16ccb03b85..2c6e8bf3bc 100644
--- a/tensorflow/core/kernels/gather_functor.h
+++ b/tensorflow/core/kernels/gather_functor.h
@@ -28,6 +28,7 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
namespace functor {
@@ -50,7 +51,7 @@ SliceIndex HandleCopies(OpKernelContext* ctx,
}
// Compute slice_bytes here so that static knowledge is available
const size_t slice_bytes = slice_elems * sizeof(T);
- auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
+ auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
mutex mu;
// Store the value of invalidate index for printing error information, it's a
// shared variable.
@@ -162,6 +163,16 @@ struct GatherFunctor<CPUDevice, T, Index> {
}
};
+template <typename Index>
+struct GatherFunctor<GPUDevice, Variant, Index> {
+ int64 operator()(OpKernelContext* ctx,
+ typename TTypes<Variant, 3>::ConstTensor params,
+ typename TTypes<Index>::ConstFlat indices,
+ typename TTypes<Variant, 3>::Tensor out) {
+ return GatherFunctorCPU<Variant, Index>()(ctx, params, indices, out);
+ }
+};
+
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index f49a05c70a..72504200cc 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -280,64 +280,6 @@ class AssignVariableOp : public OpKernel {
};
template <typename Device>
-Status VariantCopyFn(OpKernelContext* context, const Tensor& from, Tensor* to);
-
-#define CPU_DENSE_COPY(T) \
- case DataTypeToEnum<T>::value: { \
- functor::DenseUpdate<CPUDevice, T, ASSIGN> copy_functor_; \
- copy_functor_(context->eigen_device<CPUDevice>(), tensor->flat<T>(), \
- from.flat<T>()); \
- break; \
- }
-
-#define INSTANTIATE_GET_VARIANT_COPY_FN(Device, TYPE_CALLER, TYPE_DENSE_COPY) \
- template <> \
- Status VariantCopyFn<Device>(OpKernelContext * context, const Tensor& from, \
- Tensor* to) { \
- PersistentTensor tmp; \
- Tensor* tensor; \
- AllocatorAttributes attr; \
- attr.set_gpu_compatible(true); \
- attr.set_nic_compatible(true); \
- TF_RETURN_IF_ERROR(context->allocate_persistent( \
- from.dtype(), from.shape(), &tmp, &tensor, attr)); \
- switch (from.dtype()) { \
- TYPE_CALLER(TYPE_DENSE_COPY); \
- default: \
- return errors::InvalidArgument( \
- "VariantCopyFn: Could not perform a deep copy of variant " \
- "element of type: ", \
- DataTypeString(from.dtype()), \
- " using device: ", context->device()->name()); \
- } \
- *to = *tensor; \
- return Status::OK(); \
- }
-
-INSTANTIATE_GET_VARIANT_COPY_FN(CPUDevice, TF_CALL_ALL_TYPES, CPU_DENSE_COPY);
-
-#if GOOGLE_CUDA
-#define GPU_DENSE_COPY(T) \
- case DataTypeToEnum<T>::value: { \
- functor::DenseUpdate<GPUDevice, T, ASSIGN> copy_functor_; \
- copy_functor_(context->eigen_device<GPUDevice>(), tensor->flat<T>(), \
- from.flat<T>()); \
- break; \
- }
-#define TF_CALL_GPU_AND_ADDITIONAL_TYPES(T) \
- TF_CALL_GPU_ALL_TYPES(T); \
- TF_CALL_int32(T); \
- TF_CALL_int64(T);
-INSTANTIATE_GET_VARIANT_COPY_FN(GPUDevice, TF_CALL_GPU_AND_ADDITIONAL_TYPES,
- GPU_DENSE_COPY);
-#undef TF_CALL_GPU_AND_ADDITIONAL_TYPES
-#undef GPU_DENSE_COPY
-#endif // GOOGLE_CUDA
-
-#undef CPU_DENSE_COPY
-#undef INSTANTIATE_GET_VARIANT_COPY_FN
-
-template <typename Device>
class AssignVariableOp<Device, Variant> : public OpKernel {
public:
explicit AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {
@@ -370,9 +312,16 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
// Copying is unnecessary if we are the last user of the value
// tensor, we can just adopt the input tensor's buffer instead.
// Note that Variant objects themselves always reside on host.
+ //
+ // We nevertheless want to signal to the runtime that the tensor
+ // should reside in memory of the associated device, as Variant
+ // tensors may be marked as sitting on either CPU or GPU. This
+ // helps to elide one or more copies.
std::unique_ptr<Tensor> input_alias = context->forward_input(
1, OpKernelContext::Params::kNoReservation /*output_index*/, DT_VARIANT,
- value.shape(), HOST_MEMORY, attr);
+ value.shape(),
+ std::is_same<Device, CPUDevice>::value ? HOST_MEMORY : DEVICE_MEMORY,
+ attr);
mutex_lock ml(*variable->mu());
variable->is_initialized = true;
@@ -396,12 +345,8 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
const auto elements_in = value.flat<Variant>();
auto elements_out = variable->tensor()->flat<Variant>();
- auto copy_fn = std::bind(&VariantCopyFn<Device>, context,
- std::placeholders::_1, std::placeholders::_2);
for (int64 i = 0; i < elements_in.size(); ++i) {
- OP_REQUIRES_OK(context, VariantDeviceCopy(
- VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
- elements_in(i), &elements_out(i), copy_fn));
+ elements_out(i) = elements_in(i);
}
}
@@ -560,7 +505,14 @@ class ResourceGatherOp : public OpKernel {
}
Tensor* out = nullptr;
- OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
+ Tensor tmp;
+ if (params.dtype() == DT_VARIANT) {
+ tmp = Tensor(DT_VARIANT, result_shape);
+ c->set_output(0, tmp);
+ out = &tmp;
+ } else {
+ OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
+ }
if (N > 0) {
const int64 gather_dim_size = params.dim_size(0);
int64 inner_size = 1;
@@ -607,6 +559,23 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_CPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_GPU);
+// Variant objects themselves sit on CPU, even if they contain data
+// pointing to a device.
+REGISTER_KERNEL_BUILDER(Name("ResourceGather")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("indices")
+ .TypeConstraint<Variant>("dtype")
+ .TypeConstraint<int32>("Tindices"),
+ ResourceGatherOp<GPUDevice, Variant, int32>)
+REGISTER_KERNEL_BUILDER(Name("ResourceGather")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("indices")
+ .TypeConstraint<Variant>("dtype")
+ .TypeConstraint<int64>("Tindices"),
+ ResourceGatherOp<GPUDevice, Variant, int64>)
+
#endif // GOOGLE_CUDA
#undef REGISTER_GATHER_CPU
@@ -721,6 +690,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_CPU);
REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
scatter_op::UpdateOp::ASSIGN);
+REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
+ scatter_op::UpdateOp::ASSIGN);
// Registers GPU kernels.
#if GOOGLE_CUDA
@@ -733,6 +704,23 @@ REGISTER_SCATTER_KERNEL(string, CPU, "ResourceScatterUpdate",
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
+REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("indices")
+ .TypeConstraint<Variant>("dtype")
+ .TypeConstraint<int32>("Tindices"),
+ ResourceScatterUpdateOp<GPUDevice, Variant, int32,
+ scatter_op::UpdateOp::ASSIGN>)
+REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("indices")
+ .TypeConstraint<Variant>("dtype")
+ .TypeConstraint<int64>("Tindices"),
+ ResourceScatterUpdateOp<GPUDevice, Variant, int64,
+ scatter_op::UpdateOp::ASSIGN>)
+
#endif // GOOGLE_CUDA
#undef REGISTER_SCATTER_ARITHMETIC
diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h
index 52666645bf..ebaa2bd9c6 100644
--- a/tensorflow/core/kernels/scatter_functor.h
+++ b/tensorflow/core/kernels/scatter_functor.h
@@ -20,8 +20,11 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -203,9 +206,9 @@ struct ScatterFunctorBase {
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Do this carefully,
+ // to avoid checking the value and grabbing it again from
+ // memory a second time (a security risk since it may change in between).
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Copy last Ndim-1 dimensions of updates[i] to params[index]
@@ -216,6 +219,42 @@ struct ScatterFunctorBase {
}
};
+template <typename Device, typename Index>
+struct ScatterFunctorVariantAssignBase {
+ Index operator()(OpKernelContext* c, const Device& d,
+ typename TTypes<Variant>::Matrix params,
+ typename TTypes<Variant>::ConstMatrix updates,
+ typename TTypes<Index>::ConstFlat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ const Index cols = static_cast<Index>(params.dimension(1));
+ DCHECK_EQ(N, updates.dimension(0));
+ DCHECK_EQ(cols, updates.dimension(1));
+ for (Index i = 0; i < N; i++) {
+ // Grab the index and check its validity. Do this carefully,
+ // to avoid checking the value and grabbing it again from
+ // memory a second time (a security risk since it may change in between).
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Copy last Ndim-1 dimensions of updates[i] to params[index]
+ for (int j = 0; j < cols; ++j) {
+ const Variant& to_scatter = updates(i, j);
+ params(index, j) = to_scatter;
+ }
+ }
+ return -1;
+ }
+};
+
+template <typename Index>
+struct ScatterFunctor<CPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
+ : ScatterFunctorVariantAssignBase<CPUDevice, Index> {};
+
+template <typename Index>
+struct ScatterFunctor<GPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
+ : ScatterFunctorVariantAssignBase<GPUDevice, Index> {};
+
#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase<SYCLDevice, T, Index, op> {
@@ -227,9 +266,9 @@ struct ScatterFunctorBase<SYCLDevice, T, Index, op> {
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Do this carefully,
+ // to avoid checking the value and grabbing it again from
+ // memory a second time (a security risk since it may change in between).
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Copy last Ndim-1 dimensions of updates[i] to params[index]
@@ -252,9 +291,10 @@ struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
const Index limit = static_cast<Index>(params.dimension(0));
if (!std::is_same<T, string>::value) {
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Do this carefully,
+ // to avoid checking the value and grabbing it again from
+ // memory a second time (a security risk since it may change in
+ // between).
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
memmove(params.data() + index * params.dimension(1),
@@ -263,9 +303,10 @@ struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
}
} else {
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Do this carefully,
+ // to avoid checking the value and grabbing it again from
+ // memory a second time (a security risk since it may change in
+ // between).
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Copy last Ndim-1 dimensions of updates[i] to params[index]
@@ -321,9 +362,9 @@ struct ScatterScalarFunctorBase {
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Do this carefully,
+ // to avoid checking the value and grabbing it again from
+ // memory a second time (a security risk since it may change in between).
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Broadcast update to params[index]
@@ -334,6 +375,41 @@ struct ScatterScalarFunctorBase {
}
};
+template <typename Device, typename Index>
+struct ScatterScalarFunctorVariantAssignBase {
+ Index operator()(OpKernelContext* c, const Device& d,
+ typename TTypes<Variant>::Matrix params,
+ const typename TTypes<Variant>::ConstScalar update,
+ typename TTypes<Index>::ConstFlat indices) {
+ // indices and params sizes were validated in DoCompute().
+ const Index N = static_cast<Index>(indices.size());
+ const Index limit = static_cast<Index>(params.dimension(0));
+ const Index cols = static_cast<Index>(params.dimension(1));
+ const Variant& to_scatter = update();
+ for (Index i = 0; i < N; i++) {
+ // Grab the index and check its validity. Do this carefully,
+ // to avoid checking the value and grabbing it again from
+ // memory a second time (a security risk since it may change in between).
+ const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
+ if (!FastBoundsCheck(index, limit)) return i;
+ // Broadcast update to params[index]
+ for (Index j = 0; j < cols; ++j) {
+ params(index, j) = to_scatter;
+ }
+ }
+ return -1;
+ }
+};
+
+template <typename Index>
+struct ScatterScalarFunctor<CPUDevice, Variant, Index,
+ scatter_op::UpdateOp::ASSIGN>
+ : ScatterScalarFunctorVariantAssignBase<CPUDevice, Index> {};
+template <typename Index>
+struct ScatterScalarFunctor<GPUDevice, Variant, Index,
+ scatter_op::UpdateOp::ASSIGN>
+ : ScatterScalarFunctorVariantAssignBase<GPUDevice, Index> {};
+
#ifdef TENSORFLOW_USE_SYCL
template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
@@ -345,9 +421,9 @@ struct ScatterScalarFunctorBase<SYCLDevice, T, Index, op> {
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Do this carefully,
+ // to avoid checking the value and grabbing it again from
+ // memory a second time (a security risk since it may change in between).
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Broadcast update to params[index]
@@ -370,9 +446,9 @@ struct ScatterScalarFunctorBase<CPUDevice, T, Index,
const Index N = static_cast<Index>(indices.size());
const Index limit = static_cast<Index>(params.dimension(0));
for (Index i = 0; i < N; i++) {
- // Grab the index and check its validity. An earlier version of the
- // code checked it and then grabbed it from memory a second time, which
- // was a security risk since it could have changed in between.
+ // Grab the index and check its validity. Do this carefully,
+ // to avoid checking the value and grabbing it again from
+ // memory a second time (a security risk since it may change in between).
const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
if (!FastBoundsCheck(index, limit)) return i;
// Broadcast update to params[index]
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index f6e2a5ae25..857daae177 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/variable_ops.h"
@@ -40,14 +41,27 @@ Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor) {
// updating.
PersistentTensor unused;
Tensor* tmp;
- AllocatorAttributes attr;
- attr.set_gpu_compatible(true);
- attr.set_nic_compatible(true);
- TF_RETURN_IF_ERROR(ctx->allocate_persistent(
- tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
- functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
- copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
- const_cast<const Tensor*>(tensor)->flat<T>());
+ if (std::is_same<T, Variant>::value) {
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+ tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
+
+ const auto elements_in = tensor->flat<Variant>();
+ auto elements_out = tmp->flat<Variant>();
+ for (int64 i = 0; i < elements_in.size(); ++i) {
+ elements_out(i) = elements_in(i);
+ }
+ } else {
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ attr.set_nic_compatible(true);
+ TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+ tensor->dtype(), tensor->shape(), &unused, &tmp, attr));
+ functor::DenseUpdate<Device, T, ASSIGN> copy_functor;
+ copy_functor(ctx->eigen_device<Device>(), tmp->flat<T>(),
+ const_cast<const Tensor*>(tensor)->flat<T>());
+ }
*tensor = *tmp;
}
return Status::OK();
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index cad617638f..c151055ee6 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -30,7 +30,8 @@ REGISTER_OP("EmptyTensorList")
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
shape_inference::ShapeHandle s;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &s));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{s, t}});
return Status::OK();
@@ -193,6 +194,7 @@ REGISTER_OP("TensorListReserve")
.Attr("element_dtype: type")
.Attr("shape_type: {int32, int64}")
.SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
shape_inference::ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
DataType t;
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 64b0fa6c00..8cf24206ed 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -822,17 +822,32 @@ def constant_value_as_shape(tensor): # pylint: disable=invalid-name
all-or-nothing.
Args:
- tensor: The rank-1 Tensor to be evaluated.
+ tensor: The rank-0 or rank-1 Tensor to be evaluated.
Returns:
A `TensorShape` based on the constant value of the given `tensor`.
+
+ Raises:
+ ValueError: If the shape is rank-0 and is not statically known to be -1.
"""
if isinstance(tensor, ops.EagerTensor):
return tensor_shape.as_shape(
[dim if dim != -1 else None for dim in tensor.numpy()])
+ if tensor.get_shape().ndims == 0:
+ value = constant_value(tensor)
+ if value is None:
+ raise ValueError(
+ "Received a scalar with unknown value as shape; require a statically "
+ "known scalar with value '-1' to describe an unknown shape.")
+ if value != -1:
+ raise ValueError(
+ "Received a scalar value '%s' as shape; require a statically known "
+ "scalar with value '-1' to describe an unknown shape." % value)
+ return tensor_shape.unknown_shape()
+
shape = tensor.get_shape().with_rank(1)
- if tensor.get_shape() == [0]:
+ if shape == [0]:
return tensor_shape.scalar()
elif tensor.op.type == "Shape":
return tensor.op.inputs[0].get_shape()
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index dbbed39c72..d969f0e03a 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -31,8 +31,11 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@@ -43,71 +46,83 @@ def scalar_shape():
class ListOpsTest(test_util.TensorFlowTestCase):
+ @test_util.run_in_graph_and_eager_modes()
def testPushPop(self):
l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
element_shape=scalar_shape())
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
- self.assertAllEqual(e, 1.0)
+ self.assertAllEqual(self.evaluate(e), 1.0)
+ @test_util.run_in_graph_and_eager_modes()
def testPushPopGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testPushPop()
+ @test_util.run_in_graph_and_eager_modes()
def testStack(self):
l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
element_shape=scalar_shape())
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
- self.assertAllEqual(t, [1.0, 2.0])
+ self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
+ @test_util.run_in_graph_and_eager_modes()
def testStackGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testStack()
+ @test_util.run_in_graph_and_eager_modes()
def testTensorListFromTensor(self):
t = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
- self.assertAllEqual(e, 2.0)
+ self.assertAllEqual(self.evaluate(e), 2.0)
l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
- self.assertAllEqual(e, 1.0)
- self.assertAllEqual(list_ops.tensor_list_length(l), 0)
+ self.assertAllEqual(self.evaluate(e), 1.0)
+ self.assertAllEqual(self.evaluate(list_ops.tensor_list_length(l)), 0)
+ @test_util.run_in_graph_and_eager_modes()
def testFromTensorGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testTensorListFromTensor()
+ @test_util.run_in_graph_and_eager_modes()
def testGetSetItem(self):
t = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
e0 = list_ops.tensor_list_get_item(l, 0, element_dtype=dtypes.float32)
- self.assertAllEqual(e0, 1.0)
+ self.assertAllEqual(self.evaluate(e0), 1.0)
l = list_ops.tensor_list_set_item(l, 0, 3.0)
t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
- self.assertAllEqual(t, [3.0, 2.0])
+ self.assertAllEqual(self.evaluate(t), [3.0, 2.0])
+ @test_util.run_in_graph_and_eager_modes()
def testGetSetGPU(self):
if not context.num_gpus():
return
with context.device("gpu:0"):
self.testGetSetItem()
+ @test_util.run_in_graph_and_eager_modes()
def testUnknownShape(self):
- l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
- element_shape=-1)
+ l = list_ops.empty_tensor_list(
+ element_dtype=dtypes.float32, element_shape=-1)
l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0, 2.0]))
- _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
- self.assertAllEqual(e, [1.0, 2.0])
+ l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ self.assertAllEqual(self.evaluate(e), [1.0, 2.0])
+ l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ self.assertAllEqual(self.evaluate(e), 1.0)
+ @test_util.run_in_graph_and_eager_modes()
def testCPUGPUCopy(self):
if not context.num_gpus():
return
@@ -116,15 +131,16 @@ class ListOpsTest(test_util.TensorFlowTestCase):
with context.device("gpu:0"):
l_gpu = array_ops.identity(l)
self.assertAllEqual(
- list_ops.tensor_list_pop_back(
- l_gpu, element_dtype=dtypes.float32)[1],
- 2.0)
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
l_cpu = array_ops.identity(l_gpu)
self.assertAllEqual(
- list_ops.tensor_list_pop_back(
- l_cpu, element_dtype=dtypes.float32)[1],
- 2.0)
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
+ @test_util.run_in_graph_and_eager_modes()
def testGraphStack(self):
with context.graph_mode(), self.test_session():
tl = list_ops.empty_tensor_list(
@@ -132,9 +148,11 @@ class ListOpsTest(test_util.TensorFlowTestCase):
element_dtype=dtypes.int32)
tl = list_ops.tensor_list_push_back(tl, [1])
self.assertAllEqual(
- list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
+ self.evaluate(
+ list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32)),
[[1]])
+ @test_util.run_in_graph_and_eager_modes()
def testGraphStackInLoop(self):
with context.graph_mode(), self.test_session():
t1 = list_ops.empty_tensor_list(
@@ -149,9 +167,10 @@ class ListOpsTest(test_util.TensorFlowTestCase):
i, t1 = control_flow_ops.while_loop(lambda i, t1: math_ops.less(i, 4),
body, [i, t1])
- s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32).eval()
- self.assertAllEqual(s1, [0, 1, 2, 3])
+ s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.int32)
+ self.assertAllEqual(self.evaluate(s1), [0, 1, 2, 3])
+ @test_util.run_in_graph_and_eager_modes()
def testGraphStackSwitchDtype(self):
with context.graph_mode(), self.test_session():
list_ = list_ops.empty_tensor_list(
@@ -169,11 +188,11 @@ class ListOpsTest(test_util.TensorFlowTestCase):
for _ in range(2):
list_, m = body(list_, m)
- s1 = list_ops.tensor_list_stack(
- list_, element_dtype=dtypes.float32).eval()
+ s1 = list_ops.tensor_list_stack(list_, element_dtype=dtypes.float32)
np_s1 = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)
- self.assertAllEqual(s1, np_s1)
+ self.assertAllEqual(self.evaluate(s1), np_s1)
+ @test_util.run_in_graph_and_eager_modes()
def testGraphStackInLoopSwitchDtype(self):
with context.graph_mode(), self.test_session():
t1 = list_ops.empty_tensor_list(
@@ -193,10 +212,11 @@ class ListOpsTest(test_util.TensorFlowTestCase):
i, m, t1 = control_flow_ops.while_loop(
lambda i, m, t1: math_ops.less(i, 4), body, [i, m, t1])
- s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32).eval()
+ s1 = list_ops.tensor_list_stack(t1, element_dtype=dtypes.float32)
np_s1 = np.vstack([np.arange(1, 4) * i for i in range(4)])
- self.assertAllEqual(s1, np_s1)
+ self.assertAllEqual(self.evaluate(s1), np_s1)
+ @test_util.run_in_graph_and_eager_modes()
def testSerialize(self):
# pylint: disable=g-import-not-at-top
try:
@@ -226,8 +246,9 @@ class ListOpsTest(test_util.TensorFlowTestCase):
l_ps, element_dtype=dtypes.float32)
with ops.device("/job:worker"):
worker_e = array_ops.identity(e)
- self.assertAllEqual(worker_e.eval(), [2.0])
+ self.assertAllEqual(self.evaluate(worker_e), [2.0])
+ @test_util.run_in_graph_and_eager_modes()
def testPushPopGradients(self):
with backprop.GradientTape() as tape:
l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
@@ -237,18 +258,24 @@ class ListOpsTest(test_util.TensorFlowTestCase):
l = list_ops.tensor_list_push_back(l, c)
l, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
e = 2 * e
- self.assertAllEqual(tape.gradient(e, [c])[0], 2.0)
+ self.assertAllEqual(self.evaluate(tape.gradient(e, [c])[0]), 2.0)
+ @test_util.run_in_graph_and_eager_modes()
def testStackFromTensorGradients(self):
with backprop.GradientTape() as tape:
c = constant_op.constant([1.0, 2.0])
tape.watch(c)
l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
- c2 = list_ops.tensor_list_stack(
- l, element_dtype=dtypes.float32)
+ c2 = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
result = c2 * 2.0
- self.assertAllEqual(tape.gradient(result, [c])[0], [2.0, 2.0])
-
+ if context.in_eager_mode():
+ # TODO(b/77609620): Fix this in graph mode.
+ grad = tape.gradient(result, [c])[0]
+ else:
+ grad = gradients.gradients(result, [c])[0]
+ self.assertAllEqual(self.evaluate(grad), [2.0, 2.0])
+
+ @test_util.run_in_graph_and_eager_modes()
def testGetSetGradients(self):
with backprop.GradientTape() as tape:
c = constant_op.constant([1.0, 2.0])
@@ -261,16 +288,40 @@ class ListOpsTest(test_util.TensorFlowTestCase):
ee = list_ops.tensor_list_get_item(l, 1, element_dtype=dtypes.float32)
y = e * e + ee * ee
grad_c, grad_c2 = tape.gradient(y, [c, c2])
- self.assertAllEqual(grad_c, [0.0, 4.0])
- self.assertAllEqual(grad_c2, 6.0)
+ self.assertAllEqual(self.evaluate(grad_c), [0.0, 4.0])
+ self.assertAllEqual(self.evaluate(grad_c2), 6.0)
+ @test_util.run_in_graph_and_eager_modes()
def testSetOutOfBounds(self):
c = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
with self.assertRaises(errors.InvalidArgumentError):
- list_ops.tensor_list_set_item(l, 20, 3.0)
+ self.evaluate(list_ops.tensor_list_set_item(l, 20, 3.0))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testResourceVariableScatterGather(self):
+ c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
+ l = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
+ v = vs.get_variable("var", initializer=[l] * 10, use_resource=True)
+ v_r_0_stacked = list_ops.tensor_list_stack(v[0], dtypes.float32)
+ self.evaluate(v.initializer)
+ self.assertAllEqual([1.0, 2.0], self.evaluate(v_r_0_stacked))
+ v_r_sparse_stacked = list_ops.tensor_list_stack(
+ v.sparse_read(0), dtypes.float32)
+ self.assertAllEqual([1.0, 2.0], self.evaluate(v_r_sparse_stacked))
+ l_new_0 = list_ops.tensor_list_from_tensor(
+ [3.0, 4.0], element_shape=scalar_shape())
+ l_new_1 = list_ops.tensor_list_from_tensor(
+ [5.0, 6.0], element_shape=scalar_shape())
+ updated_v = state_ops.scatter_update(v, [3, 5], [l_new_0, l_new_1])
+ updated_v_elems = array_ops.unstack(updated_v)
+ updated_v_stacked = [
+ list_ops.tensor_list_stack(el, dtypes.float32) for el in updated_v_elems
+ ]
+ expected = ([[1.0, 2.0]] * 3 + [[3.0, 4.0], [1.0, 2.0], [5.0, 6.0]] +
+ [[1.0, 2.0]] * 4)
+ self.assertAllEqual(self.evaluate(updated_v_stacked), expected)
if __name__ == "__main__":
- ops.enable_eager_execution()
test.main()
diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py
index bba59ebcef..bdf0774bbf 100644
--- a/tensorflow/python/ops/list_ops.py
+++ b/tensorflow/python/ops/list_ops.py
@@ -54,8 +54,8 @@ def _TensorListStackGrad(unused_op, dtensor):
@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
"""Gradient for TensorListFromTensor."""
- if op.inputs[0].shape[0] is not None:
- num_elements = op.inputs[0].shape[0]
+ if op.inputs[0].shape[0].value is not None:
+ num_elements = op.inputs[0].shape[0].value
else:
num_elements = None
if dlist is None:
@@ -63,9 +63,10 @@ def _TensorListFromTensorGrad(op, dlist):
element_dtype=op.inputs[0].dtype,
element_shape=gen_list_ops.tensor_list_element_shape(
op.outputs[0], shape_type=dtypes.int32))
- return gen_list_ops.tensor_list_stack(
- dlist, element_dtype=op.inputs[0].dtype,
- num_elements=num_elements)
+ tensor_grad = gen_list_ops.tensor_list_stack(
+ dlist, element_dtype=op.inputs[0].dtype, num_elements=num_elements)
+ shape_grad = None
+ return tensor_grad, shape_grad
@ops.RegisterGradient("TensorListGetItem")