aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/dense_update_ops_test.py
blob: 120e10314f66f95a574cceeb4335c34066c096e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.Assign*."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class AssignOpTest(test.TestCase):
  """Tests for state_ops.assign, assign_add and assign_sub on dense variables."""

  def _initOpFetch(self, x, y, op_fn, use_gpu=False):
    """Create a variable initialized to `x`, apply `op_fn(var, y)`, and fetch.

    Args:
      x: Initial value of the variable.
      y: Second operand passed to `op_fn`.
      op_fn: A state-update function such as `state_ops.assign`.
      use_gpu: Whether the session may place ops on a GPU.

    Returns:
      A `(var_value, op_value)` tuple: the variable's value read after the
      update op ran, and the value the update op itself evaluated to.
    """
    with self.cached_session(use_gpu=use_gpu):
      p = variables.Variable(x)
      update = op_fn(p, y)
      p.initializer.run()
      # Evaluate the update first so the subsequent read observes its effect.
      new_value = update.eval()
      return p.eval(), new_value

  def _initAssignFetch(self, x, y, use_gpu=False):
    """Initialize a param to x and assign y to it (param = y)."""
    return self._initOpFetch(x, y, state_ops.assign, use_gpu=use_gpu)

  def _initAssignAddFetch(self, x, y, use_gpu=False):
    """Initialize a param to x, and compute param += y."""
    return self._initOpFetch(x, y, state_ops.assign_add, use_gpu=use_gpu)

  def _initAssignSubFetch(self, x, y, use_gpu=False):
    """Initialize a param to x, and compute param -= y."""
    return self._initOpFetch(x, y, state_ops.assign_sub, use_gpu=use_gpu)

  def _testTypes(self, vals):
    """Check assign, assign_add and assign_sub across several dtypes.

    Every dtype is exercised on CPU. Float dtypes are additionally exercised
    on GPU when TensorFlow was built with CUDA.

    Args:
      vals: A numpy array used as the update operand (cast to each dtype).
    """
    for dtype in [np.float32, np.float64, np.int32, np.int64]:
      x = np.zeros(vals.shape).astype(dtype)
      y = vals.astype(dtype)
      use_gpu_options = [False]
      # Only float dtypes are exercised on GPU here.
      if test.is_built_with_cuda() and dtype in [np.float32, np.float64]:
        use_gpu_options.append(True)
      for use_gpu in use_gpu_options:
        var_value, op_value = self._initAssignFetch(x, y, use_gpu=use_gpu)
        self.assertAllEqual(y, var_value)
        self.assertAllEqual(y, op_value)
        var_value, op_value = self._initAssignAddFetch(x, y, use_gpu=use_gpu)
        self.assertAllEqual(x + y, var_value)
        self.assertAllEqual(x + y, op_value)
        var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=use_gpu)
        self.assertAllEqual(x - y, var_value)
        self.assertAllEqual(x - y, op_value)

  def testBasic(self):
    self._testTypes(np.arange(0, 20).reshape([4, 5]))

  def testAssignNonStrictShapeChecking(self):
    """With validate_shape=False, assign may change the variable's shape."""
    with self.cached_session():
      data = array_ops.fill([1024, 1024], 0)
      p = variables.VariableV1([1])
      a = state_ops.assign(p, data, validate_shape=False)
      a.op.run()
      self.assertAllEqual(p.eval(), data.eval())

      # Assign to yet another shape
      data2 = array_ops.fill([10, 10], 1)
      a2 = state_ops.assign(p, data2, validate_shape=False)
      a2.op.run()
      self.assertAllEqual(p.eval(), data2.eval())

  def testInitRequiredAssignAdd(self):
    """assign_add on a never-initialized variable must raise an op error."""
    with self.cached_session():
      # dtype must be passed by keyword: VariableV1's second positional
      # parameter is `trainable`, not `dtype`.
      p = variables.VariableV1(
          array_ops.fill([1024, 1024], 1), dtype=dtypes.int32)
      a = state_ops.assign_add(p, array_ops.fill([1024, 1024], 0))
      with self.assertRaisesOpError("use uninitialized"):
        a.op.run()

  def testInitRequiredAssignSub(self):
    """assign_sub on a never-initialized variable must raise an op error."""
    with self.cached_session():
      # dtype must be passed by keyword: VariableV1's second positional
      # parameter is `trainable`, not `dtype`.
      p = variables.VariableV1(
          array_ops.fill([1024, 1024], 1), dtype=dtypes.int32)
      a = state_ops.assign_sub(p, array_ops.fill([1024, 1024], 0))
      with self.assertRaisesOpError("use uninitialized"):
        a.op.run()


if __name__ == "__main__":
  # Run every test case defined in this module via TensorFlow's test main.
  test.main()