aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/snapshot_op.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/snapshot_op.h')
-rw-r--r--tensorflow/core/kernels/snapshot_op.h26
1 files changed, 8 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/snapshot_op.h b/tensorflow/core/kernels/snapshot_op.h
index b94834f159..a18065d42b 100644
--- a/tensorflow/core/kernels/snapshot_op.h
+++ b/tensorflow/core/kernels/snapshot_op.h
@@ -26,29 +26,19 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace functor {
+// Functor used by SnapshotOp.
template <typename Device, typename Scalar>
-class SnapshotOp : public OpKernel {
- public:
- explicit SnapshotOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- const Tensor& input = context->input(0);
- Tensor* output = nullptr;
- // Try to use buffer forwarding to avoid an explicit copy.
- OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {0}, 0, input.shape(), &output));
- if (!output->SharesBufferWith(input)) {
- // We had to allocate a new buffer since the refcount on the input was
- // greater than 1. Copy the input to the new buffer.
- const Device& device = context->eigen_device<Device>();
- device.memcpy(output->template flat<Scalar>().data(),
- input.template flat<Scalar>().data(),
- input.NumElements() * sizeof(Scalar));
- }
+struct Snapshot {
+ void operator()(const Device& device,
+ typename TTypes<Scalar>::ConstTensor input,
+ typename TTypes<Scalar>::Tensor output) {
+ device.memcpy(output.data(), input.data(), input.size() * sizeof(Scalar));
}
};
+} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_SNAPSHOT_OP_H_