aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/batching
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/batching')
-rw-r--r--tensorflow/contrib/batching/kernels/batch_kernels.cc28
1 files changed, 12 insertions, 16 deletions
diff --git a/tensorflow/contrib/batching/kernels/batch_kernels.cc b/tensorflow/contrib/batching/kernels/batch_kernels.cc
index 5598d11f75..7f9fc447b1 100644
--- a/tensorflow/contrib/batching/kernels/batch_kernels.cc
+++ b/tensorflow/contrib/batching/kernels/batch_kernels.cc
@@ -239,7 +239,18 @@ class BatchResource : public ResourceBase {
OpInputList tensors;
TF_RETURN_IF_ERROR(context->input_list("in_tensors", &tensors));
for (int i = 0; i < tensors.size(); ++i) {
- batch_components->inputs.push_back(tensors[i]);
+ const Tensor& tensor = tensors[i];
+ if (tensor.shape().dims() == 0) {
+ return errors::InvalidArgument(
+ "Batching input tensors must have at least one dimension");
+ }
+ if (tensors.size() >= 2 &&
+ tensor.shape().dim_size(0) != tensors[0].shape().dim_size(0)) {
+ return errors::InvalidArgument(
+ "Batching input tensors supplied in a given op invocation must "
+ "have equal 0th-dimension size");
+ }
+ batch_components->inputs.push_back(tensor);
}
batch_components->context = context;
batch_components->done_callback = std::move(done_callback);
@@ -279,21 +290,6 @@ class BatchResource : public ResourceBase {
return errors::InvalidArgument(
"Batching inputs must have equal number of edges");
}
-
- for (int edge = 0; edge < task.inputs.size(); ++edge) {
- const Tensor& tensor = task.inputs[edge];
-
- if (tensor.shape().dims() == 0) {
- return errors::InvalidArgument(
- "Batching input tensors must have at least one dimension");
- }
-
- if (tensor.shape().dim_size(0) != task.inputs[0].shape().dim_size(0)) {
- return errors::InvalidArgument(
- "Batching input tensors supplied in a given op invocation must "
- "have equal 0th-dimension size");
- }
- }
}
return Status::OK();