diff options
Diffstat (limited to 'tensorflow/core/kernels/xent_op.cc')
-rw-r--r-- | tensorflow/core/kernels/xent_op.cc | 65 |
1 files changed, 45 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc index a6a71fdfaf..9a3612bd72 100644 --- a/tensorflow/core/kernels/xent_op.cc +++ b/tensorflow/core/kernels/xent_op.cc @@ -17,12 +17,14 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "tensorflow/core/kernels/xent_op.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/xent_op.h" +#include "tensorflow/core/util/bcast.h" namespace tensorflow { @@ -41,37 +43,56 @@ class SoftmaxXentWithLogitsOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& logits_in = context->input(0); const Tensor& labels_in = context->input(1); - OP_REQUIRES(context, logits_in.IsSameSize(labels_in), - errors::InvalidArgument( - "logits and labels must be same size: logits_size=", - logits_in.shape().DebugString(), - " labels_size=", labels_in.shape().DebugString())); - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()), - errors::InvalidArgument("logits must be 2-dimensional")); - // As we already tested that both inputs have the same shape no need to - // check that "labels" is a matrix too. + + TensorShape shape_in = logits_in.shape(); + + BCast bcast(BCast::FromShape(logits_in.shape()), + BCast::FromShape(labels_in.shape())); + if (!logits_in.IsSameSize(labels_in)) { + OP_REQUIRES(context, bcast.IsValid(), + errors::InvalidArgument( + "logits and labels must be broadcastable: logits_size=", + logits_in.shape().DebugString(), + " labels_size=", labels_in.shape().DebugString())); + shape_in = BCast::ToShape(bcast.output_shape()); + } + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in), + errors::InvalidArgument("logits and labels must be beither " + "2-dimensional, or roadcasted to " + "2-dimensional")); // loss is 1-D (one per example), and size is batch_size. Tensor scratch; OP_REQUIRES_OK( context, context->allocate_temp(DataTypeToEnum<T>::value, - TensorShape({logits_in.dim_size(0), 1}), + TensorShape({shape_in.dim_size(0), 1}), &scratch)); Tensor* loss_out = nullptr; OP_REQUIRES_OK(context, context->allocate_output( - 0, TensorShape({logits_in.dim_size(0)}), &loss_out)); + 0, TensorShape({shape_in.dim_size(0)}), &loss_out)); Tensor* back_out = nullptr; // Try to reuse the logits_in buffer for the backprop output. OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( - {0}, 1, logits_in.shape(), &back_out)); - if (logits_in.dim_size(0) > 0) { + {0}, 1, shape_in, &back_out)); + if (shape_in.dim_size(0) > 0) { functor::XentFunctor<Device, T> functor; - functor(context->eigen_device<Device>(), logits_in.matrix<T>(), - labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(), - back_out->matrix<T>()); + if (logits_in.IsSameSize(labels_in)) { + functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(), + Eigen::array<Eigen::DenseIndex, 2>{1, 1}, + Eigen::array<Eigen::DenseIndex, 2>{1, 1}, logits_in.matrix<T>(), + labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(), + back_out->matrix<T>()); + } else { + functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(), + BCast::ToIndexArray<2>(bcast.x_bcast()), + BCast::ToIndexArray<2>(bcast.y_bcast()), + logits_in.template shaped<T, 2>(bcast.x_reshape()), + labels_in.template shaped<T, 2>(bcast.y_reshape()), + scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>()); + } } } }; @@ -81,13 +102,17 @@ class SoftmaxXentWithLogitsOp : public OpKernel { namespace functor { template <typename Device, typename T> struct XentFunctorBase { - void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits, + void operator()(const Device& d, + const Eigen::DSizes<Eigen::DenseIndex, 2>& shape, + const Eigen::array<Eigen::DenseIndex, 2>& logits_bcast, + const Eigen::array<Eigen::DenseIndex, 2>& labels_bcast, + typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels, typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss, typename TTypes<T>::Matrix backprop) { - XentEigenImpl<Device, T>::Compute(d, logits, labels, scratch, loss, - backprop); + XentEigenImpl<Device, T>::Compute(d, shape, logits_bcast, labels_bcast, + logits, labels, scratch, loss, backprop); } }; |