aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/batching
diff options
context:
space:
mode:
authorGravatar Christopher Olston <olston@google.com>2017-04-27 13:32:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-27 14:47:43 -0700
commitfa8381593d0cbe354cb54d691e0a8c42bf4b69d0 (patch)
treec8d94d8dc252f8fa47ef8a498fa972b9fa63a0c3 /tensorflow/contrib/batching
parent80581f852c784d8dc1aa937d4f5193a79e236a0a (diff)
Move Batch op input validation before enqueueing to the batch scheduler, because earlier error detection is better (and also so batch->size() doesn't crash if #dims==0 :).
Change: 154470659
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--tensorflow/contrib/batching/kernels/batch_kernels.cc28
1 files changed, 12 insertions, 16 deletions
diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/contrib/batching/kernels/batch_kernels.cc
index 5598d11f75..7f9fc447b1 100644
--- a/tensorflow/contrib/batching/kernels/batch_kernels.cc
+++ b/tensorflow/contrib/batching/kernels/batch_kernels.cc
@@ -239,7 +239,18 @@ class BatchResource : public ResourceBase {
OpInputList tensors;
TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
for (int i = 0; i < tensors.size(); ++i) {
- batch_components->inputs.push_back(tensors[i]);
+ const Tensor& tensor = tensors[i];
+ if (tensor.shape().dims() == 0) {
+ return errors::InvalidArgument(
+ "Batching input tensors must have at least one dimension");
+ }
+ if (tensors.size() >= 2 &&
+ tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
+ return errors::InvalidArgument(
+ "Batching input tensors supplied in a given op invocation must "
+ "have equal 0th-dimension size");
+ }
+ batch_components->inputs.push_back(tensor);
}
batch_components->context = context;
batch_components->done_callback = std::move(done_callback);
@@ -279,21 +290,6 @@ class BatchResource : public ResourceBase {
return errors::InvalidArgument(
"Batching inputs must have equal number of edges");
}
-
- for (int edge = 0; edge < task.inputs.size(); ++edge) {
- const Tensor& tensor = task.inputs[edge];
-
- if (tensor.shape().dims() == 0) {
- return errors::InvalidArgument(
- "Batching input tensors must have at least one dimension");
- }
-
- if (tensor.shape().dim_size(0) != task.inputs[0].shape().dim_size(0)) {
- return errors::InvalidArgument(
- "Batching input tensors supplied in a given op invocation must "
- "have equal 0th-dimension size");
- }
- }
}
return Status::OK();