author     Zongheng Yang <zongheng.y@gmail.com>          2016-05-23 21:59:21 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-05-23 23:02:52 -0700
commit     989166223cf856a669a0d35cecb40c071a31be38 (patch)
tree       32fe88c1239551e33edaa21106433bdcded2d7cc /tensorflow/core/kernels/softmax_op_functor.h
parent     26f54a9fcdec365307e38cb3c472beb75d19d2b8 (diff)
Speeds up Softmax by up to 43%, by changing "/ sum" to "* (1/sum)".
Benchmarked using third_party/tensorflow/core/kernels:nn_ops_test.

Wall time improves 10-43%:

Benchmark                          Base (ns)   New (ns)  Improvement
------------------------------------------------------------------
BM_ImageNetSoftmaxFwd_32_1008_1       713325     620705       +13.0%
BM_ImageNetSoftmaxFwd_128_1008_1     3097766    2782433       +10.2%
BM_ImageNetSoftmaxFwd_32_1008_4      1254561     703238       +43.9%
BM_ImageNetSoftmaxFwd_128_1008_4     3225011    2543525       +21.1%

CPU time improves 4-17%:

Benchmark                          Base (ns)   New (ns)  Improvement
------------------------------------------------------------------
BM_ImageNetSoftmaxFwd_32_1008_1       711375     618729       +13.0%
BM_ImageNetSoftmaxFwd_128_1008_1     3087158    2779777       +10.0%
BM_ImageNetSoftmaxFwd_32_1008_4       959016     795579       +17.0%
BM_ImageNetSoftmaxFwd_128_1008_4     3774543    3591573        +4.8%

Change: 123074430
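To illustrate why the change helps (this sketch is not part of the commit; the
function names and row-wise framing are hypothetical): dividing every element
by the row sum performs one floating-point division per element, while taking
the reciprocal once per row replaces those divisions with multiplications,
which are cheaper and vectorize better on most CPUs.

    // Minimal standalone sketch of the transformation, assuming a
    // single row of already-exponentiated softmax values.
    #include <vector>

    // Before: softmax[j] = softmax[j] / sum  (one division per element)
    void normalize_div(std::vector<float>& row, float sum) {
      for (float& v : row) v /= sum;
    }

    // After: softmax[j] = softmax[j] * (1 / sum)  (one division per row)
    void normalize_mul(std::vector<float>& row, float sum) {
      const float inv = 1.0f / sum;  // single division, hoisted out
      for (float& v : row) v *= inv; // cheap multiplies in the loop
    }

In the Eigen expression below, the same idea appears as inserting .inverse()
after the row-sum reduction and switching the outer operator from / to *.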
Diffstat (limited to 'tensorflow/core/kernels/softmax_op_functor.h')
-rw-r--r--  tensorflow/core/kernels/softmax_op_functor.h  27
1 file changed, 15 insertions(+), 12 deletions(-)
diff --git a/tensorflow/core/kernels/softmax_op_functor.h b/tensorflow/core/kernels/softmax_op_functor.h
index 47bb9de411..c3b0881b0c 100644
--- a/tensorflow/core/kernels/softmax_op_functor.h
+++ b/tensorflow/core/kernels/softmax_op_functor.h
@@ -63,31 +63,34 @@ struct SoftmaxEigenImpl {
Eigen::IndexList<Eigen::type2index<1>, int> one_by_class;
one_by_class.set(1, num_classes);
#endif
- //shifted_logits = logits - max(logits along classes);
- auto shifted_logits = (logits - logits.maximum(along_class)
- .eval()
- .reshape(batch_by_one)
- .broadcast(one_by_class));
+ // shifted_logits = logits - max(logits along classes);
+ auto shifted_logits = (logits -
+ logits.maximum(along_class)
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class));
if (log) {
// Calculate the log of the softmax
// softmax = logits - max(logits along classes);
softmax.device(d) = shifted_logits;
// softmax = softmax - log(sum(exp(softmax along classes)));
softmax.device(d) = (softmax -
- softmax.exp().sum(along_class)
- .eval()
- .reshape(batch_by_one)
- .broadcast(one_by_class)
- .log());
+ softmax.exp()
+ .sum(along_class)
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class)
+ .log());
} else {
// NOTE(touts): If you modify this implementation please run
// the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc.
//
// softmax = exp(logits - max(logits along classes));
softmax.device(d) = shifted_logits.exp();
- // softmax = softmax / sum(softmax along classes);
- softmax.device(d) = (softmax /
+ // softmax = softmax * (1 / sum(softmax along classes));
+ softmax.device(d) = (softmax *
softmax.sum(along_class)
+ .inverse()
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class));
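For context, a self-contained sketch of the computation this kernel performs
(a hypothetical scalar helper, not part of this file): subtract the row max
before exponentiating for numerical stability, then normalize by multiplying
with the reciprocal of the row sum, mirroring the optimized Eigen expression.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    std::vector<float> softmax_row(const std::vector<float>& logits) {
      // shifted_logits = logits - max(logits): avoids overflow in exp().
      const float max_logit =
          *std::max_element(logits.begin(), logits.end());
      std::vector<float> out(logits.size());
      float sum = 0.0f;
      for (std::size_t i = 0; i < logits.size(); ++i) {
        out[i] = std::exp(logits[i] - max_logit);  // shifted_logits.exp()
        sum += out[i];
      }
      const float inv_sum = 1.0f / sum;  // sum(along_class).inverse()
      for (float& v : out) v *= inv_sum; // softmax * (1 / sum)
      return out;
    }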