aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/candidate_sampling_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/candidate_sampling_ops.cc')
-rw-r--r--tensorflow/core/ops/candidate_sampling_ops.cc5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/core/ops/candidate_sampling_ops.cc b/tensorflow/core/ops/candidate_sampling_ops.cc
index 6e4d100b04..6e589c8d1c 100644
--- a/tensorflow/core/ops/candidate_sampling_ops.cc
+++ b/tensorflow/core/ops/candidate_sampling_ops.cc
@@ -145,12 +145,15 @@ REGISTER_OP("ComputeAccidentalHits")
int64 num_true;
TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true));
- // Validate true_classes.
+ // Validate true_classes, must be a matrix.
ShapeHandle true_classes;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes));
DimensionHandle unused;
TF_RETURN_IF_ERROR(
c->WithValue(c->Dim(true_classes, 1), num_true, &unused));
+ // Validate sampled_candidates, must be a vector.
+ ShapeHandle sampled_candidates;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sampled_candidates));
// All three outputs are the same shape.
ShapeHandle v = c->Vector(InferenceContext::kUnknownDim);