path: root/tensorflow/python/ops/nn.py
blob: 7a4dc25e8b686fa68eaa5fd30127f2e2b9634a39 (plain)
# pylint: disable=wildcard-import,unused-import,g-bad-import-order
"""## Activation Functions

The activation ops provide different types of nonlinearities for use in
neural networks.  These include smooth nonlinearities (`sigmoid`,
`tanh`, and `softplus`), continuous but not everywhere differentiable
functions (`relu`, `relu6`, and `relu_x`), and random regularization
(`dropout`).

All activation ops apply componentwise, and produce a tensor of the same
shape as the input tensor.
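
As a rough illustration of "componentwise, same shape", the following plain
Python sketch mimics a few of these ops on nested lists.  It is not the
TensorFlow implementation: `apply_componentwise` is a hypothetical helper, the
`relu6` and `softplus` formulas are the usual textbook definitions (assumed
here), and `dropout` follows the keep-and-rescale rule described in the
`dropout` docstring of this module.

```python
import math
import random


def relu(x):
    return max(x, 0.0)


def relu6(x):
    # Assumed definition: relu capped at 6.
    return min(max(x, 0.0), 6.0)


def softplus(x):
    # log(1 + exp(x)), arranged so exp() never receives a large positive value.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def apply_componentwise(fn, rows):
    # Hypothetical helper: applies fn to every element, preserving the shape.
    return [[fn(v) for v in row] for row in rows]


def dropout(rows, keep_prob, rng):
    # Keeps each element with probability keep_prob and scales it by
    # 1 / keep_prob; every other element becomes 0.
    return [[v / keep_prob if rng.random() < keep_prob else 0.0 for v in row]
            for row in rows]


x = [[-2.0, 0.5], [3.0, 8.0]]
relu_out = apply_componentwise(relu, x)
relu6_out = apply_componentwise(relu6, x)
dropped = dropout(x, keep_prob=0.5, rng=random.Random(0))
```

Every result has the same nesting as `x`; only the element values change.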

@@relu
@@relu6
@@softplus
@@dropout
@@bias_add
@@sigmoid
@@tanh

## Convolution

The convolution ops sweep a 2-D filter over a batch of images, applying the
filter to each window of the appropriate size in each image.  The different
ops trade off between generic and specific filters:

* `conv2d`: Arbitrary filters that can mix channels together.
* `depthwise_conv2d`: Filters that operate on each channel independently.
* `separable_conv2d`: A depthwise spatial filter followed by a pointwise filter.

Note that although these ops are called "convolution", they are strictly
speaking "cross-correlation" since the filter is combined with an input window
without reversing the filter.  For details, see [the properties of
cross-correlation](https://en.wikipedia.org/wiki/Cross-correlation#Properties).

The filter is applied to image patches of the same size as the filter and
strided according to the `strides` argument.  `strides = [1, 1, 1, 1]` applies
the filter to a patch at every offset, `strides = [1, 2, 2, 1]` applies the
filter to every other image patch in each dimension, etc.

Ignoring channels for the moment, the spatial semantics of the convolution ops
are as follows.  If the 4-D `input` has shape
`[batch, in_height, in_width, ...]` and the 4-D `filter` has shape
`[filter_height, filter_width, ...]`, then

    output.shape = [batch,
                    (in_height - filter_height + 1) / strides[1],
                    (in_width - filter_width + 1) / strides[2],
                    ...]

    output[b, i, j, :] =
        sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, ...] *
                     filter[di, dj, ...]
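
The sum above can be sketched in plain Python for a single image with the
batch and channel dimensions dropped.  This is only an illustration under
stated assumptions: `cross_correlate2d` is a hypothetical name, and only
windows that fit entirely inside the image are visited.

```python
def cross_correlate2d(image, kernel, stride_h=1, stride_w=1):
    # Sweeps kernel over image without flipping it, using full windows only.
    in_h, in_w = len(image), len(image[0])
    k_h, k_w = len(kernel), len(kernel[0])
    out_h = (in_h - k_h) // stride_h + 1
    out_w = (in_w - k_w) // stride_w + 1
    output = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            total = 0.0
            for di in range(k_h):
                for dj in range(k_w):
                    total += (image[stride_h * i + di][stride_w * j + dj] *
                              kernel[di][dj])
            row.append(total)
        output.append(row)
    return output


image = [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0],
         [9.0, 10.0, 11.0, 12.0],
         [13.0, 14.0, 15.0, 16.0]]
kernel = [[1.0, 0.0],
          [0.0, -1.0]]

dense = cross_correlate2d(image, kernel)          # a window at every offset
strided = cross_correlate2d(image, kernel, 2, 2)  # every other window
```

Each output element here is `image[i][j] - image[i + 1][j + 1]`.  A true
convolution would flip the kernel first and produce the opposite sign.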

Since `input` is 4-D, each `input[b, i, j, :]` is a vector.  For `conv2d`, these
vectors are multiplied by the `filter[di, dj, :, :]` matrices to produce new
vectors.  For `depthwise_conv2d`, each scalar component `input[b, i, j, k]`
is multiplied by a vector `filter[di, dj, k]`, and all the vectors are
concatenated.
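
To make the channel handling concrete, here is a small sketch for a single
pixel and a single filter tap `(di, dj)`, so the spatial sum is left out.
The function names are hypothetical.  The output ordering of the depthwise
step follows the `k * channel_multiplier + q` indexing given in the
`depthwise_conv2d` docstring, and the separable step chains it with a
pointwise channel mix as described for `separable_conv2d`.

```python
def conv2d_tap(pixel, filter_matrix):
    # pixel: [in_channels]; filter_matrix: [in_channels][out_channels].
    # A vector-matrix product, so output channels mix all input channels.
    out_channels = len(filter_matrix[0])
    return [sum(pixel[k] * filter_matrix[k][q] for k in range(len(pixel)))
            for q in range(out_channels)]


def depthwise_tap(pixel, filter_matrix):
    # pixel: [in_channels]; filter_matrix: [in_channels][channel_multiplier].
    # Each scalar is scaled by its own vector; the vectors are concatenated.
    out = []
    for k, value in enumerate(pixel):
        out.extend(value * w for w in filter_matrix[k])
    return out


def separable_tap(pixel, depthwise_filter, pointwise_filter):
    # Depthwise step followed by a pointwise (channel mixing) step.
    return conv2d_tap(depthwise_tap(pixel, depthwise_filter), pointwise_filter)


pixel = [1.0, 2.0]
weights = [[1.0, 10.0],
           [100.0, 1000.0]]
pointwise = [[1.0], [1.0], [1.0], [1.0]]

mixed = conv2d_tap(pixel, weights)
per_channel = depthwise_tap(pixel, weights)
separable = separable_tap(pixel, weights, pointwise)
```

With the same weights, `conv2d_tap` returns 2 mixed channels while
`depthwise_tap` returns `in_channels * channel_multiplier = 4` unmixed ones.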

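In the formula for `output.shape`, the handling of partial windows depends on
padding:

* `padding = 'VALID'`: Only full size windows are considered, and the
  quotient above is rounded up.
* `padding = 'SAME'`: Partial windows at the edges are included, so the
  output size is `ceil(in_size / stride)` regardless of the filter size.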

@@conv2d
@@depthwise_conv2d
@@separable_conv2d

## Pooling

The pooling ops sweep a rectangular window over the input tensor, computing a
reduction operation for each window (average, max, or max with argmax).  Each
pooling op uses rectangular windows of size `ksize` separated by offset
`strides`.  For example, if `strides` is all ones every window is used, if
`strides` is all twos every other window is used in each dimension, etc.

In detail, the output is

    output[i] = reduce(value[strides * i:strides * i + ksize])

for each tuple of indices `i`.  The output shape is

    output.shape = (value.shape - ksize + 1) / strides

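where the handling of partial windows depends on padding:

* `padding = 'VALID'`: Only full size windows are considered, and the
  quotient above is rounded up.
* `padding = 'SAME'`: Partial windows at the edges are included, so the
  output shape is `ceil(value.shape / strides)` regardless of `ksize`.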
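
A minimal sketch of the windowed reduction for one 2-D slice, assuming that
only full windows are reduced.  `pool2d` and `mean` are hypothetical helpers
written for this example, not part of the module.

```python
def pool2d(value, ksize, strides, reduce_fn):
    # Reduces each ksize window of a 2-D nested list, stepping by strides.
    k_h, k_w = ksize
    s_h, s_w = strides
    out_h = (len(value) - k_h) // s_h + 1
    out_w = (len(value[0]) - k_w) // s_w + 1
    return [[reduce_fn([value[s_h * i + di][s_w * j + dj]
                        for di in range(k_h) for dj in range(k_w)])
             for j in range(out_w)]
            for i in range(out_h)]


def mean(window):
    return sum(window) / len(window)


value = [[1.0, 2.0, 3.0, 4.0],
         [5.0, 6.0, 7.0, 8.0],
         [9.0, 10.0, 11.0, 12.0],
         [13.0, 14.0, 15.0, 16.0]]

max_pooled = pool2d(value, ksize=(2, 2), strides=(2, 2), reduce_fn=max)
avg_pooled = pool2d(value, ksize=(2, 2), strides=(2, 2), reduce_fn=mean)
overlapping = pool2d(value, ksize=(2, 2), strides=(1, 1), reduce_fn=max)
```

With strides of 2 the 4x4 input yields a 2x2 output, and with strides of 1
the windows overlap and the output is 3x3.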

@@avg_pool
@@max_pool
@@max_pool_with_argmax

## Normalization

Normalization is useful to prevent neurons from saturating when inputs may
have varying scale, and to aid generalization.
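
For the 1-D case, the arithmetic behind `l2_normalize` and `moments` can be
sketched as follows.  The formulas are taken from the docstrings and code of
those functions in this module, and the list-based signatures are simplified
assumptions made for the example.

```python
import math


def l2_normalize(x, epsilon=1e-12):
    # x / sqrt(max(sum(x**2), epsilon)) for a 1-D list.
    square_sum = sum(v * v for v in x)
    inv_norm = 1.0 / math.sqrt(max(square_sum, epsilon))
    return [v * inv_norm for v in x]


def moments(x):
    # Mean and variance of a 1-D list, both scaled by 1 / len(x).
    divisor = 1.0 / len(x)
    mean = sum(x) * divisor
    variance = sum((v - mean) ** 2 for v in x) * divisor
    return mean, variance


unit = l2_normalize([3.0, 4.0])
zero_safe = l2_normalize([0.0, 0.0])  # epsilon keeps the divisor nonzero
mean, variance = moments([1.0, 2.0, 3.0, 4.0])
```

Note that the variance divides by the number of elements, not by one less
than it.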

@@l2_normalize
@@local_response_normalization
@@moments

## Losses

The loss ops measure error between two tensors, or between a tensor and zero.
These can be used for measuring accuracy of a network in a regression task
or for regularization purposes (weight decay).

@@l2_loss

## Classification

TensorFlow provides several operations that help you perform classification.
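
The docstring of `sigmoid_cross_entropy_with_logits` in this module gives
both a direct logistic loss formula and a rearranged one that avoids
overflow.  The scalar sketch below, with hypothetical function names,
compares the two under the same notation (`x` for a logit, `z` for a target).

```python
import math


def logistic_loss_naive(x, z):
    # x - x * z + log(1 + exp(-x)); exp(-x) overflows for very negative x.
    return x - x * z + math.log(1.0 + math.exp(-x))


def logistic_loss_stable(x, z):
    # max(x, 0) - x * z + log(1 + exp(-abs(x))); exp() only sees values <= 0.
    return max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))


pairs = [(x, z) for x in (-3.0, -0.5, 0.0, 2.0) for z in (0.0, 0.3, 1.0)]
max_gap = max(abs(logistic_loss_naive(x, z) - logistic_loss_stable(x, z))
              for x, z in pairs)
large_negative = logistic_loss_stable(-1000.0, 1.0)
```

For moderate logits the two agree to rounding error, while for a logit such
as `-1000.0` only the rearranged form can be evaluated in floating point.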

@@sigmoid_cross_entropy_with_logits
@@softmax
@@softmax_cross_entropy_with_logits

## Embeddings

TensorFlow provides several operations that help you compute embeddings.

@@embedding_lookup
@@embedding_lookup_sparse

## Evaluation

The evaluation ops are useful for measuring the performance of a network.
Since they are nondifferentiable, they are typically used at evaluation time.

@@top_k
@@in_top_k

## Candidate Sampling

Do you want to train a multiclass or multilabel model with thousands
or millions of output classes (for example, a language model with a
large vocabulary)?  Training with a full Softmax is slow in this case,
since all of the classes are evaluated for every training example.
Candidate Sampling training algorithms can speed up your step times by
only considering a small randomly-chosen subset of contrastive classes
(called candidates) for each batch of training examples.

See our [Candidate Sampling Algorithms Reference]
(http://www.tensorflow.org/extras/candidate_sampling.pdf)
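
The core idea can be sketched in plain Python: score each example against
its true classes plus one shared set of sampled candidates, instead of all
classes.  This is a simplified illustration with hypothetical names.  It
samples candidates uniformly (the module's default sampler is log-uniform)
and leaves out the expected-count correction and accidental-hit removal
performed by the helper in this module.

```python
import random


def dot(a, b):
    return sum(u * v for u, v in zip(a, b))


def sampled_logits_and_labels(weights, biases, inputs, labels, sampled):
    # Builds [true | sampled] logits and matching target rows per example.
    out_logits, out_labels = [], []
    for activation, true_classes in zip(inputs, labels):
        true_logits = [dot(activation, weights[c]) + biases[c]
                       for c in true_classes]
        sampled_logits = [dot(activation, weights[c]) + biases[c]
                          for c in sampled]
        out_logits.append(true_logits + sampled_logits)
        num_true = len(true_classes)
        # Each true class gets target 1 / num_true; candidates get 0.
        out_labels.append([1.0 / num_true] * num_true + [0.0] * len(sampled))
    return out_logits, out_labels


rng = random.Random(0)
num_classes, dim, num_sampled = 1000, 4, 5
weights = [[rng.gauss(0.0, 1.0) for _ in range(dim)]
           for _ in range(num_classes)]
biases = [0.0] * num_classes
inputs = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(2)]
labels = [[3], [17]]
sampled = rng.sample(range(num_classes), num_sampled)  # one set per batch

logits, targets = sampled_logits_and_labels(
    weights, biases, inputs, labels, sampled)
```

Each example is scored against `num_true + num_sampled = 6` classes here
rather than all 1000.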

### Sampled Loss Functions

TensorFlow provides the following sampled loss functions for faster training.

@@nce_loss
@@sampled_softmax_loss

### Candidate Samplers

TensorFlow provides the following samplers for randomly sampling candidate
classes when using one of the sampled loss functions above.

@@uniform_candidate_sampler
@@log_uniform_candidate_sampler
@@learned_unigram_candidate_sampler
@@fixed_unigram_candidate_sampler

### Miscellaneous candidate sampling utilities

@@compute_accidental_hits

"""

from tensorflow.python.framework import ops
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import candidate_sampling_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import numerics
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh

# Bring more nn-associated functionality into this package.
from tensorflow.python.ops.nn_ops import *
from tensorflow.python.ops.candidate_sampling_ops import *
from tensorflow.python.ops.embedding_ops import *


def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
  """Computes sigmoid cross entropy given `logits`.

  Measures the probability error in discrete classification tasks in which each
  class is independent and not mutually exclusive.  For instance, one could
  perform multilabel classification where a picture can contain both an elephant
  and a dog at the same time.

  For brevity, let `x = logits`, `z = targets`.  The logistic loss is

      x - x * z + log(1 + exp(-x))

  To ensure stability and avoid overflow, the implementation uses

      max(x, 0) - x * z + log(1 + exp(-abs(x)))

  `logits` and `targets` must have the same type and shape.

  Args:
    logits: A `Tensor` of type `float32` or `float64`.
    targets: A `Tensor` of the same type and shape as `logits`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of the same shape as `logits` with the componentwise
    logistic losses.
  """
  with ops.op_scope([logits, targets], name, "logistic_loss") as name:
    logits = ops.convert_to_tensor(logits, name="logits")
    targets = ops.convert_to_tensor(targets, name="targets")
    # The logistic loss formula from above is
    #   x - x * z + log(1 + exp(-x))
    # For x < 0, a more numerically stable formula is
    #   -x * z + log(1 + exp(x))
    # To avoid branching, we use the combined version
    #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
    return math_ops.add(nn_ops.relu(logits) - logits * targets,
                        math_ops.log(1 + math_ops.exp(-math_ops.abs(logits))),
                        name=name)


def xw_plus_b(x, weights, biases, name=None):
  """Computes matmul(x, weights) + biases.

  Args:
    x: a 2D tensor.  Dimensions typically: batch, in_units
    weights: a 2D tensor.  Dimensions typically: in_units, out_units
    biases: a 1D tensor.  Dimensions: out_units
    name: A name for the operation (optional).  If not specified
      "xw_plus_b" is used.

  Returns:
    A 2-D Tensor computing matmul(x, weights) + biases.
    Dimensions typically: batch, out_units.
  """
  with ops.op_scope([x, weights, biases], name, "xw_plus_b") as name:
    x = ops.convert_to_tensor(x, name="x")
    weights = ops.convert_to_tensor(weights, name="weights")
    biases = ops.convert_to_tensor(biases, name="biases")
    mm = math_ops.matmul(x, weights)
    return nn_ops.bias_add(mm, biases, name=name)


def relu_layer(x, weights, biases, name=None):
  """Computes relu(matmul(x, weights) + biases).

  Args:
    x: a 2D tensor.  Dimensions typically: batch, in_units
    weights: a 2D tensor.  Dimensions typically: in_units, out_units
    biases: a 1D tensor.  Dimensions: out_units
    name: A name for the operation (optional).  If not specified
      "relu_layer" is used.

  Returns:
    A 2-D Tensor computing relu(matmul(x, weights) + biases).
    Dimensions typically: batch, out_units.
  """
  with ops.op_scope([x, weights, biases], name, "relu_layer") as name:
    x = ops.convert_to_tensor(x, name="x")
    weights = ops.convert_to_tensor(weights, name="weights")
    biases = ops.convert_to_tensor(biases, name="biases")
    xw_plus_b = nn_ops.bias_add(math_ops.matmul(x, weights), biases)
    return nn_ops.relu(xw_plus_b, name=name)


def l2_normalize(x, dim, epsilon=1e-12, name=None):
  """Normalizes along dimension `dim` using an L2 norm.

  For a 1-D tensor with `dim = 0`, computes

      output = x / sqrt(max(sum(x**2), epsilon))

  For `x` with more dimensions, independently normalizes each 1-D slice along
  dimension `dim`.

  Args:
    x: A `Tensor`.
    dim: Dimension along which to normalize.
    epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
      divisor if `norm < sqrt(epsilon)`.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` with the same shape as `x`.
  """
  with ops.op_scope([x], name, "l2_normalize") as name:
    x = ops.convert_to_tensor(x, name="x")
    square_sum = math_ops.reduce_sum(math_ops.square(x), [dim], keep_dims=True)
    x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
    return math_ops.mul(x, x_inv_norm, name=name)


def zero_fraction(value, name=None):
  """Returns the fraction of zeros in `value`.

  If `value` is empty, the result is `nan`.

  This is useful in summaries to measure and report sparsity.  For example,

      z = tf.nn.relu(...)
      summ = tf.scalar_summary('sparsity', tf.nn.zero_fraction(z))

  Args:
    value: A tensor of numeric type.
    name: A name for the operation (optional).

  Returns:
    The fraction of zeros in `value`, with type `float32`.
  """
  with ops.op_scope([value], name, "zero_fraction"):
    value = ops.convert_to_tensor(value, name="value")
    zero = constant_op.constant(0, dtype=value.dtype, name="zero")
    return math_ops.reduce_mean(math_ops.cast(math_ops.equal(value, zero),
                                              types.float32))


def dropout(x, keep_prob, noise_shape=None, seed=None, name=None):
  """Computes dropout.

  With probability `keep_prob`, outputs the input element scaled up by
  `1 / keep_prob`, otherwise outputs `0`.  The scaling is so that the expected
  sum is unchanged.

  By default, each element is kept or dropped independently.  If `noise_shape`
  is specified, it must be
  [broadcastable](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
  to the shape of `x`, and only dimensions with `noise_shape[i] == x.shape[i]`
  will make independent decisions.  For example, if `x.shape = [b, x, y, c]` and
  `noise_shape = [b, 1, 1, c]`, each batch and channel component will be
  kept independently and each row and column will be kept or not kept together.

  Args:
    x: A tensor.
    keep_prob: Float probability that each element is kept.
    noise_shape: Shape for randomly generated keep/drop flags.
    seed: A Python integer. Used to create a random seed.
      See [`set_random_seed`](constant_op.md#set_random_seed) for behavior.
    name: A name for this operation (optional).

  Returns:
    A `Tensor` of the same shape as `x`.

  Raises:
    ValueError: If `keep_prob` is not in `(0, 1]`.
  """
  if not (0 < keep_prob <= 1):
    raise ValueError("Expected keep_prob in (0, 1], got %g" % keep_prob)
  with ops.op_scope([x], name, "dropout") as name:
    x = ops.convert_to_tensor(x, name="x")
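    # Test against None explicitly so that a tensor-valued `noise_shape` is
    # never evaluated for truthiness.
    if noise_shape is None:
      noise_shape = array_ops.shape(x)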
    # uniform [keep_prob, 1.0 + keep_prob)
    random_tensor = keep_prob
    random_tensor += random_ops.random_uniform(
        noise_shape, seed=seed, dtype=x.dtype)
    # 0. if [keep_prob, 1.0) and 1. if [1.0, 1.0 + keep_prob)
    binary_tensor = math_ops.floor(random_tensor)
    return x * (1.0 / keep_prob) * binary_tensor


def depthwise_conv2d(input, filter, strides, padding, name=None):
  """Depthwise 2-D convolution.

  Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
  and a filter tensor of shape
  `[filter_height, filter_width, in_channels, channel_multiplier]`
  containing `in_channels` convolutional filters of depth 1, `depthwise_conv2d`
  applies a different filter to each input channel (expanding from 1 channel
  to `channel_multiplier` channels for each), then concatenates the results
  together.  The output has `in_channels * channel_multiplier` channels.

  In detail,

      output[b, i, j, k * channel_multiplier + q] =
          sum_{di, dj} input[b, strides[1] * i + di, strides[2] * j + dj, k] *
                       filter[di, dj, k, q]

  Must have `strides[0] = strides[3] = 1`.  For the most common case of the
  same horizontal and vertical strides, `strides = [1, stride, stride, 1]`.

  Args:
    input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
    filter: 4-D with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
    strides: 1-D of size 4.  The stride of the sliding window for each
      dimension of `input`.
    padding: A string, either `'VALID'` or `'SAME'`.  The padding algorithm.
    name: A name for this operation (optional).

  Returns:
    A 4-D `Tensor` of shape
    `[batch, out_height, out_width, in_channels * channel_multiplier]`.
  """
  with ops.op_scope([input, filter], name, "depthwise") as name:
    input = ops.convert_to_tensor(input, name="tensor_in")
    filter = ops.convert_to_tensor(filter, name="filter_in")
    # A shape is required to statically compute the number of separable filters.
    if filter.get_shape().ndims is not None:
      assert len(filter.get_shape()) == 4
      in_channels = filter.get_shape()[2]
      # Sanity checks, if shape information is available for the inputs.
      if input.get_shape().ndims is not None:
        assert len(input.get_shape()) == 4
        assert input.get_shape()[3] == in_channels, (
            "Mismatched input depth %d and number of depthwise filters %d." % (
                input.get_shape()[3].value, in_channels))
    else:
      assert input.get_shape().ndims is not None, (
          "Either tensor must provide static shape information.")
      assert input.get_shape().ndims == 4
      in_channels = input.get_shape()[3]

    if in_channels == 1:
      return nn_ops.conv2d(input, filter, strides, padding, name=name)
    else:
      # Create one separate convolution per channel.
      convs = []
      for channel in xrange(in_channels):
        with ops.name_scope("depth%d" % channel) as channel_scope:
          t_in = array_ops.slice(input, [0, 0, 0, channel], [-1, -1, -1, 1],
                                 name="slice_inputs")
          f_in = array_ops.slice(filter, [0, 0, channel, 0], [-1, -1, 1, -1],
                                 name="slice_params")
          convs.append(nn_ops.conv2d(t_in, f_in,
                                     strides, padding, name=channel_scope))
      # Concatenate the per-channel convolutions along the channel dimension.
      return array_ops.concat(3, convs, name=name)


def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
                     padding,
                     name=None):
  """2-D convolution with separable filters.

  Performs a depthwise convolution that acts separately on channels followed by
  a pointwise convolution that mixes channels.  Note that this is separability
  between dimensions `[1, 2]` and `3`, not spatial separability between
  dimensions `1` and `2`.

  In detail,

      output[b, i, j, k] = sum_{di, dj, q, r}
          input[b, strides[1] * i + di, strides[2] * j + dj, q] *
          depthwise_filter[di, dj, q, r] *
          pointwise_filter[0, 0, q * channel_multiplier + r, k]

  `strides` controls the strides for the depthwise convolution only, since
  the pointwise convolution has implicit strides of `[1, 1, 1, 1]`.  Must have
  `strides[0] = strides[3] = 1`.  For the most common case of the same
  horizontal and vertical strides, `strides = [1, stride, stride, 1]`.

  Args:
    input: 4-D `Tensor` with shape `[batch, in_height, in_width, in_channels]`.
    depthwise_filter: 4-D `Tensor` with shape
      `[filter_height, filter_width, in_channels, channel_multiplier]`.
      Contains `in_channels` convolutional filters of depth 1.
    pointwise_filter: 4-D `Tensor` with shape
      `[1, 1, channel_multiplier * in_channels, out_channels]`.  Pointwise
      filter to mix channels after `depthwise_filter` has convolved spatially.
    strides: 1-D of size 4.  The strides for the depthwise convolution for
      each dimension of `input`.
    padding: A string, either `'VALID'` or `'SAME'`.  The padding algorithm.
    name: A name for this operation (optional).

  Returns:
    A 4-D `Tensor` of shape `[batch, out_height, out_width, out_channels]`.
  """
  with ops.op_scope([input, depthwise_filter, pointwise_filter],
                   name, "separable_conv2d") as name:
    input = ops.convert_to_tensor(input, name="tensor_in")
    depthwise_filter = ops.convert_to_tensor(depthwise_filter,
                                             name="depthwise_filter")
    pointwise_filter = ops.convert_to_tensor(pointwise_filter,
                                             name="pointwise_filter")

    if pointwise_filter.get_shape().ndims is not None:
      assert len(pointwise_filter.get_shape()) == 4
      assert pointwise_filter.get_shape()[0] == 1
      assert pointwise_filter.get_shape()[1] == 1
      if depthwise_filter.get_shape().ndims and input.get_shape().ndims:
        channel_multiplier = depthwise_filter.get_shape()[3]
        in_channels = input.get_shape()[3]
        out_channels = pointwise_filter.get_shape()[3]
        # Otherwise the separable convolution would be over-parametrized.
        assert channel_multiplier * in_channels < out_channels
    # The layout of the ops in the graph is expected to be as follows:
    # separable_conv2d  // Conv2D op corresponding to the pointwise conv.
    # separable_conv2d/depthwise  // Concat op for the depthwise outputs.
    # separable_conv2d/depthwise/depth0  // Conv2D op for depth 0
    # separable_conv2d/depthwise/depth1  // Conv2D op for depth 1
    # separable_conv2d/depthwise/depth2  // Conv2D op for depth 2
    depthwise = depthwise_conv2d(input, depthwise_filter, strides,
                                 padding, name="depthwise")
    return nn_ops.conv2d(depthwise, pointwise_filter, [1, 1, 1, 1],
                         padding="VALID", name=name)


def moments(x, axes, name=None):
  """Calculate the mean and variance of `x`.

  The mean and variance are calculated by aggregating the contents of `x`
  across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
  and variance of a vector.

  For so-called "global normalization", as needed for convolutional filters,
  pass `axes=[0, 1, 2]` (batch, height, width).  For batch normalization, pass
  `axes=[0]` (batch).

  Args:
    x: A `Tensor`.
    axes: array of ints.  Axes along which to compute mean and
      variance.
    name: Name used to scope the operations that compute the moments.

  Returns:
    Two `Tensors`: `mean` and `variance`.
  """
  with ops.op_scope([x, axes], name, "moments"):
    x = ops.convert_to_tensor(x, name="x")
    divisor = 1.0
    for d in xrange(len(x.get_shape())):
      if d in axes:
        divisor *= x.get_shape()[d].value
    divisor = constant_op.constant(1.0 / divisor, x.dtype, name="divisor")
    axes = constant_op.constant(axes, name="axes")
    # Note: We do not use Mean here because it is very slow on GPU.
    # Note 2: The expression below is potentially more stable.
    # It is however a bit slower and stability doesn't appear to be an issue.
    # mean = math_ops.reduce_sum(math_ops.mul(x, divisor), axes, name="mean")
    # var = math_ops.reduce_sum(math_ops.mul(math_ops.square(x - mean),
    #                                        divisor), axes,
    #                    name="variance")
    mean = math_ops.mul(math_ops.reduce_sum(x, axes), divisor, name="mean")
    var = math_ops.mul(math_ops.reduce_sum(math_ops.square(x - mean), axes),
                       divisor, name="variance")
    return mean, var


def _sum_rows(x):
  """Returns a vector summing up each row of the matrix x."""
  # _sum_rows(x) is equivalent to math_ops.reduce_sum(x, 1) when x is
  # a matrix.  The gradient of _sum_rows(x) is more efficient than
  # reduce_sum(x, 1)'s gradient in today's implementation. Therefore,
  # we use _sum_rows(x) in the nce_loss() computation since the loss
  # is mostly used for training.
  cols = array_ops.shape(x)[1]
  ones_shape = array_ops.pack([cols, 1])
  ones = array_ops.ones(ones_shape, x.dtype)
  return array_ops.reshape(math_ops.matmul(x, ones), [-1])


def _compute_sampled_logits(weights, biases, inputs, labels, num_sampled,
                            num_classes, num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            remove_accidental_hits=False,
                            name=None):
  """Helper function for nce_loss and sampled_softmax_loss functions.

  Computes sampled output training logits and labels suitable for implementing
  e.g. noise-contrastive estimation (see nce_loss) or sampled softmax (see
  sampled_softmax_loss).

  Note: In the case where num_true > 1, we assign to each target class
  the target probability 1 / num_true so that the target probabilities
  sum to 1 per-example.

  Args:
    weights: tensor of label embeddings with shape = [num_classes, dim]
    biases: tensor of num_classes label biases
    inputs: tensor with shape = [batch_size, dim] corresponding to forward
        activations of the input network
    labels: int tensor with shape [batch_size, num_true]
    num_sampled: number of label classes to sample per batch
    num_classes: number of possible label classes in the data (e.g. vocab size)
    num_true: number of target classes per example (default: 1)
    sampled_values: a tuple of (sampled_candidates, true_expected_count,
        sampled_expected_count) returned by a *CandidateSampler function to use
        (if None, we default to LogUniformCandidateSampler)
    subtract_log_q: subtract the log expected count of the labels in the sample
        to get the logits of the true labels (default: True)
        Turn off for Negative Sampling.
    remove_accidental_hits: whether to remove "accidental hits" where a sampled
        label equals the true labels (bool, default: False)
    name: name for this op

  Returns:
    out_logits, out_labels: tensors with shape [batch_size, num_true +
        num_sampled] for passing to either SigmoidCrossEntropyWithLogits (NCE)
        or SoftmaxCrossEntropyWithLogits (sampled softmax).

  """

  with ops.op_scope(
      [weights, biases, inputs, labels], name, "compute_sampled_logits"):
    if labels.dtype != types.int64:
      labels = math_ops.cast(labels, types.int64)
    labels_flat = array_ops.reshape(labels, [-1])

    # Sample the negative labels.
    #   sampled shape: num_sampled vector
    #   true_expected_count shape = [batch_size, num_true]
    #   sampled_expected_count shape = num_sampled vector
    if sampled_values is None:
      sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes)
    # NOTE: pylint cannot tell that 'sampled_values' is a sequence
    # pylint: disable=unpacking-non-sequence
    sampled, true_expected_count, sampled_expected_count = sampled_values
    # pylint: enable=unpacking-non-sequence

    # weights shape is [num_classes, dim]
    # labels_flat is a [batch_size * num_true] vector
    # true_w shape is [batch_size * num_true, dim]
    # true_b is a [batch_size * num_true] vector
    true_w = embedding_ops.embedding_lookup(weights, labels_flat)
    true_b = embedding_ops.embedding_lookup(biases, labels_flat)

    # inputs shape is [batch_size, dim]
    # true_w shape is [batch_size * num_true, dim]
    # row_wise_dots is [batch_size, num_true, dim]
    dim = array_ops.shape(true_w)[1:2]
    new_true_w_shape = array_ops.concat(0, [[-1, num_true], dim])
    row_wise_dots = math_ops.mul(
        array_ops.expand_dims(inputs, 1),
        array_ops.reshape(true_w, new_true_w_shape))
    # We want the row-wise dot plus biases which yields a
    # [batch_size, num_true] tensor of true_logits.
    dots_as_matrix = array_ops.reshape(row_wise_dots,
                                       array_ops.concat(0, [[-1], dim]))
    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
    true_b = array_ops.reshape(true_b, [-1, num_true])
    true_logits += true_b

    # Lookup weights and biases for sampled labels.
    #   sampled is a num_sampled int vector
    #   sampled_w shape is [num_sampled, dim]
    #   sampled_b is a num_sampled float vector
    sampled_w = embedding_ops.embedding_lookup(weights, sampled)
    sampled_b = embedding_ops.embedding_lookup(biases, sampled)

    # inputs has shape [batch_size, dim]
    # sampled_w has shape [num_sampled, dim]
    # sampled_b has shape [num_sampled]
    # Apply X*W'+B, which yields [batch_size, num_sampled]
    sampled_logits = math_ops.matmul(inputs,
                                     sampled_w,
                                     transpose_b=True) + sampled_b

    if remove_accidental_hits:
      acc_hits = candidate_sampling_ops.compute_accidental_hits(
          labels, sampled, num_true=num_true)
      acc_indices, acc_ids, acc_weights = acc_hits

      # This is how SparseToDense expects the indices.
      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
      acc_ids_2d_int32 = array_ops.reshape(math_ops.cast(
          acc_ids, types.int32), [-1, 1])
      sparse_indices = array_ops.concat(
          1, [acc_indices_2d, acc_ids_2d_int32], "sparse_indices")
      # Create sampled_logits_shape = [batch_size, num_sampled]
      sampled_logits_shape = array_ops.concat(
          0,
          [array_ops.shape(labels)[:1], array_ops.expand_dims(num_sampled, 0)])
      sampled_logits += sparse_ops.sparse_to_dense(
          sparse_indices, sampled_logits_shape, acc_weights, 0.0)

    if subtract_log_q:
      # Subtract log of Q(l), prior probability that l appears in sampled.
      true_logits -= math_ops.log(true_expected_count)
      sampled_logits -= math_ops.log(sampled_expected_count)

    # Construct output logits and labels. The true labels/logits start at col 0.
    out_logits = array_ops.concat(1, [true_logits, sampled_logits])
    # true_logits is a float tensor, ones_like(true_logits) is a float tensor
    # of ones. We then divide by num_true to ensure the per-example labels sum
    # to 1.0, i.e. form a proper probability distribution.
    out_labels = array_ops.concat(
        1, [array_ops.ones_like(true_logits) / num_true,
            array_ops.zeros_like(sampled_logits)])

  return out_logits, out_labels


def nce_loss(weights, biases, inputs, labels, num_sampled, num_classes,
             num_true=1,
             sampled_values=None,
             remove_accidental_hits=False,
             name="nce_loss"):
  """Computes and returns the noise-contrastive estimation training loss.

  See [Noise-contrastive estimation: A new estimation principle for
  unnormalized statistical models]
  (http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
  Also see our [Candidate Sampling Algorithms Reference]
  (http://www.tensorflow.org/extras/candidate_sampling.pdf).

  Note: In the case where num_true > 1, we assign to each target class
  the target probability 1 / num_true so that the target probabilities
  sum to 1 per-example.
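
  As a rough illustration (not the TensorFlow implementation itself), the
  per-example loss can be sketched in plain Python.  The helper name
  `nce_example_loss` is hypothetical, and the logits are assumed to already
  have log Q subtracted:

  ```python
  import math


  def nce_example_loss(true_logits, sampled_logits):
    # Hypothetical helper: loss for one example, given its num_true true
    # logits and its num_sampled sampled logits.
    num_true = len(true_logits)
    logits = list(true_logits) + list(sampled_logits)
    # Each target class gets 1 / num_true; each sampled class gets 0.
    targets = [1.0 / num_true] * num_true + [0.0] * len(sampled_logits)
    # Numerically stable sigmoid cross entropy, summed over all columns:
    #   max(x, 0) - x * z + log(1 + exp(-|x|))
    return sum(max(x, 0.0) - x * z + math.log1p(math.exp(-abs(x)))
               for x, z in zip(logits, targets))
  ```

  With every logit at zero, each column contributes log(2) to the sum,
  whatever its target probability.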

  Note: It would be useful to allow a variable number of target classes per
  example.  We hope to provide this functionality in a future release.
  For now, if you have a variable number of target classes, you can pad them
  out to a constant number by either repeating them or by padding
  with an otherwise unused class.
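
  A minimal sketch of both padding strategies, in plain Python.  The helper
  name `pad_labels` and its `pad_class` argument are hypothetical and not
  part of this module:

  ```python
  def pad_labels(label_lists, num_true, pad_class=None):
    # Hypothetical helper: pads every example to exactly num_true labels.
    padded = []
    for labels in label_lists:
      if not labels or len(labels) > num_true:
        raise ValueError("each example needs between 1 and num_true labels")
      if pad_class is None:
        # Strategy 1: repeat the example's own labels cyclically.
        row = [labels[i % len(labels)] for i in range(num_true)]
      else:
        # Strategy 2: fill the remainder with an otherwise unused class id.
        row = list(labels) + [pad_class] * (num_true - len(labels))
      padded.append(row)
    return padded
  ```

  Repeating a label gives that class a larger share of the per-example
  target mass, since every column of `labels` receives 1 / num_true.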

  Args:
    weights: A `Tensor` of shape `[num_classes, dim]`.  The class embeddings.
    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
        activations of the input network.
    labels: A `Tensor` of type `int64` and shape `[batch_size,
      num_true]`. The target classes.
    num_sampled: An `int`.  The number of classes to randomly sample per batch.
    num_classes: An `int`. The number of possible classes.
    num_true: An `int`.  The number of target classes per training example.
    sampled_values: A tuple of `(sampled_candidates, true_expected_count,
        sampled_expected_count)` returned by a `*_candidate_sampler` function.
        If `None`, we default to `log_uniform_candidate_sampler`.
    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
        where a sampled class equals one of the target classes.  If set to
        `True`, this is a "Sampled Logistic" loss instead of NCE, and we are
        learning to generate log-odds instead of log probabilities.  See
        our [Candidate Sampling Algorithms Reference]
        (http://www.tensorflow.org/extras/candidate_sampling.pdf).
        Default is False.
    name: A name for the operation (optional).

  Returns:
    A 1-D `Tensor` of shape `[batch_size]` holding the per-example NCE losses.
  """
  logits, labels = _compute_sampled_logits(
      weights, biases, inputs, labels, num_sampled, num_classes,
      num_true=num_true,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      name=name)
  sampled_losses = sigmoid_cross_entropy_with_logits(logits,
                                                     labels,
                                                     name="sampled_losses")
  # sampled_losses is batch_size x {true_loss, sampled_losses...}
  # We sum out true and sampled losses.
  return _sum_rows(sampled_losses)


def sampled_softmax_loss(weights, biases, inputs, labels, num_sampled,
                         num_classes, num_true=1,
                         sampled_values=None,
                         remove_accidental_hits=True,
                         name="sampled_softmax_loss"):
  """Computes and returns the sampled softmax training loss.

  This is a faster way to train a softmax classifier over a huge number of
  classes.

  This operation is for training only.  It is generally an underestimate of
  the full softmax loss.

  At inference time, you can compute full softmax probabilities with the
  expression `tf.nn.softmax(tf.matmul(inputs, tf.transpose(weights)) + biases)`.
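
  The shapes matter here: `weights` is stored as `[num_classes, dim]`, so
  the logits are `inputs` times the transpose of `weights`.  A plain-Python
  sketch of that computation, using nested lists in place of tensors (the
  helper name `full_softmax` is hypothetical):

  ```python
  import math


  def full_softmax(inputs, weights, biases):
    # inputs:  [batch_size][dim]
    # weights: [num_classes][dim]
    # biases:  [num_classes]
    probs = []
    for x in inputs:
      # One logit per class: dot(x, w_class) + b_class.
      logits = [sum(xi * wi for xi, wi in zip(x, w)) + b
                for w, b in zip(weights, biases)]
      # Subtract the max before exponentiating for numerical stability.
      top = max(logits)
      exps = [math.exp(l - top) for l in logits]
      total = sum(exps)
      probs.append([e / total for e in exps])
    return probs
  ```

  Each returned row has `num_classes` entries that sum to 1.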

  See our [Candidate Sampling Algorithms Reference]
  (http://www.tensorflow.org/extras/candidate_sampling.pdf).

  Also see Section 3 of http://arxiv.org/abs/1412.2007 for the math.

  Args:
    weights: A `Tensor` of shape `[num_classes, dim]`.  The class embeddings.
    biases: A `Tensor` of shape `[num_classes]`.  The class biases.
    inputs: A `Tensor` of shape `[batch_size, dim]`.  The forward
        activations of the input network.
    labels: A `Tensor` of type `int64` and shape `[batch_size,
      num_true]`. The target classes.  Note that this format differs from
      the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
    num_sampled: An `int`.  The number of classes to randomly sample per batch.
    num_classes: An `int`. The number of possible classes.
    num_true: An `int`.  The number of target classes per training example.
    sampled_values: A tuple of `(sampled_candidates, true_expected_count,
        sampled_expected_count)` returned by a `*_candidate_sampler` function.
        If `None`, we default to `log_uniform_candidate_sampler`.
    remove_accidental_hits:  A `bool`.  Whether to remove "accidental hits"
        where a sampled class equals one of the target classes.  Default is
        True.
    name: A name for the operation (optional).

  Returns:
    A 1-D `Tensor` of shape `[batch_size]` holding the per-example sampled
    softmax losses.

  """
  logits, labels = _compute_sampled_logits(
      weights, biases, inputs, labels, num_sampled, num_classes,
      num_true=num_true,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      name=name)
  sampled_losses = nn_ops.softmax_cross_entropy_with_logits(logits, labels)
  # sampled_losses is a batch_size vector.
  return sampled_losses