aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/xent_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/xent_op.cc')
-rw-r--r--tensorflow/core/kernels/xent_op.cc65
1 files changed, 45 insertions, 20 deletions
diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc
index a6a71fdfaf..9a3612bd72 100644
--- a/tensorflow/core/kernels/xent_op.cc
+++ b/tensorflow/core/kernels/xent_op.cc
@@ -17,12 +17,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "tensorflow/core/kernels/xent_op.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/xent_op.h"
+#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
@@ -41,37 +43,56 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& logits_in = context->input(0);
const Tensor& labels_in = context->input(1);
- OP_REQUIRES(context, logits_in.IsSameSize(labels_in),
- errors::InvalidArgument(
- "logits and labels must be same size: logits_size=",
- logits_in.shape().DebugString(),
- " labels_size=", labels_in.shape().DebugString()));
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(logits_in.shape()),
- errors::InvalidArgument("logits must be 2-dimensional"));
- // As we already tested that both inputs have the same shape no need to
- // check that "labels" is a matrix too.
+
+ TensorShape shape_in = logits_in.shape();
+
+ BCast bcast(BCast::FromShape(logits_in.shape()),
+ BCast::FromShape(labels_in.shape()));
+ if (!logits_in.IsSameSize(labels_in)) {
+ OP_REQUIRES(context, bcast.IsValid(),
+ errors::InvalidArgument(
+ "logits and labels must be broadcastable: logits_size=",
+ logits_in.shape().DebugString(),
+ " labels_size=", labels_in.shape().DebugString()));
+ shape_in = BCast::ToShape(bcast.output_shape());
+ }
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(shape_in),
+ errors::InvalidArgument("logits and labels must be beither "
+ "2-dimensional, or roadcasted to "
+ "2-dimensional"));
// loss is 1-D (one per example), and size is batch_size.
Tensor scratch;
OP_REQUIRES_OK(
context, context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({logits_in.dim_size(0), 1}),
+ TensorShape({shape_in.dim_size(0), 1}),
&scratch));
Tensor* loss_out = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(
- 0, TensorShape({logits_in.dim_size(0)}), &loss_out));
+ 0, TensorShape({shape_in.dim_size(0)}), &loss_out));
Tensor* back_out = nullptr;
// Try to reuse the logits_in buffer for the backprop output.
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {0}, 1, logits_in.shape(), &back_out));
- if (logits_in.dim_size(0) > 0) {
+ {0}, 1, shape_in, &back_out));
+ if (shape_in.dim_size(0) > 0) {
functor::XentFunctor<Device, T> functor;
- functor(context->eigen_device<Device>(), logits_in.matrix<T>(),
- labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
- back_out->matrix<T>());
+ if (logits_in.IsSameSize(labels_in)) {
+ functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
+ Eigen::array<Eigen::DenseIndex, 2>{1, 1},
+ Eigen::array<Eigen::DenseIndex, 2>{1, 1}, logits_in.matrix<T>(),
+ labels_in.matrix<T>(), scratch.matrix<T>(), loss_out->vec<T>(),
+ back_out->matrix<T>());
+ } else {
+ functor(context->eigen_device<Device>(), shape_in.AsEigenDSizes<2>(),
+ BCast::ToIndexArray<2>(bcast.x_bcast()),
+ BCast::ToIndexArray<2>(bcast.y_bcast()),
+ logits_in.template shaped<T, 2>(bcast.x_reshape()),
+ labels_in.template shaped<T, 2>(bcast.y_reshape()),
+ scratch.matrix<T>(), loss_out->vec<T>(), back_out->matrix<T>());
+ }
}
}
};
@@ -81,13 +102,17 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
namespace functor {
template <typename Device, typename T>
struct XentFunctorBase {
- void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ void operator()(const Device& d,
+ const Eigen::DSizes<Eigen::DenseIndex, 2>& shape,
+ const Eigen::array<Eigen::DenseIndex, 2>& logits_bcast,
+ const Eigen::array<Eigen::DenseIndex, 2>& labels_bcast,
+ typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch,
typename TTypes<T>::Vec loss,
typename TTypes<T>::Matrix backprop) {
- XentEigenImpl<Device, T>::Compute(d, logits, labels, scratch, loss,
- backprop);
+ XentEigenImpl<Device, T>::Compute(d, shape, logits_bcast, labels_bcast,
+ logits, labels, scratch, loss, backprop);
}
};