about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow
diff options
context:
space:
mode:
author: Peter Hawkins <phawkins@google.com> 2017-07-07 06:08:53 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-07-07 06:13:19 -0700
commit: 7c2cfdef2693f17295471f717f86f31a3d308659 (patch)
tree: e7439eec9d86b164b5c7916e9ae6369f38d0daea /tensorflow
parent: 4e0f4e462f80e2dc84aa38da3fbc39cb15da3482 (diff)
Add Stack*V2 operators, which use a resource handle instead of a Ref tensor to store the stack handle.
Change in preparation for adding support for loop gradients in XLA-compiled graphs. XLA compilation requires the use of resource types instead of Ref types, and also requires statically known shapes.

PiperOrigin-RevId: 161194280
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/cc/ops/op_gen_overrides.pbtxt 4
-rw-r--r-- tensorflow/core/kernels/stack_ops.cc 152
-rw-r--r-- tensorflow/core/ops/data_flow_ops.cc 61
-rw-r--r-- tensorflow/core/public/version.h 1
-rw-r--r-- tensorflow/python/kernel_tests/stack_ops_test.py 159
-rw-r--r-- tensorflow/python/ops/hidden_ops.txt 4
6 files changed, 325 insertions, 56 deletions
diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt
index 1dffb10c03..a1f79177f7 100644
--- a/tensorflow/cc/ops/op_gen_overrides.pbtxt
+++ b/tensorflow/cc/ops/op_gen_overrides.pbtxt
@@ -100,6 +100,10 @@ op { name: "Stack" skip: true }
op { name: "StackClose" skip: true }
op { name: "StackPop" skip: true }
op { name: "StackPush" skip: true }
+op { name: "StackV2" skip: true }
+op { name: "StackCloseV2" skip: true }
+op { name: "StackPopV2" skip: true }
+op { name: "StackPushV2" skip: true }
op { name: "TensorArrayCloseV2" skip: true }
op { name: "TensorArrayCloseV3" rename_to: "TensorArrayClose" }
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 2db3e5ef77..a474e75d6a 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -54,12 +54,19 @@ class Stack : public ResourceBase {
bool swapped_to_cpu;
};
- Stack(const DataType& elem_type, const Tensor& handle)
- : elem_type_(elem_type), handle_(handle), closed_(false) {}
+ Stack(const DataType& elem_type, const string& stack_name, int max_size)
+ : elem_type_(elem_type),
+ stack_name_(stack_name),
+ max_size_(max_size),
+ closed_(false) {}
Status Push(const TensorAndAllocation& value) {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(CheckNotClosed());
+ if (max_size_ >= 0 && stack_.size() >= max_size_) {
+ return errors::InvalidArgument("Stack[", stack_name_, "] overflowed ",
+ "its max_size (", max_size_, ")");
+ }
stack_.push_back(value);
return Status::OK();
}
@@ -68,8 +75,7 @@ class Stack : public ResourceBase {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(CheckNotClosed());
if (stack_.empty()) {
- const string& stack_name = handle_.vec<string>()(1);
- return errors::InvalidArgument("Stack[", stack_name,
+ return errors::InvalidArgument("Stack[", stack_name_,
"] is empty when calling Pop().");
}
*value = stack_.back();
@@ -98,25 +104,26 @@ class Stack : public ResourceBase {
string DebugString() override {
mutex_lock l(mu_);
- const string& stack_name = handle_.vec<string>()(1);
- return strings::StrCat("Stack[", stack_name, "]");
+ return strings::StrCat("Stack[", stack_name_, "]");
}
+ const string& stack_name() { return stack_name_; }
+
private:
friend class StackOp;
mutex* mu() { return &mu_; }
- Tensor* handle() { return &handle_; }
mutable mutex mu_;
DataType elem_type_;
+ const string stack_name_;
Tensor handle_;
+ int max_size_;
bool closed_ GUARDED_BY(mu_);
std::vector<TensorAndAllocation> stack_ GUARDED_BY(mu_);
Status CheckNotClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (closed_) {
- const string& stack_name = handle_.vec<string>()(1);
- return errors::InvalidArgument("Stack[", stack_name,
+ return errors::InvalidArgument("Stack[", stack_name_,
"] has already been closed.");
}
return Status::OK();
@@ -124,20 +131,26 @@ class Stack : public ResourceBase {
};
Status GetStack(OpKernelContext* ctx, Stack** stack) {
- Tensor Tstack_handle = ctx->mutable_input(0, false);
- if (Tstack_handle.NumElements() != 2) {
- return errors::InvalidArgument(
- "Stack handle must have two elements, but had shape: ",
- Tstack_handle.shape().DebugString());
+ string key;
+ if (ctx->input_dtype(0) == DT_RESOURCE) {
+ auto resource = ctx->input(0).flat<ResourceHandle>()(0);
+ key = resource.name();
+ } else {
+ Tensor Tstack_handle = ctx->mutable_input(0, false);
+ if (Tstack_handle.NumElements() != 2) {
+ return errors::InvalidArgument(
+ "Stack handle must have two elements, but had shape: ",
+ Tstack_handle.shape().DebugString());
+ }
+ const string& container = Tstack_handle.flat<string>()(0);
+ const string& stack_name = Tstack_handle.flat<string>()(1);
+ key = strings::StrCat(container, stack_name);
}
- const string& container = Tstack_handle.flat<string>()(0);
- const string& stack_name = Tstack_handle.flat<string>()(1);
ResourceMgr* rm = ctx->resource_manager();
if (rm == nullptr) {
return errors::Internal("No resource manager.");
}
- TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(),
- strings::StrCat(container, stack_name), stack));
+ TF_RETURN_IF_ERROR(rm->Lookup(ctx->step_container()->name(), key, stack));
return Status::OK();
}
@@ -154,25 +167,48 @@ class StackOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- // Create the stack handle.
- Tensor stack_handle;
- AllocatorAttributes alloc_attr;
- alloc_attr.set_on_host(true);
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING,
- tensorflow::TensorShape({2}),
- &stack_handle, alloc_attr));
+ int32 size = std::numeric_limits<int32>::max();
+ if (ctx->num_inputs() > 0) {
+ const Tensor* tensor_size;
+ OP_REQUIRES_OK(ctx, ctx->input("max_size", &tensor_size));
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_size->shape()),
+ errors::InvalidArgument(
+ "Stack size must be a scalar, but had shape: ",
+ tensor_size->shape().DebugString()));
+
+ int32 size_value = tensor_size->scalar<int32>()();
+ if (size_value >= 0) {
+ size = size_value;
+ }
+ }
+
+ static const char kContainer[] = "_stacks";
auto stack_id = Stack::stack_counter.fetch_add(1);
- auto handle = stack_handle.flat<string>();
- handle(0) = "_stacks";
- handle(1) = strings::StrCat(stack_name_, "_", stack_id);
+ string stack_name = strings::StrCat(stack_name_, "_", stack_id);
// Store the handle in a per-step container.
ResourceMgr* rm = ctx->resource_manager();
OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager."));
- Stack* stack = new Stack(elem_type_, stack_handle);
- OP_REQUIRES_OK(ctx,
- rm->Create(ctx->step_container()->name(),
- strings::StrCat(handle(0), handle(1)), stack));
- ctx->set_output_ref(0, stack->mu(), stack->handle());
+ string key = strings::StrCat(kContainer, stack_name);
+ Stack* stack = new Stack(elem_type_, stack_name, size);
+ OP_REQUIRES_OK(ctx, rm->Create(ctx->step_container()->name(), key, stack));
+ if (IsRefType(ctx->expected_output_dtype(0))) {
+ // Create the stack handle.
+ AllocatorAttributes alloc_attr;
+ alloc_attr.set_on_host(true);
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_STRING,
+ tensorflow::TensorShape({2}),
+ &stack->handle_, alloc_attr));
+ auto handle = stack->handle_.flat<string>();
+ handle(0) = kContainer;
+ handle(1) = std::move(stack_name);
+ ctx->set_output_ref(0, stack->mu(), &stack->handle_);
+ } else {
+ Tensor* handle;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
+ handle->flat<ResourceHandle>()(0) =
+ MakePerStepResourceHandle<Stack>(ctx, key);
+ }
}
private:
@@ -185,9 +221,20 @@ class StackOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_CPU), StackOp);
REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_GPU).HostMemory("handle"),
StackOp);
+REGISTER_KERNEL_BUILDER(Name("StackV2").Device(DEVICE_CPU), StackOp);
+REGISTER_KERNEL_BUILDER(Name("StackV2")
+ .Device(DEVICE_GPU)
+ .HostMemory("max_size")
+ .HostMemory("handle"),
+ StackOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("Stack").Device(DEVICE_SYCL).HostMemory("handle"),
StackOp);
+REGISTER_KERNEL_BUILDER(Name("StackV2")
+ .Device(DEVICE_SYCL)
+ .HostMemory("max_size")
+ .HostMemory("handle"),
+ StackOp);
#endif // TENSORFLOW_USE_SYCL
template <typename Device>
@@ -272,12 +319,19 @@ class StackPushOp : public AsyncOpKernel {
REGISTER_KERNEL_BUILDER(Name("StackPush").Device(DEVICE_CPU),
StackPushOp<CPUDevice>);
+REGISTER_KERNEL_BUILDER(Name("StackPushV2").Device(DEVICE_CPU),
+ StackPushOp<CPUDevice>);
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("StackPush") \
.Device(DEVICE_GPU) \
.HostMemory("handle") \
.TypeConstraint<type>("T"), \
+ StackPushOp<GPUDevice>); \
+ REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("T"), \
StackPushOp<GPUDevice>);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
@@ -293,7 +347,14 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
.HostMemory("elem") \
.HostMemory("output") \
.TypeConstraint<type>("T"), \
- StackPushOp<GPUDevice>)
+ StackPushOp<GPUDevice>); \
+ REGISTER_KERNEL_BUILDER(Name("StackPushV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("handle") \
+ .HostMemory("elem") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ StackPushOp<GPUDevice>);
REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
@@ -368,13 +429,19 @@ class StackPopOp : public AsyncOpKernel {
};
REGISTER_KERNEL_BUILDER(Name("StackPop").Device(DEVICE_CPU), StackPopOp);
+REGISTER_KERNEL_BUILDER(Name("StackPopV2").Device(DEVICE_CPU), StackPopOp);
#define REGISTER_GPU_KERNEL(type) \
REGISTER_KERNEL_BUILDER(Name("StackPop") \
.Device(DEVICE_GPU) \
.HostMemory("handle") \
.TypeConstraint<type>("elem_type"), \
- StackPopOp)
+ StackPopOp); \
+ REGISTER_KERNEL_BUILDER(Name("StackPopV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("handle") \
+ .TypeConstraint<type>("elem_type"), \
+ StackPopOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
@@ -388,7 +455,13 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
.HostMemory("handle") \
.HostMemory("elem") \
.TypeConstraint<type>("elem_type"), \
- StackPopOp)
+ StackPopOp); \
+ REGISTER_KERNEL_BUILDER(Name("StackPopV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("handle") \
+ .HostMemory("elem") \
+ .TypeConstraint<type>("elem_type"), \
+ StackPopOp);
REGISTER_GPU_HOST_KERNEL(int32);
REGISTER_GPU_HOST_KERNEL(bool);
@@ -437,8 +510,15 @@ class StackCloseOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("StackClose").Device(DEVICE_CPU), StackCloseOp);
REGISTER_KERNEL_BUILDER(
Name("StackClose").Device(DEVICE_GPU).HostMemory("handle"), StackCloseOp);
+REGISTER_KERNEL_BUILDER(Name("StackCloseV2").Device(DEVICE_CPU), StackCloseOp);
+REGISTER_KERNEL_BUILDER(
+ Name("StackCloseV2").Device(DEVICE_GPU).HostMemory("handle"), StackCloseOp);
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
Name("StackClose").Device(DEVICE_SYCL).HostMemory("handle"), StackCloseOp);
+REGISTER_KERNEL_BUILDER(
+ Name("StackCloseV2").Device(DEVICE_SYCL).HostMemory("handle"),
+ StackCloseOp);
#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index 4bf3d24324..51a964e3e3 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -1114,8 +1114,9 @@ dtype: The data type of accumulated gradients. Needs to correspond to the type
// --------------------------------------------------------------------------
-REGISTER_OP("Stack")
- .Output("handle: Ref(string)")
+REGISTER_OP("StackV2")
+ .Input("max_size: int32")
+ .Output("handle: resource")
.Attr("elem_type: type")
.Attr("stack_name: string = ''")
.SetIsStateful()
@@ -1123,14 +1124,16 @@ REGISTER_OP("Stack")
.Doc(R"doc(
A stack that produces elements in first-in last-out order.
+max_size: The maximum size of the stack if non-negative. If negative, the stack
+ size is unlimited.
handle: The handle to the stack.
elem_type: The type of the elements on the stack.
stack_name: Overrides the name used for the temporary stack resource. Default
value is the name of the 'Stack' op (which is guaranteed unique).
)doc");
-REGISTER_OP("StackPush")
- .Input("handle: Ref(string)")
+REGISTER_OP("StackPushV2")
+ .Input("handle: resource")
.Input("elem: T")
.Output("output: T")
.Attr("T: type")
@@ -1148,8 +1151,8 @@ output: The same tensor as the input 'elem'.
swap_memory: Swap `elem` to CPU. Default to false.
)doc");
-REGISTER_OP("StackPop")
- .Input("handle: Ref(string)")
+REGISTER_OP("StackPopV2")
+ .Input("handle: resource")
.Output("elem: elem_type")
.Attr("elem_type: type")
.SetShapeFn(shape_inference::UnknownShape)
@@ -1161,8 +1164,8 @@ elem: The tensor that is popped from the top of the stack.
elem_type: The type of the elem that is popped.
)doc");
-REGISTER_OP("StackClose")
- .Input("handle: Ref(string)")
+REGISTER_OP("StackCloseV2")
+ .Input("handle: resource")
.SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
.Doc(R"doc(
Delete the stack from its resource container.
@@ -1170,6 +1173,48 @@ Delete the stack from its resource container.
handle: The handle to a stack.
)doc");
+// Deprecated ref-typed variants of stack.
+
+REGISTER_OP("Stack")
+ .Output("handle: Ref(string)")
+ .Attr("elem_type: type")
+ .Attr("stack_name: string = ''")
+ .SetIsStateful()
+ .SetShapeFn(TwoElementOutput)
+ .Doc(R"doc(
+Deprecated, use StackV2.
+)doc");
+
+REGISTER_OP("StackPush")
+ .Input("handle: Ref(string)")
+ .Input("elem: T")
+ .Output("output: T")
+ .Attr("T: type")
+ .Attr("swap_memory: bool = false")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->input(1));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Deprecated, use StackPushV2.
+)doc");
+
+REGISTER_OP("StackPop")
+ .Input("handle: Ref(string)")
+ .Output("elem: elem_type")
+ .Attr("elem_type: type")
+ .SetShapeFn(shape_inference::UnknownShape)
+ .Doc(R"doc(
+Deprecated, use StackPopV2.
+)doc");
+
+REGISTER_OP("StackClose")
+ .Input("handle: Ref(string)")
+ .SetShapeFn(TwoElementVectorInputsAndScalarOutputs)
+ .Doc(R"doc(
+Deprecated, use StackCloseV2.
+)doc");
+
// --------------------------------------------------------------------------
REGISTER_OP("TensorArrayV3")
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 0e5611e359..e76876c9dc 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -89,6 +89,7 @@ limitations under the License.
// produced at version 22 or later. (04/10/2016)
// 23. Remove NonMaxSuppression in favor of NonMaxSuppressionV2.
// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
+// 25. Deprecate stack (v1) ops in favor of v2 (2017/6/15).
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
diff --git a/tensorflow/python/kernel_tests/stack_ops_test.py b/tensorflow/python/kernel_tests/stack_ops_test.py
index 441256df26..aa409336f5 100644
--- a/tensorflow/python/kernel_tests/stack_ops_test.py
+++ b/tensorflow/python/kernel_tests/stack_ops_test.py
@@ -22,7 +22,6 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import control_flow_ops
@@ -35,10 +34,11 @@ class StackOpTest(test.TestCase):
def _testStackPushPop(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
- h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")
- c = gen_data_flow_ops._stack_push(h, [[4.0, 5.0]])
+ h = gen_data_flow_ops._stack_v2(
+ -1, elem_type=dtypes.float32, stack_name="foo")
+ c = gen_data_flow_ops._stack_push_v2(h, [[4.0, 5.0]])
with ops.control_dependencies([c]):
- c1 = gen_data_flow_ops._stack_pop(h, dtypes.float32)
+ c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32)
self.assertAllClose([[4.0, 5.0]], c1.eval())
def testStackPushPop(self):
@@ -49,10 +49,11 @@ class StackOpTest(test.TestCase):
with self.test_session(use_gpu=use_gpu):
a = np.arange(2000)
x = constant_op.constant(a, dtype=dtypes.float32)
- h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")
- c = gen_data_flow_ops._stack_push(h, x, swap_memory=True)
+ h = gen_data_flow_ops._stack_v2(
+ -1, elem_type=dtypes.float32, stack_name="foo")
+ c = gen_data_flow_ops._stack_push_v2(h, x, swap_memory=True)
with ops.control_dependencies([c]):
- c1 = gen_data_flow_ops._stack_pop(h, dtypes.float32)
+ c1 = gen_data_flow_ops._stack_pop_v2(h, dtypes.float32)
self.assertAllClose(a, c1.eval())
def testStackPushPopSwap(self):
@@ -62,7 +63,8 @@ class StackOpTest(test.TestCase):
def _testStackWhileSwap(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
n = constant_op.constant(0)
- h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")
+ h = gen_data_flow_ops._stack_v2(
+ -1, elem_type=dtypes.float32, stack_name="foo")
def c(x):
return math_ops.less(x, 10)
@@ -70,7 +72,7 @@ class StackOpTest(test.TestCase):
def b(x):
with ops.control_dependencies([x]):
a = constant_op.constant(np.ones(2000), dtype=dtypes.float32)
- v = gen_data_flow_ops._stack_push(h, a, swap_memory=True)
+ v = gen_data_flow_ops._stack_push_v2(h, a, swap_memory=True)
with ops.control_dependencies([v]):
return math_ops.add(x, 1)
@@ -79,14 +81,15 @@ class StackOpTest(test.TestCase):
v = constant_op.constant(np.zeros(2000), dtype=dtypes.float32)
def c1(x, y):
+ del y
return math_ops.greater(x, 0)
def b1(x, y):
nx = math_ops.subtract(x, 1)
- ny = y + gen_data_flow_ops._stack_pop(h, dtypes.float32)
+ ny = y + gen_data_flow_ops._stack_pop_v2(h, dtypes.float32)
return [nx, ny]
- rx, ry = control_flow_ops.while_loop(
+ _, ry = control_flow_ops.while_loop(
c1, b1, [r, v], [r.get_shape(), tensor_shape.unknown_shape()])
self.assertAllClose(np.ones(2000) * 10.0, ry.eval())
@@ -96,6 +99,102 @@ class StackOpTest(test.TestCase):
def _testMultiStack(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
+ h1 = gen_data_flow_ops._stack_v2(
+ -1, elem_type=dtypes.float32, stack_name="foo")
+ c1 = gen_data_flow_ops._stack_push_v2(h1, 4.0)
+ with ops.control_dependencies([c1]):
+ c1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32)
+ h2 = gen_data_flow_ops._stack_v2(
+ -1, elem_type=dtypes.float32, stack_name="bar")
+ c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0)
+ with ops.control_dependencies([c2]):
+ c2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32)
+ r = c1 + c2
+ self.assertAllClose(9.0, r.eval())
+
+ def testMultiStack(self):
+ self._testMultiStack(use_gpu=False)
+ self._testMultiStack(use_gpu=True)
+
+ def _testSameNameStacks(self, use_gpu):
+ """Different stacks with the same name do not interfere."""
+ with self.test_session(use_gpu=use_gpu) as sess:
+ h1 = gen_data_flow_ops._stack_v2(
+ -1, elem_type=dtypes.float32, stack_name="foo")
+ h2 = gen_data_flow_ops._stack_v2(
+ -1, elem_type=dtypes.float32, stack_name="foo")
+
+ c1 = gen_data_flow_ops._stack_push_v2(h1, 4.0)
+ with ops.control_dependencies([c1]):
+ c2 = gen_data_flow_ops._stack_push_v2(h2, 5.0)
+ with ops.control_dependencies([c2]):
+ pop1 = gen_data_flow_ops._stack_pop_v2(h1, dtypes.float32)
+ pop2 = gen_data_flow_ops._stack_pop_v2(h2, dtypes.float32)
+
+ out1, out2 = sess.run([pop1, pop2])
+ self.assertAllClose(out1, 4.0)
+ self.assertAllClose(out2, 5.0)
+
+ def testSameNameStacks(self):
+ self._testSameNameStacks(use_gpu=False)
+ self._testSameNameStacks(use_gpu=True)
+
+ def _testCloseStack(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu) as sess:
+ h = gen_data_flow_ops._stack_v2(
+ -1, elem_type=dtypes.float32, stack_name="foo")
+ c1 = gen_data_flow_ops._stack_close_v2(h)
+ sess.run(c1)
+
+ def testCloseStack(self):
+ self._testCloseStack(use_gpu=False)
+ self._testCloseStack(use_gpu=True)
+
+ def _testPushCloseStack(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu) as sess:
+ h = gen_data_flow_ops._stack_v2(
+ -1, elem_type=dtypes.float32, stack_name="foo")
+ c = gen_data_flow_ops._stack_push_v2(h, [[4.0, 5.0]])
+ with ops.control_dependencies([c]):
+ c1 = gen_data_flow_ops._stack_close_v2(h)
+ sess.run(c1)
+
+ def testPushCloseStack(self):
+ self._testPushCloseStack(use_gpu=False)
+ self._testPushCloseStack(use_gpu=True)
+
+
+class StackOpRefTest(test.TestCase):
+ """Tests for deprecated non-resource variant of stack ops."""
+
+ def _testStackPushPop(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")
+ c = gen_data_flow_ops._stack_push(h, [[4.0, 5.0]])
+ with ops.control_dependencies([c]):
+ c1 = gen_data_flow_ops._stack_pop(h, dtypes.float32)
+ self.assertAllClose([[4.0, 5.0]], c1.eval())
+
+ def testStackPushPop(self):
+ self._testStackPushPop(use_gpu=False)
+ self._testStackPushPop(use_gpu=True)
+
+ def _testStackPushPopSwap(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ a = np.arange(2000)
+ x = constant_op.constant(a, dtype=dtypes.float32)
+ h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")
+ c = gen_data_flow_ops._stack_push(h, x, swap_memory=True)
+ with ops.control_dependencies([c]):
+ c1 = gen_data_flow_ops._stack_pop(h, dtypes.float32)
+ self.assertAllClose(a, c1.eval())
+
+ def testStackPushPopSwap(self):
+ self._testStackPushPopSwap(use_gpu=False)
+ self._testStackPushPopSwap(use_gpu=True)
+
+ def _testMultiStack(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
h1 = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")
c1 = gen_data_flow_ops._stack_push(h1, 4.0)
with ops.control_dependencies([c1]):
@@ -107,6 +206,42 @@ class StackOpTest(test.TestCase):
r = c1 + c2
self.assertAllClose(9.0, r.eval())
+ def _testStackWhileSwap(self, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ n = constant_op.constant(0)
+ h = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")
+
+ def c(x):
+ return math_ops.less(x, 10)
+
+ def b(x):
+ with ops.control_dependencies([x]):
+ a = constant_op.constant(np.ones(2000), dtype=dtypes.float32)
+ v = gen_data_flow_ops._stack_push(h, a, swap_memory=True)
+ with ops.control_dependencies([v]):
+ return math_ops.add(x, 1)
+
+ r = control_flow_ops.while_loop(c, b, [n])
+
+ v = constant_op.constant(np.zeros(2000), dtype=dtypes.float32)
+
+ def c1(x, y):
+ del y
+ return math_ops.greater(x, 0)
+
+ def b1(x, y):
+ nx = math_ops.subtract(x, 1)
+ ny = y + gen_data_flow_ops._stack_pop(h, dtypes.float32)
+ return [nx, ny]
+
+ _, ry = control_flow_ops.while_loop(
+ c1, b1, [r, v], [r.get_shape(), tensor_shape.unknown_shape()])
+ self.assertAllClose(np.ones(2000) * 10.0, ry.eval())
+
+ def testStackWhileSwap(self):
+ self._testStackWhileSwap(use_gpu=False)
+ self._testStackWhileSwap(use_gpu=True)
+
def testMultiStack(self):
self._testMultiStack(use_gpu=False)
self._testMultiStack(use_gpu=True)
@@ -117,7 +252,7 @@ class StackOpTest(test.TestCase):
c1 = gen_data_flow_ops._stack_push(h1, 4.0)
h2 = gen_data_flow_ops._stack(dtypes.float32, stack_name="foo")
c2 = gen_data_flow_ops._stack_push(h2, 5.0)
- r = c1 + c2
+ _ = c1 + c2
self.assertNotEqual(h1.eval()[1], h2.eval()[1])
def testSameNameStacks(self):
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 9aef6bffde..6408d52d8c 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -111,6 +111,10 @@ Stack
StackClose
StackPop
StackPush
+StackV2
+StackCloseV2
+StackPopV2
+StackPushV2
TensorArray
TensorArrayClose
TensorArrayCloseV2