diff options
Diffstat (limited to 'tensorflow/core/kernels/xent_op.h')
-rw-r--r-- | tensorflow/core/kernels/xent_op.h | 35 |
1 files changed, 23 insertions, 12 deletions
diff --git a/tensorflow/core/kernels/xent_op.h b/tensorflow/core/kernels/xent_op.h index e689fca7ff..87be17fca9 100644 --- a/tensorflow/core/kernels/xent_op.h +++ b/tensorflow/core/kernels/xent_op.h @@ -18,6 +18,7 @@ limitations under the License. // Functor definition for XentOp, must be compilable by nvcc. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + #include "tensorflow/core/framework/tensor_types.h" namespace tensorflow { @@ -33,7 +34,11 @@ struct XentFunctor { // scratch: temporary tensor, dims: batch_size, 1 // loss: output tensor for the loss, dims: batch_size. // backprop: output tensor for the backprop, dims: batch_size, num_classes. - void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits, + void operator()(const Device &d, + const Eigen::DSizes<Eigen::DenseIndex, 2> &shape, + const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast, + const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast, + typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels, typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss, @@ -45,7 +50,11 @@ struct XentFunctor { // specializations for both device types. template <typename Device, typename T> struct XentEigenImpl { - static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits, + static void Compute(const Device &d, + const Eigen::DSizes<Eigen::DenseIndex, 2> &shape, + const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast, + const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast, + typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels, typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss, @@ -57,8 +66,8 @@ struct XentEigenImpl { const int kBatchDim = 0; const int kClassDim = 1; - const int batch_size = logits.dimension(kBatchDim); - const int num_classes = logits.dimension(kClassDim); + const int batch_size = shape[kBatchDim]; + const int num_classes = shape[kClassDim]; // These arrays are used to reduce along the class dimension, and broadcast // the resulting value to all classes. @@ -84,10 +93,12 @@ struct XentEigenImpl { #endif // max_logits along classes. - scratch.reshape(batch_only).device(d) = logits.maximum(along_class); + scratch.reshape(batch_only).device(d) = + logits.broadcast(logits_bcast).maximum(along_class); // logits - max_logits. - backprop.device(d) = logits - scratch.broadcast(one_by_class); + backprop.device(d) = + logits.broadcast(logits_bcast) - scratch.broadcast(one_by_class); // sum(exp(logits - max_logits)) along classes. scratch.reshape(batch_only).device(d) = backprop.exp().sum(along_class); @@ -99,15 +110,15 @@ struct XentEigenImpl { // sum(-labels * // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) // along classes - loss.device(d) = - (labels * (scratch.log().eval().broadcast(one_by_class) - backprop)) - .eval() - .sum(along_class); + loss.device(d) = (labels.broadcast(labels_bcast) * + (scratch.log().eval().broadcast(one_by_class) - backprop)) + .eval() + .sum(along_class); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) - backprop.device(d) = - (backprop.exp() / scratch.broadcast(one_by_class)) - labels; + backprop.device(d) = (backprop.exp() / scratch.broadcast(one_by_class)) - + labels.broadcast(labels_bcast); } }; |