diff options
author    | 2016-05-23 21:59:21 -0800
committer | 2016-05-23 23:02:52 -0700
commit    | 989166223cf856a669a0d35cecb40c071a31be38 (patch)
tree      | 32fe88c1239551e33edaa21106433bdcded2d7cc /tensorflow/core/kernels/softmax_op_functor.h
parent    | 26f54a9fcdec365307e38cb3c472beb75d19d2b8 (diff)
Speeds up Softmax by up to 43%, by changing "/ sum" to "* (1/sum)".
Benchmarked using third_party/tensorflow/core/kernels:nn_ops_test.
Wall time improves 10-43%:
Benchmark Base (ns) New (ns) Improvement
------------------------------------------------------------------
BM_ImageNetSoftmaxFwd_32_1008_1 713325 620705 +13.0%
BM_ImageNetSoftmaxFwd_128_1008_1 3097766 2782433 +10.2%
BM_ImageNetSoftmaxFwd_32_1008_4 1254561 703238 +43.9%
BM_ImageNetSoftmaxFwd_128_1008_4 3225011 2543525 +21.1%
CPU time improves 4-17%:
Benchmark Base (ns) New (ns) Improvement
------------------------------------------------------------------
BM_ImageNetSoftmaxFwd_32_1008_1 711375 618729 +13.0%
BM_ImageNetSoftmaxFwd_128_1008_1 3087158 2779777 +10.0%
BM_ImageNetSoftmaxFwd_32_1008_4 959016 795579 +17.0%
BM_ImageNetSoftmaxFwd_128_1008_4 3774543 3591573 +4.8%
Change: 123074430
Diffstat (limited to 'tensorflow/core/kernels/softmax_op_functor.h')
-rw-r--r-- | tensorflow/core/kernels/softmax_op_functor.h | 27 |
1 file changed, 15 insertions(+), 12 deletions(-)
diff --git a/tensorflow/core/kernels/softmax_op_functor.h b/tensorflow/core/kernels/softmax_op_functor.h
index 47bb9de411..c3b0881b0c 100644
--- a/tensorflow/core/kernels/softmax_op_functor.h
+++ b/tensorflow/core/kernels/softmax_op_functor.h
@@ -63,31 +63,34 @@ struct SoftmaxEigenImpl {
   Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
   one_by_class.set(1, num_classes);
 #endif
-  //shifted_logits = logits - max(logits along classes);
-  auto shifted_logits = (logits - logits.maximum(along_class)
-                                      .eval()
-                                      .reshape(batch_by_one)
-                                      .broadcast(one_by_class));
+  // shifted_logits = logits - max(logits along classes);
+  auto shifted_logits = (logits -
+                         logits.maximum(along_class)
+                             .eval()
+                             .reshape(batch_by_one)
+                             .broadcast(one_by_class));
   if (log) {
     // Calculate the log of the softmax
     // softmax = logits - max(logits along classes);
     softmax.device(d) = shifted_logits;
     // softmax = softmax - log(sum(exp(softmax along classes)));
     softmax.device(d) = (softmax -
-                         softmax.exp().sum(along_class)
-                             .eval()
-                             .reshape(batch_by_one)
-                             .broadcast(one_by_class)
-                             .log());
+                         softmax.exp()
+                             .sum(along_class)
+                             .eval()
+                             .reshape(batch_by_one)
+                             .broadcast(one_by_class)
+                             .log());
   } else {
     // NOTE(touts): If you modify this implementation please run
     // the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc.
     //
     // softmax = exp(logits - max(logits along classes));
     softmax.device(d) = shifted_logits.exp();
-    // softmax = softmax / sum(softmax along classes);
-    softmax.device(d) = (softmax /
+    // softmax = softmax * (1 / sum(softmax along classes));
+    softmax.device(d) = (softmax *
                          softmax.sum(along_class)
+                             .inverse()
                              .eval()
                              .reshape(batch_by_one)
                              .broadcast(one_by_class));