diff options
Diffstat (limited to 'tensorflow/core/kernels/xent_op.h')
-rw-r--r-- | tensorflow/core/kernels/xent_op.h | 35 |
1 files changed, 12 insertions, 23 deletions
diff --git a/tensorflow/core/kernels/xent_op.h b/tensorflow/core/kernels/xent_op.h index 87be17fca9..e689fca7ff 100644 --- a/tensorflow/core/kernels/xent_op.h +++ b/tensorflow/core/kernels/xent_op.h @@ -18,7 +18,6 @@ limitations under the License. // Functor definition for XentOp, must be compilable by nvcc. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" - #include "tensorflow/core/framework/tensor_types.h" namespace tensorflow { @@ -34,11 +33,7 @@ struct XentFunctor { // scratch: temporary tensor, dims: batch_size, 1 // loss: output tensor for the loss, dims: batch_size. // backprop: output tensor for the backprop, dims: batch_size, num_classes. - void operator()(const Device &d, - const Eigen::DSizes<Eigen::DenseIndex, 2> &shape, - const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast, - const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast, - typename TTypes<T>::ConstMatrix logits, + void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels, typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss, @@ -50,11 +45,7 @@ struct XentFunctor { // specializations for both device types. template <typename Device, typename T> struct XentEigenImpl { - static void Compute(const Device &d, - const Eigen::DSizes<Eigen::DenseIndex, 2> &shape, - const Eigen::array<Eigen::DenseIndex, 2> &logits_bcast, - const Eigen::array<Eigen::DenseIndex, 2> &labels_bcast, - typename TTypes<T>::ConstMatrix logits, + static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits, typename TTypes<T>::ConstMatrix labels, typename TTypes<T>::Matrix scratch, typename TTypes<T>::Vec loss, @@ -66,8 +57,8 @@ struct XentEigenImpl { const int kBatchDim = 0; const int kClassDim = 1; - const int batch_size = shape[kBatchDim]; - const int num_classes = shape[kClassDim]; + const int batch_size = logits.dimension(kBatchDim); + const int num_classes = logits.dimension(kClassDim); // These arrays are used to reduce along the class dimension, and broadcast // the resulting value to all classes. @@ -93,12 +84,10 @@ struct XentEigenImpl { #endif // max_logits along classes. - scratch.reshape(batch_only).device(d) = - logits.broadcast(logits_bcast).maximum(along_class); + scratch.reshape(batch_only).device(d) = logits.maximum(along_class); // logits - max_logits. - backprop.device(d) = - logits.broadcast(logits_bcast) - scratch.broadcast(one_by_class); + backprop.device(d) = logits - scratch.broadcast(one_by_class); // sum(exp(logits - max_logits)) along classes. scratch.reshape(batch_only).device(d) = backprop.exp().sum(along_class); @@ -110,15 +99,15 @@ struct XentEigenImpl { // sum(-labels * // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) // along classes - loss.device(d) = (labels.broadcast(labels_bcast) * - (scratch.log().eval().broadcast(one_by_class) - backprop)) - .eval() - .sum(along_class); + loss.device(d) = + (labels * (scratch.log().eval().broadcast(one_by_class) - backprop)) + .eval() + .sum(along_class); // backprop: prob - labels, where // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) - backprop.device(d) = (backprop.exp() / scratch.broadcast(one_by_class)) - - labels.broadcast(labels_bcast); + backprop.device(d) = + (backprop.exp() / scratch.broadcast(one_by_class)) - labels; } }; |