diff options
Diffstat (limited to 'tensorflow/core/ops/candidate_sampling_ops.cc')
-rw-r--r-- | tensorflow/core/ops/candidate_sampling_ops.cc | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/core/ops/candidate_sampling_ops.cc b/tensorflow/core/ops/candidate_sampling_ops.cc index 6e4d100b04..6e589c8d1c 100644 --- a/tensorflow/core/ops/candidate_sampling_ops.cc +++ b/tensorflow/core/ops/candidate_sampling_ops.cc @@ -145,12 +145,15 @@ REGISTER_OP("ComputeAccidentalHits") int64 num_true; TF_RETURN_IF_ERROR(c->GetAttr("num_true", &num_true)); - // Validate true_classes. + // Validate true_classes, must be a matrix. ShapeHandle true_classes; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &true_classes)); DimensionHandle unused; TF_RETURN_IF_ERROR( c->WithValue(c->Dim(true_classes, 1), num_true, &unused)); + // Validate sampled_candidates, must be a vector. + ShapeHandle sampled_candidates; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &sampled_candidates)); // All three outputs are the same shape. ShapeHandle v = c->Vector(InferenceContext::kUnknownDim); |