path: root/tensorflow/contrib/distribute/python/single_loss_example.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple network to use in tests and examples."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distribute.python import step_fn
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.layers import core
from tensorflow.python.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def single_loss_example(optimizer_fn, distribution, use_bias=False,
                        iterations_per_step=1):
  """Build a very simple network to use in tests and examples."""

  def dataset_fn():
    return dataset_ops.Dataset.from_tensors([[1.]]).repeat()

  optimizer = optimizer_fn()
  layer = core.Dense(1, use_bias=use_bias)

  def loss_fn(ctx, x):
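    # With the constant input [[1.]], layer(x) reduces to the kernel value
    # (plus the bias when use_bias=True), so this quadratic loss is
    # minimized when the layer's output reaches 1.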
    del ctx
    y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
    return y * y

  single_loss_step = step_fn.StandardSingleLossStep(
      dataset_fn, loss_fn, optimizer, distribution, iterations_per_step)

  # The layer is returned so tests can inspect its kernel.
  return single_loss_step, layer
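
# A minimal usage sketch (hypothetical test wiring; the strategy and
# optimizer below are illustrative stand-ins, not the only valid choices):
#
#   from tensorflow.contrib.distribute.python import mirrored_strategy
#   from tensorflow.python.training import gradient_descent
#
#   distribution = mirrored_strategy.MirroredStrategy(["/device:CPU:0"])
#   step, layer = single_loss_example(
#       lambda: gradient_descent.GradientDescentOptimizer(0.2), distribution)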


def minimize_loss_example(optimizer_fn,
                          use_bias=False,
                          use_callable_loss=True,
                          create_optimizer_inside_model_fn=False):
  """Example of non-distribution-aware legacy code."""

  def dataset_fn():
    dataset = dataset_ops.Dataset.from_tensors([[1.]]).repeat()
    # TODO(isaprykin): batching with `drop_remainder=True` makes the shapes
    # fully defined, which the TPU requires.  Remove this when XLA supports
    # dynamic shapes.
    return dataset.batch(1, drop_remainder=True)

  # An Optimizer instance is created either outside or inside model_fn.
  outer_optimizer = None
  if not create_optimizer_inside_model_fn:
    outer_optimizer = optimizer_fn()

  layer = core.Dense(1, use_bias=use_bias)

  def model_fn(x):
    """A very simple model written by the user."""

    def loss_fn():
      y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
      return y * y

    optimizer = outer_optimizer or optimizer_fn()
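
    # In TF 1.x, `Optimizer.minimize` accepts either a loss Tensor (graph
    # mode) or a no-argument callable that returns the loss (required when
    # eager execution is enabled); `use_callable_loss` exercises both paths.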

    if use_callable_loss:
      return optimizer.minimize(loss_fn)
    else:
      return optimizer.minimize(loss_fn())

  return model_fn, dataset_fn, layer
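
# A minimal usage sketch (hypothetical; iterator wiring varies by execution
# mode and test harness):
#
#   model_fn, dataset_fn, layer = minimize_loss_example(
#       lambda: gradient_descent.GradientDescentOptimizer(0.2))
#   iterator = dataset_fn().make_one_shot_iterator()
#   train_op = model_fn(iterator.get_next())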


def batchnorm_example(optimizer_fn,
                      batch_per_epoch=1,
                      momentum=0.9,
                      renorm=False,
                      update_ops_in_tower_mode=False):
  """Example of non-distribution-aware legacy code with batch normalization."""

  def dataset_fn():
    # Each slice has shape [16, 8], with values increasing along both
    # dimensions; successive slices are offset by 100.
    return dataset_ops.Dataset.from_tensor_slices(
        [[[float(x * 8 + y + z * 100)
           for y in range(8)]
          for x in range(16)]
         for z in range(batch_per_epoch)]).repeat()

  optimizer = optimizer_fn()
  batchnorm = normalization.BatchNormalization(
      renorm=renorm, momentum=momentum, fused=False)
  layer = core.Dense(1, use_bias=False)

  def model_fn(x):
    """A model that uses batchnorm."""

    def loss_fn():
      y = batchnorm(x, training=True)
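      # In graph mode, batch norm's moving-mean/variance updates are added
      # to the UPDATE_OPS collection rather than run automatically; the
      # control dependency below forces them to run together with the loss
      # when update_ops_in_tower_mode is True.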
      with ops.control_dependencies(
          ops.get_collection(ops.GraphKeys.UPDATE_OPS)
          if update_ops_in_tower_mode else []):
        loss = math_ops.reduce_mean(
            math_ops.reduce_sum(layer(y)) - constant_op.constant(1.))
      # `x` and `y` will be fetched by the gradient computation, but not `loss`.
      return loss

    # Callable loss.
    return optimizer.minimize(loss_fn)

  return model_fn, dataset_fn, batchnorm
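
# A minimal usage sketch (hypothetical; after running the returned train op,
# tests can read `batchnorm.moving_mean` to check that the update ops ran):
#
#   model_fn, dataset_fn, batchnorm = batchnorm_example(
#       lambda: gradient_descent.GradientDescentOptimizer(0.01),
#       update_ops_in_tower_mode=True)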