diff options
author | 2015-12-09 17:40:18 -0800 | |
---|---|---|
committer | 2015-12-09 17:40:18 -0800 | |
commit | 27259353e50e6bcaeeedbc26dc3aaaa5695fe500 (patch) | |
tree | c8710d98861fe7ca767059faff0ad44858869d92 /tensorflow/core/kernels/scatter_op.cc | |
parent | 5de908567355337fdebd997fb5c60993cbe9ba2e (diff) |
TensorFlow: Upstream changes from git.
Change 109849574
Avoid some missing return warnings
Change 109837783
Add invalid aggregation to error message.
Change 109835474
Improves docstring of RegisterGradient decorator.
Fixes a typo (input -> output) and uses lowercase name for neg in the provided example.
Change 109834486
Update generated Op docs.
Change 109830497
Fix per_image_whitening to handle edge case by preventing the sqrt() of a negative number which is possible due to numerical floating point issues. Unit test added.
Fixes #411
Change 109824286
Change TensorBoard/TAG to 4
Change 109824197
Update tutorials and documentation usage of print to use print as function not statement.
This way you can copy+paste code in a python3 context and it will still work.
Change 109824020
Fix another case where TensorBoard discards values after a restart.
We also need to not discard on graph_def, since user code or SummaryWriter may add graph_defs at step 0 after every restart.
Change 109821924
Defines Templates for variable sharing.
A Template is a function that generates a sub-graph with the same variables each time it is called.
Two different templates defined with the same underlying function also return different variables.
Change 109815801
Don't instatiate the eigen expressions for additions and subtractions of
boolean since they won't be called. This reduces the size of the binary a bit.
Change 109795730
Allow casts to and from int8
Change 109791914
Python 3 fix: filter has no len
gradients.py calls len on the output of filter. A call to tuple is needed in
between.
Not sure why this wasn't caught when we ran the Python 3 tests. If I break it
for Python 2 several tests break.
Change 109757009
Fix minor grammatical errors in about.html
The missing article needs no justification, I think.
has -> have, because subjects are 'usability and functionality', not 'TensorFlow'.
and also -> and, because 'also' is superfluous in this use.
Change 109756627
TensorFlow: some doc updates to models/ files
Change 109743899
TensorFlow: remove one more clang warning (class / struct inconsistency).
Change 109741933
Document default for max_images in tf.image_summary
It used to say max_images=None which hid the C++ defalut of 3.
Now it says max_images=3.
Fixes https://github.com/tensorflow/tensorflow/issues/441.
It's unfortunate that an edit-distance-5 change produces such a large CL.
Change 109741569
Update generated Op docs.
Change 109739599
Renaming the Python variables in the layer weights of the fully connected
MNIST model so that the variable and the TensorFlow names are different. This
allows the documentation to be more explicit about the distinction between the
weights and biases of different layers. Also, the documentation gets to
describe the whether the TF name or the Python name is being used.
Base CL: 109851372
Diffstat (limited to 'tensorflow/core/kernels/scatter_op.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_op.cc | 43 |
1 files changed, 26 insertions, 17 deletions
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 59315876aa..84dd625a9f 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -24,6 +24,30 @@ namespace tensorflow { enum class UpdateOp { ASSIGN, ADD, SUB }; +template <UpdateOp Op> +struct Assign {}; +template <> +struct Assign<UpdateOp::ASSIGN> { + template <typename Params, typename Update> + static void Run(Params p, Update u) { + p = u; + } +}; +template <> +struct Assign<UpdateOp::ADD> { + template <typename Params, typename Update> + static void Run(Params p, Update u) { + p += u; + } +}; +template <> +struct Assign<UpdateOp::SUB> { + template <typename Params, typename Update> + static void Run(Params p, Update u) { + p -= u; + } +}; + template <class T, typename Index, UpdateOp op> class ScatterUpdateOp : public OpKernel { public: @@ -105,23 +129,8 @@ class ScatterUpdateOp : public OpKernel { for (Index i = 0; i < N; i++) { // Copy last Ndim-1 dimensions of Tupdates[i] to // Tparams[Tindices[i]] - switch (op) { - case UpdateOp::ASSIGN: { - Tparams_flat.template chip<0>(Tindices_vec(i)) = - Tupdates_flat.template chip<0>(i); - break; - } - case UpdateOp::ADD: { - Tparams_flat.template chip<0>(Tindices_vec(i)) += - Tupdates_flat.template chip<0>(i); - break; - } - case UpdateOp::SUB: { - Tparams_flat.template chip<0>(Tindices_vec(i)) -= - Tupdates_flat.template chip<0>(i); - break; - } - } + Assign<op>::Run(Tparams_flat.template chip<0>(Tindices_vec(i)), + Tupdates_flat.template chip<0>(i)); } } } |