diff options
author | Christopher Olston <olston@google.com> | 2017-04-27 13:32:09 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-27 14:47:43 -0700 |
commit | fa8381593d0cbe354cb54d691e0a8c42bf4b69d0 (patch) | |
tree | c8d94d8dc252f8fa47ef8a498fa972b9fa63a0c3 /tensorflow/contrib/batching | |
parent | 80581f852c784d8dc1aa937d4f5193a79e236a0a (diff) |
Move Batch op input validation before enqueueing to the batch scheduler, because earlier error detection is better (and also so batch->size() doesn't crash if #dims==0 :).
Change: 154470659
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r-- | tensorflow/contrib/batching/kernels/batch_kernels.cc | 28 |
1 files changed, 12 insertions, 16 deletions
diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/contrib/batching/kernels/batch_kernels.cc index 5598d11f75..7f9fc447b1 100644 --- a/tensorflow/contrib/batching/kernels/batch_kernels.cc +++ b/tensorflow/contrib/batching/kernels/batch_kernels.cc @@ -239,7 +239,18 @@ class BatchResource : public ResourceBase { OpInputList tensors; TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors)); for (int i = 0; i < tensors.size(); ++i) { - batch_components->inputs.push_back(tensors[i]); + const Tensor& tensor = tensors[i]; + if (tensor.shape().dims() == 0) { + return errors::InvalidArgument( + "Batching input tensors must have at least one dimension"); + } + if (tensors.size() >= 2 && + tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) { + return errors::InvalidArgument( + "Batching input tensors supplied in a given op invocation must " + "have equal 0th-dimension size"); + } + batch_components->inputs.push_back(tensor); } batch_components->context = context; batch_components->done_callback = std::move(done_callback); @@ -279,21 +290,6 @@ class BatchResource : public ResourceBase { return errors::InvalidArgument( "Batching inputs must have equal number of edges"); } - - for (int edge = 0; edge < task.inputs.size(); ++edge) { - const Tensor& tensor = task.inputs[edge]; - - if (tensor.shape().dims() == 0) { - return errors::InvalidArgument( - "Batching input tensors must have at least one dimension"); - } - - if (tensor.shape().dim_size(0) != task.inputs[0].shape().dim_size(0)) { - return errors::InvalidArgument( - "Batching input tensors supplied in a given op invocation must " - "have equal 0th-dimension size"); - } - } } return Status::OK(); |