author    Max Galkin <maxgalkin@google.com>    2017-09-26 11:08:56 -0700
committer    TensorFlower Gardener <gardener@tensorflow.org>    2017-09-26 11:12:37 -0700
commit    272a2c86ab4a040c4dd08933e4272b0cd5458ebb (patch)
tree    6210e360b920fd531cf3caaa3e955f0bf66888a9 /tensorflow/core/framework/shape_inference.cc
parent    202d7e812ebcb2a88fc44cba145dbde560b31ffe (diff)
Shape inference for user-defined functions in TF. For now it is completely opt-in via the ShapeRefiner API and does not yet affect any existing validation or inference anywhere; eventually graph validation should start using it.
It does not yet support recursive functions, nor the more complex shape-propagation scenarios where several iterations may be needed to infer shapes.

PiperOrigin-RevId: 170078811
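For context, a minimal sketch of what the opt-in could look like from the caller's side. Everything here is layered on top of this commit as an assumption: the setter name set_function_library_for_shape_inference, the RefineGraphShapes wrapper, and the node-visiting order are illustrative, not taken from this diff.

// Hypothetical opt-in sketch; the names below are assumptions, not part of this diff.
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

Status RefineGraphShapes(const Graph& graph,
                         const FunctionLibraryDefinition* flib_def) {
  ShapeRefiner refiner(graph.versions().producer(), graph.op_registry());
  // The opt-in: without this call, nodes that invoke user-defined functions
  // fall back to whatever shape function is registered for the call op.
  refiner.set_function_library_for_shape_inference(flib_def);

  // AddNode expects a node's inputs to be refined before the node itself,
  // so visit nodes in reverse post-order (a topological order).
  std::vector<Node*> order;
  GetReversePostOrder(graph, &order);
  for (const Node* node : order) {
    TF_RETURN_IF_ERROR(refiner.AddNode(node));
  }
  return Status::OK();
}

}  // namespace tensorflow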
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- tensorflow/core/framework/shape_inference.cc | 39
1 file changed, 30 insertions(+), 9 deletions(-)
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index ca6eb5b7fb..ffa235d15c 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -38,7 +38,7 @@ InferenceContext::InferenceContext(
std::unique_ptr<std::vector<std::pair<TensorShapeProto, DataType>>>>&
input_handle_shapes_and_types)
: graph_def_version_(graph_def_version),
- node_def_(*CHECK_NOTNULL(node_def)) {
+ node_def_(CHECK_NOTNULL(node_def)) {
std::vector<ShapeHandle> input_tensors_as_shape_handles;
for (const TensorShapeProto& p : input_tensors_as_shapes) {
ShapeHandle shape;
@@ -58,6 +58,7 @@ InferenceContext::InferenceContext(
}
inputs_.push_back(shape);
}
+
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> handle_data(
input_shapes.size());
for (int i = 0; i < input_handle_shapes_and_types.size(); ++i) {
@@ -90,7 +91,7 @@ InferenceContext::InferenceContext(
std::unique_ptr<std::vector<std::pair<PartialTensorShape, DataType>>>>&
input_handle_shapes_and_types)
: graph_def_version_(graph_def_version),
- node_def_(*CHECK_NOTNULL(node_def)) {
+ node_def_(CHECK_NOTNULL(node_def)) {
std::vector<ShapeHandle> input_tensors_as_shape_handles;
for (const PartialTensorShape& p : input_tensors_as_shapes) {
ShapeHandle shape;
@@ -140,7 +141,7 @@ InferenceContext::InferenceContext(
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
input_handle_shapes_and_types)
: graph_def_version_(graph_def_version),
- node_def_(*CHECK_NOTNULL(node_def)) {
+ node_def_(CHECK_NOTNULL(node_def)) {
PreInputInit(op_def, input_tensors, input_tensors_as_shapes);
if (!construction_status_.ok()) return;
inputs_ = input_shapes;
@@ -159,7 +160,7 @@ Status InferenceContext::Run(
#ifndef NDEBUG
for (int i = 0; i < num_outputs(); ++i) {
DCHECK(output(i).IsSet())
- << i << " for " << node_def_.name() << " of type " << node_def_.op();
+ << i << " for " << node_def_->name() << " of type " << node_def_->op();
}
#endif // NDEBUG
return s;
@@ -212,14 +213,16 @@ Status InferenceContext::output(StringPiece output_name,
return Status::OK();
}
+string InferenceContext::op() const { return node_def_->op(); }
+
void InferenceContext::PreInputInit(
const OpDef& op_def, const std::vector<const Tensor*>& input_tensors,
const std::vector<ShapeHandle>& input_tensors_as_shapes) {
input_tensors_ = input_tensors;
input_tensors_as_shapes_ = input_tensors_as_shapes;
- construction_status_ =
- NameRangesForNode(node_def_, op_def, &input_name_map_, &output_name_map_);
+ construction_status_ = NameRangesForNode(*node_def_, op_def, &input_name_map_,
+ &output_name_map_);
if (!construction_status_.ok()) return;
int num_outputs = 0;
@@ -266,6 +269,24 @@ void InferenceContext::PostInputInit(
requested_input_tensor_as_partial_shape_.resize(inputs_.size());
}
+void InferenceContext::ShapeHandleToProto(ShapeHandle handle,
+ TensorShapeProto* proto) {
+ if (!RankKnown(handle)) {
+ proto->set_unknown_rank(true);
+ return;
+ }
+
+ for (int32 i = 0; i < Rank(handle); ++i) {
+ DimensionHandle dim = Dim(handle, i);
+ auto* dim_shape = proto->add_dim();
+ if (ValueKnown(dim)) {
+ dim_shape->set_size(Value(dim));
+ } else {
+ dim_shape->set_size(-1);
+ }
+ }
+}
+
bool InferenceContext::FullyDefined(ShapeHandle s) {
if (!RankKnown(s)) return false;
for (int i = 0; i < Rank(s); ++i) {
@@ -302,7 +323,7 @@ string InferenceContext::DebugString(DimensionHandle d) {
string InferenceContext::DebugString() const {
return strings::StrCat("InferenceContext for node: ",
- ProtoDebugString(node_def_));
+ ProtoDebugString(*node_def_));
}
Status InferenceContext::WithRank(ShapeHandle shape, int64 rank,
@@ -642,7 +663,7 @@ ShapeHandle InferenceContext::UnknownShape() {
ShapeHandle InferenceContext::UnknownShapeOfRank(int64 rank) {
CHECK_LE(rank, kint32max) << "rank must be less than kint32max";
- if(rank == kUnknownRank) {
+ if (rank == kUnknownRank) {
return UnknownShape();
}
CHECK_GE(rank, 0) << "rank must not be negative";
@@ -994,7 +1015,7 @@ Status InferenceContext::AttachContext(const Status& status) {
}
string error_context = strings::StrCat(
- " for '", node_def_.name(), "' (op: '", node_def_.op(),
+ " for '", node_def_->name(), "' (op: '", node_def_->op(),
"') with input shapes: ", str_util::Join(input_shapes, ", "));
if (!input_from_tensors_str.empty()) {
strings::StrAppend(&error_context, " and with computed input tensors: ",
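For reference, a minimal sketch (not part of this commit) of the new ShapeHandleToProto helper in use. Only the conversion rules come from the hunk above: an unknown rank serializes as unknown_rank=true, and each unknown dimension serializes as size -1. The surrounding shape function, the op it would belong to, and the helper's public visibility are assumptions.

#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/errors.h"

using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;

// Hypothetical shape function illustrating the serialization convention.
tensorflow::Status ExampleShapeFn(InferenceContext* c) {
  ShapeHandle out;
  // Constrain input 0 to rank >= 2, purely as an example.
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &out));
  c->set_output(0, out);

  // Serialize the inferred shape. Per the diff above, an unknown rank maps to
  // unknown_rank=true and each unknown dimension maps to dim.size() == -1.
  tensorflow::TensorShapeProto proto;
  c->ShapeHandleToProto(out, &proto);
  return tensorflow::Status::OK();
}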