diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-29 12:21:31 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-29 13:32:40 -0700 |
commit | b7d5df182b7394ab17c11ccc949ce07812920bd9 (patch) | |
tree | ab244238a4c7fdb099a62ca8c2396fd8e15216c3 | |
parent | 4323a658b5228fe8d5482941edfacf58506dea34 (diff) |
Make (tf.contrib) BlockLSTMOp take 3D tensors instead of lists of 2D tensors.
This facilitates dealing with dynamic time lengths.
Updated documentation.
Change: 134699973
-rw-r--r-- | tensorflow/contrib/rnn/kernels/lstm_ops.cc | 403 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/kernels/lstm_ops.h | 23 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc | 3 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/ops/lstm_ops.cc | 213 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/ops/lstm_ops_test.cc | 181 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/lstm_ops.py | 147 |
6 files changed, 611 insertions, 359 deletions
diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc index 2749d7d797..7fec457a4a 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc @@ -523,11 +523,118 @@ REGISTER_GPU_KERNEL(float); #undef REGISTER_GPU_KERNEL #endif // GOOGLE_CUDA +namespace { + +// This helper class can be used to access timeslices of a 3D tensor. If a slice +// happens to be unaligned (usually because both batch size and number of cells +// are odd - this isn't common) this involves overhead, since data needs to be +// copied. However, if all slices are aligned, the bits aren't copied. In the +// cases where copying is needed, the outputs have to be recopied back. +// At the end of each time step you should call FinishTimeStep which does this, +// and also allows for reuse of temporary tensors. +template <typename Device, typename T> +class SliceHelper { + public: + SliceHelper(OpKernelContext* ctx) + : ctx_(ctx), device_(ctx_->eigen_device<Device>()) {} + + ~SliceHelper() { + CHECK(copy_out_.empty()); + for (const auto& entry : pool_) { + CHECK(!entry.second.second); // nothing is in use + } + } + + // Slice through an input tensor. This may copy unaligned slices, but no + // copying back will be done at the end. + const Tensor InputSlice(const Tensor& t, int pos, const string& name) { + Tensor res = UnalignedSlice(t, pos); + if (res.IsAligned()) { + return res; + } else { + return AlignTensor(res, name); + } + } + + // Slice through an output tensor. This may copy unaligned slices, and + // schedule copying back on destruction. + Tensor OutputSlice(Tensor* t, int pos, const string& name) { + Tensor res = UnalignedSlice(*t, pos); + if (res.IsAligned()) { + return res; + } else { + Tensor aligned = AlignTensor(res, name); + copy_out_.emplace_back(res, aligned); + return aligned; + } + } + + void FinishTimeStep() { + for (const auto& p : copy_out_) { + const Tensor& aligned = p.second; + Tensor original = p.first; + // Copy from aligned back to original. + functor::TensorCopyToUnaligned<Device, T>()(device_, aligned.flat<T>(), + original.unaligned_flat<T>()); + } + copy_out_.clear(); + // Mark all entries as not in use. + for (auto& entry : pool_) { + entry.second.second = false; + } + } + + private: + // Return a slice at position 'pos'. Result may be unaligned. The resulting + // tensor always shares data with the source tensor. + Tensor UnalignedSlice(const Tensor& t, int pos) const { + Tensor res; + // CHECK should never fail here, since the number of elements must match + CHECK(res.CopyFrom(t.Slice(pos, pos + 1), {t.dim_size(1), t.dim_size(2)})); + return res; + } + + // Assumes input is not aligned, creates a temporary aligned tensor of the + // same shape and copies the original tensor's content into it. + Tensor AlignTensor(const Tensor& t, const string& name) { + VLOG(1) << "AlignTensor called for " << name << ", shape " + << t.shape().DebugString() + << ". This is unnecessary copying. Consider using shapes with even " + << "sizes"; + Tensor aligned; + auto found = pool_.find(name); + if (found != pool_.end()) { // found in pool + CHECK(!found->second.second) << "Tensor " << name << " is in use"; + found->second.second = true; // mark in use + aligned = found->second.first; + CHECK(aligned.shape().IsSameSize(t.shape())); + CHECK_EQ(aligned.dtype(), t.dtype()); + } else { // allocate a new temporary tensor + ctx_->allocate_temp(t.dtype(), t.shape(), &aligned); + pool_.emplace(name, std::make_pair(aligned, true)); + } + functor::TensorCopyUnaligned<Device, T>()(device_, t.unaligned_flat<T>(), + aligned.flat<T>()); + return aligned; + } + + // Tensors to be copied. + std::vector<std::pair<Tensor, const Tensor>> copy_out_; + // A pool of pre-allocated temporary tensors, with an indicator for whether + // it's in use. + std::map<string, std::pair<Tensor, bool>> pool_; + // Op context + OpKernelContext* ctx_ = nullptr; + // Device + const Device& device_; +}; + +} // namespace + template <typename Device, typename T, bool USE_CUBLAS> class BlockLSTMOp : public OpKernel { public: explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("max_len", &max_len_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_)); @@ -537,77 +644,117 @@ class BlockLSTMOp : public OpKernel { const Tensor* seq_len_max_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor)); - OpInputList x_list; - OP_REQUIRES_OK(ctx, ctx->input_list("x", &x_list)); - const int64 batch_size = x_list[0].dim_size(0); - const int64 input_size = x_list[0].dim_size(1); + const Tensor* x; + OP_REQUIRES_OK(ctx, ctx->input("x", &x)); + OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D")); + const int64 timelen = x->dim_size(0); + const int64 batch_size = x->dim_size(1); + const int64 input_size = x->dim_size(2); const Tensor* cs_prev_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor)); + OP_REQUIRES(ctx, cs_prev_tensor->dims() == 2, + errors::InvalidArgument("cs_prev must be 2D")); + OP_REQUIRES(ctx, cs_prev_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("cs_prev.dims(0) != batch_size: ", + cs_prev_tensor->dim_size(0), " vs. ", + batch_size)); + const int64 cell_size = cs_prev_tensor->dim_size(1); + + if (batch_size * input_size % 2 == 1) { + LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and " + << "input_size are odd. You are using: batch_size=" + << batch_size << ", input_size=" << input_size; + } + if (batch_size * cell_size % 2 == 1) { + LOG(WARNING) << "BlockLSTMOp is inefficient when both batch_size and " + << "cell_size are odd. You are using: batch_size=" + << batch_size << ", cell_size=" << cell_size; + } const Tensor* h_prev_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("h_prev", &h_prev_tensor)); + OP_REQUIRES(ctx, h_prev_tensor->dims() == 2, + errors::InvalidArgument("h_prev must be 2D")); + OP_REQUIRES(ctx, h_prev_tensor->dim_size(0) == batch_size, + errors::InvalidArgument("h_prev.dims(0) != batch_size: ", + h_prev_tensor->dim_size(0), " vs. ", + batch_size)); + OP_REQUIRES(ctx, h_prev_tensor->dim_size(1) == cell_size, + errors::InvalidArgument("h_prev.dims(1) != cell_size: ", + h_prev_tensor->dim_size(1), " vs. ", + cell_size)); const Tensor* w_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("w", &w_tensor)); + OP_REQUIRES(ctx, w_tensor->dims() == 2, + errors::InvalidArgument("w must be 2D")); + OP_REQUIRES(ctx, w_tensor->dim_size(0) == input_size + cell_size, + errors::InvalidArgument( + "w.dim_size(0) != input_size + cell_size: ", + w_tensor->dim_size(0), " vs. ", input_size + cell_size)); + OP_REQUIRES( + ctx, w_tensor->dim_size(1) == cell_size * 4, + errors::InvalidArgument("w.dim_size(1) != cell_size * 4: ", + w_tensor->dim_size(1), " vs. ", cell_size * 4)); const Tensor* wci_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wci", &wci_tensor)); + OP_REQUIRES(ctx, wci_tensor->dims() == 1, + errors::InvalidArgument("wci must be 1D")); + OP_REQUIRES( + ctx, wci_tensor->dim_size(0) == cell_size, + errors::InvalidArgument("wci.dim_size(0) != cell_size: ", + wci_tensor->dim_size(0), " vs. ", cell_size)); const Tensor* wcf_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wcf", &wcf_tensor)); + OP_REQUIRES(ctx, wcf_tensor->dims() == 1, + errors::InvalidArgument("wcf must be 1D")); + OP_REQUIRES( + ctx, wcf_tensor->dim_size(0) == cell_size, + errors::InvalidArgument("wcf.dim_size(0) != cell_size: ", + wcf_tensor->dim_size(0), " vs. ", cell_size)); const Tensor* wco_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("wco", &wco_tensor)); + OP_REQUIRES(ctx, wco_tensor->dims() == 1, + errors::InvalidArgument("wco must be 1D")); + OP_REQUIRES( + ctx, wco_tensor->dim_size(0) == cell_size, + errors::InvalidArgument("wco.dim_size(0) != cell_size: ", + wco_tensor->dim_size(0), " vs. ", cell_size)); const Tensor* b_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("b", &b_tensor)); - const int64 cell_size = b_tensor->dim_size(0) / 4; - - OpOutputList i_list; - OP_REQUIRES_OK(ctx, ctx->output_list("i", &i_list)); - - OpOutputList cs_list; - OP_REQUIRES_OK(ctx, ctx->output_list("cs", &cs_list)); - - OpOutputList f_list; - OP_REQUIRES_OK(ctx, ctx->output_list("f", &f_list)); - - OpOutputList o_list; - OP_REQUIRES_OK(ctx, ctx->output_list("o", &o_list)); - - OpOutputList ci_list; - OP_REQUIRES_OK(ctx, ctx->output_list("ci", &ci_list)); - - OpOutputList co_list; - OP_REQUIRES_OK(ctx, ctx->output_list("co", &co_list)); + OP_REQUIRES(ctx, b_tensor->dims() == 1, + errors::InvalidArgument("b must be 1D")); + OP_REQUIRES( + ctx, b_tensor->dim_size(0) == cell_size * 4, + errors::InvalidArgument("b.dim_size(0) != cell_size * 4: ", + b_tensor->dim_size(0), " vs. ", cell_size * 4)); - OpOutputList h_list; - OP_REQUIRES_OK(ctx, ctx->output_list("h", &h_list)); + TensorShape batch_cell_shape({timelen, batch_size, cell_size}); + Tensor* i_out; + OP_REQUIRES_OK(ctx, ctx->allocate_output("i", batch_cell_shape, &i_out)); - TensorShape batch_cell_shape({batch_size, cell_size}); - for (int64 t = 0; t < max_len_; ++t) { - Tensor* i_tensor = nullptr; - OP_REQUIRES_OK(ctx, i_list.allocate(t, batch_cell_shape, &i_tensor)); + Tensor* cs_out; + OP_REQUIRES_OK(ctx, ctx->allocate_output("cs", batch_cell_shape, &cs_out)); - Tensor* cs_tensor = nullptr; - OP_REQUIRES_OK(ctx, cs_list.allocate(t, batch_cell_shape, &cs_tensor)); + Tensor* f_out; + OP_REQUIRES_OK(ctx, ctx->allocate_output("f", batch_cell_shape, &f_out)); - Tensor* f_tensor = nullptr; - OP_REQUIRES_OK(ctx, f_list.allocate(t, batch_cell_shape, &f_tensor)); + Tensor* o_out; + OP_REQUIRES_OK(ctx, ctx->allocate_output("o", batch_cell_shape, &o_out)); - Tensor* o_tensor = nullptr; - OP_REQUIRES_OK(ctx, o_list.allocate(t, batch_cell_shape, &o_tensor)); + Tensor* ci_out; + OP_REQUIRES_OK(ctx, ctx->allocate_output("ci", batch_cell_shape, &ci_out)); - Tensor* ci_tensor = nullptr; - OP_REQUIRES_OK(ctx, ci_list.allocate(t, batch_cell_shape, &ci_tensor)); + Tensor* co_out; + OP_REQUIRES_OK(ctx, ctx->allocate_output("co", batch_cell_shape, &co_out)); - Tensor* co_tensor = nullptr; - OP_REQUIRES_OK(ctx, co_list.allocate(t, batch_cell_shape, &co_tensor)); - - Tensor* h_tensor = nullptr; - OP_REQUIRES_OK(ctx, h_list.allocate(t, batch_cell_shape, &h_tensor)); - } + Tensor* h_out; + OP_REQUIRES_OK(ctx, ctx->allocate_output("h", batch_cell_shape, &h_out)); Tensor xh_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_temp( @@ -628,19 +775,22 @@ class BlockLSTMOp : public OpKernel { : nullptr; const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()(); + SliceHelper<Device, T> slicer(ctx); for (int64 t = 0; t < seq_len_max; ++t) { - const Tensor& x_tensor = x_list[t]; + const Tensor x_tensor = slicer.InputSlice(*x, t, "x"); const Tensor& cs_prev_tensor2 = - t == 0 ? *cs_prev_tensor : *cs_list[t - 1]; - const Tensor& h_prev_tensor2 = t == 0 ? *h_prev_tensor : *h_list[t - 1]; - - Tensor* i_tensor = i_list[t]; - Tensor* cs_tensor = cs_list[t]; - Tensor* f_tensor = f_list[t]; - Tensor* o_tensor = o_list[t]; - Tensor* ci_tensor = ci_list[t]; - Tensor* co_tensor = co_list[t]; - Tensor* h_tensor = h_list[t]; + t == 0 ? *cs_prev_tensor + : slicer.OutputSlice(cs_out, t - 1, "cs_prev"); + const Tensor& h_prev_tensor2 = + t == 0 ? *h_prev_tensor : slicer.OutputSlice(h_out, t - 1, "h_prev"); + + Tensor i_tensor = slicer.OutputSlice(i_out, t, "i_out"); + Tensor cs_tensor = slicer.OutputSlice(cs_out, t, "cs_out"); + Tensor f_tensor = slicer.OutputSlice(f_out, t, "f_out"); + Tensor o_tensor = slicer.OutputSlice(o_out, t, "o_out"); + Tensor ci_tensor = slicer.OutputSlice(ci_out, t, "ci_out"); + Tensor co_tensor = slicer.OutputSlice(co_out, t, "co_out"); + Tensor h_tensor = slicer.OutputSlice(h_out, t, "h_out"); functor::LSTMBlockCellFprop<Device, T, USE_CUBLAS>(batch_size, input_size, cell_size)( @@ -648,23 +798,25 @@ class BlockLSTMOp : public OpKernel { x_tensor.matrix<T>(), cs_prev_tensor2.matrix<T>(), h_prev_tensor2.matrix<T>(), w_tensor->matrix<T>(), wci_tensor->vec<T>(), wcf_tensor->vec<T>(), wco_tensor->vec<T>(), - b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor->matrix<T>(), - cs_tensor->matrix<T>(), f_tensor->matrix<T>(), o_tensor->matrix<T>(), - ci_tensor->matrix<T>(), co_tensor->matrix<T>(), - icfo_tensor.matrix<T>(), h_tensor->matrix<T>()); + b_tensor->vec<T>(), xh_tensor.matrix<T>(), i_tensor.matrix<T>(), + cs_tensor.matrix<T>(), f_tensor.matrix<T>(), o_tensor.matrix<T>(), + ci_tensor.matrix<T>(), co_tensor.matrix<T>(), icfo_tensor.matrix<T>(), + h_tensor.matrix<T>()); + slicer.FinishTimeStep(); } - for (int64 t = seq_len_max; t < max_len_; ++t) { - Tensor* cs_tensor = cs_list[t]; - Tensor* h_tensor = h_list[t]; + if (seq_len_max < timelen) { + Tensor cs_tensor = cs_out->Slice(seq_len_max, timelen); + Tensor h_tensor = h_out->Slice(seq_len_max, timelen); - functor::TensorZero<Device, T>()(device, cs_tensor->flat<float>()); - functor::TensorZero<Device, T>()(device, h_tensor->flat<float>()); + functor::TensorUnalignedZero<Device, T>()( + device, cs_tensor.unaligned_flat<float>()); + functor::TensorUnalignedZero<Device, T>()( + device, h_tensor.unaligned_flat<float>()); } } private: - int64 max_len_; float forget_bias_; float cell_clip_; bool use_peephole_; @@ -685,7 +837,13 @@ namespace functor { void TensorZero<GPUDevice, T>::operator()(const GPUDevice& d, \ typename TTypes<T>::Flat t); \ \ - extern template struct TensorZero<GPUDevice, T>; + extern template struct TensorZero<GPUDevice, T>; \ + \ + template <> \ + void TensorUnalignedZero<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::UnalignedFlat t); \ + \ + extern template struct TensorUnalignedZero<GPUDevice, T>; DECLARE_GPU_SPEC(float); // DECLARE_GPU_SPEC(double); @@ -708,7 +866,6 @@ template <typename Device, typename T, bool USE_CUBLAS> class BlockLSTMGradOp : public OpKernel { public: explicit BlockLSTMGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("max_len", &max_len_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_)); } @@ -716,10 +873,12 @@ class BlockLSTMGradOp : public OpKernel { const Tensor* seq_len_max_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("seq_len_max", &seq_len_max_tensor)); - OpInputList x_list; - OP_REQUIRES_OK(ctx, ctx->input_list("x", &x_list)); - const int64 batch_size = x_list[0].dim_size(0); - const int64 input_size = x_list[0].dim_size(1); + const Tensor* x; + OP_REQUIRES_OK(ctx, ctx->input("x", &x)); + OP_REQUIRES(ctx, x->dims() == 3, errors::InvalidArgument("x must be 3D")); + const int64 timelen = x->dim_size(0); + const int64 batch_size = x->dim_size(1); + const int64 input_size = x->dim_size(2); const Tensor* cs_prev_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->input("cs_prev", &cs_prev_tensor)); @@ -751,35 +910,37 @@ class BlockLSTMGradOp : public OpKernel { errors::InvalidArgument("w and b cell_size don't match: ", cell_size, " vs. ", b_tensor->dim_size(0))); - OpInputList i_list; - OP_REQUIRES_OK(ctx, ctx->input_list("i", &i_list)); + const Tensor* i_out = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("i", &i_out)); - OpInputList cs_list; - OP_REQUIRES_OK(ctx, ctx->input_list("cs", &cs_list)); + const Tensor* cs_out = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("cs", &cs_out)); - OpInputList f_list; - OP_REQUIRES_OK(ctx, ctx->input_list("f", &f_list)); + const Tensor* f_out = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("f", &f_out)); - OpInputList o_list; - OP_REQUIRES_OK(ctx, ctx->input_list("o", &o_list)); + const Tensor* o_out = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("o", &o_out)); - OpInputList ci_list; - OP_REQUIRES_OK(ctx, ctx->input_list("ci", &ci_list)); + const Tensor* ci_out = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("ci", &ci_out)); - OpInputList co_list; - OP_REQUIRES_OK(ctx, ctx->input_list("co", &co_list)); + const Tensor* co_out = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("co", &co_out)); - OpInputList h_list; - OP_REQUIRES_OK(ctx, ctx->input_list("h", &h_list)); + const Tensor* h_out = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("h", &h_out)); - OpInputList cs_grad_list; - OP_REQUIRES_OK(ctx, ctx->input_list("cs_grad", &cs_grad_list)); + const Tensor* cs_grad = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("cs_grad", &cs_grad)); - OpInputList h_grad_list; - OP_REQUIRES_OK(ctx, ctx->input_list("h_grad", &h_grad_list)); + const Tensor* h_grad = nullptr; + OP_REQUIRES_OK(ctx, ctx->input("h_grad", &h_grad)); - OpOutputList x_grad_list; - OP_REQUIRES_OK(ctx, ctx->output_list("x_grad", &x_grad_list)); + TensorShape batch_input_shape({timelen, batch_size, input_size}); + Tensor* x_grad; + OP_REQUIRES_OK(ctx, + ctx->allocate_output("x_grad", batch_input_shape, &x_grad)); Tensor* cs_prev_grad_tensor = nullptr; OP_REQUIRES_OK(ctx, @@ -811,13 +972,7 @@ class BlockLSTMGradOp : public OpKernel { OP_REQUIRES_OK( ctx, ctx->allocate_output("b_grad", b_tensor->shape(), &b_grad_tensor)); - TensorShape batch_input_shape({batch_size, input_size}); TensorShape batch_cell_shape({batch_size, cell_size}); - for (int64 t = 0; t < max_len_; ++t) { - Tensor* x_grad_tensor = nullptr; - OP_REQUIRES_OK( - ctx, x_grad_list.allocate(t, batch_input_shape, &x_grad_tensor)); - } Tensor xh_tensor; OP_REQUIRES_OK(ctx, ctx->allocate_temp( @@ -882,33 +1037,40 @@ class BlockLSTMGradOp : public OpKernel { functor::TensorZero<Device, T>()(device, b_grad_tensor->flat<float>()); const int64 seq_len_max = seq_len_max_tensor->scalar<int64>()(); + SliceHelper<Device, T> slicer(ctx); for (int64 t = seq_len_max - 1; t >= 0; --t) { - const Tensor& x_tensor = x_list[t]; - const Tensor& cs_prev_tensor2 = t == 0 ? *cs_prev_tensor : cs_list[t - 1]; - const Tensor& h_prev_tensor2 = t == 0 ? *h_prev_tensor : h_list[t - 1]; - const Tensor& i_tensor = i_list[t]; - const Tensor& cs_tensor = cs_list[t]; - const Tensor& f_tensor = f_list[t]; - const Tensor& o_tensor = o_list[t]; - const Tensor& ci_tensor = ci_list[t]; - const Tensor& co_tensor = co_list[t]; + const Tensor& x_tensor = slicer.InputSlice(*x, t, "x"); + const Tensor& cs_prev_tensor2 = + t == 0 ? *cs_prev_tensor + : slicer.InputSlice(*cs_out, t - 1, "cs_prev"); + const Tensor& h_prev_tensor2 = + t == 0 ? *h_prev_tensor : slicer.InputSlice(*h_out, t - 1, "h_prev"); + const Tensor& i_tensor = slicer.InputSlice(*i_out, t, "i_out"); + const Tensor& cs_tensor = slicer.InputSlice(*cs_out, t, "cs_out"); + const Tensor& f_tensor = slicer.InputSlice(*f_out, t, "f_out"); + const Tensor& o_tensor = slicer.InputSlice(*o_out, t, "o_out"); + const Tensor& ci_tensor = slicer.InputSlice(*ci_out, t, "ci_out"); + const Tensor& co_tensor = slicer.InputSlice(*co_out, t, "co_out"); // Grab previous CS grad. const Tensor& const_cs_prev_grad_tensor = *cs_prev_grad_tensor; + const Tensor const_cs_grad_slice = + slicer.InputSlice(*cs_grad, t, "cs_grad"); functor::TensorAdd<Device, T>()( device, const_cs_prev_grad_tensor.flat<T>(), - cs_grad_list[t].flat<T>(), cs_grad_tensor.flat<T>()); + const_cs_grad_slice.flat<T>(), cs_grad_tensor.flat<T>()); // Combine previous h grad and h grad coming on top. const Tensor& const_h_prev_grad_tensor = *h_prev_grad_tensor; + const Tensor const_h_grad_slice = slicer.InputSlice(*h_grad, t, "h_grad"); functor::TensorAdd<Device, T>()( - device, const_h_prev_grad_tensor.flat<T>(), h_grad_list[t].flat<T>(), - h_grad_tensor.flat<T>()); + device, const_h_prev_grad_tensor.flat<T>(), + const_h_grad_slice.flat<T>(), h_grad_tensor.flat<T>()); const Tensor& const_cs_grad_tensor = cs_grad_tensor; const Tensor& const_h_grad_tensor = h_grad_tensor; - Tensor* x_grad_tensor = x_grad_list[t]; + Tensor x_grad_tensor = slicer.OutputSlice(x_grad, t, "x_grad"); functor::BlockLSTMBprop<Device, T, USE_CUBLAS>(batch_size, input_size, cell_size)( ctx, stream, device, use_peephole_, x_tensor.matrix<T>(), @@ -922,19 +1084,20 @@ class BlockLSTMGradOp : public OpKernel { df_tensor.matrix<T>(), di_tensor.matrix<T>(), dicfo_tensor.matrix<T>(), cs_prev_grad_tensor->matrix<T>(), h_prev_grad_tensor->matrix<T>(), xh_grad_tensor.matrix<T>(), - x_grad_tensor->matrix<T>(), w_grad_tensor->matrix<T>(), + x_grad_tensor.matrix<T>(), w_grad_tensor->matrix<T>(), wci_grad_tensor->vec<T>(), wcf_grad_tensor->vec<T>(), wco_grad_tensor->vec<T>(), b_grad_tensor->vec<T>()); + slicer.FinishTimeStep(); } - for (int64 t = seq_len_max; t < max_len_; ++t) { - Tensor* x_grad_tensor = x_grad_list[t]; - functor::TensorZero<Device, T>()(device, x_grad_tensor->flat<T>()); + if (seq_len_max < timelen) { + Tensor x_grad_tensor = x_grad->Slice(seq_len_max, timelen); + functor::TensorUnalignedZero<Device, T>()( + device, x_grad_tensor.unaligned_flat<T>()); } } private: - int64 max_len_; bool use_peephole_; }; @@ -955,6 +1118,16 @@ namespace functor { typename TTypes<T>::Flat dst); \ \ template <> \ + void TensorCopyUnaligned<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::UnalignedConstFlat src, \ + typename TTypes<T>::Flat dst); \ + \ + template <> \ + void TensorCopyToUnaligned<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T>::ConstFlat src, \ + typename TTypes<T>::UnalignedFlat dst); \ + \ + template <> \ void TensorAdd<GPUDevice, T>::operator()( \ const GPUDevice& d, typename TTypes<T>::ConstFlat a, \ typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c); \ diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h index 5a9dda5755..1332b88002 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.h +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h @@ -41,6 +41,13 @@ struct TensorZero { }; template <typename Device, typename T> +struct TensorUnalignedZero { + void operator()(const Device& d, typename TTypes<T>::UnalignedFlat t) { + t.device(d) = t.constant(T(0)); + } +}; + +template <typename Device, typename T> struct TensorCopy { void operator()(const Device& d, typename TTypes<T>::ConstFlat src, typename TTypes<T>::Flat dst) { @@ -49,6 +56,22 @@ struct TensorCopy { }; template <typename Device, typename T> +struct TensorCopyUnaligned { + void operator()(const Device& d, typename TTypes<T>::UnalignedConstFlat src, + typename TTypes<T>::Flat dst) { + dst.device(d) = src; + } +}; + +template <typename Device, typename T> +struct TensorCopyToUnaligned { + void operator()(const Device& d, typename TTypes<T>::ConstFlat src, + typename TTypes<T>::UnalignedFlat dst) { + dst.device(d) = src; + } +}; + +template <typename Device, typename T> struct TensorAdd { void operator()(const Device& d, typename TTypes<T>::ConstFlat a, typename TTypes<T>::ConstFlat b, typename TTypes<T>::Flat c) { diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc index 47298d9501..b33ca5fc8d 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -26,7 +26,10 @@ typedef Eigen::GpuDevice GPUDevice; #define DEFINE_GPU_SPECS(T) \ template struct TensorZero<GPUDevice, T>; \ + template struct TensorUnalignedZero<GPUDevice, T>; \ template struct TensorCopy<GPUDevice, T>; \ + template struct TensorCopyUnaligned<GPUDevice, T>; \ + template struct TensorCopyToUnaligned<GPUDevice, T>; \ template struct TensorAdd<GPUDevice, T>; \ template struct LSTMBlockCellFprop<GPUDevice, T, true>; \ template struct LSTMBlockCellBprop<GPUDevice, T, true>; \ diff --git a/tensorflow/contrib/rnn/ops/lstm_ops.cc b/tensorflow/contrib/rnn/ops/lstm_ops.cc index 08dbd3c822..2de40825c9 100644 --- a/tensorflow/contrib/rnn/ops/lstm_ops.cc +++ b/tensorflow/contrib/rnn/ops/lstm_ops.cc @@ -58,8 +58,8 @@ REGISTER_OP("LSTMBlockCell") .Doc(R"doc( Computes the LSTM cell forward propagation for 1 time step. -This implementation uses 1 weight matrix and 1 bias vector, there is no -diagonal peephole connection. +This implementation uses 1 weight matrix and 1 bias vector, and there's an +optional peephole connection. This kernel op implements the following mathematical equations: @@ -68,21 +68,34 @@ xh = [x, h_prev] [i, f, ci, o] = xh * w + b f = f + forget_bias -i = sigmoid(i) -f = sigmoid(f) +if not use_peephole: + wci = wcf = wco = 0 + +i = sigmoid(cs_prev * wci + i) +f = sigmoid(cs_prev * wcf + f) ci = tanh(ci) -o = sigmoid(o) cs = ci .* i + cs_prev .* f -co = tanh(cs) +cs = clip(cs, cell_clip) +o = sigmoid(cs * wco + f) +co = tanh(cs) h = co .* o ``` +cell_clip: Value to clip the 'cs' value to. +use_peephole: Whether to use peephole weights. forget_bias: The forget gate bias. -x: The input to the LSTM cell. + +x: The input to the LSTM cell, shape (batch_size, num_inputs). +cs_prev: Value of the cell state at previous time step. +h_prev: Output of the previous cell at previous time step. w: The weight matrix. +wci: The weight matrix for input gate peephole connection. +wcf: The weight matrix for forget gate peephole connection. +wco: The weight matrix for output gate peephole connection. b: The bias vector. + i: The input gate. cs: The cell state before the tanh. f: The forget gate. @@ -139,10 +152,14 @@ Computes the LSTM cell backward propagation for 1 timestep. This implementation is to be used in conjunction of LSTMBlockCell. -x: The input to the LSTM cell. +use_peephole: Whether the cell uses peephole connections. +x: The input to the LSTM cell, shape (batch_size, num_inputs). cs_prev: The previous cell state. h_prev: The previous h state. w: The weight matrix. +wci: The weight matrix for input gate peephole connection. +wcf: The weight matrix for forget gate peephole connection. +wco: The weight matrix for output gate peephole connection. b: The bias vector. i: The input gate. cs: The cell state before the tanh. @@ -150,14 +167,18 @@ f: The forget gate. o: The output gate. ci: The cell input. co: The cell after the tanh. -h_grad: THe gradient of h vector. -cs_prev_grad: The gradient of cs. +cs_grad: The current gradient of cs. +h_grad: The gradient of h vector. +cs_prev_grad: The gradient of cs to be back-propped. dicfo: The derivative wrt to [i, cs, f, o]. +wci_grad: The gradient for wci to be back-propped. +wcf_grad: The gradient for wcf to be back-propped. +wco_grad: The gradient for wco to be back-propped. )doc"); REGISTER_OP("BlockLSTM") .Input("seq_len_max: int64") - .Input("x: max_len * T") + .Input("x: T") .Input("cs_prev: T") .Input("h_prev: T") .Input("w: T") @@ -165,46 +186,83 @@ REGISTER_OP("BlockLSTM") .Input("wcf: T") .Input("wco: T") .Input("b: T") - .Output("i: max_len * T") - .Output("cs: max_len * T") - .Output("f: max_len * T") - .Output("o: max_len * T") - .Output("ci: max_len * T") - .Output("co: max_len * T") - .Output("h: max_len * T") - .Attr("max_len: int") + .Output("i: T") + .Output("cs: T") + .Output("f: T") + .Output("o: T") + .Output("ci: T") + .Output("co: T") + .Output("h: T") .Attr("forget_bias: float = 1.0") .Attr("cell_clip: float = 3.0") .Attr("use_peephole: bool = false") .Attr("T: {float}") .SetShapeFn([](InferenceContext* c) { ShapeHandle x, b; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &x)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b)); - DimensionHandle batch_size = c->Dim(x, 0); + DimensionHandle timelen = c->Dim(x, 0); + DimensionHandle batch_size = c->Dim(x, 1); DimensionHandle cell_size; TF_RETURN_IF_ERROR( c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size)); - int64 max_len; - TF_RETURN_IF_ERROR(c->GetAttr("max_len", &max_len)); - - DCHECK_EQ(max_len * 7, c->num_outputs()); - ShapeHandle output = c->Matrix(batch_size, cell_size); - for (int i = 0; i < max_len; ++i) { - for (int j = 0; j < 7; ++j) { - c->set_output(i * 7 + j, output); - } + DCHECK_EQ(7, c->num_outputs()); + ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size}); + for (int i = 0; i < 7; ++i) { + c->set_output(i, output); } return Status::OK(); }) .Doc(R"doc( +Computes the LSTM cell forward propagation for all the time steps. + +This is equivalent to applying LSTMBlockCell in a loop, like so: + +```python +for x1 in unpack(x): + i1, cs1, f1, o1, ci1, co1, h1 = LSTMBlock( + x1, cs_prev, h_prev, w, wci, wcf, wco, b) + cs_prev = cs1 + h_prev = h1 + i.append(i1) + cs.append(cs1) + f.append(f1) + o.append(o1) + ci.append(ci1) + co.append(co1) + h.append(h1) +return pack(i), pack(cs), pack(f), pack(o), pack(ci), pack(ch), pack(h) +``` + +cell_clip: Value to clip the 'cs' value to. +use_peephole: Whether to use peephole weights. +forget_bias: The forget gate bias. + +seq_len_max: Maximum time length actually used by this input. Outputs are padded + with zeros beyond this length. +x: The sequence input to the LSTM, shape (timelen, batch_size, num_inputs). +cs_prev: Value of the initial cell state. +h_prev: Initial output of cell (to be used for peephole). +w: The weight matrix. +wci: The weight matrix for input gate peephole connection. +wcf: The weight matrix for forget gate peephole connection. +wco: The weight matrix for output gate peephole connection. +b: The bias vector. + +i: The input gate over the whole time sequence. +cs: The cell state before the tanh over the whole time sequence. +f: The forget gate over the whole time sequence. +o: The output gate over the whole time sequence. +ci: The cell input over the whole time sequence. +co: The cell after the tanh over the whole time sequence. +h: The output h vector over the whole time sequence. )doc"); REGISTER_OP("BlockLSTMGrad") .Input("seq_len_max: int64") - .Input("x: max_len * T") + .Input("x: T") .Input("cs_prev: T") .Input("h_prev: T") .Input("w: T") @@ -212,16 +270,16 @@ REGISTER_OP("BlockLSTMGrad") .Input("wcf: T") .Input("wco: T") .Input("b: T") - .Input("i: max_len * T") - .Input("cs: max_len * T") - .Input("f: max_len * T") - .Input("o: max_len * T") - .Input("ci: max_len * T") - .Input("co: max_len * T") - .Input("h: max_len * T") - .Input("cs_grad: max_len * T") - .Input("h_grad: max_len * T") - .Output("x_grad: max_len * T") + .Input("i: T") + .Input("cs: T") + .Input("f: T") + .Input("o: T") + .Input("ci: T") + .Input("co: T") + .Input("h: T") + .Input("cs_grad: T") + .Input("h_grad: T") + .Output("x_grad: T") .Output("cs_prev_grad: T") .Output("h_prev_grad: T") .Output("w_grad: T") @@ -229,36 +287,65 @@ REGISTER_OP("BlockLSTMGrad") .Output("wcf_grad: T") .Output("wco_grad: T") .Output("b_grad: T") - .Attr("max_len: int") .Attr("use_peephole: bool") .Attr("T: {float}") .SetShapeFn([](InferenceContext* c) { - int64 max_len; - TF_RETURN_IF_ERROR(c->GetAttr("max_len", &max_len)); - ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &x)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(1 + max_len), 2, &cs_prev)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(2 + max_len), 2, &h_prev)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(3 + max_len), 2, &w)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(4 + max_len), 1, &wci)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(5 + max_len), 1, &wco)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(6 + max_len), 1, &wcf)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(7 + max_len), 1, &b)); - - int out_idx = 0; - for (int i = 0; i < max_len; ++i) c->set_output(out_idx++, x); - c->set_output(out_idx++, cs_prev); - c->set_output(out_idx++, h_prev); - c->set_output(out_idx++, w); - c->set_output(out_idx++, wci); - c->set_output(out_idx++, wco); - c->set_output(out_idx++, wcf); - c->set_output(out_idx++, b); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b)); + + c->set_output(0, x); + c->set_output(1, cs_prev); + c->set_output(2, h_prev); + c->set_output(3, w); + c->set_output(4, wci); + c->set_output(5, wco); + c->set_output(6, wcf); + c->set_output(7, b); return Status::OK(); }) .Doc(R"doc( +Computes the LSTM cell backward propagation for the entire time sequence. + +This implementation is to be used in conjunction of LSTMBlock. + +use_peephole: Whether to use peephole weights. + +seq_len_max: Maximum time length actually used by this input. Outputs are padded + with zeros beyond this length. +x: The sequence input to the LSTM, shape (timelen, batch_size, num_inputs). +cs_prev: Value of the initial cell state. +h_prev: Initial output of cell (to be used for peephole). +w: The weight matrix. +wci: The weight matrix for input gate peephole connection. +wcf: The weight matrix for forget gate peephole connection. +wco: The weight matrix for output gate peephole connection. +b: The bias vector. +i: The input gate over the whole time sequence. +cs: The cell state before the tanh over the whole time sequence. +f: The forget gate over the whole time sequence. +o: The output gate over the whole time sequence. +ci: The cell input over the whole time sequence. +co: The cell after the tanh over the whole time sequence. +h: The output h vector over the whole time sequence. +cs_grad: The current gradient of cs. +h_grad: The gradient of h vector. + +x_grad: The gradient of x to be back-propped. +cs_prev_grad: The gradient of cs_prev to be back-propped. +h_prev_grad: The gradient of h_prev to be back-propped. +w_grad: The gradient for w to be back-propped. +wci_grad: The gradient for wci to be back-propped. +wcf_grad: The gradient for wcf to be back-propped. +wco_grad: The gradient for wco to be back-propped. +b_grad: The gradient for w to be back-propped. )doc"); } // end namespace tensorflow diff --git a/tensorflow/contrib/rnn/ops/lstm_ops_test.cc b/tensorflow/contrib/rnn/ops/lstm_ops_test.cc index 4acbfa7ced..a25e0674ea 100644 --- a/tensorflow/contrib/rnn/ops/lstm_ops_test.cc +++ b/tensorflow/contrib/rnn/ops/lstm_ops_test.cc @@ -79,118 +79,85 @@ TEST_F(LSTMOpsTest, LSTMBlockCellGrad_ShapeFn) { TEST_F(LSTMOpsTest, BlockLSTM_ShapeFn) { ShapeInferenceTestOp op("BlockLSTM"); - auto set_op = [&op](int max_len) { - std::vector<NodeDefBuilder::NodeOut> x_list; - for (int i = 0; i < max_len; ++i) x_list.emplace_back("a", 0, DT_FLOAT); - TF_ASSERT_OK(NodeDefBuilder("test", "BlockLSTM") - .Input({"seq_len_max", 0, DT_INT64}) - .Input(x_list) - .Input({"cs_prev", 0, DT_FLOAT}) - .Input({"h_prev", 0, DT_FLOAT}) - .Input({"w", 0, DT_FLOAT}) - .Input({"wci", 0, DT_FLOAT}) - .Input({"wcf", 0, DT_FLOAT}) - .Input({"wco", 0, DT_FLOAT}) - .Input({"b", 0, DT_FLOAT}) - .Attr("max_len", max_len) - .Finalize(&op.node_def)); - }; - - for (int max_len = 1; max_len < 11; max_len += 3) { - set_op(max_len); - int num_in = max_len + 8; - int num_out = max_len * 7; - - // Middle inputs don't affect shape inference. - string infix = ";" + JoinedCopies("?", num_in - 3) + ";"; - - // Rank checks. - INFER_ERROR("must be rank 2", op, "?;[?]" + infix + "?"); - INFER_ERROR("must be rank 1", op, "?;?" + infix + "[?,?]"); - - // Output - INFER_OK(op, "?;?" + infix + "?", JoinedCopies("[?,?]", num_out)); - INFER_OK(op, "?;[?,?]" + infix + "?", JoinedCopies("[d1_0,?]", num_out)); - INFER_OK(op, "?;[?,?]" + infix + "[?]", JoinedCopies("[d1_0,?]", num_out)); - INFER_OK(op, "?;[?,?]" + infix + "[20]", JoinedCopies("[d1_0,5]", num_out)); - - // cell_size must be divisible by 4. - INFER_ERROR("must be evenly divisible", op, "?;?" + infix + "[11]"); - } + TF_ASSERT_OK(NodeDefBuilder("test", "BlockLSTM") + .Input({"seq_len_max", 0, DT_INT64}) + .Input({"x", 0, DT_FLOAT}) + .Input({"cs_prev", 0, DT_FLOAT}) + .Input({"h_prev", 0, DT_FLOAT}) + .Input({"w", 0, DT_FLOAT}) + .Input({"wci", 0, DT_FLOAT}) + .Input({"wcf", 0, DT_FLOAT}) + .Input({"wco", 0, DT_FLOAT}) + .Input({"b", 0, DT_FLOAT}) + .Finalize(&op.node_def)); + + // Middle inputs don't affect shape inference. + string infix = ";" + JoinedCopies("?", 6) + ";"; + + // Rank checks. + INFER_ERROR("must be rank 3", op, "?;[?]" + infix + "?"); + INFER_ERROR("must be rank 1", op, "?;?" + infix + "[?,?]"); + + // Output + INFER_OK(op, "?;?" + infix + "?", JoinedCopies("[?,?,?]", 7)); + INFER_OK(op, "?;[?,?,?]" + infix + "?", JoinedCopies("[d1_0,d1_1,?]", 7)); + INFER_OK(op, "?;[?,?,?]" + infix + "[?]", JoinedCopies("[d1_0,d1_1,?]", 7)); + INFER_OK(op, "?;[?,?,?]" + infix + "[20]", JoinedCopies("[d1_0,d1_1,5]", 7)); + + // cell_size must be divisible by 4. + INFER_ERROR("must be evenly divisible", op, "?;?" + infix + "[11]"); } TEST_F(LSTMOpsTest, BlockLSTMGrad_ShapeFn) { ShapeInferenceTestOp op("BlockLSTMGrad"); - - auto set_op = [&op](int max_len) { - std::vector<NodeDefBuilder::NodeOut> x_list; - for (int i = 0; i < max_len; ++i) x_list.emplace_back("a", 0, DT_FLOAT); - TF_ASSERT_OK(NodeDefBuilder("test", "BlockLSTMGrad") - .Input({"seq_len_max", 0, DT_INT64}) - .Input(x_list) - .Input({"cs_prev", 0, DT_FLOAT}) - .Input({"h_prev", 0, DT_FLOAT}) - .Input({"w", 0, DT_FLOAT}) - .Input({"wci", 0, DT_FLOAT}) - .Input({"wcf", 0, DT_FLOAT}) - .Input({"wco", 0, DT_FLOAT}) - .Input({"b", 0, DT_FLOAT}) - .Input(x_list) - .Input(x_list) - .Input(x_list) - .Input(x_list) - .Input(x_list) - .Input(x_list) - .Input(x_list) - .Input(x_list) - .Input(x_list) - .Attr("max_len", max_len) - .Finalize(&op.node_def)); - }; - - for (int max_len = 1; max_len < 11; max_len += 3) { - set_op(max_len); - int num_in = max_len * 10 + 8; - int num_out = max_len + 7; - - // Last inputs don't affect shape inference. - string suffix = ";" + JoinedCopies("?", 9 * max_len); - - // Rank check for x - string invalid_x = JoinedCopies("[?]", max_len); - INFER_ERROR("must be rank 2", op, - "?;" + invalid_x + ";?;?;?;?;?;?;?" + suffix); - - // Rank checks for cs_prev through b. - string unknown_x = JoinedCopies("?", max_len); - INFER_ERROR("must be rank 2", op, - "?;" + unknown_x + ";[1];?;?;?;?;?;?" + suffix); - INFER_ERROR("must be rank 2", op, - "?;" + unknown_x + ";?;[1];?;?;?;?;?" + suffix); - INFER_ERROR("must be rank 2", op, - "?;" + unknown_x + ";?;?;[1];?;?;?;?" + suffix); - INFER_ERROR("must be rank 1", op, - "?;" + unknown_x + ";?;?;?;[1,?];?;?;?" + suffix); - INFER_ERROR("must be rank 1", op, - "?;" + unknown_x + ";?;?;?;?;[1,?];?;?" + suffix); - INFER_ERROR("must be rank 1", op, - "?;" + unknown_x + ";?;?;?;?;?;[1,?];?" + suffix); - INFER_ERROR("must be rank 1", op, - "?;" + unknown_x + ";?;?;?;?;?;?;[1,?]" + suffix); - - // Output with all input knowns makes known rank outputs. - INFER_OK(op, JoinedCopies("?", num_in), - JoinedCopies("[?,?]", num_out - 4) + ";" + JoinedCopies("[?]", 4)); - - // Output with copies input shapes to output. - string input = strings::StrCat("?;", JoinedCopies("[?,?]", max_len + 3), - ";", JoinedCopies("[?]", 4), suffix); - string expected = JoinedCopies("in1", max_len); // copies of x; - for (int i = max_len; i < num_out; ++i) { - strings::StrAppend(&expected, ";in", (i + 1)); - } - INFER_OK(op, input, expected); + TF_ASSERT_OK(NodeDefBuilder("test", "BlockLSTMGrad") + .Input({"seq_len_max", 0, DT_INT64}) + .Input({"x", 0, DT_FLOAT}) + .Input({"cs_prev", 0, DT_FLOAT}) + .Input({"h_prev", 0, DT_FLOAT}) + .Input({"w", 0, DT_FLOAT}) + .Input({"wci", 0, DT_FLOAT}) + .Input({"wcf", 0, DT_FLOAT}) + .Input({"wco", 0, DT_FLOAT}) + .Input({"b", 0, DT_FLOAT}) + .Input({"i", 0, DT_FLOAT}) + .Input({"cs", 0, DT_FLOAT}) + .Input({"f", 0, DT_FLOAT}) + .Input({"o", 0, DT_FLOAT}) + .Input({"ci", 0, DT_FLOAT}) + .Input({"co", 0, DT_FLOAT}) + .Input({"h", 0, DT_FLOAT}) + .Input({"cs_grad", 0, DT_FLOAT}) + .Input({"h_grad", 0, DT_FLOAT}) + .Finalize(&op.node_def)); + + // Last inputs don't affect shape inference. + string suffix = ";" + JoinedCopies("?", 9); + + // Rank check for x + INFER_ERROR("must be rank 3", op, "?;[?];?;?;?;?;?;?;?" + suffix); + + // Rank checks for cs_prev through b. + INFER_ERROR("must be rank 2", op, "?;?;[1];?;?;?;?;?;?" + suffix); + INFER_ERROR("must be rank 2", op, "?;?;?;[1];?;?;?;?;?" + suffix); + INFER_ERROR("must be rank 2", op, "?;?;?;?;[1];?;?;?;?" + suffix); + INFER_ERROR("must be rank 1", op, "?;?;?;?;?;[1,?];?;?;?" + suffix); + INFER_ERROR("must be rank 1", op, "?;?;?;?;?;?;[1,?];?;?" + suffix); + INFER_ERROR("must be rank 1", op, "?;?;?;?;?;?;?;[1,?];?" + suffix); + INFER_ERROR("must be rank 1", op, "?;?;?;?;?;?;?;?;[1,?]" + suffix); + + // Output with all input knowns makes known rank outputs. + INFER_OK(op, JoinedCopies("?", 18), "[?,?,?];" + JoinedCopies("[?,?]", 3) + + ";" + JoinedCopies("[?]", 4)); + + // Output with copies input shapes to output. + string input = strings::StrCat("?;[?,?,?];", JoinedCopies("[?,?]", 3), ";", + JoinedCopies("[?]", 4), suffix); + string expected = "in1"; + for (int i = 1; i < 8; ++i) { + strings::StrAppend(&expected, ";in", (i + 1)); } + INFER_OK(op, input, expected); } } // namespace tensorflow diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 7d863d17dc..4979856668 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """LSTM Block Cell ops.""" from __future__ import absolute_import from __future__ import division @@ -50,8 +49,8 @@ def _lstm_block_cell(x, name=None): r"""Computes the LSTM cell forward propagation for 1 time step. - This implementation uses 1 weight matrix and 1 bias vector, there is no - diagonal peephole connection. + This implementation uses 1 weight matrix and 1 bias vector, and there's an + optional peephole connection. This kernel op implements the following mathematical equations: @@ -60,30 +59,41 @@ def _lstm_block_cell(x, [i, f, ci, o] = xh * w + b f = f + forget_bias - i = sigmoid(i) - f = sigmoid(f) + if not use_peephole: + wci = wcf = wco = 0 + + i = sigmoid(cs_prev * wci + i) + f = sigmoid(cs_prev * wcf + f) ci = tanh(ci) - o = sigmoid(o) cs = ci .* i + cs_prev .* f - co = tanh(cs) + cs = clip(cs, cell_clip) + o = sigmoid(cs * wco + f) + co = tanh(cs) h = co .* o ``` Args: - x: A `Tensor`. Must be one of the following types: `float32`, `float64`. - The input to the LSTM cell. + x: A `Tensor`. Must be one of the following types: `float32`. + The input to the LSTM cell, shape (batch_size, num_inputs). cs_prev: A `Tensor`. Must have the same type as `x`. + Value of the cell state at previous time step. h_prev: A `Tensor`. Must have the same type as `x`. + Output of the previous cell at previous time step. w: A `Tensor`. Must have the same type as `x`. The weight matrix. b: A `Tensor`. Must have the same type as `x`. The bias vector. wci: A `Tensor`. Must have the same type as `x`. + The weight matrix for input gate peephole connection. wcf: A `Tensor`. Must have the same type as `x`. + The weight matrix for forget gate peephole connection. wco: A `Tensor`. Must have the same type as `x`. + The weight matrix for output gate peephole connection. forget_bias: An optional `float`. Defaults to `1`. The forget gate bias. cell_clip: An optional `float`. Defaults to `3`. + Value to clip the 'cs' value to. use_peephole: An optional `bool`. Defaults to `False`. + Whether to use peephole weights. name: A name for the operation (optional). Returns: @@ -108,18 +118,19 @@ def _lstm_block_cell(x, wcf = wci # pylint: disable=protected-access - return _lstm_ops_so.lstm_block_cell(x=x, - cs_prev=cs_prev, - h_prev=h_prev, - w=w, - wci=wci, - wco=wco, - wcf=wcf, - b=b, - forget_bias=forget_bias, - cell_clip=cell_clip, - use_peephole=use_peephole, - name=name) + return _lstm_ops_so.lstm_block_cell( + x=x, + cs_prev=cs_prev, + h_prev=h_prev, + w=w, + wci=wci, + wco=wco, + wcf=wcf, + b=b, + forget_bias=forget_bias, + cell_clip=cell_clip, + use_peephole=use_peephole, + name=name) # pylint: enable=protected-access @@ -180,9 +191,8 @@ def _block_lstm(seq_len_max, cell_size = cell_size4 / 4 zero_state = None if cs_prev is None or h_prev is None: - zero_state = array_ops.constant(0, - dtype=dtypes.float32, - shape=[batch_size, cell_size]) + zero_state = array_ops.constant( + 0, dtype=dtypes.float32, shape=[batch_size, cell_size]) if cs_prev is None: cs_prev = zero_state if h_prev is None: @@ -193,26 +203,30 @@ def _block_lstm(seq_len_max, wcf = wci # pylint: disable=protected-access - return _lstm_ops_so.block_lstm(seq_len_max=seq_len_max, - x=x, - cs_prev=cs_prev, - h_prev=h_prev, - w=w, - wci=wci, - wco=wco, - wcf=wcf, - b=b, - forget_bias=forget_bias, - cell_clip=cell_clip, - name=name, - use_peephole=use_peephole) + i, cs, f, o, ci, co, h = _lstm_ops_so.block_lstm( + seq_len_max=seq_len_max, + x=array_ops.pack(x), + cs_prev=cs_prev, + h_prev=h_prev, + w=w, + wci=wci, + wco=wco, + wcf=wcf, + b=b, + forget_bias=forget_bias, + cell_clip=cell_clip, + name=name, + use_peephole=use_peephole) + + return array_ops.unpack(i), array_ops.unpack(cs), array_ops.unpack( + f), array_ops.unpack(o), array_ops.unpack(ci), array_ops.unpack( + co), array_ops.unpack(h) # pylint: enable=protected-access # pylint: enable=invalid-name _lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"] - ops.RegisterShape("LSTMBlockCell")(common_shapes.call_cpp_shape_fn) @@ -283,28 +297,11 @@ ops.RegisterShape("BlockLSTM")(common_shapes.call_cpp_shape_fn) @ops.RegisterGradient("BlockLSTM") def _BlockLSTMGrad(op, *grad): """Gradient for BlockLSTM.""" - max_len = op.get_attr("max_len") - - seq_len_max = op.inputs[0] - x = op.inputs[1:1 + max_len] - cs_prev = op.inputs[-7] - h_prev = op.inputs[-6] - w = op.inputs[-5] - wci = op.inputs[-4] - wco = op.inputs[-3] - wcf = op.inputs[-2] - b = op.inputs[-1] - - i = op.outputs[0 * max_len:1 * max_len] - cs = op.outputs[1 * max_len:2 * max_len] - f = op.outputs[2 * max_len:3 * max_len] - o = op.outputs[3 * max_len:4 * max_len] - ci = op.outputs[4 * max_len:5 * max_len] - co = op.outputs[5 * max_len:6 * max_len] - h = op.outputs[6 * max_len:7 * max_len] - - cs_grad = grad[-max_len * 2:-max_len] - h_grad = grad[-max_len:] + seq_len_max, x, cs_prev, h_prev, w, wci, wco, wcf, b = op.inputs + i, cs, f, o, ci, co, h = op.outputs + + cs_grad = grad[1] + h_grad = grad[6] (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, wcf_grad, b_grad) = _lstm_ops_so.block_lstm_grad( @@ -328,8 +325,8 @@ def _BlockLSTMGrad(op, *grad): h_grad, use_peephole=op.get_attr("use_peephole")) - return [None] + x_grad + [cs_prev_grad, h_prev_grad, w_grad, wci_grad, - wco_grad, wcf_grad, b_grad] + return [None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, + wcf_grad, b_grad] ops.RegisterShape("BlockLSTMGrad")(common_shapes.call_cpp_shape_fn) @@ -379,21 +376,23 @@ class LSTMBlockCell(rnn_cell.RNNCell): input_size = x_shape[1] w = vs.get_variable("W", [input_size + self._num_units, self._num_units * 4]) - b = vs.get_variable("b", [w.get_shape().with_rank(2)[1]], - initializer=init_ops.constant_initializer(0.0)) + b = vs.get_variable( + "b", [w.get_shape().with_rank(2)[1]], + initializer=init_ops.constant_initializer(0.0)) wci = vs.get_variable("wci", [self._num_units]) wco = vs.get_variable("wco", [self._num_units]) wcf = vs.get_variable("wcf", [self._num_units]) (cs_prev, h_prev) = states_prev - (_, cs, _, _, _, _, h) = _lstm_block_cell(x, - cs_prev, - h_prev, - w, - b, - wci=wci, - wco=wco, - wcf=wcf, - forget_bias=self._forget_bias, - use_peephole=self._use_peephole) + (_, cs, _, _, _, _, h) = _lstm_block_cell( + x, + cs_prev, + h_prev, + w, + b, + wci=wci, + wco=wco, + wcf=wcf, + forget_bias=self._forget_bias, + use_peephole=self._use_peephole) return (h, (cs, h)) |