about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/xent_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/xent_op.h')
-rw-r--r--  tensorflow/core/kernels/xent_op.h  35
1 file changed, 23 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/xent_op.h b/tensorflow/core/kernels/xent_op.h
index e689fca7ff..87be17fca9 100644
--- a/tensorflow/core/kernels/xent_op.h
+++ b/tensorflow/core/kernels/xent_op.h
@@ -18,6 +18,7 @@ limitations under the License.
// Functor definition for XentOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
#include "tensorflow/core/framework/tensor_types.h"
namespace tensorflow {
@@ -33,7 +34,11 @@ struct XentFunctor {
// scratch: temporary tensor, dims: batch_size, 1
// loss: output tensor for the loss, dims: batch_size.
// backprop: output tensor for the backprop, dims: batch_size, num_classes.
- void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ void operator()(const Device &d,
+ const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
+ const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
+ const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
+ typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch,
typename TTypes<T>::Vec loss,
@@ -45,7 +50,11 @@ struct XentFunctor {
// specializations for both device types.
template <typename Device, typename T>
struct XentEigenImpl {
- static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits,
+ static void Compute(const Device &d,
+ const Eigen::DSizes<Eigen::DenseIndex, 2> &shape,
+ const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast,
+ const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast,
+ typename TTypes<T>::ConstMatrix logits,
typename TTypes<T>::ConstMatrix labels,
typename TTypes<T>::Matrix scratch,
typename TTypes<T>::Vec loss,
@@ -57,8 +66,8 @@ struct XentEigenImpl {
const int kBatchDim = 0;
const int kClassDim = 1;
- const int batch_size = logits.dimension(kBatchDim);
- const int num_classes = logits.dimension(kClassDim);
+ const int batch_size = shape[kBatchDim];
+ const int num_classes = shape[kClassDim];
// These arrays are used to reduce along the class dimension, and broadcast
// the resulting value to all classes.
@@ -84,10 +93,12 @@ struct XentEigenImpl {
#endif
// max_logits along classes.
- scratch.reshape(batch_only).device(d) = logits.maximum(along_class);
+ scratch.reshape(batch_only).device(d) =
+ logits.broadcast(logits_bcast).maximum(along_class);
// logits - max_logits.
- backprop.device(d) = logits - scratch.broadcast(one_by_class);
+ backprop.device(d) =
+ logits.broadcast(logits_bcast) - scratch.broadcast(one_by_class);
// sum(exp(logits - max_logits)) along classes.
scratch.reshape(batch_only).device(d) = backprop.exp().sum(along_class);
@@ -99,15 +110,15 @@ struct XentEigenImpl {
// sum(-labels *
// ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
// along classes
- loss.device(d) =
- (labels * (scratch.log().eval().broadcast(one_by_class) - backprop))
- .eval()
- .sum(along_class);
+ loss.device(d) = (labels.broadcast(labels_bcast) *
+ (scratch.log().eval().broadcast(one_by_class) - backprop))
+ .eval()
+ .sum(along_class);
// backprop: prob - labels, where
// prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
- backprop.device(d) =
- (backprop.exp() / scratch.broadcast(one_by_class)) - labels;
+ backprop.device(d) = (backprop.exp() / scratch.broadcast(one_by_class)) -
+ labels.broadcast(labels_bcast);
}
};