aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/assign_op.h
blob: d5a34bae736f59cdfc226d0020ac122b3db01353 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/* Copyright 2015 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_ASSIGN_OP_H_
#define TENSORFLOW_KERNELS_ASSIGN_OP_H_

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {

// TODO(jeff): Get rid of use_exclusive_lock_ option

// Computes *input[0] = input[1]
class AssignOp : public OpKernel {
 public:
  explicit AssignOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("use_locking", &use_exclusive_lock_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("validate_shape", &validate_shape_));
    OP_REQUIRES(context, IsRefType(context->input_type(0)),
                errors::InvalidArgument("lhs input needs to be a ref type"));
  }

  void Compute(OpKernelContext* context) override {
    Tensor rhs = context->input(1);

    // We always return the input ref.
    context->forward_ref_input_to_ref_output(0, 0);

    // If the left hand side is not initialized, or the shape of the
    // right-hand side is different than the left hand side, we need
    // to allocate a new tensor.
    {
      mutex_lock l(*context->input_ref_mutex(0));

      Tensor old_lhs = context->mutable_input(0, true);

      if (validate_shape_) {
        OP_REQUIRES(
            context, old_lhs.shape().IsSameSize(rhs.shape()),
            errors::InvalidArgument(
                "Assign requires shapes of both tensors to match. lhs shape= ",
                old_lhs.shape().ShortDebugString(), " rhs shape= ",
                rhs.shape().ShortDebugString()));
      }

      const bool same_shape = old_lhs.shape().IsSameSize(rhs.shape());
      if (!old_lhs.IsInitialized() || !same_shape) {
        // Create new tensor whose shape matches the right hand side
        // and copy then hand off to lhs.
        // We can't always know how this value will be used downstream,
        // so make conservative assumptions in specifying the memory
        // allocation attributes.
        AllocatorAttributes attr;
        attr.set_gpu_compatible(true);
        PersistentTensor copy;
        Tensor* copyTensor = nullptr;
        OP_REQUIRES_OK(
            context, context->allocate_persistent(old_lhs.dtype(), rhs.shape(),
                                                  &copy, &copyTensor, attr));
        Copy(context, copyTensor, rhs);
        context->replace_ref_input(0, *copyTensor, true);
        return;
      }

      // The tensor has already been initialized and the right hand side
      // matches the left hand side's shape.
      if (use_exclusive_lock_) {
        Copy(context, &old_lhs, rhs);
        return;
      }
    }

    // The tensor has already been initialized and the right hand side
    // matches the left hand side's shape. We have been told to do the
    // copy outside the lock.
    Tensor old_unlocked_lhs = context->mutable_input(0, false);
    Copy(context, &old_unlocked_lhs, rhs);
  }

  virtual void Copy(OpKernelContext* context, Tensor* lhs,
                    const Tensor& rhs) = 0;

  bool use_exclusive_lock_;
  bool validate_shape_;
};

}  // end namespace tensorflow

#endif  // TENSORFLOW_KERNELS_ASSIGN_OP_H_