aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-10-08 13:50:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 13:58:40 -0700
commiteec9ca8f0baccd249a49046fe31b460903e44850 (patch)
treeb6397af544af7c05abca4bea08bd6354f90bedf1 /tensorflow/core/framework
parent494bbdfced3fd8596721d12e73676c4967f452e4 (diff)
Partial support tfe.defun in tf.gradients.
Doesn't attempt to deal with cases where we might have already generated the functiondef for the parent function as in that case we cannot easily modify the forward pass. PiperOrigin-RevId: 216243224
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r--tensorflow/core/framework/shape_inference.cc9
-rw-r--r--tensorflow/core/framework/shape_inference.h9
2 files changed, 15 insertions, 3 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 3e77028a5f..4dcc80680f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -239,6 +239,15 @@ void InferenceContext::PreInputInit(
output_handle_shapes_and_types_.resize(num_outputs);
}
+Status InferenceContext::ExpandOutputs(int new_output_size) {
+ if (new_output_size < outputs_.size()) {
+ return errors::InvalidArgument("Trying to reduce number of outputs of op.");
+ }
+ outputs_.resize(new_output_size, nullptr);
+ output_handle_shapes_and_types_.resize(new_output_size);
+ return Status::OK();
+}
+
void InferenceContext::PostInputInit(
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
int num_inputs_from_node_def = 0;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 81258b55b3..e3885b7d9e 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -323,13 +323,13 @@ class InferenceContext {
return input_tensors_as_shapes_;
}
- ShapeHandle output(int64 idx) const { return outputs_[idx]; }
- void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
+ ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
+ void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
Status set_output(StringPiece output_name,
const std::vector<ShapeHandle>& shapes);
int num_outputs() const { return outputs_.size(); }
- ShapeHandle output(int idx) const { return outputs_[idx]; }
+ ShapeHandle output(int idx) const { return outputs_.at(idx); }
Status output(StringPiece output_name,
std::vector<ShapeHandle>* output) const;
@@ -645,6 +645,9 @@ class InferenceContext {
return merged_dims_;
}
+ // Adds new outputs; useful when mutating the graph.
+ Status ExpandOutputs(int new_output_size);
+
private:
// Creates and stores shapes for use in InferenceContext.
class ShapeManager {