author | 2016-05-05 09:40:47 -0800 | |
---|---|---|
committer | 2016-05-05 10:51:04 -0700 | |
commit | 5b5ff3f57264d19538e9d7c0bc862429a557f547 (patch) | |
tree | d61bb813ff3aaeaf52f478530e95f690ac5aa4f5 /tensorflow/python/ops/constant_op.py | |
parent | b8476afedf7332d3554455a94e531c99ccfc410c (diff) |
Introduce a Multinomial op and a parallel CPU kernel.
Example usage:
samples = tf.multinomial(tf.log([[0.5, 0.5]]), 10)
# samples has shape [1, 10], where each value is either 0 or 1 (equal prob.).
samples = tf.multinomial([[1, -1, -1]], 10)
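# samples has shape [1, 10]. Treating the inputs as unnormalized log-probabilities
# (as the tf.log call above implies), class 0 is drawn with probability
# softmax([1, -1, -1])[0], about 0.79, so most but not necessarily all values are 0.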
The implementation uses the Gumbel noise trick. To check that a native op is
worth adding, we benchmark it against the one-liner approach of composing
existing TF ops to compute the same result. From
"third_party/tensorflow/python:multinomial_op_test" built with "-c opt --copt=-mavx":
("sec" is wall time in seconds, aggregated over 5 iterations.)
Composition of existing ops vs. Native Multinomial op [5 iters]
BatchSize  NumClasses  NumSamples  sec(composed)  sec(native)  speedup
        1       10000           1          0.069        0.040     1.74
        1       10000           4          0.006        0.004     1.54
        1       10000         128          0.056        0.063     0.89
        1      100000           1          0.009        0.008     1.16
        1      100000           4          0.017        0.022     0.77
        1      100000         128          0.328        0.600     0.55
       32       10000           1          0.019        0.007     2.86
       32       10000           4          0.048        0.009     5.56
       32       10000         128          0.847        0.091     9.31
       32      100000           1          0.102        0.027     3.74
       32      100000           4          0.274        0.064     4.28
       32      100000         128         10.579        0.880    12.02
      128       10000           1          0.050        0.036     1.39
      128       10000           4          0.135        0.048     2.84
      128       10000         128          3.071        0.377     8.15
      128      100000           1          0.352        0.133     2.65
      128      100000           4          0.995        0.260     3.82
      128      100000         128         40.455        3.574    11.32
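The speedup is up to 12x (batch size 32, 100000 classes, 128 samples). With
batch size 1 the native op is slower in three of the six cases (as low as 0.55x).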
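For readers unfamiliar with the Gumbel noise trick named above, here is a minimal NumPy sketch of the idea. It is not the kernel's code and not the composed-ops baseline from the benchmark: the function name `gumbel_max_sample` and the use of NumPy are assumptions made purely for illustration. It also assumes the inputs are unnormalized log-probabilities.

```python
import numpy as np


def gumbel_max_sample(logits, num_samples, rng=None):
    """Draw class indices from unnormalized log-probabilities.

    Hypothetical helper for illustration only.
    logits: array-like of shape [batch_size, num_classes].
    Returns an int64 array of shape [batch_size, num_samples].
    """
    rng = np.random.default_rng() if rng is None else rng
    logits = np.asarray(logits, dtype=np.float64)
    batch_size, num_classes = logits.shape
    # Independent Gumbel(0, 1) noise for every (batch, sample, class) triple.
    noise = rng.gumbel(size=(batch_size, num_samples, num_classes))
    # The argmax over classes of (logit + noise) follows softmax(logits).
    return np.argmax(logits[:, None, :] + noise, axis=-1).astype(np.int64)


samples = gumbel_max_sample(np.log([[0.5, 0.5]]), 10, rng=np.random.default_rng(0))
print(samples.shape)  # (1, 10)
```

A logit of `-np.inf` gives that class zero probability, because no finite amount of noise can lift it above a finite logit.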
Change: 121593174
Diffstat (limited to 'tensorflow/python/ops/constant_op.py')
-rw-r--r-- | tensorflow/python/ops/constant_op.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/python/ops/constant_op.py b/tensorflow/python/ops/constant_op.py
index 4c328fb7d4..26fb7e1f8a 100644
--- a/tensorflow/python/ops/constant_op.py
+++ b/tensorflow/python/ops/constant_op.py
@@ -94,8 +94,8 @@ print(sess.run(var))
 @@random_uniform
 @@random_shuffle
 @@random_crop
+@@multinomial
 @@set_random_seed
-
 """
 
 # Must be separate from array_ops to avoid a cyclic dependency.