author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-09-29 12:21:31 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-09-29 13:32:40 -0700
commit    b7d5df182b7394ab17c11ccc949ce07812920bd9 (patch)
tree      ab244238a4c7fdb099a62ca8c2396fd8e15216c3
parent    4323a658b5228fe8d5482941edfacf58506dea34 (diff)
Make (tf.contrib) BlockLSTMOp take 3D tensors instead of lists of 2D tensors.
This facilitates dealing with dynamic time lengths. Updated documentation.

Change: 134699973
-rw-r--r--  tensorflow/contrib/rnn/kernels/lstm_ops.cc         403
-rw-r--r--  tensorflow/contrib/rnn/kernels/lstm_ops.h           23
-rw-r--r--  tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc    3
-rw-r--r--  tensorflow/contrib/rnn/ops/lstm_ops.cc             213
-rw-r--r--  tensorflow/contrib/rnn/ops/lstm_ops_test.cc        181
-rw-r--r--  tensorflow/contrib/rnn/python/ops/lstm_ops.py      147
6 files changed, 611 insertions, 359 deletions
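The user-visible shape change is easiest to see at the call site. Below is a minimal Python sketch of the before/after interface (the sizes are made up for illustration, and tf.pack/tf.unpack are this era's names for what later became tf.stack/tf.unstack):

```python
import tensorflow as tf

# Hypothetical sizes, for illustration only.
timelen, batch_size, input_size = 8, 4, 16

# Before this change: BlockLSTM consumed a list of `max_len` 2D tensors,
# one (batch_size, input_size) tensor per time step, with max_len an attr.
x_list = [tf.zeros([batch_size, input_size]) for _ in range(timelen)]

# After this change: it consumes a single 3D tensor, so the time length
# is an ordinary runtime dimension. The list form converts with pack/unpack.
x_3d = tf.pack(x_list)    # shape (timelen, batch_size, input_size)
x_back = tf.unpack(x_3d)  # back to a list of timelen 2D tensors
```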
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc
index 2749d7d797..7fec457a4a 100644
--- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc
@@ -523,11 +523,118 @@ REGISTER_GPU_KERNEL(float);
#undef REGISTER_GPU_KERNEL
#endif // GOOGLE_CUDA
+namespace {
+
+// This helper class can be used to access timeslices of a 3D tensor. A slice
+// may be unaligned (usually because both the batch size and the number of
+// cells are odd, which is uncommon); in that case it has to be copied into an
+// aligned temporary, which adds overhead. If all slices are aligned, no bits
+// are copied. When output slices had to be copied, the results must be copied
+// back afterwards: call FinishTimeStep at the end of each time step, which
+// performs this copy-back and also allows temporary tensors to be reused.
+template <typename Device, typename T>
+class SliceHelper {
+ public:
+ SliceHelper(OpKernelContext* ctx)
+ : ctx_(ctx), device_(ctx_->eigen_device<Device>()) {}
+
+ ~SliceHelper() {
+ CHECK(copy_out_.empty());
+ for (const auto& entry : pool_) {
+ CHECK(!entry.second.second); // nothing is in use
+ }
+ }
+
+ // Slice through an input tensor. This may copy unaligned slices, but no
+ // copying back will be done at the end.
+ const Tensor InputSlice(const Tensor& t, int pos, const string& name) {
+ Tensor res = UnalignedSlice(t, pos);
+ if (res.IsAligned()) {
+ return res;
+ } else {
+ return AlignTensor(res, name);
+ }
+ }
+
+ // Slice through an output tensor. This may copy unaligned slices, and
+ // schedule copying back on destruction.
+ Tensor OutputSlice(Tensor* t, int pos, const string& name) {
+ Tensor res = UnalignedSlice(*t, pos);
+ if (res.IsAligned()) {
+ return res;
+ } else {
+ Tensor aligned = AlignTensor(res, name);
+ copy_out_.emplace_back(res, aligned);
+ return aligned;
+ }
+ }
+
+ void FinishTimeStep() {
+ for (const auto& p : copy_out_) {
+ const Tensor& aligned = p.second;
+ Tensor original = p.first;
+ // Copy from aligned back to original.
+ functor::TensorCopyToUnaligned<Device, T>()(device_, aligned.flat<T>(),
+ original.unaligned_flat<T>());
+ }
+ copy_out_.clear();
+ // Mark all entries as not in use.
+ for (auto& entry : pool_) {
+ entry.second.second = false;
+ }
+ }
+
+ private:
+ // Return a slice at position 'pos'. Result may be unaligned. The resulting
+ // tensor always shares data with the source tensor.
+ Tensor UnalignedSlice(const Tensor& t, int pos) const {
+ Tensor res;
+ // CHECK should never fail here, since the number of elements must match
+ CHECK(res.CopyFrom(t.Slice(pos, pos + 1), {t.dim_size(1), t.dim_size(2)}));
+ return res;
+ }
+
+ // Assumes input is not aligned, creates a temporary aligned tensor of the
+ // same shape and copies the original tensor's content into it.
+ Tensor AlignTensor(const Tensor& t, const string& name) {
+ VLOG(1) << "AlignTensor called for " << name << ", shape "
+ << t.shape().DebugString()
+ << ". This is unnecessary copying. Consider using shapes with even "
+ << "sizes";
+ Tensor aligned;
+ auto found = pool_.find(name);
+ if (found != pool_.end()) { // found in pool
+ CHECK(!found->second.second) << "Tensor " << name << " is in use";
+ found->second.second = true; // mark in use
+ aligned = found->second.first;
+ CHECK(aligned.shape().IsSameSize(t.shape()));
+ CHECK_EQ(aligned.dtype(), t.dtype());
+ } else { // allocate a new temporary tensor
+ ctx_->allocate_temp(t.dtype(), t.shape(), &aligned);
+ pool_.emplace(name, std::make_pair(aligned, true));
+ }
+ functor::TensorCopyUnaligned<Device, T>()(device_, t.unaligned_flat<T>(),
+ aligned.flat<T>());
+ return aligned;
+ }
+
+ // Tensors to be copied.
+ std::vector<std::pair<Tensor, const Tensor>> copy_out_;
+ // A pool of pre-allocated temporary tensors, with an indicator for whether
+ // each one is in use.
+ std::map<string, std::pair<Tensor, bool>> pool_;
+ // Op context
+ OpKernelContext* ctx_ = nullptr;
+ // Device
+ const Device& device_;
+};
+
+} // namespace
+
template <typename Device, typename T, bool USE_CUBLAS>
class BlockLSTMOp : public OpKernel {
public:
explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("max_len", &max_len_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
@@ -537,77 +644,117 @@ class BlockLSTMOp : public OpKernel {
const Tensor* seq_len_max_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
- OpInputList x_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("x", &x_list));
- const int64 batch_size = x_list[0].dim_size(0);
- const int64 input_size = x_list[0].dim_size(1);
+ const Tensor* x;
+ OP_REQUIRES_OK(ctx, ctx->input("x", &x));
+ OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
+ const int64 timelen = x->dim_size(0);
+ const int64 batch_size = x->dim_size(1);
+ const int64 input_size = x->dim_size(2);
const Tensor* cs_prev_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
+ OP_REQUIRES(ctx, cs_prev_tensor->dims() == 2,
+ errors::InvalidArgument("cs_prev must be 2D"));
+ OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size,
+ errors::InvalidArgument("cs_prev.dims(0) != batch_size: ",
+ cs_prev_tensor->dim_size(0), " vs. ",
+ batch_size));
+ const int64 cell_size = cs_prev_tensor->dim_size(1);
+
+ if (batch_size * input_size % 2 == 1) {
+ LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
+ << "input_size are odd. You are using: batch_size="
+ << batch_size << ", input_size=" << input_size;
+ }
+ if (batch_size * cell_size % 2 == 1) {
+ LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and "
+ << "cell_size are odd. You are using: batch_size="
+ << batch_size << ", cell_size=" << cell_size;
+ }
const Tensor* h_prev_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor));
+ OP_REQUIRES(ctx, h_prev_tensor->dims() == 2,
+ errors::InvalidArgument("h_prev must be 2D"));
+ OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size,
+ errors::InvalidArgument("h_prev.dims(0) != batch_size: ",
+ h_prev_tensor->dim_size(0), " vs. ",
+ batch_size));
+ OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size,
+ errors::InvalidArgument("h_prev.dims(1) != cell_size: ",
+ h_prev_tensor->dim_size(1), " vs. ",
+ cell_size));
const Tensor* w_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor));
+ OP_REQUIRES(ctx, w_tensor->dims() == 2,
+ errors::InvalidArgument("w must be 2D"));
+ OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size,
+ errors::InvalidArgument(
+ "w.dim_size(0) != input_size + cell_size: ",
+ w_tensor->dim_size(0), " vs. ", input_size + cell_size));
+ OP_REQUIRES(
+ ctx, w_tensor->dim_size(1) == cell_size * 4,
+ errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ",
+ w_tensor->dim_size(1), " vs. ", cell_size * 4));
const Tensor* wci_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor));
+ OP_REQUIRES(ctx, wci_tensor->dims() == 1,
+ errors::InvalidArgument("wci must be 1D"));
+ OP_REQUIRES(
+ ctx, wci_tensor->dim_size(0) == cell_size,
+ errors::InvalidArgument("wci.dim_size(0) != cell_size: ",
+ wci_tensor->dim_size(0), " vs. ", cell_size));
const Tensor* wcf_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor));
+ OP_REQUIRES(ctx, wcf_tensor->dims() == 1,
+ errors::InvalidArgument("wcf must be 1D"));
+ OP_REQUIRES(
+ ctx, wcf_tensor->dim_size(0) == cell_size,
+ errors::InvalidArgument("wcf.dim_size(0) != cell_size: ",
+ wcf_tensor->dim_size(0), " vs. ", cell_size));
const Tensor* wco_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor));
+ OP_REQUIRES(ctx, wco_tensor->dims() == 1,
+ errors::InvalidArgument("wco must be 1D"));
+ OP_REQUIRES(
+ ctx, wco_tensor->dim_size(0) == cell_size,
+ errors::InvalidArgument("wco.dim_size(0) != cell_size: ",
+ wco_tensor->dim_size(0), " vs. ", cell_size));
const Tensor* b_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor));
- const int64 cell_size = b_tensor->dim_size(0) / 4;
-
- OpOutputList i_list;
- OP_REQUIRES_OK(ctx, ctx->output_list("i", &i_list));
-
- OpOutputList cs_list;
- OP_REQUIRES_OK(ctx, ctx->output_list("cs", &cs_list));
-
- OpOutputList f_list;
- OP_REQUIRES_OK(ctx, ctx->output_list("f", &f_list));
-
- OpOutputList o_list;
- OP_REQUIRES_OK(ctx, ctx->output_list("o", &o_list));
-
- OpOutputList ci_list;
- OP_REQUIRES_OK(ctx, ctx->output_list("ci", &ci_list));
-
- OpOutputList co_list;
- OP_REQUIRES_OK(ctx, ctx->output_list("co", &co_list));
+ OP_REQUIRES(ctx, b_tensor->dims() == 1,
+ errors::InvalidArgument("b must be 1D"));
+ OP_REQUIRES(
+ ctx, b_tensor->dim_size(0) == cell_size * 4,
+ errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ",
+ b_tensor->dim_size(0), " vs. ", cell_size * 4));
- OpOutputList h_list;
- OP_REQUIRES_OK(ctx, ctx->output_list("h", &h_list));
+ TensorShape batch_cell_shape({timelen, batch_size, cell_size});
+ Tensor* i_out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("i", batch_cell_shape, &i_out));
- TensorShape batch_cell_shape({batch_size, cell_size});
- for (int64 t = 0; t < max_len_; ++t) {
- Tensor* i_tensor = nullptr;
- OP_REQUIRES_OK(ctx, i_list.allocate(t, batch_cell_shape, &i_tensor));
+ Tensor* cs_out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("cs", batch_cell_shape, &cs_out));
- Tensor* cs_tensor = nullptr;
- OP_REQUIRES_OK(ctx, cs_list.allocate(t, batch_cell_shape, &cs_tensor));
+ Tensor* f_out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("f", batch_cell_shape, &f_out));
- Tensor* f_tensor = nullptr;
- OP_REQUIRES_OK(ctx, f_list.allocate(t, batch_cell_shape, &f_tensor));
+ Tensor* o_out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("o", batch_cell_shape, &o_out));
- Tensor* o_tensor = nullptr;
- OP_REQUIRES_OK(ctx, o_list.allocate(t, batch_cell_shape, &o_tensor));
+ Tensor* ci_out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("ci", batch_cell_shape, &ci_out));
- Tensor* ci_tensor = nullptr;
- OP_REQUIRES_OK(ctx, ci_list.allocate(t, batch_cell_shape, &ci_tensor));
+ Tensor* co_out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("co", batch_cell_shape, &co_out));
- Tensor* co_tensor = nullptr;
- OP_REQUIRES_OK(ctx, co_list.allocate(t, batch_cell_shape, &co_tensor));
-
- Tensor* h_tensor = nullptr;
- OP_REQUIRES_OK(ctx, h_list.allocate(t, batch_cell_shape, &h_tensor));
- }
+ Tensor* h_out;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("h", batch_cell_shape, &h_out));
Tensor xh_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(
@@ -628,19 +775,22 @@ class BlockLSTMOp : public OpKernel {
: nullptr;
const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
+ SliceHelper<Device, T> slicer(ctx);
for (int64 t = 0; t < seq_len_max; ++t) {
- const Tensor& x_tensor = x_list[t];
+ const Tensor x_tensor = slicer.InputSlice(*x, t, "x");
const Tensor& cs_prev_tensor2 =
- t == 0 ? *cs_prev_tensor : *cs_list[t - 1];
- const Tensor& h_prev_tensor2 = t == 0 ? *h_prev_tensor : *h_list[t - 1];
-
- Tensor* i_tensor = i_list[t];
- Tensor* cs_tensor = cs_list[t];
- Tensor* f_tensor = f_list[t];
- Tensor* o_tensor = o_list[t];
- Tensor* ci_tensor = ci_list[t];
- Tensor* co_tensor = co_list[t];
- Tensor* h_tensor = h_list[t];
+ t == 0 ? *cs_prev_tensor
+ : slicer.OutputSlice(cs_out, t - 1, "cs_prev");
+ const Tensor& h_prev_tensor2 =
+ t == 0 ? *h_prev_tensor : slicer.OutputSlice(h_out, t - 1, "h_prev");
+
+ Tensor i_tensor = slicer.OutputSlice(i_out, t, "i_out");
+ Tensor cs_tensor = slicer.OutputSlice(cs_out, t, "cs_out");
+ Tensor f_tensor = slicer.OutputSlice(f_out, t, "f_out");
+ Tensor o_tensor = slicer.OutputSlice(o_out, t, "o_out");
+ Tensor ci_tensor = slicer.OutputSlice(ci_out, t, "ci_out");
+ Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out");
+ Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out");
functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
@@ -648,23 +798,25 @@ class BlockLSTMOp : public OpKernel {
x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(),
h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(),
wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(),
- b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor->matrix<T>(),
- cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(),
- ci_tensor->matrix<T>(), co_tensor->matrix<T>(),
- icfo_tensor.matrix<T>(), h_tensor->matrix<T>());
+ b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor.matrix<T>(),
+ cs_tensor.matrix<T>(), f_tensor.matrix<T>(), o_tensor.matrix<T>(),
+ ci_tensor.matrix<T>(), co_tensor.matrix<T>(), icfo_tensor.matrix<T>(),
+ h_tensor.matrix<T>());
+ slicer.FinishTimeStep();
}
- for (int64 t = seq_len_max; t < max_len_; ++t) {
- Tensor* cs_tensor = cs_list[t];
- Tensor* h_tensor = h_list[t];
+ if (seq_len_max < timelen) {
+ Tensor cs_tensor = cs_out->Slice(seq_len_max, timelen);
+ Tensor h_tensor = h_out->Slice(seq_len_max, timelen);
- functor::TensorZero<Device, T>()(device, cs_tensor->flat<float>());
- functor::TensorZero<Device, T>()(device, h_tensor->flat<float>());
+ functor::TensorUnalignedZero<Device, T>()(
+ device, cs_tensor.unaligned_flat<float>());
+ functor::TensorUnalignedZero<Device, T>()(
+ device, h_tensor.unaligned_flat<float>());
}
}
private:
- int64 max_len_;
float forget_bias_;
float cell_clip_;
bool use_peephole_;
@@ -685,7 +837,13 @@ namespace functor {
void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d, \
typename TTypes<T>::Flat t); \
\
- extern template struct TensorZero<GPUDevice, T>;
+ extern template struct TensorZero<GPUDevice, T>; \
+ \
+ template <> \
+ void TensorUnalignedZero<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::UnalignedFlat t); \
+ \
+ extern template struct TensorUnalignedZero<GPUDevice, T>;
DECLARE_GPU_SPEC(float);
// DECLARE_GPU_SPEC(double);
@@ -708,7 +866,6 @@ template <typename Device, typename T, bool USE_CUBLAS>
class BlockLSTMGradOp : public OpKernel {
public:
explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("max_len", &max_len_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
}
@@ -716,10 +873,12 @@ class BlockLSTMGradOp : public OpKernel {
const Tensor* seq_len_max_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor));
- OpInputList x_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("x", &x_list));
- const int64 batch_size = x_list[0].dim_size(0);
- const int64 input_size = x_list[0].dim_size(1);
+ const Tensor* x;
+ OP_REQUIRES_OK(ctx, ctx->input("x", &x));
+ OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D"));
+ const int64 timelen = x->dim_size(0);
+ const int64 batch_size = x->dim_size(1);
+ const int64 input_size = x->dim_size(2);
const Tensor* cs_prev_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor));
@@ -751,35 +910,37 @@ class BlockLSTMGradOp : public OpKernel {
errors::InvalidArgument("w and b cell_size don't match: ", cell_size,
" vs. ", b_tensor->dim_size(0)));
- OpInputList i_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("i", &i_list));
+ const Tensor* i_out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("i", &i_out));
- OpInputList cs_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("cs", &cs_list));
+ const Tensor* cs_out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_out));
- OpInputList f_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("f", &f_list));
+ const Tensor* f_out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("f", &f_out));
- OpInputList o_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("o", &o_list));
+ const Tensor* o_out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("o", &o_out));
- OpInputList ci_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("ci", &ci_list));
+ const Tensor* ci_out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_out));
- OpInputList co_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("co", &co_list));
+ const Tensor* co_out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("co", &co_out));
- OpInputList h_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("h", &h_list));
+ const Tensor* h_out = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("h", &h_out));
- OpInputList cs_grad_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("cs_grad", &cs_grad_list));
+ const Tensor* cs_grad = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad));
- OpInputList h_grad_list;
- OP_REQUIRES_OK(ctx, ctx->input_list("h_grad", &h_grad_list));
+ const Tensor* h_grad = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad));
- OpOutputList x_grad_list;
- OP_REQUIRES_OK(ctx, ctx->output_list("x_grad", &x_grad_list));
+ TensorShape batch_input_shape({timelen, batch_size, input_size});
+ Tensor* x_grad;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output("x_grad", batch_input_shape, &x_grad));
Tensor* cs_prev_grad_tensor = nullptr;
OP_REQUIRES_OK(ctx,
@@ -811,13 +972,7 @@ class BlockLSTMGradOp : public OpKernel {
OP_REQUIRES_OK(
ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor));
- TensorShape batch_input_shape({batch_size, input_size});
TensorShape batch_cell_shape({batch_size, cell_size});
- for (int64 t = 0; t < max_len_; ++t) {
- Tensor* x_grad_tensor = nullptr;
- OP_REQUIRES_OK(
- ctx, x_grad_list.allocate(t, batch_input_shape, &x_grad_tensor));
- }
Tensor xh_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(
@@ -882,33 +1037,40 @@ class BlockLSTMGradOp : public OpKernel {
functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<float>());
const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()();
+ SliceHelper<Device, T> slicer(ctx);
for (int64 t = seq_len_max - 1; t >= 0; --t) {
- const Tensor& x_tensor = x_list[t];
- const Tensor& cs_prev_tensor2 = t == 0 ? *cs_prev_tensor : cs_list[t - 1];
- const Tensor& h_prev_tensor2 = t == 0 ? *h_prev_tensor : h_list[t - 1];
- const Tensor& i_tensor = i_list[t];
- const Tensor& cs_tensor = cs_list[t];
- const Tensor& f_tensor = f_list[t];
- const Tensor& o_tensor = o_list[t];
- const Tensor& ci_tensor = ci_list[t];
- const Tensor& co_tensor = co_list[t];
+ const Tensor& x_tensor = slicer.InputSlice(*x, t, "x");
+ const Tensor& cs_prev_tensor2 =
+ t == 0 ? *cs_prev_tensor
+ : slicer.InputSlice(*cs_out, t - 1, "cs_prev");
+ const Tensor& h_prev_tensor2 =
+ t == 0 ? *h_prev_tensor : slicer.InputSlice(*h_out, t - 1, "h_prev");
+ const Tensor& i_tensor = slicer.InputSlice(*i_out, t, "i_out");
+ const Tensor& cs_tensor = slicer.InputSlice(*cs_out, t, "cs_out");
+ const Tensor& f_tensor = slicer.InputSlice(*f_out, t, "f_out");
+ const Tensor& o_tensor = slicer.InputSlice(*o_out, t, "o_out");
+ const Tensor& ci_tensor = slicer.InputSlice(*ci_out, t, "ci_out");
+ const Tensor& co_tensor = slicer.InputSlice(*co_out, t, "co_out");
// Grab previous CS grad.
const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor;
+ const Tensor const_cs_grad_slice =
+ slicer.InputSlice(*cs_grad, t, "cs_grad");
functor::TensorAdd<Device, T>()(
device, const_cs_prev_grad_tensor.flat<T>(),
- cs_grad_list[t].flat<T>(), cs_grad_tensor.flat<T>());
+ const_cs_grad_slice.flat<T>(), cs_grad_tensor.flat<T>());
// Combine previous h grad and h grad coming on top.
const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor;
+ const Tensor const_h_grad_slice = slicer.InputSlice(*h_grad, t, "h_grad");
functor::TensorAdd<Device, T>()(
- device, const_h_prev_grad_tensor.flat<T>(), h_grad_list[t].flat<T>(),
- h_grad_tensor.flat<T>());
+ device, const_h_prev_grad_tensor.flat<T>(),
+ const_h_grad_slice.flat<T>(), h_grad_tensor.flat<T>());
const Tensor& const_cs_grad_tensor = cs_grad_tensor;
const Tensor& const_h_grad_tensor = h_grad_tensor;
- Tensor* x_grad_tensor = x_grad_list[t];
+ Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad");
functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size,
cell_size)(
ctx, stream, device, use_peephole_, x_tensor.matrix<T>(),
@@ -922,19 +1084,20 @@ class BlockLSTMGradOp : public OpKernel {
df_tensor.matrix<T>(), di_tensor.matrix<T>(),
dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(),
h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(),
- x_grad_tensor->matrix<T>(), w_grad_tensor->matrix<T>(),
+ x_grad_tensor.matrix<T>(), w_grad_tensor->matrix<T>(),
wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(),
wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>());
+ slicer.FinishTimeStep();
}
- for (int64 t = seq_len_max; t < max_len_; ++t) {
- Tensor* x_grad_tensor = x_grad_list[t];
- functor::TensorZero<Device, T>()(device, x_grad_tensor->flat<T>());
+ if (seq_len_max < timelen) {
+ Tensor x_grad_tensor = x_grad->Slice(seq_len_max, timelen);
+ functor::TensorUnalignedZero<Device, T>()(
+ device, x_grad_tensor.unaligned_flat<T>());
}
}
private:
- int64 max_len_;
bool use_peephole_;
};
@@ -955,6 +1118,16 @@ namespace functor {
typename TTypes<T>::Flat dst); \
\
template <> \
+ void TensorCopyUnaligned<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::UnalignedConstFlat src, \
+ typename TTypes<T>::Flat dst); \
+ \
+ template <> \
+ void TensorCopyToUnaligned<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstFlat src, \
+ typename TTypes<T>::UnalignedFlat dst); \
+ \
+ template <> \
void TensorAdd<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::ConstFlat a, \
typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c); \
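For background on the SliceHelper introduced above: Eigen expects 16-byte-aligned buffers, and the time slice at step t of a 3D float32 tensor starts at byte offset 4 * t * batch_size * cell_size, so whether slices stay aligned depends on batch_size * cell_size. A minimal numpy sketch of the idea (the 16-byte constant and numpy's base-allocation alignment are assumptions for illustration):

```python
import numpy as np

ALIGN = 16  # assumed Eigen alignment requirement, in bytes

def slice_is_aligned(t3d, pos):
    """True if the time slice t3d[pos] starts on an ALIGN-byte boundary."""
    return t3d[pos].ctypes.data % ALIGN == 0

# Assuming the base allocation is 16-byte aligned (numpy usually is):
even = np.zeros((8, 4, 6), np.float32)  # slice = 24 floats = 96 B, 96 % 16 == 0
odd = np.zeros((8, 3, 5), np.float32)   # slice = 15 floats = 60 B

print(all(slice_is_aligned(even, t) for t in range(8)))  # True: no copies
print([slice_is_aligned(odd, t) for t in range(8)])      # most slices land off
# a 16-byte boundary; these are the ones AlignTensor copies into an aligned
# temporary, and FinishTimeStep copies output slices back after each step.
```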
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h
index 5a9dda5755..1332b88002 100644
--- a/tensorflow/contrib/rnn/kernels/lstm_ops.h
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h
@@ -41,6 +41,13 @@ struct TensorZero {
};
template <typename Device, typename T>
+struct TensorUnalignedZero {
+ void operator()(const Device& d, typename TTypes<T>::UnalignedFlat t) {
+ t.device(d) = t.constant(T(0));
+ }
+};
+
+template <typename Device, typename T>
struct TensorCopy {
void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
typename TTypes<T>::Flat dst) {
@@ -49,6 +56,22 @@ struct TensorCopy {
};
template <typename Device, typename T>
+struct TensorCopyUnaligned {
+ void operator()(const Device& d, typename TTypes<T>::UnalignedConstFlat src,
+ typename TTypes<T>::Flat dst) {
+ dst.device(d) = src;
+ }
+};
+
+template <typename Device, typename T>
+struct TensorCopyToUnaligned {
+ void operator()(const Device& d, typename TTypes<T>::ConstFlat src,
+ typename TTypes<T>::UnalignedFlat dst) {
+ dst.device(d) = src;
+ }
+};
+
+template <typename Device, typename T>
struct TensorAdd {
void operator()(const Device& d, typename TTypes<T>::ConstFlat a,
typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c) {
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc
index 47298d9501..b33ca5fc8d 100644
--- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc
+++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc
@@ -26,7 +26,10 @@ typedef Eigen::GpuDevice GPUDevice;
#define DEFINE_GPU_SPECS(T) \
template struct TensorZero<GPUDevice, T>; \
+ template struct TensorUnalignedZero<GPUDevice, T>; \
template struct TensorCopy<GPUDevice, T>; \
+ template struct TensorCopyUnaligned<GPUDevice, T>; \
+ template struct TensorCopyToUnaligned<GPUDevice, T>; \
template struct TensorAdd<GPUDevice, T>; \
template struct LSTMBlockCellFprop<GPUDevice, T, true>; \
template struct LSTMBlockCellBprop<GPUDevice, T, true>; \
diff --git a/tensorflow/contrib/rnn/ops/lstm_ops.cc b/tensorflow/contrib/rnn/ops/lstm_ops.cc
index 08dbd3c822..2de40825c9 100644
--- a/tensorflow/contrib/rnn/ops/lstm_ops.cc
+++ b/tensorflow/contrib/rnn/ops/lstm_ops.cc
@@ -58,8 +58,8 @@ REGISTER_OP("LSTMBlockCell")
.Doc(R"doc(
Computes the LSTM cell forward propagation for 1 time step.
-This implementation uses 1 weight matrix and 1 bias vector, there is no
-diagonal peephole connection.
+This implementation uses 1 weight matrix and 1 bias vector, and there's an
+optional peephole connection.
This kernel op implements the following mathematical equations:
@@ -68,21 +68,34 @@ xh = [x, h_prev]
[i, f, ci, o] = xh * w + b
f = f + forget_bias
-i = sigmoid(i)
-f = sigmoid(f)
+if not use_peephole:
+ wci = wcf = wco = 0
+
+i = sigmoid(cs_prev * wci + i)
+f = sigmoid(cs_prev * wcf + f)
ci = tanh(ci)
-o = sigmoid(o)
cs = ci .* i + cs_prev .* f
-co = tanh(cs)
+cs = clip(cs, cell_clip)
+o = sigmoid(cs * wco + o)
+co = tanh(cs)
h = co .* o
```
+cell_clip: Value to clip the 'cs' value to.
+use_peephole: Whether to use peephole weights.
forget_bias: The forget gate bias.
-x: The input to the LSTM cell.
+
+x: The input to the LSTM cell, shape (batch_size, num_inputs).
+cs_prev: Value of the cell state at previous time step.
+h_prev: Output of the previous cell at previous time step.
w: The weight matrix.
+wci: The weight matrix for input gate peephole connection.
+wcf: The weight matrix for forget gate peephole connection.
+wco: The weight matrix for output gate peephole connection.
b: The bias vector.
+
i: The input gate.
cs: The cell state before the tanh.
f: The forget gate.
@@ -139,10 +152,14 @@ Computes the LSTM cell backward propagation for 1 timestep.
This implementation is to be used in conjunction with LSTMBlockCell.
-x: The input to the LSTM cell.
+use_peephole: Whether the cell uses peephole connections.
+x: The input to the LSTM cell, shape (batch_size, num_inputs).
cs_prev: The previous cell state.
h_prev: The previous h state.
w: The weight matrix.
+wci: The weight matrix for input gate peephole connection.
+wcf: The weight matrix for forget gate peephole connection.
+wco: The weight matrix for output gate peephole connection.
b: The bias vector.
i: The input gate.
cs: The cell state before the tanh.
@@ -150,14 +167,18 @@ f: The forget gate.
o: The output gate.
ci: The cell input.
co: The cell after the tanh.
-h_grad: THe gradient of h vector.
-cs_prev_grad: The gradient of cs.
+cs_grad: The current gradient of cs.
+h_grad: The gradient of h vector.
+cs_prev_grad: The gradient of cs to be back-propped.
dicfo: The derivative wrt [i, cs, f, o].
+wci_grad: The gradient for wci to be back-propped.
+wcf_grad: The gradient for wcf to be back-propped.
+wco_grad: The gradient for wco to be back-propped.
)doc");
REGISTER_OP("BlockLSTM")
.Input("seq_len_max: int64")
- .Input("x: max_len * T")
+ .Input("x: T")
.Input("cs_prev: T")
.Input("h_prev: T")
.Input("w: T")
@@ -165,46 +186,83 @@ REGISTER_OP("BlockLSTM")
.Input("wcf: T")
.Input("wco: T")
.Input("b: T")
- .Output("i: max_len * T")
- .Output("cs: max_len * T")
- .Output("f: max_len * T")
- .Output("o: max_len * T")
- .Output("ci: max_len * T")
- .Output("co: max_len * T")
- .Output("h: max_len * T")
- .Attr("max_len: int")
+ .Output("i: T")
+ .Output("cs: T")
+ .Output("f: T")
+ .Output("o: T")
+ .Output("ci: T")
+ .Output("co: T")
+ .Output("h: T")
.Attr("forget_bias: float = 1.0")
.Attr("cell_clip: float = 3.0")
.Attr("use_peephole: bool = false")
.Attr("T: {float}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle x, b;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &x));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b));
- DimensionHandle batch_size = c->Dim(x, 0);
+ DimensionHandle timelen = c->Dim(x, 0);
+ DimensionHandle batch_size = c->Dim(x, 1);
DimensionHandle cell_size;
TF_RETURN_IF_ERROR(
c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size));
- int64 max_len;
- TF_RETURN_IF_ERROR(c->GetAttr("max_len", &max_len));
-
- DCHECK_EQ(max_len * 7, c->num_outputs());
- ShapeHandle output = c->Matrix(batch_size, cell_size);
- for (int i = 0; i < max_len; ++i) {
- for (int j = 0; j < 7; ++j) {
- c->set_output(i * 7 + j, output);
- }
+ DCHECK_EQ(7, c->num_outputs());
+ ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size});
+ for (int i = 0; i < 7; ++i) {
+ c->set_output(i, output);
}
return Status::OK();
})
.Doc(R"doc(
+Computes the LSTM cell forward propagation for all the time steps.
+
+This is equivalent to applying LSTMBlockCell in a loop, like so:
+
+```python
+for x1 in unpack(x):
+ i1, cs1, f1, o1, ci1, co1, h1 = LSTMBlockCell(
+ x1, cs_prev, h_prev, w, wci, wcf, wco, b)
+ cs_prev = cs1
+ h_prev = h1
+ i.append(i1)
+ cs.append(cs1)
+ f.append(f1)
+ o.append(o1)
+ ci.append(ci1)
+ co.append(co1)
+ h.append(h1)
+return pack(i), pack(cs), pack(f), pack(o), pack(ci), pack(co), pack(h)
+```
+
+cell_clip: Value to clip the 'cs' value to.
+use_peephole: Whether to use peephole weights.
+forget_bias: The forget gate bias.
+
+seq_len_max: Maximum time length actually used by this input. Outputs are padded
+ with zeros beyond this length.
+x: The sequence input to the LSTM, shape (timelen, batch_size, num_inputs).
+cs_prev: Value of the initial cell state.
+h_prev: Initial output of cell (to be used for peephole).
+w: The weight matrix.
+wci: The weight matrix for input gate peephole connection.
+wcf: The weight matrix for forget gate peephole connection.
+wco: The weight matrix for output gate peephole connection.
+b: The bias vector.
+
+i: The input gate over the whole time sequence.
+cs: The cell state before the tanh over the whole time sequence.
+f: The forget gate over the whole time sequence.
+o: The output gate over the whole time sequence.
+ci: The cell input over the whole time sequence.
+co: The cell after the tanh over the whole time sequence.
+h: The output h vector over the whole time sequence.
)doc");
REGISTER_OP("BlockLSTMGrad")
.Input("seq_len_max: int64")
- .Input("x: max_len * T")
+ .Input("x: T")
.Input("cs_prev: T")
.Input("h_prev: T")
.Input("w: T")
@@ -212,16 +270,16 @@ REGISTER_OP("BlockLSTMGrad")
.Input("wcf: T")
.Input("wco: T")
.Input("b: T")
- .Input("i: max_len * T")
- .Input("cs: max_len * T")
- .Input("f: max_len * T")
- .Input("o: max_len * T")
- .Input("ci: max_len * T")
- .Input("co: max_len * T")
- .Input("h: max_len * T")
- .Input("cs_grad: max_len * T")
- .Input("h_grad: max_len * T")
- .Output("x_grad: max_len * T")
+ .Input("i: T")
+ .Input("cs: T")
+ .Input("f: T")
+ .Input("o: T")
+ .Input("ci: T")
+ .Input("co: T")
+ .Input("h: T")
+ .Input("cs_grad: T")
+ .Input("h_grad: T")
+ .Output("x_grad: T")
.Output("cs_prev_grad: T")
.Output("h_prev_grad: T")
.Output("w_grad: T")
@@ -229,36 +287,65 @@ REGISTER_OP("BlockLSTMGrad")
.Output("wcf_grad: T")
.Output("wco_grad: T")
.Output("b_grad: T")
- .Attr("max_len: int")
.Attr("use_peephole: bool")
.Attr("T: {float}")
.SetShapeFn([](InferenceContext* c) {
- int64 max_len;
- TF_RETURN_IF_ERROR(c->GetAttr("max_len", &max_len));
-
ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &x));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + max_len), 2, &cs_prev));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2 + max_len), 2, &h_prev));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(3 + max_len), 2, &w));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4 + max_len), 1, &wci));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(5 + max_len), 1, &wco));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(6 + max_len), 1, &wcf));
- TF_RETURN_IF_ERROR(c->WithRank(c->input(7 + max_len), 1, &b));
-
- int out_idx = 0;
- for (int i = 0; i < max_len; ++i) c->set_output(out_idx++, x);
- c->set_output(out_idx++, cs_prev);
- c->set_output(out_idx++, h_prev);
- c->set_output(out_idx++, w);
- c->set_output(out_idx++, wci);
- c->set_output(out_idx++, wco);
- c->set_output(out_idx++, wcf);
- c->set_output(out_idx++, b);
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b));
+
+ c->set_output(0, x);
+ c->set_output(1, cs_prev);
+ c->set_output(2, h_prev);
+ c->set_output(3, w);
+ c->set_output(4, wci);
+ c->set_output(5, wco);
+ c->set_output(6, wcf);
+ c->set_output(7, b);
return Status::OK();
})
.Doc(R"doc(
+Computes the LSTM cell backward propagation for the entire time sequence.
+
+This implementation is to be used in conjunction with BlockLSTM.
+
+use_peephole: Whether to use peephole weights.
+
+seq_len_max: Maximum time length actually used by this input. Outputs are padded
+ with zeros beyond this length.
+x: The sequence input to the LSTM, shape (timelen, batch_size, num_inputs).
+cs_prev: Value of the initial cell state.
+h_prev: Initial output of cell (to be used for peephole).
+w: The weight matrix.
+wci: The weight matrix for input gate peephole connection.
+wcf: The weight matrix for forget gate peephole connection.
+wco: The weight matrix for output gate peephole connection.
+b: The bias vector.
+i: The input gate over the whole time sequence.
+cs: The cell state before the tanh over the whole time sequence.
+f: The forget gate over the whole time sequence.
+o: The output gate over the whole time sequence.
+ci: The cell input over the whole time sequence.
+co: The cell after the tanh over the whole time sequence.
+h: The output h vector over the whole time sequence.
+cs_grad: The current gradient of cs.
+h_grad: The gradient of h vector.
+
+x_grad: The gradient of x to be back-propped.
+cs_prev_grad: The gradient of cs_prev to be back-propped.
+h_prev_grad: The gradient of h_prev to be back-propped.
+w_grad: The gradient for w to be back-propped.
+wci_grad: The gradient for wci to be back-propped.
+wcf_grad: The gradient for wcf to be back-propped.
+wco_grad: The gradient for wco to be back-propped.
+b_grad: The gradient for b to be back-propped.
)doc");
} // end namespace tensorflow
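As a cross-check of the equations documented above, here is a minimal numpy reference of a single LSTMBlockCell forward step (a sketch in the docstring's notation; the [i, f, ci, o] gate order and symmetric cell clipping are taken from the doc, not from the kernel source):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_block_cell_step(x, cs_prev, h_prev, w, b, wci, wcf, wco,
                         forget_bias=1.0, cell_clip=3.0, use_peephole=False):
    """One forward step following the LSTMBlockCell doc above."""
    if not use_peephole:
        wci = wcf = wco = 0.0
    xh = np.concatenate([x, h_prev], axis=1)              # xh = [x, h_prev]
    i, f, ci, o = np.split(np.dot(xh, w) + b, 4, axis=1)  # [i, f, ci, o]
    f = f + forget_bias
    i = sigmoid(cs_prev * wci + i)
    f = sigmoid(cs_prev * wcf + f)
    ci = np.tanh(ci)
    cs = np.clip(ci * i + cs_prev * f, -cell_clip, cell_clip)
    o = sigmoid(cs * wco + o)
    co = np.tanh(cs)
    h = co * o
    return i, cs, f, o, ci, co, h
```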
diff --git a/tensorflow/contrib/rnn/ops/lstm_ops_test.cc b/tensorflow/contrib/rnn/ops/lstm_ops_test.cc
index 4acbfa7ced..a25e0674ea 100644
--- a/tensorflow/contrib/rnn/ops/lstm_ops_test.cc
+++ b/tensorflow/contrib/rnn/ops/lstm_ops_test.cc
@@ -79,118 +79,85 @@ TEST_F(LSTMOpsTest, LSTMBlockCellGrad_ShapeFn) {
TEST_F(LSTMOpsTest, BlockLSTM_ShapeFn) {
ShapeInferenceTestOp op("BlockLSTM");
- auto set_op = [&op](int max_len) {
- std::vector<NodeDefBuilder::NodeOut> x_list;
- for (int i = 0; i < max_len; ++i) x_list.emplace_back("a", 0, DT_FLOAT);
- TF_ASSERT_OK(NodeDefBuilder("test", "BlockLSTM")
- .Input({"seq_len_max", 0, DT_INT64})
- .Input(x_list)
- .Input({"cs_prev", 0, DT_FLOAT})
- .Input({"h_prev", 0, DT_FLOAT})
- .Input({"w", 0, DT_FLOAT})
- .Input({"wci", 0, DT_FLOAT})
- .Input({"wcf", 0, DT_FLOAT})
- .Input({"wco", 0, DT_FLOAT})
- .Input({"b", 0, DT_FLOAT})
- .Attr("max_len", max_len)
- .Finalize(&op.node_def));
- };
-
- for (int max_len = 1; max_len < 11; max_len += 3) {
- set_op(max_len);
- int num_in = max_len + 8;
- int num_out = max_len * 7;
-
- // Middle inputs don't affect shape inference.
- string infix = ";" + JoinedCopies("?", num_in - 3) + ";";
-
- // Rank checks.
- INFER_ERROR("must be rank 2", op, "?;[?]" + infix + "?");
- INFER_ERROR("must be rank 1", op, "?;?" + infix + "[?,?]");
-
- // Output
- INFER_OK(op, "?;?" + infix + "?", JoinedCopies("[?,?]", num_out));
- INFER_OK(op, "?;[?,?]" + infix + "?", JoinedCopies("[d1_0,?]", num_out));
- INFER_OK(op, "?;[?,?]" + infix + "[?]", JoinedCopies("[d1_0,?]", num_out));
- INFER_OK(op, "?;[?,?]" + infix + "[20]", JoinedCopies("[d1_0,5]", num_out));
-
- // cell_size must be divisible by 4.
- INFER_ERROR("must be evenly divisible", op, "?;?" + infix + "[11]");
- }
+ TF_ASSERT_OK(NodeDefBuilder("test", "BlockLSTM")
+ .Input({"seq_len_max", 0, DT_INT64})
+ .Input({"x", 0, DT_FLOAT})
+ .Input({"cs_prev", 0, DT_FLOAT})
+ .Input({"h_prev", 0, DT_FLOAT})
+ .Input({"w", 0, DT_FLOAT})
+ .Input({"wci", 0, DT_FLOAT})
+ .Input({"wcf", 0, DT_FLOAT})
+ .Input({"wco", 0, DT_FLOAT})
+ .Input({"b", 0, DT_FLOAT})
+ .Finalize(&op.node_def));
+
+ // Middle inputs don't affect shape inference.
+ string infix = ";" + JoinedCopies("?", 6) + ";";
+
+ // Rank checks.
+ INFER_ERROR("must be rank 3", op, "?;[?]" + infix + "?");
+ INFER_ERROR("must be rank 1", op, "?;?" + infix + "[?,?]");
+
+ // Output
+ INFER_OK(op, "?;?" + infix + "?", JoinedCopies("[?,?,?]", 7));
+ INFER_OK(op, "?;[?,?,?]" + infix + "?", JoinedCopies("[d1_0,d1_1,?]", 7));
+ INFER_OK(op, "?;[?,?,?]" + infix + "[?]", JoinedCopies("[d1_0,d1_1,?]", 7));
+ INFER_OK(op, "?;[?,?,?]" + infix + "[20]", JoinedCopies("[d1_0,d1_1,5]", 7));
+
+ // cell_size must be divisible by 4.
+ INFER_ERROR("must be evenly divisible", op, "?;?" + infix + "[11]");
}
TEST_F(LSTMOpsTest, BlockLSTMGrad_ShapeFn) {
ShapeInferenceTestOp op("BlockLSTMGrad");
-
- auto set_op = [&op](int max_len) {
- std::vector<NodeDefBuilder::NodeOut> x_list;
- for (int i = 0; i < max_len; ++i) x_list.emplace_back("a", 0, DT_FLOAT);
- TF_ASSERT_OK(NodeDefBuilder("test", "BlockLSTMGrad")
- .Input({"seq_len_max", 0, DT_INT64})
- .Input(x_list)
- .Input({"cs_prev", 0, DT_FLOAT})
- .Input({"h_prev", 0, DT_FLOAT})
- .Input({"w", 0, DT_FLOAT})
- .Input({"wci", 0, DT_FLOAT})
- .Input({"wcf", 0, DT_FLOAT})
- .Input({"wco", 0, DT_FLOAT})
- .Input({"b", 0, DT_FLOAT})
- .Input(x_list)
- .Input(x_list)
- .Input(x_list)
- .Input(x_list)
- .Input(x_list)
- .Input(x_list)
- .Input(x_list)
- .Input(x_list)
- .Input(x_list)
- .Attr("max_len", max_len)
- .Finalize(&op.node_def));
- };
-
- for (int max_len = 1; max_len < 11; max_len += 3) {
- set_op(max_len);
- int num_in = max_len * 10 + 8;
- int num_out = max_len + 7;
-
- // Last inputs don't affect shape inference.
- string suffix = ";" + JoinedCopies("?", 9 * max_len);
-
- // Rank check for x
- string invalid_x = JoinedCopies("[?]", max_len);
- INFER_ERROR("must be rank 2", op,
- "?;" + invalid_x + ";?;?;?;?;?;?;?" + suffix);
-
- // Rank checks for cs_prev through b.
- string unknown_x = JoinedCopies("?", max_len);
- INFER_ERROR("must be rank 2", op,
- "?;" + unknown_x + ";[1];?;?;?;?;?;?" + suffix);
- INFER_ERROR("must be rank 2", op,
- "?;" + unknown_x + ";?;[1];?;?;?;?;?" + suffix);
- INFER_ERROR("must be rank 2", op,
- "?;" + unknown_x + ";?;?;[1];?;?;?;?" + suffix);
- INFER_ERROR("must be rank 1", op,
- "?;" + unknown_x + ";?;?;?;[1,?];?;?;?" + suffix);
- INFER_ERROR("must be rank 1", op,
- "?;" + unknown_x + ";?;?;?;?;[1,?];?;?" + suffix);
- INFER_ERROR("must be rank 1", op,
- "?;" + unknown_x + ";?;?;?;?;?;[1,?];?" + suffix);
- INFER_ERROR("must be rank 1", op,
- "?;" + unknown_x + ";?;?;?;?;?;?;[1,?]" + suffix);
-
- // Output with all input knowns makes known rank outputs.
- INFER_OK(op, JoinedCopies("?", num_in),
- JoinedCopies("[?,?]", num_out - 4) + ";" + JoinedCopies("[?]", 4));
-
- // Output with copies input shapes to output.
- string input = strings::StrCat("?;", JoinedCopies("[?,?]", max_len + 3),
- ";", JoinedCopies("[?]", 4), suffix);
- string expected = JoinedCopies("in1", max_len); // copies of x;
- for (int i = max_len; i < num_out; ++i) {
- strings::StrAppend(&expected, ";in", (i + 1));
- }
- INFER_OK(op, input, expected);
+ TF_ASSERT_OK(NodeDefBuilder("test", "BlockLSTMGrad")
+ .Input({"seq_len_max", 0, DT_INT64})
+ .Input({"x", 0, DT_FLOAT})
+ .Input({"cs_prev", 0, DT_FLOAT})
+ .Input({"h_prev", 0, DT_FLOAT})
+ .Input({"w", 0, DT_FLOAT})
+ .Input({"wci", 0, DT_FLOAT})
+ .Input({"wcf", 0, DT_FLOAT})
+ .Input({"wco", 0, DT_FLOAT})
+ .Input({"b", 0, DT_FLOAT})
+ .Input({"i", 0, DT_FLOAT})
+ .Input({"cs", 0, DT_FLOAT})
+ .Input({"f", 0, DT_FLOAT})
+ .Input({"o", 0, DT_FLOAT})
+ .Input({"ci", 0, DT_FLOAT})
+ .Input({"co", 0, DT_FLOAT})
+ .Input({"h", 0, DT_FLOAT})
+ .Input({"cs_grad", 0, DT_FLOAT})
+ .Input({"h_grad", 0, DT_FLOAT})
+ .Finalize(&op.node_def));
+
+ // Last inputs don't affect shape inference.
+ string suffix = ";" + JoinedCopies("?", 9);
+
+ // Rank check for x
+ INFER_ERROR("must be rank 3", op, "?;[?];?;?;?;?;?;?;?" + suffix);
+
+ // Rank checks for cs_prev through b.
+ INFER_ERROR("must be rank 2", op, "?;?;[1];?;?;?;?;?;?" + suffix);
+ INFER_ERROR("must be rank 2", op, "?;?;?;[1];?;?;?;?;?" + suffix);
+ INFER_ERROR("must be rank 2", op, "?;?;?;?;[1];?;?;?;?" + suffix);
+ INFER_ERROR("must be rank 1", op, "?;?;?;?;?;[1,?];?;?;?" + suffix);
+ INFER_ERROR("must be rank 1", op, "?;?;?;?;?;?;[1,?];?;?" + suffix);
+ INFER_ERROR("must be rank 1", op, "?;?;?;?;?;?;?;[1,?];?" + suffix);
+ INFER_ERROR("must be rank 1", op, "?;?;?;?;?;?;?;?;[1,?]" + suffix);
+
+ // Output with all input knowns makes known rank outputs.
+ INFER_OK(op, JoinedCopies("?", 18), "[?,?,?];" + JoinedCopies("[?,?]", 3) +
+ ";" + JoinedCopies("[?]", 4));
+
+ // Output with copies input shapes to output.
+ string input = strings::StrCat("?;[?,?,?];", JoinedCopies("[?,?]", 3), ";",
+ JoinedCopies("[?]", 4), suffix);
+ string expected = "in1";
+ for (int i = 1; i < 8; ++i) {
+ strings::StrAppend(&expected, ";in", (i + 1));
}
+ INFER_OK(op, input, expected);
}
} // namespace tensorflow
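The shape relationships the rewritten test exercises can be summarized compactly. A Python sketch of the expected behavior (not the C++ shape function itself):

```python
def block_lstm_output_shapes(x_shape, b_shape):
    """BlockLSTM output shapes, as checked by the test above."""
    timelen, batch_size, _ = x_shape  # x must be rank 3
    (b_len,) = b_shape                # b must be rank 1
    if b_len % 4 != 0:
        raise ValueError("b's size must be evenly divisible by 4")
    cell_size = b_len // 4
    # All 7 outputs (i, cs, f, o, ci, co, h) share one shape.
    return [(timelen, batch_size, cell_size)] * 7

# Matches INFER_OK(op, "?;[?,?,?]" + infix + "[20]", ...[d1_0,d1_1,5]...):
assert block_lstm_output_shapes((8, 4, 16), (20,)) == [(8, 4, 5)] * 7
```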
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 7d863d17dc..4979856668 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""LSTM Block Cell ops."""
from __future__ import absolute_import
from __future__ import division
@@ -50,8 +49,8 @@ def _lstm_block_cell(x,
name=None):
r"""Computes the LSTM cell forward propagation for 1 time step.
- This implementation uses 1 weight matrix and 1 bias vector, there is no
- diagonal peephole connection.
+ This implementation uses 1 weight matrix and 1 bias vector, and there's an
+ optional peephole connection.
This kernel op implements the following mathematical equations:
@@ -60,30 +59,41 @@ def _lstm_block_cell(x,
[i, f, ci, o] = xh * w + b
f = f + forget_bias
- i = sigmoid(i)
- f = sigmoid(f)
+ if not use_peephole:
+ wci = wcf = wco = 0
+
+ i = sigmoid(cs_prev * wci + i)
+ f = sigmoid(cs_prev * wcf + f)
ci = tanh(ci)
- o = sigmoid(o)
cs = ci .* i + cs_prev .* f
- co = tanh(cs)
+ cs = clip(cs, cell_clip)
+ o = sigmoid(cs * wco + o)
+ co = tanh(cs)
h = co .* o
```
Args:
- x: A `Tensor`. Must be one of the following types: `float32`, `float64`.
- The input to the LSTM cell.
+ x: A `Tensor`. Must be one of the following types: `float32`.
+ The input to the LSTM cell, shape (batch_size, num_inputs).
cs_prev: A `Tensor`. Must have the same type as `x`.
+ Value of the cell state at previous time step.
h_prev: A `Tensor`. Must have the same type as `x`.
+ Output of the previous cell at previous time step.
w: A `Tensor`. Must have the same type as `x`. The weight matrix.
b: A `Tensor`. Must have the same type as `x`. The bias vector.
wci: A `Tensor`. Must have the same type as `x`.
+ The weight matrix for input gate peephole connection.
wcf: A `Tensor`. Must have the same type as `x`.
+ The weight matrix for forget gate peephole connection.
wco: A `Tensor`. Must have the same type as `x`.
+ The weight matrix for output gate peephole connection.
forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
cell_clip: An optional `float`. Defaults to `3`.
+ Value to clip the 'cs' value to.
use_peephole: An optional `bool`. Defaults to `False`.
+ Whether to use peephole weights.
name: A name for the operation (optional).
Returns:
@@ -108,18 +118,19 @@ def _lstm_block_cell(x,
wcf = wci
# pylint: disable=protected-access
- return _lstm_ops_so.lstm_block_cell(x=x,
- cs_prev=cs_prev,
- h_prev=h_prev,
- w=w,
- wci=wci,
- wco=wco,
- wcf=wcf,
- b=b,
- forget_bias=forget_bias,
- cell_clip=cell_clip,
- use_peephole=use_peephole,
- name=name)
+ return _lstm_ops_so.lstm_block_cell(
+ x=x,
+ cs_prev=cs_prev,
+ h_prev=h_prev,
+ w=w,
+ wci=wci,
+ wco=wco,
+ wcf=wcf,
+ b=b,
+ forget_bias=forget_bias,
+ cell_clip=cell_clip,
+ use_peephole=use_peephole,
+ name=name)
# pylint: enable=protected-access
@@ -180,9 +191,8 @@ def _block_lstm(seq_len_max,
cell_size = cell_size4 / 4
zero_state = None
if cs_prev is None or h_prev is None:
- zero_state = array_ops.constant(0,
- dtype=dtypes.float32,
- shape=[batch_size, cell_size])
+ zero_state = array_ops.constant(
+ 0, dtype=dtypes.float32, shape=[batch_size, cell_size])
if cs_prev is None:
cs_prev = zero_state
if h_prev is None:
@@ -193,26 +203,30 @@ def _block_lstm(seq_len_max,
wcf = wci
# pylint: disable=protected-access
- return _lstm_ops_so.block_lstm(seq_len_max=seq_len_max,
- x=x,
- cs_prev=cs_prev,
- h_prev=h_prev,
- w=w,
- wci=wci,
- wco=wco,
- wcf=wcf,
- b=b,
- forget_bias=forget_bias,
- cell_clip=cell_clip,
- name=name,
- use_peephole=use_peephole)
+ i, cs, f, o, ci, co, h = _lstm_ops_so.block_lstm(
+ seq_len_max=seq_len_max,
+ x=array_ops.pack(x),
+ cs_prev=cs_prev,
+ h_prev=h_prev,
+ w=w,
+ wci=wci,
+ wco=wco,
+ wcf=wcf,
+ b=b,
+ forget_bias=forget_bias,
+ cell_clip=cell_clip,
+ name=name,
+ use_peephole=use_peephole)
+
+ return (array_ops.unpack(i), array_ops.unpack(cs), array_ops.unpack(f),
+ array_ops.unpack(o), array_ops.unpack(ci), array_ops.unpack(co),
+ array_ops.unpack(h))
# pylint: enable=protected-access
# pylint: enable=invalid-name
_lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"]
-
ops.RegisterShape("LSTMBlockCell")(common_shapes.call_cpp_shape_fn)
@@ -283,28 +297,11 @@ ops.RegisterShape("BlockLSTM")(common_shapes.call_cpp_shape_fn)
@ops.RegisterGradient("BlockLSTM")
def _BlockLSTMGrad(op, *grad):
"""Gradient for BlockLSTM."""
- max_len = op.get_attr("max_len")
-
- seq_len_max = op.inputs[0]
- x = op.inputs[1:1 + max_len]
- cs_prev = op.inputs[-7]
- h_prev = op.inputs[-6]
- w = op.inputs[-5]
- wci = op.inputs[-4]
- wco = op.inputs[-3]
- wcf = op.inputs[-2]
- b = op.inputs[-1]
-
- i = op.outputs[0 * max_len:1 * max_len]
- cs = op.outputs[1 * max_len:2 * max_len]
- f = op.outputs[2 * max_len:3 * max_len]
- o = op.outputs[3 * max_len:4 * max_len]
- ci = op.outputs[4 * max_len:5 * max_len]
- co = op.outputs[5 * max_len:6 * max_len]
- h = op.outputs[6 * max_len:7 * max_len]
-
- cs_grad = grad[-max_len * 2:-max_len]
- h_grad = grad[-max_len:]
+ seq_len_max, x, cs_prev, h_prev, w, wci, wco, wcf, b = op.inputs
+ i, cs, f, o, ci, co, h = op.outputs
+
+ cs_grad = grad[1]
+ h_grad = grad[6]
(x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, wcf_grad,
b_grad) = _lstm_ops_so.block_lstm_grad(
@@ -328,8 +325,8 @@ def _BlockLSTMGrad(op, *grad):
h_grad,
use_peephole=op.get_attr("use_peephole"))
- return [None] + x_grad + [cs_prev_grad, h_prev_grad, w_grad, wci_grad,
- wco_grad, wcf_grad, b_grad]
+ return [None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad,
+ wcf_grad, b_grad]
ops.RegisterShape("BlockLSTMGrad")(common_shapes.call_cpp_shape_fn)
@@ -379,21 +376,23 @@ class LSTMBlockCell(rnn_cell.RNNCell):
input_size = x_shape[1]
w = vs.get_variable("W", [input_size + self._num_units,
self._num_units * 4])
- b = vs.get_variable("b", [w.get_shape().with_rank(2)[1]],
- initializer=init_ops.constant_initializer(0.0))
+ b = vs.get_variable(
+ "b", [w.get_shape().with_rank(2)[1]],
+ initializer=init_ops.constant_initializer(0.0))
wci = vs.get_variable("wci", [self._num_units])
wco = vs.get_variable("wco", [self._num_units])
wcf = vs.get_variable("wcf", [self._num_units])
(cs_prev, h_prev) = states_prev
- (_, cs, _, _, _, _, h) = _lstm_block_cell(x,
- cs_prev,
- h_prev,
- w,
- b,
- wci=wci,
- wco=wco,
- wcf=wcf,
- forget_bias=self._forget_bias,
- use_peephole=self._use_peephole)
+ (_, cs, _, _, _, _, h) = _lstm_block_cell(
+ x,
+ cs_prev,
+ h_prev,
+ w,
+ b,
+ wci=wci,
+ wco=wco,
+ wcf=wcf,
+ forget_bias=self._forget_bias,
+ use_peephole=self._use_peephole)
return (h, (cs, h))
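A closing note on the simplified gradient: with exactly 7 outputs (i, cs, f, o, ci, co, h) instead of 7 * max_len, _BlockLSTMGrad can index the incoming gradients positionally. The sketch below makes the correspondence explicit (the helper and its names are illustrative only):

```python
OUTPUT_NAMES = ("i", "cs", "f", "o", "ci", "co", "h")

def pick_block_lstm_grads(grad):
    """Select the two gradients that _BlockLSTMGrad actually consumes."""
    by_name = dict(zip(OUTPUT_NAMES, grad))
    return by_name["cs"], by_name["h"]  # i.e. grad[1] and grad[6]

assert pick_block_lstm_grads(tuple(range(7))) == (1, 6)
```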