diff options
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r-- | tensorflow/contrib/batching/kernels/batch_kernels.cc | 28 |
1 files changed, 12 insertions, 16 deletions
diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/contrib/batching/kernels/batch_kernels.cc index 5598d11f75..7f9fc447b1 100644 --- a/tensorflow/contrib/batching/kernels/batch_kernels.cc +++ b/tensorflow/contrib/batching/kernels/batch_kernels.cc @@ -239,7 +239,18 @@ class BatchResource : public ResourceBase { OpInputList tensors; TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors)); for (int i = 0; i < tensors.size(); ++i) { - batch_components->inputs.push_back(tensors[i]); + const Tensor& tensor = tensors[i]; + if (tensor.shape().dims() == 0) { + return errors::InvalidArgument( + "Batching input tensors must have at least one dimension"); + } + if (tensors.size() >= 2 && + tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) { + return errors::InvalidArgument( + "Batching input tensors supplied in a given op invocation must " + "have equal 0th-dimension size"); + } + batch_components->inputs.push_back(tensor); } batch_components->context = context; batch_components->done_callback = std::move(done_callback); @@ -279,21 +290,6 @@ class BatchResource : public ResourceBase { return errors::InvalidArgument( "Batching inputs must have equal number of edges"); } - - for (int edge = 0; edge < task.inputs.size(); ++edge) { - const Tensor& tensor = task.inputs[edge]; - - if (tensor.shape().dims() == 0) { - return errors::InvalidArgument( - "Batching input tensors must have at least one dimension"); - } - - if (tensor.shape().dim_size(0) != task.inputs[0].shape().dim_size(0)) { - return errors::InvalidArgument( - "Batching input tensors supplied in a given op invocation must " - "have equal 0th-dimension size"); - } - } } return Status::OK(); |