diff options
Diffstat (limited to 'tensorflow/core/framework/shape_inference.h')
-rw-r--r-- | tensorflow/core/framework/shape_inference.h | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index bb6a66dc53..6385177bc1 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -17,6 +17,8 @@ limitations under the License. #include <vector> +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -80,7 +82,10 @@ class InferenceContext { // the same Dimension*. // // <input_tensors> is NULL-padded to be the same size as <input_shapes>. - InferenceContext(const std::vector<string>& input_shapes, int num_outputs, + // + // REQUIRES: <node_def> is not NULL, and must outlive the InferenceContext. + InferenceContext(const NodeDef* node_def, + const std::vector<string>& input_shapes, int num_outputs, const std::vector<const Tensor*>& input_tensors = {}); ~InferenceContext(); @@ -162,6 +167,12 @@ class InferenceContext { const Dimension* CreateDim(int64 value); const Dimension* CreateUnknownDim(); + // Look up the attr for the NodeDef being evaluated with name attr_name and + // set *value to its value. If no attr with attr_name is found in def(), or + // the attr does not have a matching type, a non-ok status will be returned. + template <class T> + Status GetAttr(StringPiece attr_name, T* value) const; + private: Status ReturnUnknownShape(const Shape** out) { *out = CreateUnknownShape(); @@ -181,9 +192,14 @@ class InferenceContext { std::vector<const Tensor*> input_tensors_; std::vector<const Shape*> outputs_; + const NodeDef& node_def_; + TF_DISALLOW_COPY_AND_ASSIGN(InferenceContext); }; +// ----------------------------------------------------------------------------- +// Template and inline method implementations, please ignore + inline Dimension::Dimension() : value_(InferenceContext::kUnknownDim) {} inline Dimension::Dimension(int64 value) : value_(value) {} @@ -191,6 +207,11 @@ inline Shape::Shape() : rank_(InferenceContext::kUnknownRank) {} inline Shape::Shape(const std::vector<const Dimension*> dims) : rank_(dims.size()), dims_(dims) {} +template <class T> +Status InferenceContext::GetAttr(StringPiece attr_name, T* value) const { + return GetNodeAttr(node_def_, attr_name, value); +} + } // namespace shape_inference } // namespace tensorflow |