1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `state_ops.batch_scatter_update`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
def _AsType(v, vtype):
  """Cast `v` to `vtype`.

  NumPy arrays are converted with `ndarray.astype`; any other value is passed
  straight to the `vtype` constructor.
  """
  if not isinstance(v, np.ndarray):
    return vtype(v)
  return v.astype(vtype)
def _NumpyUpdate(ref, indices, updates):
  """NumPy reference for `batch_scatter_update`; modifies `ref` in place.

  For every position `pos` of `indices`, the leading coordinates `pos[:-1]`
  are kept as the batch coordinates and the stored index value selects the
  slot along axis `indices.ndim - 1` of `ref`:

      ref[pos[:-1] + (indices[pos],)] = updates[pos]

  Positions are visited in `np.ndenumerate` order, so if an index repeats
  within a batch row the later write wins.

  Args:
    ref: ndarray that is overwritten in place.
    indices: integer ndarray of slot indices.
    updates: ndarray indexable by the positions of `indices`.
  """
  for pos, target in np.ndenumerate(indices):
    batch_coords = pos[:-1]
    ref[batch_coords + (target,)] = updates[pos]
# Maps each TensorFlow scatter op under test to the NumPy function used to
# compute its expected result.
_TF_OPS_TO_NUMPY = {state_ops.batch_scatter_update: _NumpyUpdate}
class ScatterTest(test.TestCase):
  """Checks `state_ops.batch_scatter_update` against a NumPy reference."""

  def _VariableRankTest(self,
                        tf_scatter,
                        vtype,
                        itype,
                        repeat_indices=False,
                        updates_are_scalar=False):
    """Compares `tf_scatter` with its NumPy reference across several ranks.

    For each index shape and trailing "extra" shape, a variable of shape
    `indices_shape + extra_shape` is scattered into with `updates` of the same
    shape. The result is compared with the NumPy function registered for
    `tf_scatter` in `_TF_OPS_TO_NUMPY`.

    Args:
      tf_scatter: the TensorFlow scatter op under test; must be a key of
        `_TF_OPS_TO_NUMPY`.
      vtype: NumPy dtype of the variable and of the updates.
      itype: NumPy integer dtype of the indices.
      repeat_indices: if False (default), each batch row of `indices` is a
        random permutation of `range(indices_shape[-1])`. No slot is then
        written twice, so the expected value does not depend on the order in
        which repeated writes are applied. If True, indices are drawn
        independently and may repeat.
      updates_are_scalar: currently unused; kept for signature compatibility.
    """
    del updates_are_scalar  # Unused.
    # A private generator keeps the data reproducible without reseeding
    # NumPy's global random state.
    rng = np.random.RandomState(8)
    np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
    with self.test_session(use_gpu=False):
      for indices_shape in (2,), (3, 7), (3, 4, 7):
        for extra_shape in (), (5,), (5, 9):
          num_slots = indices_shape[-1]
          if repeat_indices:
            indices = rng.randint(num_slots, size=indices_shape, dtype=itype)
          else:
            # Sorting random keys along the last axis yields an independent
            # random permutation per batch row, i.e. no duplicate indices.
            indices = np.argsort(
                rng.random_sample(indices_shape), axis=-1).astype(itype)
          value_shape = indices_shape + extra_shape
          updates = _AsType(rng.randn(*value_shape), vtype)
          old = _AsType(rng.randn(*value_shape), vtype)

          # Scatter via numpy.
          new = old.copy()
          np_scatter(new, indices, updates)

          # Scatter via tensorflow.
          ref = variables.Variable(old)
          ref.initializer.run()
          tf_scatter(ref, indices, updates).eval()
          self.assertAllClose(ref.eval(), new)

  def testVariableRankUpdate(self):
    """Runs the rank sweep for every value-dtype / index-dtype pair."""
    for vtype in (np.float32, np.float64):
      for itype in (np.int32, np.int64):
        self._VariableRankTest(state_ops.batch_scatter_update, vtype, itype)

  def testBooleanScatterUpdate(self):
    """Scatters into a boolean variable with list and int64 indices."""
    with self.test_session(use_gpu=False) as session:
      var = variables.Variable([True, False])
      update0 = state_ops.batch_scatter_update(var, [1], [True])
      update1 = state_ops.batch_scatter_update(
          var, constant_op.constant([0], dtype=dtypes.int64), [False])
      var.initializer.run()
      # The two updates touch disjoint positions, so running them together
      # yields the same result in either order.
      session.run([update0, update1])
      self.assertAllEqual([False, True], var.eval())

  def testScatterOutOfRange(self):
    """Checks that negative or too-large indices raise an op error."""
    params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
    updates = np.array([-3, -4, -5]).astype(np.float32)
    with self.test_session(use_gpu=False):
      ref = variables.Variable(params)
      ref.initializer.run()

      # Indices all in range, no problem.
      indices = np.array([2, 0, 5])
      state_ops.batch_scatter_update(ref, indices, updates).eval()

      # Test some out of range errors.
      indices = np.array([-1, 0, 5])
      with self.assertRaisesOpError(
          r'indices\[0\] = \[-1\] does not index into shape \[6\]'):
        state_ops.batch_scatter_update(ref, indices, updates).eval()

      indices = np.array([2, 0, 6])
      with self.assertRaisesOpError(r'indices\[2\] = \[6\] does not index into '
                                    r'shape \[6\]'):
        state_ops.batch_scatter_update(ref, indices, updates).eval()
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when run as a script.
  test.main()
|