author     Alexandre Passos <apassos@google.com>            2018-01-23 11:20:29 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-01-23 11:24:07 -0800
commit     2a3559feb6564e4e46a56a71b200f6a17afe69e7 (patch)
tree       0ae6820ca8e962d0de7861281b8d2a6165aeb730 /tensorflow/contrib/batching
parent     d1020bfdedadc3da7c89a69651c45cc790922b69 (diff)
Moves batch ops to core and exposes tf.contrib.batching
PiperOrigin-RevId: 182963906
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--  tensorflow/contrib/batching/BUILD                     |  47
-rw-r--r--  tensorflow/contrib/batching/kernels/BUILD             |  34
-rw-r--r--  tensorflow/contrib/batching/kernels/batch_kernels.cc  | 997
-rw-r--r--  tensorflow/contrib/batching/ops/batch_ops.cc          | 164
-rw-r--r--  tensorflow/contrib/batching/python/ops/batch_ops.py   |  12
5 files changed, 14 insertions, 1240 deletions
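
With this change the Batch/Unbatch kernels and op registrations live in core, and the contrib Python package re-exports the generated wrappers from //tensorflow/python instead of loading a custom _batch_ops.so. The following is a minimal usage sketch, not part of this commit; the wrapper names and keyword arguments (gen_batch_ops.batch / gen_batch_ops.unbatch) are assumed from the REGISTER_OP definitions removed from ops/batch_ops.cc further down in this diff.

# Illustrative sketch only: wrapper names/signatures are assumed from the
# Batch/Unbatch op registrations shown later in this diff.
import tensorflow as tf
from tensorflow.python.ops import gen_batch_ops

inp = tf.placeholder(dtype=tf.float32, shape=[None, 4])

# Batch concurrent invocations together; emits (possibly empty) batched
# tensors plus a [K, 3] batch_index and a scalar id for this invocation.
batched, index, batch_id = gen_batch_ops.batch(
    in_tensors=[inp],
    num_batch_threads=1,
    max_batch_size=8,
    batch_timeout_micros=100000,  # wait at most 100 ms for a full batch
    grad_timeout_micros=0,
    batching_queue="")

result = batched[0] * 2.0  # any per-batch computation

# Recover the slice belonging to this invocation from the batched result.
unbatched = gen_batch_ops.unbatch(
    result, index, batch_id, timeout_micros=1000000)

Concurrent session runs that share the same container/shared_name are batched together; all but one of them receive empty outputs from Batch and wait in Unbatch until their slice of the batched result becomes available.
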
diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD
index cd98f0e703..ee67909133 100644
--- a/tensorflow/contrib/batching/BUILD
+++ b/tensorflow/contrib/batching/BUILD
@@ -67,48 +67,14 @@ load(
)
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
-tf_custom_op_library(
- name = "python/ops/_batch_ops.so",
- srcs = ["ops/batch_ops.cc"],
- deps = [
- "//tensorflow/contrib/batching/kernels:batch_kernels",
- ],
-)
-
-tf_gen_op_libs(
- op_lib_names = ["batch_ops"],
-)
-
-tf_gen_op_wrapper_py(
- name = "batch_ops",
- deps = [":batch_ops_op_lib"],
-)
-
-tf_kernel_library(
- name = "batch_ops_kernels",
- deps = [
- "//tensorflow/contrib/batching/kernels:batch_kernels",
- "//tensorflow/contrib/batching/util:periodic_function",
- "//tensorflow/core/kernels:concat_lib",
- "//tensorflow/core/kernels:ops_util",
- "//tensorflow/core/kernels:split_lib",
- ],
- alwayslink = 1,
-)
-
-tf_custom_op_py_library(
+py_library(
name = "batch_py",
srcs = glob(["python/ops/*.py"]) + ["__init__.py"],
- dso = [":python/ops/_batch_ops.so"],
- kernels = [
- ":batch_ops_kernels",
- ":batch_ops_op_lib",
- ],
srcs_version = "PY2AND3",
deps = [
- ":batch_ops",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:batch_ops_gen",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
@@ -118,6 +84,14 @@ tf_custom_op_py_library(
],
)
+cc_library(
+ name = "batch_ops_kernels",
+ deps = [
+ "//tensorflow/core/kernels:batch_kernels",
+ ],
+ alwayslink = 1,
+)
+
py_test(
name = "batch_ops_test",
size = "small",
@@ -133,6 +107,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:framework",
"//tensorflow/python:gradients",
"//tensorflow/python:script_ops",
],
diff --git a/tensorflow/contrib/batching/kernels/BUILD b/tensorflow/contrib/batching/kernels/BUILD
deleted file mode 100644
index 6e53dd9a5f..0000000000
--- a/tensorflow/contrib/batching/kernels/BUILD
+++ /dev/null
@@ -1,34 +0,0 @@
-# Description:
-# Contains kernels for the batching ops.
-
-package(default_visibility = ["//tensorflow:__subpackages__"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-cc_library(
- name = "batch_kernels",
- srcs = ["batch_kernels.cc"],
- deps = [
- "//tensorflow/contrib/batching:shared_batch_scheduler_hdrs",
- "//tensorflow/contrib/batching/util:periodic_function_dynamic",
- "//tensorflow/core:framework_headers_lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core/kernels:concat_lib_hdrs",
- "//tensorflow/core/kernels:ops_util_hdrs",
- "//tensorflow/core/kernels:split_lib_hdrs",
- ],
- alwayslink = 1,
-)
-
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
-)
diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/contrib/batching/kernels/batch_kernels.cc
deleted file mode 100644
index 6041d8c9b2..0000000000
--- a/tensorflow/contrib/batching/kernels/batch_kernels.cc
+++ /dev/null
@@ -1,997 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/contrib/batching/shared_batch_scheduler.h"
-#include "tensorflow/contrib/batching/util/periodic_function.h"
-#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/register_types.h"
-#include "tensorflow/core/framework/resource_mgr.h"
-#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/framework/tensor_util.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/concat_lib.h"
-#include "tensorflow/core/kernels/ops_util.h"
-#include "tensorflow/core/kernels/split_lib.h"
-#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/platform/macros.h"
-
-namespace tensorflow {
-
-typedef Eigen::ThreadPoolDevice CPUDevice;
-typedef Eigen::GpuDevice GPUDevice;
-#ifdef TENSORFLOW_USE_SYCL
-typedef Eigen::SyclDevice SYCLDevice;
-#endif // TENSORFLOW_USE_SYCL
-
-// Concatenates 'inputs' into a single tensor along the zeroth dimension.
-// Requires that all elements of 'inputs' have element type T. Writes to the
-// op's output at position 'output_index', using 'context' for the allocation to
-// ensure proper device placement.
-template <typename T>
-Status Concat(OpKernelContext* context, const gtl::ArraySlice<Tensor>& inputs,
- int output_index) {
- const int input_dims = inputs[0].dims();
- const TensorShape& input_shape = inputs[0].shape();
-
- // Note that we reduce the concat of k-dimensional tensors into a two
- // dimensional concat. Assuming the dimensions of any input tensor are
- // {y0, y1,...,ym-1}, we flatten it to {1, y}, where y = Prod_i(yi).
- std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> inputs_flat;
- inputs_flat.reserve(inputs.size());
- int64 output_dim0 = 0;
- for (size_t i = 0; i < inputs.size(); ++i) {
- const Tensor& input = inputs[i];
- if (input.dims() != input_dims) {
- return errors::InvalidArgument(
- "Ranks of all input tensors should match: shape[0] = ",
- input_shape.DebugString(), " vs. shape[", i,
- "] = ", input.shape().DebugString());
- }
- for (int j = 1; j < input_dims; ++j) {
- if (input.dim_size(j) != input_shape.dim_size(j)) {
- return errors::InvalidArgument(
- "Dimensions of inputs should match: shape[0] = ",
- input_shape.DebugString(), " vs. shape[", i,
- "] = ", input.shape().DebugString());
- }
- }
- if (input.NumElements() > 0) {
- inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
- input.shaped<T, 2>({1, input.NumElements()})));
- }
- output_dim0 += input.dim_size(0);
- }
-
- TensorShape output_shape(input_shape);
- output_shape.set_dim(0, output_dim0);
- Tensor* output = nullptr;
- TF_RETURN_IF_ERROR(
- context->allocate_output(output_index, output_shape, &output));
- if (output->NumElements() > 0) {
- auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
-#if GOOGLE_CUDA
- if (std::is_same<Device, GPUDevice>::value) {
- ConcatGPU<T>(context, inputs_flat, output, &output_flat);
- return Status::OK();
- }
-#endif // GOOGLE_CUDA
- ConcatCPU<T>(context->device(), inputs_flat, &output_flat);
- }
-
- return Status::OK();
-}
-
-// The Split*() functions split 'input' with element type T into 'sizes.size()'
-// tensors along the zeroth dimension, with the ith split having zeroth-
-// dimension size 'sizes[i]'. They allocate the output tensors using 'context',
-// for proper device placement.
-
-// Handles special cases that are cheap. Sets 'done==true' iff it found an
-// applicable special case and wrote to the outputs. Otherwise acts as a no-op.
-template <typename T>
-Status SplitEasyCases(OpKernelContext* context, const Tensor& input,
- const gtl::ArraySlice<int64>& sizes,
- std::vector<Tensor>* outputs, bool* done) {
- *done = false;
-
- int64 total_size = 0;
- for (const int64 size : sizes) {
- total_size += size;
- }
- if (total_size > input.shape().dim_size(0)) {
- return errors::InvalidArgument(
- "Sum of split sizes must not exceed dim0-size of input tensor");
- }
-
- // Special case 0: trivial 1-way split.
- if (sizes.size() == 1 && sizes.at(0) == input.shape().dim_size(0)) {
- outputs->push_back(input);
- *done = true;
- return Status::OK();
- }
-
- // Special case 1: input is aligned.
- if (IsInnerDimsSizeAligned<T>(input.shape())) {
- int64 position = 0;
- for (const int64 size : sizes) {
- outputs->emplace_back(input.Slice(position, position + size));
- position += size;
- }
- *done = true;
- return Status::OK();
- }
-
- return Status::OK();
-}
-
-// Handles the general case, on CPU.
-template <typename T>
-Status SplitCPU(OpKernelContext* context, const Tensor& input,
- const gtl::ArraySlice<int64>& sizes,
- std::vector<Tensor>* outputs) {
- int64 suffix_dim_size = 1;
- for (int i = 1; i < input.shape().dims(); ++i) {
- suffix_dim_size *= input.shape().dim_size(i);
- }
- auto input_reshaped =
- input.shaped<T, 3>({1, input.shape().dim_size(0), suffix_dim_size});
-
- int64 position = 0;
- for (const int64 size : sizes) {
- TensorShape output_shape = input.shape();
- output_shape.set_dim(0, size);
- Tensor output;
- TF_RETURN_IF_ERROR(
- context->allocate_temp(input.dtype(), output_shape, &output));
- auto output_shaped = output.shaped<T, 3>({1, size, suffix_dim_size});
-
- Eigen::DSizes<Eigen::DenseIndex, 3> slice_indices{0, position, 0};
- Eigen::DSizes<Eigen::DenseIndex, 3> slice_sizes{1, size, suffix_dim_size};
- functor::Split<CPUDevice, T>()(context->eigen_device<CPUDevice>(),
- output_shaped, input_reshaped, slice_indices,
- slice_sizes);
-
- outputs->emplace_back(output);
-
- position += size;
- }
-
- return Status::OK();
-}
-
-#if GOOGLE_CUDA
-
-// Handles the general case, on GPU.
-template <typename T>
-Status SplitGPU(OpKernelContext* context, const Tensor& input,
- const gtl::ArraySlice<int64>& sizes,
- std::vector<Tensor>* outputs) {
- // TODO(olston, apassos): Implement this.
- LOG(FATAL) << "Not yet implemented"; // Crash ok
-}
-
-#endif // GOOGLE_CUDA
-
-// The outer function that dispatches to the various Split*() functions above.
-template <typename T>
-Status Split(OpKernelContext* context, const Tensor& input,
- const gtl::ArraySlice<int64>& sizes,
- std::vector<Tensor>* outputs) {
- bool easy_cases_done;
- TF_RETURN_IF_ERROR(
- SplitEasyCases<T>(context, input, sizes, outputs, &easy_cases_done));
- if (easy_cases_done) {
- return Status::OK();
- }
-
-#if GOOGLE_CUDA
-// TODO(olston, apassos): Handle non-CPU cases.
-// return SplitGPU<T>(context, input, sizes, outputs);
-#endif // GOOGLE_CUDA
- return SplitCPU<T>(context, input, sizes, outputs);
-}
-
-// A class encapsulating the state and logic for batching tensors.
-class BatchResource : public ResourceBase {
- public:
- static Status Create(int32 num_batch_threads, int32 max_batch_size,
- int32 batch_timeout_micros,
- const std::vector<int32>& allowed_batch_sizes,
- std::unique_ptr<BatchResource>* resource) {
- std::unique_ptr<BatchResource> new_resource(new BatchResource);
-
- Batcher::Options batcher_options;
- batcher_options.num_batch_threads = num_batch_threads;
- TF_RETURN_IF_ERROR(
- Batcher::Create(batcher_options, &new_resource->batcher_));
-
- new_resource->batcher_queue_options_.max_batch_size = max_batch_size;
- new_resource->batcher_queue_options_.batch_timeout_micros =
- batch_timeout_micros;
-
- new_resource->allowed_batch_sizes_ = allowed_batch_sizes;
-
- *resource = std::move(new_resource);
- return Status::OK();
- }
-
- string DebugString() final { return "BatchResource"; }
-
- // Ingests data from one invocation of the batch op. The data is enqueued to
- // be combined with others into a batch, asynchronously.
- Status RegisterInput(int64 guid, OpKernelContext* context,
- const string& batcher_queue_name,
- AsyncOpKernel::DoneCallback done_callback) {
- std::unique_ptr<BatchTask> batch_components(new BatchTask);
- batch_components->guid = guid;
- OpInputList tensors;
- TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
- for (int i = 0; i < tensors.size(); ++i) {
- const Tensor& tensor = tensors[i];
- if (tensor.shape().dims() == 0) {
- return errors::InvalidArgument(
- "Batching input tensors must have at least one dimension");
- }
- if (tensors.size() >= 2 &&
- tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
- return errors::InvalidArgument(
- "Batching input tensors supplied in a given op invocation must "
- "have equal 0th-dimension size");
- }
- batch_components->inputs.push_back(tensor);
- }
- batch_components->context = context;
- batch_components->done_callback = std::move(done_callback);
-
- BatcherQueue* batcher_queue;
- TF_RETURN_IF_ERROR(
- LookupOrCreateBatcherQueue(batcher_queue_name, &batcher_queue));
- return batcher_queue->Schedule(&batch_components);
- }
-
- private:
- BatchResource() = default;
-
- // One input to be batched. Corresponds to one invocation of the batch op.
- struct BatchTask : public serving::BatchTask {
- // A unique ID to identify this invocation of Batch.
- int64 guid;
-
- std::vector<Tensor> inputs;
- OpKernelContext* context;
- AsyncOpKernel::DoneCallback done_callback;
-
- size_t size() const override { return inputs[0].shape().dim_size(0); }
- };
-
- using Batcher = serving::SharedBatchScheduler<BatchTask>;
- using BatcherQueue = serving::BatchScheduler<BatchTask>;
- using Batch = serving::Batch<BatchTask>;
-
- // Validates that it's legal to combine the tasks in 'batch' into a batch.
- // Assumes the batch is non-empty.
- static Status ValidateBatch(const Batch& batch) {
- for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
- const BatchTask& task = batch.task(task_idx);
-
- if (task.inputs.size() != batch.task(0).inputs.size()) {
- return errors::InvalidArgument(
- "Batching inputs must have equal number of edges");
- }
- }
-
- return Status::OK();
- }
-
- // Returns the smallest entry in 'allowed_batch_sizes_' that is greater than
- // or equal to 'batch_size'. If 'allowed_batch_sizes_' is empty, simply
- // returns 'batch_size'.
- int RoundToLowestAllowedBatchSize(int batch_size) const {
- if (allowed_batch_sizes_.empty()) {
- return batch_size;
- }
- for (int allowed_size : allowed_batch_sizes_) {
- if (allowed_size >= batch_size) {
- return allowed_size;
- }
- }
- LOG(ERROR) << "Maximum batch size greater than largest allowed size; "
- "ignoring allowed sizes constraint";
- return batch_size;
- }
-
- // Processes a batch of one or more BatchTask entries.
- void ProcessBatch(std::unique_ptr<Batch> batch) const {
- if (batch->empty()) {
- return;
- }
- const int padded_batch_size = RoundToLowestAllowedBatchSize(batch->size());
- const int padding_amount = padded_batch_size - batch->size();
-
- OpKernelContext* last_task_context =
- batch->task(batch->num_tasks() - 1).context;
- AsyncOpKernel::DoneCallback last_task_callback =
- batch->task(batch->num_tasks() - 1).done_callback;
-
- OP_REQUIRES_OK_ASYNC(last_task_context, ValidateBatch(*batch),
- last_task_callback);
-
- // All tasks should have the same number of input edges.
- const int num_input_edges = batch->task(0).inputs.size();
-
- // Process each input edge one at a time (the typical case has just one).
- for (int i = 0; i < num_input_edges; ++i) {
- // Emit batch->num_tasks() - 1 empty output tensors.
- for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
- const BatchTask& task = batch->task(task_idx);
- TensorShape output_shape(task.inputs.at(i).shape());
- output_shape.set_dim(0, 0);
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- task.context,
- task.context->allocate_output(i, output_shape, &output),
- task.done_callback);
- }
-
- // Concatenate the tasks ith input tensors into a big output tensor.
- std::vector<Tensor> to_concatenate;
- to_concatenate.reserve(batch->num_tasks());
- for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
- to_concatenate.push_back(batch->task(task_idx).inputs.at(i));
- }
-
- // Add padding as needed. Use the first row of the first task's tensor as
- // the data for padding.
- if (padding_amount > 0) {
- const Tensor& padding_source = batch->task(0).inputs.at(i);
- Tensor padding;
- if (padding_source.shape().dim_size(0) == 1) {
- padding = padding_source;
- } else {
- const std::vector<int64> slice_sizes = {1};
- const DataType type = padding_source.dtype();
- Status slice_status;
- std::vector<Tensor> slices;
- switch (type) {
-#define CASE(type) \
- case DataTypeToEnum<type>::value: \
- slice_status = SplitCPU<type>(last_task_context, padding_source, \
- slice_sizes, &slices); \
- break;
- TF_CALL_ALL_TYPES(CASE);
-#undef CASE
- default:
- slice_status =
- errors::InvalidArgument("Unsupported data type: ", type);
- break;
- }
- OP_REQUIRES_OK_ASYNC(last_task_context, slice_status,
- last_task_callback);
- padding = slices.at(0);
- }
- for (int i = 0; i < padding_amount; ++i) {
- to_concatenate.push_back(padding);
- }
- }
-
- const DataType type = to_concatenate[0].dtype();
- Status concat_status;
- switch (type) {
-#define CASE(type) \
- case DataTypeToEnum<type>::value: \
- concat_status = Concat<type>(last_task_context, to_concatenate, i); \
- break;
- TF_CALL_ALL_TYPES(CASE);
-#undef CASE
- default:
- concat_status =
- errors::InvalidArgument("Unsupported data type: ", type);
- break;
- }
- OP_REQUIRES_OK_ASYNC(last_task_context, concat_status,
- last_task_callback);
- }
-
- // Emit batch->num_tasks() - 1 empty index tensors.
- for (int task_idx = 0; task_idx < batch->num_tasks() - 1; ++task_idx) {
- const BatchTask& task = batch->task(task_idx);
- TensorShape index_shape({0, 3});
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- task.context,
- task.context->allocate_output(num_input_edges, index_shape, &output),
- task.done_callback);
- }
- // Emit all ID tensors.
- for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
- const BatchTask& task = batch->task(task_idx);
- Tensor* id;
- OP_REQUIRES_OK_ASYNC(task.context,
- task.context->allocate_output(num_input_edges + 1,
- TensorShape({}), &id),
- task.done_callback);
- id->scalar<int64>()() = task.guid;
- }
- OP_REQUIRES_OK_ASYNC(
- last_task_context,
- EmitIndexTensor(last_task_context, *batch, num_input_edges),
- last_task_callback);
-
- // Signal done for each element of the batch. (At this point, the contexts
- // are no longer guaranteed to remain live.)
- for (int task_idx = 0; task_idx < batch->num_tasks(); ++task_idx) {
- batch->mutable_task(task_idx)->done_callback();
- }
- }
-
- // Emits an index tensor, which the Unbatch op will use to un-concatenate
- // the tensor and attribute the pieces to the right batch keys. The index
- // tensor contains, for each input: [batch_key, start_offset, end_offset]
- // where start_offset and end_offset represent the range of entries in the
- // concatenated tensors that belong to that input.
- //
- // Emits the result to the output at 'output_index' using 'context'.
- static Status EmitIndexTensor(OpKernelContext* context, const Batch& batch,
- int output_index) {
- const TensorShape index_shape({batch.num_tasks(), 3});
- Tensor* index = nullptr;
- TF_RETURN_IF_ERROR(
- context->allocate_output(output_index, index_shape, &index));
- auto index_flat = index->shaped<int64, 2>({batch.num_tasks(), 3});
- size_t offset = 0;
- for (int task_idx = 0; task_idx < batch.num_tasks(); ++task_idx) {
- const BatchTask& task = batch.task(task_idx);
- index_flat(task_idx, 0) = task.guid;
- index_flat(task_idx, 1) = offset;
- index_flat(task_idx, 2) = offset + task.size();
- offset += task.size();
- }
- return Status::OK();
- }
-
- // Looks up the batcher queue for 'queue_name'. If it didn't previously exist,
- // creates it.
- Status LookupOrCreateBatcherQueue(const string& queue_name,
- BatcherQueue** queue) {
- mutex_lock l(batcher_queues_mu_);
-
- auto it = batcher_queues_.find(queue_name);
- if (it != batcher_queues_.end()) {
- *queue = it->second.get();
- return Status::OK();
- }
-
- std::unique_ptr<BatcherQueue> new_queue;
- auto process_batch_callback = [this](std::unique_ptr<Batch> batch) {
- ProcessBatch(std::move(batch));
- };
- TF_RETURN_IF_ERROR(batcher_->AddQueue(batcher_queue_options_,
- process_batch_callback, &new_queue));
- *queue = new_queue.get();
- batcher_queues_[queue_name] = std::move(new_queue);
- return Status::OK();
- }
-
- // A batch scheduler, and options for creating queues.
- std::shared_ptr<Batcher> batcher_;
- Batcher::QueueOptions batcher_queue_options_;
-
- // A collection of batcher queues, keyed on queue name.
- // TODO(olston): Garbage-collect unused queues (perhaps simply remove empty
- // ones (with a time delay?); it's okay if they get recreated later).
- mutable mutex batcher_queues_mu_;
- std::map<string, std::unique_ptr<BatcherQueue>> batcher_queues_
- GUARDED_BY(batcher_queues_mu_);
-
- std::vector<int32> allowed_batch_sizes_;
-};
-
-class BatchKernel : public AsyncOpKernel {
- public:
- explicit BatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
- OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
- OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
- // If shared_name is not supplied, use name instead (prevent collisions by
- // default).
- if (shared_name_.empty()) {
- shared_name_ = name();
- }
- OP_REQUIRES_OK(c, c->GetAttr("batching_queue", &batcher_queue_));
- OP_REQUIRES_OK(c, c->GetAttr("num_batch_threads", &num_batch_threads_));
- OP_REQUIRES_OK(c, c->GetAttr("max_batch_size", &max_batch_size_));
- OP_REQUIRES_OK(c,
- c->GetAttr("batch_timeout_micros", &batch_timeout_micros_));
- OP_REQUIRES_OK(c, c->GetAttr("allowed_batch_sizes", &allowed_batch_sizes_));
- OP_REQUIRES_OK(c, ValidateAllowedBatchSizes());
- }
-
- void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
- BatchResource* br;
- std::function<Status(BatchResource * *r)> creator =
- [this](BatchResource** r) {
- std::unique_ptr<BatchResource> new_resource;
- TF_RETURN_IF_ERROR(BatchResource::Create(
- num_batch_threads_, max_batch_size_, batch_timeout_micros_,
- allowed_batch_sizes_, &new_resource));
- *r = new_resource.release();
- return Status::OK();
- };
- OP_REQUIRES_OK_ASYNC(c,
- c->resource_manager()->LookupOrCreate(
- container_, shared_name_, &br, creator),
- done);
- const Status status =
- br->RegisterInput(random::New64(), c, batcher_queue_, done);
- br->Unref();
- if (!status.ok()) {
- OP_REQUIRES_OK_ASYNC(c, status, done);
- }
- // Assume br calls done, so nothing to do here.
- }
-
- // Validates 'allowed_batch_sizes_'. The entries must increase monotonically,
- // and the last one must equal 'max_batch_size_'.
- Status ValidateAllowedBatchSizes() const {
- if (allowed_batch_sizes_.empty()) {
- return Status::OK();
- }
- int32 last_size = 0;
- for (size_t i = 0; i < allowed_batch_sizes_.size(); ++i) {
- const int32 size = allowed_batch_sizes_.at(i);
- if (i > 0 && size <= last_size) {
- return errors::InvalidArgument(
- "allowed_batch_sizes entries must be monotonically increasing");
- }
- if (i == allowed_batch_sizes_.size() - 1 && size != max_batch_size_) {
- return errors::InvalidArgument(
- "final entry in allowed_batch_sizes must equal max_batch_size");
- }
- last_size = size;
- }
- return Status::OK();
- }
-
- private:
- string container_;
- string shared_name_;
- string batcher_queue_;
- int32 num_batch_threads_;
- int32 max_batch_size_;
- int32 batch_timeout_micros_;
- std::vector<int32> allowed_batch_sizes_;
-};
-
-REGISTER_KERNEL_BUILDER(Name("Batch").Device(DEVICE_CPU), BatchKernel);
-
-// A class encapsulating the state and logic for unbatching tensors.
-//
-// UnbatchResource keeps two data structures indexed by batch-key: one which has
-// the continuations for all concurrent kernels which are waiting for tensors
-// and another which has tensors which are waiting for their corresponding
-// kernels to run. Whenever a kernel runs, we either grab its tensor if it's
-// waiting already, or we insert it in the queue and then look at its tensor to
-// see if it can be used to dispatch any stored continuations.
-class UnbatchResource : public ResourceBase {
- public:
- explicit UnbatchResource(int32 timeout_micros)
- : timeout_micros_(timeout_micros),
- timeout_enforcer_(new serving::PeriodicFunction(
- [this] { EnforceTimeout(); }, 1000 /* 1 ms */)) {}
-
- ~UnbatchResource() override {
- // Tear down 'timeout_enforcer_' first, since it accesses other state in
- // this class.
- timeout_enforcer_ = nullptr;
- }
-
- string DebugString() final { return "UnbatchResource"; }
-
- Status Compute(OpKernelContext* context, AsyncOpKernel::DoneCallback done) {
- const Tensor& data_t = context->input(0);
- const Tensor& batch_index_t = context->input(1);
-
- if (batch_index_t.shape().dim_size(0) > data_t.shape().dim_size(0)) {
- return errors::InvalidArgument(
- "Wrong shape for index tensor. Expected 0th dimension size to be no "
- "greater than ",
- data_t.shape().dim_size(0),
- "; Got: ", batch_index_t.shape().dim_size(0), ".");
- }
- if (batch_index_t.shape().dim_size(1) != 3) {
- return errors::InvalidArgument(
- "Wrong shape for index tensor. Expected 1st dimension size to be 3 ; "
- "Got: ",
- batch_index_t.shape().dim_size(1), ".");
- }
-
- const int64 batch_key = context->input(2).scalar<int64>()();
- const bool nonempty_input = batch_index_t.dim_size(0) > 0;
-
- // If we have a non-empty tensor, slice it up.
- // (It is important to do this outside of the critical section below.)
- // The following variables are populated iff 'nonempty_input==true'.
- std::vector<int64> sizes;
- std::vector<int64> batch_keys;
- std::vector<Tensor> split_inputs;
- if (nonempty_input) {
- auto batch_indices =
- batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
- for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
- sizes.push_back(batch_indices(i, 2) - batch_indices(i, 1));
- batch_keys.push_back(batch_indices(i, 0));
- }
-
- const DataType type = data_t.dtype();
- switch (type) {
-#define CASE(type) \
- case DataTypeToEnum<type>::value: \
- TF_RETURN_IF_ERROR(Split<type>(context, data_t, sizes, &split_inputs)); \
- break;
- TF_CALL_ALL_TYPES(CASE);
-#undef CASE
- default:
- return errors::InvalidArgument("Unsupported data type: ", type);
- }
- }
-
- // Critical section.
- std::vector<AsyncOpKernel::DoneCallback> done_callbacks_to_call;
- Status status = [&]() -> Status {
- mutex_lock ml(mu_);
-
- // Check to see whether the tensor we want is already ready.
- auto tensor_it = waiting_tensors_.find(batch_key);
- if (tensor_it != waiting_tensors_.end()) {
- context->set_output(0, tensor_it->second.tensor);
- waiting_tensors_.erase(tensor_it);
- done_callbacks_to_call.push_back(done);
- return Status::OK();
- }
-
- const uint64 deadline_micros =
- Env::Default()->NowMicros() + timeout_micros_;
-
- // Add ourselves to the waitlist for tensors.
- if (!waiting_callbacks_
- .emplace(batch_key,
- WaitingCallback{deadline_micros, context, done})
- .second) {
- return errors::AlreadyExists(
- "Multiple session runs with the same batch key.");
- }
-
- // If we have a non-empty tensor, finish the waitlisted runs,
- // and store any remaining pieces.
- if (nonempty_input) {
- for (size_t i = 0; i < batch_keys.size(); ++i) {
- auto runs_it = waiting_callbacks_.find(batch_keys[i]);
- if (runs_it != waiting_callbacks_.end()) {
- runs_it->second.context->set_output(0, split_inputs[i]);
- done_callbacks_to_call.push_back(runs_it->second.done);
- waiting_callbacks_.erase(runs_it);
- } else {
- // Note: the deadline here is in case we are arriving late and the
- // kernel that should rendezvous with this tensor has already waited
- // and timed out.
- if (!waiting_tensors_
- .emplace(batch_keys[i],
- WaitingTensor{deadline_micros, split_inputs[i]})
- .second) {
- return errors::AlreadyExists(
- "Multiple tensors returned for same batch key.");
- }
- }
- }
- }
-
- return Status::OK();
- }();
-
- for (const AsyncOpKernel::DoneCallback& done_callback :
- done_callbacks_to_call) {
- done_callback();
- }
-
- return status;
- }
-
- private:
- // Evicts waiting tensors and callbacks that have exceeded their deadline.
- void EnforceTimeout() {
- const uint64 now = Env::Default()->NowMicros();
- std::vector<WaitingCallback> evicted_callbacks;
-
- {
- mutex_lock ml(mu_);
-
- for (auto it = waiting_tensors_.begin(); it != waiting_tensors_.end();) {
- const WaitingTensor& waiting_tensor = it->second;
- if (waiting_tensor.deadline_micros < now) {
- it = waiting_tensors_.erase(it);
- } else {
- ++it;
- }
- }
-
- for (auto it = waiting_callbacks_.begin();
- it != waiting_callbacks_.end();) {
- const WaitingCallback& waiting_callback = it->second;
- if (waiting_callback.deadline_micros < now) {
- evicted_callbacks.push_back(waiting_callback);
- it = waiting_callbacks_.erase(it);
- } else {
- ++it;
- }
- }
- }
-
- for (const WaitingCallback& evicted_callback : evicted_callbacks) {
- evicted_callback.context->CtxFailureWithWarning(errors::DeadlineExceeded(
- "Batched data did not arrive within timeout window."));
- evicted_callback.done();
- }
- }
-
- struct WaitingTensor {
- uint64 deadline_micros;
- Tensor tensor;
- };
-
- struct WaitingCallback {
- uint64 deadline_micros;
- OpKernelContext* context;
- AsyncOpKernel::DoneCallback done;
- };
-
- const int32 timeout_micros_;
-
- mutex mu_;
-
- // Maps keyed by BatchKey of tensors waiting for callbacks and callbacks
- // waiting for tensors.
- std::unordered_map<int64, WaitingTensor> waiting_tensors_ GUARDED_BY(mu_);
- std::unordered_map<int64, WaitingCallback> waiting_callbacks_ GUARDED_BY(mu_);
-
- // A thread that evicts waiting tensors and callbacks that have exceeded their
- // deadline.
- std::unique_ptr<serving::PeriodicFunction> timeout_enforcer_;
-};
-
-class UnbatchKernel : public AsyncOpKernel {
- public:
- explicit UnbatchKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
- OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
- OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
- // If shared_name is not supplied, use name instead (prevent collisions by
- // default).
- if (shared_name_.empty()) {
- shared_name_ = name();
- }
- OP_REQUIRES_OK(c, c->GetAttr("timeout_micros", &timeout_micros_));
- }
-
- void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
- UnbatchResource* ubr;
- std::function<Status(UnbatchResource * *r)> creator =
- [this](UnbatchResource** r) {
- *r = new UnbatchResource(timeout_micros_);
- return Status::OK();
- };
- OP_REQUIRES_OK_ASYNC(c,
- c->resource_manager()->LookupOrCreate(
- container_, shared_name_, &ubr, creator),
- done);
- auto status = ubr->Compute(c, done);
- ubr->Unref();
- if (!status.ok()) {
- OP_REQUIRES_OK_ASYNC(c, status, done);
- }
- // Assume ubr calls done, so nothing to do here.
- }
-
- private:
- string container_;
- string shared_name_;
- int32 timeout_micros_;
-};
-REGISTER_KERNEL_BUILDER(Name("Unbatch").Device(DEVICE_CPU), UnbatchKernel);
-
-// A class encapsulating the state and logic for batching tensors
-// deterministically for the gradient of unbatch.
-class UnbatchGradResource : public ResourceBase {
- public:
- UnbatchGradResource() {}
-
- string DebugString() final { return "UnbatchGradResource"; }
-
- // Flushes the information for one batch, given its context and done
- // callback. Clears all information about it from the available_tensors_.
- Status OutputBatch(OpKernelContext* context,
- const AsyncOpKernel::DoneCallback& done)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- const Tensor& batch_index_t = context->input(1);
- auto batch_index =
- batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
- std::vector<Tensor> tensors;
- for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
- auto available_it = available_tensors_.find(batch_index(i, 0));
- if (available_it == available_tensors_.end()) {
- return errors::Internal("bad bookkeeping of available tensors.");
- }
- tensors.push_back(available_it->second);
- available_tensors_.erase(available_it);
- }
-
- const DataType type = tensors[0].dtype();
- switch (type) {
-#define CASE(type) \
- case DataTypeToEnum<type>::value: \
- TF_RETURN_IF_ERROR(Concat<type>(context, tensors, 0)); \
- break;
- TF_CALL_ALL_TYPES(CASE);
-#undef CASE
- default:
- return errors::InvalidArgument("Unsupported data type: ", type);
- }
- done();
- return Status::OK();
- }
-
- // Ingests data from one invocation of the op.
- Status Compute(OpKernelContext* context,
- const AsyncOpKernel::DoneCallback& done) {
- const Tensor& data_t = context->input(0);
- const Tensor& batch_index_t = context->input(1);
- const Tensor& grad_t = context->input(2);
-
- mutex_lock ml(mu_);
-
- const int64 batch_key = context->input(3).scalar<int64>()();
- // Mark our tensor as available.
- if (!available_tensors_.emplace(batch_key, grad_t).second) {
- return errors::InvalidArgument("Two runs with the same batch key.");
- }
-
- // Check whether we have a valid input tensor and, if so, create its
- // dispatch logic.
- if (data_t.NumElements() > 0) {
- if (batch_index_t.NumElements() == 0) {
- return errors::InvalidArgument(
- "batch_index is empty while the tensor isn't.");
- }
- std::unordered_set<int64> missing_tensors;
- const auto batch_index =
- batch_index_t.shaped<int64, 2>({batch_index_t.dim_size(0), 3});
- for (int i = 0; i < batch_index_t.dim_size(0); ++i) {
- const int64 batch_key = batch_index(i, 0);
- if (available_tensors_.find(batch_key) == available_tensors_.end()) {
- missing_tensors.emplace(batch_key);
- }
- }
- if (missing_tensors.empty()) {
- return OutputBatch(context, done);
- }
- if (!available_batches_
- .emplace(batch_key, Batch{missing_tensors, context, done})
- .second) {
- return errors::InvalidArgument(
- "Batch key with valid batch used twice.");
- }
- for (const int64 i : missing_tensors) {
- if (!desired_tensor_to_batch_map_.emplace(i, batch_key).second) {
- return errors::InvalidArgument(
- "Missing tensor wanted by more than one batch.");
- }
- }
- } else {
- // If we don't have a valid input tensor we can output an empty tensor and
- // call our done closure.
- TensorShape output_shape(grad_t.shape());
- output_shape.set_dim(0, 0);
- Tensor* output = nullptr;
- TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, &output));
- done();
- }
-
- // Search to see whether our tensor is desired by any existing batch.
- auto desire_it = desired_tensor_to_batch_map_.find(batch_key);
- if (desire_it != desired_tensor_to_batch_map_.end()) {
- // Mark our tensor as no longer missing.
- auto batch_it = available_batches_.find(desire_it->second);
- desired_tensor_to_batch_map_.erase(desire_it);
- if (batch_it == available_batches_.end()) {
- return errors::InvalidArgument("Batch no longer exists.");
- }
- batch_it->second.missing_tensors.erase(batch_key);
- // If all tensors are available we should concatenate them and dispatch
- // the batch.
- if (batch_it->second.missing_tensors.empty()) {
- TF_RETURN_IF_ERROR(
- OutputBatch(batch_it->second.context, batch_it->second.done));
- available_batches_.erase(batch_it);
- }
- }
- return Status::OK();
- }
-
- private:
- mutex mu_;
-
- // Represents a still-incomplete batch of tensors. When all tensors become
- // available they will be concatenated in the right order and sent through the
- // context.
- struct Batch {
- // Batch keys for tensors which are still missing from this batch. When this
- // is empty the Tensors can be concatenated and forwarded.
- std::unordered_set<int64> missing_tensors;
-
- // Context and callback for the session responsible for finishing this
- // batch.
- OpKernelContext* context;
- AsyncOpKernel::DoneCallback done;
- };
-
- // Map from batch key of the session which will output the batched gradients
- // to still-incomplete batches.
- std::unordered_map<int64, Batch> available_batches_;
-
- // Map from batch key to tensors which are waiting for their batches to be
- // available.
- std::unordered_map<int64, Tensor> available_tensors_;
-
- // Map from batch key of a tensor which is not yet available to the batch key
- // of the batch to which it belongs.
- std::unordered_map<int64, int64> desired_tensor_to_batch_map_;
-};
-
-class UnbatchGradKernel : public AsyncOpKernel {
- public:
- explicit UnbatchGradKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {
- OP_REQUIRES_OK(c, c->GetAttr("container", &container_));
- OP_REQUIRES_OK(c, c->GetAttr("shared_name", &shared_name_));
- // If shared_name is not supplied, use name instead (prevent collisions by
- // default).
- if (shared_name_.empty()) {
- shared_name_ = name();
- }
- }
-
- void ComputeAsync(OpKernelContext* c, DoneCallback done) final {
- UnbatchGradResource* ubr;
- std::function<Status(UnbatchGradResource * *r)> creator =
- [this](UnbatchGradResource** r) {
- *r = new UnbatchGradResource();
- return Status::OK();
- };
- OP_REQUIRES_OK_ASYNC(c,
- c->resource_manager()->LookupOrCreate(
- container_, shared_name_, &ubr, creator),
- done);
- Status status = ubr->Compute(c, done);
- ubr->Unref();
- if (!status.ok()) {
- OP_REQUIRES_OK_ASYNC(c, status, done);
- }
- // Assume ubr calls done, so nothing to do here.
- }
-
- private:
- string container_;
- string shared_name_;
-};
-REGISTER_KERNEL_BUILDER(Name("UnbatchGrad").Device(DEVICE_CPU),
- UnbatchGradKernel);
-
-} // namespace tensorflow
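
The deleted kernel above concatenates each invocation's inputs along dimension 0 and, via EmitIndexTensor, records one [K, 3] batch_index row per invocation: its id (guid) plus the half-open [start, end) range of its rows in the concatenated tensor. The NumPy sketch below is illustrative only (not TensorFlow code) and shows how that index maps a batched result back to individual invocations, which is essentially the bookkeeping the Unbatch kernel performs.

# Illustrative sketch of the [K, 3] batch_index layout produced by
# EmitIndexTensor: each row holds an invocation id and the half-open
# [start, end) range of its rows in the concatenated batch.
import numpy as np

def split_by_index(batched, batch_index):
    """Maps each invocation id to its slice of the concatenated batch."""
    out = {}
    for guid, start, end in batch_index:
        out[int(guid)] = batched[int(start):int(end)]
    return out

batched = np.arange(12).reshape(6, 2)          # 6 rows from 2 invocations
batch_index = np.array([[101, 0, 4],           # invocation 101 owns rows 0..3
                        [102, 4, 6]])          # invocation 102 owns rows 4..5
pieces = split_by_index(batched, batch_index)
assert pieces[101].shape == (4, 2) and pieces[102].shape == (2, 2)
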
diff --git a/tensorflow/contrib/batching/ops/batch_ops.cc b/tensorflow/contrib/batching/ops/batch_ops.cc
deleted file mode 100644
index 85e0ccba4a..0000000000
--- a/tensorflow/contrib/batching/ops/batch_ops.cc
+++ /dev/null
@@ -1,164 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-
-namespace tensorflow {
-
-REGISTER_OP("Batch")
- .Input("in_tensors: T")
- .Output("batched_tensors: T")
- .Output("batch_index: int64")
- .Output("id: int64")
- .Attr("num_batch_threads: int")
- .Attr("max_batch_size: int")
- .Attr("batch_timeout_micros: int")
- .Attr("allowed_batch_sizes: list(int) = []")
- .Attr("grad_timeout_micros: int")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("batching_queue: string = ''")
- .Attr("T: list(type)")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- std::vector<shape_inference::ShapeHandle> in_shapes;
- TF_RETURN_IF_ERROR(c->input("in_tensors", &in_shapes));
- std::vector<shape_inference::ShapeHandle> out_shapes(in_shapes.size());
- for (int i = 0; i < in_shapes.size(); ++i) {
- TF_RETURN_IF_ERROR(
- c->ReplaceDim(in_shapes[i], 0, c->UnknownDim(), &out_shapes[i]));
- }
- TF_RETURN_IF_ERROR(c->set_output("batched_tensors", out_shapes));
- TF_RETURN_IF_ERROR(c->set_output("id", {c->Scalar()}));
- TF_RETURN_IF_ERROR(c->set_output(
- "batch_index",
- {c->MakeShape({shape_inference::DimensionOrConstant(c->UnknownDim()),
- shape_inference::DimensionOrConstant(3)})}));
- return Status::OK();
- })
- .Doc(R"doc(
-Batches all input tensors nondeterministically.
-
-When many instances of this Op are being run concurrently with the same
-container/shared_name in the same device, some will output zero-shaped Tensors
-and others will output Tensors of size up to max_batch_size.
-
-All Tensors in in_tensors are batched together (so, for example, labels and
-features should be batched with a single instance of this operation.
-
-Each invocation of batch emits an `id` scalar which will be used to identify
-this particular invocation when doing unbatch or its gradient.
-
-Each op which emits a non-empty batch will also emit a non-empty batch_index
-Tensor, which, is a [K, 3] matrix where each row contains the invocation's id,
-start, and length of elements of each set of Tensors present in batched_tensors.
-
-Batched tensors are concatenated along the first dimension, and all tensors in
-in_tensors must have the first dimension of the same size.
-
-in_tensors: The tensors to be batched.
-num_batch_threads: Number of scheduling threads for processing batches of work.
- Determines the number of batches processed in parallel.
-max_batch_size: Batch sizes will never be bigger than this.
-batch_timeout_micros: Maximum number of microseconds to wait before outputting
- an incomplete batch.
-allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does
- nothing. Otherwise, supplies a list of batch sizes, causing the op to pad
- batches up to one of those sizes. The entries must increase monotonically, and
- the final entry must equal max_batch_size.
-grad_timeout_micros: The timeout to use for the gradient. See Unbatch.
-batched_tensors: Either empty tensors or a batch of concatenated Tensors.
-batch_index: If out_tensors is non-empty, has information to invert it.
-container: Controls the scope of sharing of this batch.
-id: always contains a scalar with a unique ID for this invocation of Batch.
-shared_name: Concurrently running instances of batch in the same device with the
- same container and shared_name will batch their elements together. If left
- empty, the op name will be used as the shared name.
-T: the types of tensors to be batched.
-)doc");
-
-REGISTER_OP("Unbatch")
- .Input("batched_tensor: T")
- .Input("batch_index: int64")
- .Input("id: int64")
- .Output("unbatched_tensor: T")
- .Attr("timeout_micros: int")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("T: type")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle out_shape;
- TF_RETURN_IF_ERROR(
- c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &out_shape));
- c->set_output(0, out_shape);
- return Status::OK();
- })
- .Doc(R"doc(
-Reverses the operation of Batch for a single output Tensor.
-
-An instance of Unbatch either receives an empty batched_tensor, in which case it
-asynchronously waits until the values become available from a concurrently
-running instance of Unbatch with the same container and shared_name, or receives
-a non-empty batched_tensor in which case it finalizes all other concurrently
-running instances and outputs its own element from the batch.
-
-batched_tensor: The possibly transformed output of Batch. The size of the first
- dimension should remain unchanged by the transformations for the operation to
- work.
-batch_index: The matching batch_index obtained from Batch.
-id: The id scalar emitted by Batch.
-unbatched_tensor: The Tensor corresponding to this execution.
-timeout_micros: Maximum amount of time (in microseconds) to wait to receive the
- batched input tensor associated with a given invocation of the op.
-container: Container to control resource sharing.
-shared_name: Instances of Unbatch with the same container and shared_name are
- assumed to possibly belong to the same batch. If left empty, the op name will
- be used as the shared name.
-)doc");
-
-REGISTER_OP("UnbatchGrad")
- .Input("original_input: T")
- .Input("batch_index: int64")
- .Input("grad: T")
- .Input("id: int64")
- .Output("batched_grad: T")
- .Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Attr("T: type")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(2))));
- return Status::OK();
- })
- .Doc(R"doc(
-Gradient of Unbatch.
-
-Acts like Batch but using the given batch_index index of batching things as they
-become available. This ensures that the gradients are propagated back in the
-same session which did the forward pass.
-
-original_input: The input to the Unbatch operation this is the gradient of.
-batch_index: The batch_index given to the Unbatch operation this is the gradient
-of.
-grad: The downstream gradient.
-id: The id scalar emitted by Batch.
-batched_grad: The return value, either an empty tensor or the batched gradient.
-container: Container to control resource sharing.
-shared_name: Instances of UnbatchGrad with the same container and shared_name
- are assumed to possibly belong to the same batch. If left empty, the op name
- will be used as the shared name.
- )doc");
-
-} // namespace tensorflow
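
The allowed_batch_sizes attribute documented above is enforced in the deleted batch_kernels.cc: RoundToLowestAllowedBatchSize picks the smallest allowed size that is at least the actual batch size (falling back to the actual size if none is large enough), and ProcessBatch pads the concatenated tensors with copies of the first task's first row. The Python sketch below restates that rule for illustration; it is an assumption-level paraphrase, not the kernel's real code path.

# Rough sketch of the allowed_batch_sizes padding rule from the deleted
# RoundToLowestAllowedBatchSize / ProcessBatch logic.
import numpy as np

def pad_to_allowed(batch, allowed_batch_sizes):
    if not allowed_batch_sizes:
        return batch
    size = batch.shape[0]
    # Smallest allowed size >= the real size; fall back to the real size.
    padded = next((s for s in allowed_batch_sizes if s >= size), size)
    # Pad with copies of the first row, mirroring the kernel's padding source.
    pad_rows = np.repeat(batch[:1], padded - size, axis=0)
    return np.concatenate([batch, pad_rows], axis=0)

batch = np.ones((5, 3))
assert pad_to_allowed(batch, [2, 4, 8]).shape == (8, 3)   # padded up to 8
assert pad_to_allowed(batch, []).shape == (5, 3)          # no constraint
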
diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py
index cee4d7b4a9..4e0b3f9af9 100644
--- a/tensorflow/contrib/batching/python/ops/batch_ops.py
+++ b/tensorflow/contrib/batching/python/ops/batch_ops.py
@@ -18,18 +18,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.batching.ops import gen_batch_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_batch_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
-from tensorflow.contrib.batching.ops.gen_batch_ops import *
+from tensorflow.python.ops.gen_batch_ops import *
# pylint: enable=wildcard-import
-from tensorflow.contrib.util import loader
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import resource_loader
-
-
-_batch_ops = loader.load_op_library(
- resource_loader.get_path_to_datafile("_batch_ops.so"))
@ops.RegisterGradient("Batch")