aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-08-23 13:33:04 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 14:46:32 -0700
commit011402d8987e753acd54c6251a7edd2e2d8155ba (patch)
treed718a73c7e5d2f340f1ab07079d2bce711b30237 /tensorflow/core/framework/shape_inference.cc
parentd0550db5736f484c12ac7f52dfaf2aa581d3170f (diff)
Allow a python shape inference fn to delegate to the cpp shape
inference function. Enable this for MatMul and SparseMatMul. Change: 131097313
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 4300784ffe..c6da445165 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -49,6 +49,25 @@ InferenceContext::InferenceContext(
InferenceContext::InferenceContext(
const NodeDef* node_def, const OpDef& op_def,
const std::vector<string>& input_shapes_string,
+ const std::vector<TensorShapeProto>& input_shapes,
+ const std::vector<const Tensor*>& input_tensors)
+ : node_def_(*CHECK_NOTNULL(node_def)) {
+ PreInputInit(op_def, input_tensors);
+ if (!construction_status_.ok()) return;
+ for (const TensorShapeProto& p : input_shapes) {
+ const Shape* shape;
+ construction_status_.Update(MakeShapeFromShapeProto(p, &shape));
+ if (!construction_status_.ok()) {
+ return;
+ }
+ inputs_.push_back(shape);
+ }
+ PostInputInit();
+}
+
+InferenceContext::InferenceContext(
+ const NodeDef* node_def, const OpDef& op_def,
+ const std::vector<string>& input_shapes_string,
const std::vector<const Shape*>& input_shapes,
const std::vector<const Tensor*>& input_tensors)
: node_def_(*CHECK_NOTNULL(node_def)) {