path: root/tensorflow/c
author    Akshay Modi <nareshmodi@google.com>    2018-09-19 14:54:07 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-09-19 14:57:27 -0700
commit c3014ec19e23e4aad7286b3fac6b25a5fb4a6326 (patch)
tree   1ffab9c512fe71884a19851b63c1025d487dbe3e /tensorflow/c
parent 4e7d5f008be62bb7ca3e1646af8d4f22287d9e50 (diff)
Allow tape tensors to have unknown shapes.

This is done by making TapeTensor a template parameter rather than a concrete struct.

PiperOrigin-RevId: 213700425
Diffstat (limited to 'tensorflow/c')
-rw-r--r--  tensorflow/c/eager/tape.h | 118
1 file changed, 56 insertions, 62 deletions
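
Before the diff itself, a brief illustration of the new contract: after this change the tape no longer reads id/dtype/shape fields off a concrete TapeTensor struct. It only calls GetID() on whatever TapeTensor type is supplied as a template argument, and it hands the whole object to VSpace::Zeros()/Ones(), so shape and dtype can remain unknown until a gradient actually has to be materialized. The sketch below is illustrative only and is not part of this commit; the Example* names are hypothetical.

// Illustrative sketch only -- not part of this commit. It mirrors the
// interface the templated tape in the diff below expects from a TapeTensor.
#include <cstdint>

namespace example {

using int64 = std::int64_t;

// Hypothetical TapeTensor: the tape itself only calls GetID(); anything
// shape- or dtype-related is resolved later by the VSpace, which receives
// the whole object in Zeros()/Ones().
class ExampleTapeTensor {
 public:
  explicit ExampleTapeTensor(int64 id) : id_(id) {}

  // Expected to be unique for the lifetime of the process, matching the
  // contract of the removed TapeTensor::id field.
  int64 GetID() const { return id_; }

 private:
  int64 id_;
};

// Sketch of the corresponding VSpace hooks after this change: they take the
// tensor wrapper instead of a (TensorShape, DataType) pair, so an unknown
// shape never has to be flattened into a TensorShape up front.
template <typename Gradient>
class ExampleVSpace {
 public:
  virtual ~ExampleVSpace() {}
  virtual Gradient* Zeros(const ExampleTapeTensor& tensor) const = 0;
  virtual Gradient* Ones(const ExampleTapeTensor& tensor) const = 0;
};

}  // namespace example
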
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 49990b6249..41b5b8ff36 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -29,15 +29,8 @@ limitations under the License.
namespace tensorflow {
namespace eager {
-// Information about a tensor.
-struct TapeTensor {
- int64 id; // Expected to be unique in the lifetime of this process.
- DataType dtype;
- TensorShape shape;
-};
-
// Represents an entry in the tape.
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
struct OpTapeEntry {
string op_type;
std::vector<TapeTensor> output_tensor_info;
@@ -57,8 +50,8 @@ struct OpTapeEntry {
using TensorTape = gtl::FlatMap<int64, int64>;
// Map from operation-id to tape entry.
-template <typename BackwardFunction>
-using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
+template <typename BackwardFunction, typename TapeTensor>
+using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction, TapeTensor>>;
// Operations the tape needs to perform on tensors to do backpropagation. Named
// "vspace" because a subset of these are related to a vector space, such as
@@ -79,7 +72,7 @@ using OpTape = gtl::FlatMap<int64, OpTapeEntry<BackwardFunction>>;
// TODO(apassos) provide concrete template instantiations for TFE_TensorHandle
// specialization, which is blocked by quite a few things needing to loop back
// into python now.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class VSpace {
public:
virtual ~VSpace() {}
@@ -93,10 +86,10 @@ class VSpace {
gtl::ArraySlice<Gradient*> gradient_tensors) const = 0;
// Returns a tensor of the right shape and dtype filled with zeros.
- virtual Gradient* Zeros(TensorShape shape, DataType dtype) const = 0;
+ virtual Gradient* Zeros(const TapeTensor& tensor) const = 0;
// Returns a Tensor which is filled with ones and like the input.
- virtual Gradient* Ones(TensorShape shape, DataType dtype) const = 0;
+ virtual Gradient* Ones(const TapeTensor& tensor) const = 0;
// Calls the passed-in backward function.
virtual Status CallBackwardFunction(
@@ -114,7 +107,7 @@ class VSpace {
// Traces the execution of operations, doing eager garbage collection, and
// exporting a full trace so other code can do backpropagation. Not thread-safe.
-template <typename Gradient, typename BackwardFunction>
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
class GradientTape {
public:
// If `persistent` is true, GradientTape will not eagerly delete backward
@@ -134,7 +127,7 @@ class GradientTape {
void Watch(int64 tensor_id);
void RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+ const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
@@ -146,17 +139,18 @@ class GradientTape {
// once) and produces the gradient of the target tensors with respect to the
// source tensors. The output gradients are used if not empty and not
// null. The result is populated with one tensor per target element.
- Status ComputeGradient(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<int64> source_tensor_id,
- gtl::ArraySlice<Gradient*> output_gradients,
- std::vector<Gradient*>* result);
+ Status ComputeGradient(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<int64> source_tensor_id,
+ gtl::ArraySlice<Gradient*> output_gradients,
+ std::vector<Gradient*>* result);
bool IsPersistent() const { return persistent_; }
private:
TensorTape tensor_tape_;
- OpTape<BackwardFunction> op_tape_;
+ OpTape<BackwardFunction, TapeTensor> op_tape_;
int64 next_op_id_{0};
// Map from tensor id to number of remaining usages (i.e. how many entries in
@@ -186,8 +180,8 @@ inline bool IsDtypeTrainable(DataType dtype) {
}
}
-template <typename Gradient, typename BackwardFunction>
-bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+bool GradientTape<Gradient, BackwardFunction, TapeTensor>::ShouldRecord(
gtl::ArraySlice<int64> tensor_ids,
gtl::ArraySlice<tensorflow::DataType> dtypes) {
CHECK_EQ(tensor_ids.size(), dtypes.size());
@@ -201,14 +195,15 @@ bool GradientTape<Gradient, BackwardFunction>::ShouldRecord(
return false;
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::Watch(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::Watch(
+ int64 tensor_id) {
tensor_tape_.emplace(tensor_id, -1);
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::RecordOperation(
- const string& op_type, gtl::ArraySlice<TapeTensor> output_tensors,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
+ const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
BackwardFunction* backward_function,
@@ -229,16 +224,18 @@ void GradientTape<Gradient, BackwardFunction>::RecordOperation(
for (const TapeTensor& o : output_tensors) {
// Note: the tensor can have already been watched and hence be in the tape,
// so we cannot check that we're inserting it here.
- tensor_tape_[o.id] = op_id;
- tensor_usage_[o.id] = 1;
+ tensor_tape_[o.GetID()] = op_id;
+ tensor_usage_[o.GetID()] = 1;
tensors.push_back(o);
}
- op_tape_[op_id] = OpTapeEntry<BackwardFunction>{
- op_type, tensors, ids, backward_function, backward_function_deleter};
+ op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
+ op_type, std::move(tensors), ids, backward_function,
+ backward_function_deleter};
}
-template <typename Gradient, typename BackwardFunction>
-void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+void GradientTape<Gradient, BackwardFunction, TapeTensor>::DeleteTrace(
+ int64 tensor_id) {
auto it = tensor_usage_.find(tensor_id);
if (it == tensor_usage_.end()) {
return;
@@ -261,7 +258,7 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
auto op_it = op_tape_.find(op_id);
CHECK(op_it != op_tape_.end());
for (const auto& output : op_it->second.output_tensor_info) {
- if (tensor_usage_.find(output.id) != tensor_usage_.end()) {
+ if (tensor_usage_.find(output.GetID()) != tensor_usage_.end()) {
// Found a usage for an output, so cannot delete the op.
return;
}
@@ -304,9 +301,9 @@ void GradientTape<Gradient, BackwardFunction>::DeleteTrace(int64 tensor_id) {
namespace {
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
struct BackpropInitialState {
- OpTape<BackwardFunction> op_tape;
+ OpTape<BackwardFunction, TapeTensor> op_tape;
// Map from tensor ID to how many references still exist for this tensor in
// the tape.
@@ -322,17 +319,17 @@ struct BackpropInitialState {
// If `persistent_tape` is false, op_tape is cleared and backwards functions
// not needed for gradient computation are deleted. Backwards functions that
// are needed, are copied and returned in BackpropInitialState.
-template <typename BackwardFunction>
-BackpropInitialState<BackwardFunction> PrepareBackprop(
+template <typename BackwardFunction, typename TapeTensor>
+BackpropInitialState<BackwardFunction, TapeTensor> PrepareBackprop(
gtl::ArraySlice<int64> target, const TensorTape& tensor_tape,
- OpTape<BackwardFunction>* op_tape, const gtl::FlatSet<int64>& sources_set,
- bool persistent_tape) {
+ OpTape<BackwardFunction, TapeTensor>* op_tape,
+ const gtl::FlatSet<int64>& sources_set, bool persistent_tape) {
std::vector<int64> tensor_stack;
tensor_stack.reserve(target.size());
for (auto t : target) {
tensor_stack.push_back(t);
}
- BackpropInitialState<BackwardFunction> result;
+ BackpropInitialState<BackwardFunction, TapeTensor> result;
while (!tensor_stack.empty()) {
int64 tensor_id = tensor_stack.back();
tensor_stack.pop_back();
@@ -383,9 +380,9 @@ BackpropInitialState<BackwardFunction> PrepareBackprop(
return result;
}
-template <typename BackwardFunction>
+template <typename BackwardFunction, typename TapeTensor>
std::vector<int64> InitialStack(
- const OpTape<BackwardFunction>& op_tape,
+ const OpTape<BackwardFunction, TapeTensor>& op_tape,
const gtl::FlatMap<int64, int64>& op_missing_tensor) {
std::vector<int64> result;
for (auto& op_entry : op_tape) {
@@ -396,13 +393,13 @@ std::vector<int64> InitialStack(
return result;
}
-template <typename Gradient, typename BackwardFunction>
-Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
- gtl::ArraySlice<int64> target_tensor_ids,
- gtl::ArraySlice<Gradient*> output_gradients,
- const TensorTape& tensor_tape,
- const OpTape<BackwardFunction>& op_tape,
- gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status InitialGradients(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
+ gtl::ArraySlice<int64> target_tensor_ids,
+ gtl::ArraySlice<Gradient*> output_gradients, const TensorTape& tensor_tape,
+ const OpTape<BackwardFunction, TapeTensor>& op_tape,
+ gtl::FlatMap<int64, std::vector<Gradient*>>* result) {
for (int i = 0; i < target_tensor_ids.size(); ++i) {
const int64 id = target_tensor_ids[i];
if (output_gradients.empty() || output_gradients[i] == nullptr) {
@@ -416,11 +413,10 @@ Status InitialGradients(const VSpace<Gradient, BackwardFunction>& vspace,
}
bool found = false;
for (int j = 0; j < op_it->second.output_tensor_info.size(); ++j) {
- if (op_it->second.output_tensor_info[j].id == id) {
+ if (op_it->second.output_tensor_info[j].GetID() == id) {
found = true;
(*result)[id].push_back(
- vspace.Ones(op_it->second.output_tensor_info[j].shape,
- op_it->second.output_tensor_info[j].dtype));
+ vspace.Ones(op_it->second.output_tensor_info[j]));
break;
}
}
@@ -469,16 +465,16 @@ gtl::FlatMap<string, gtl::FlatSet<int>>* FunctionsAcceptingNoneForIndicesMap() {
constexpr int kMinAggregateCount = 4;
constexpr int kMinAggregateBytes = 128 * 1024 * 1024;
-template <typename Gradient, typename BackwardFunction>
-Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
- const VSpace<Gradient, BackwardFunction>& vspace,
+template <typename Gradient, typename BackwardFunction, typename TapeTensor>
+Status GradientTape<Gradient, BackwardFunction, TapeTensor>::ComputeGradient(
+ const VSpace<Gradient, BackwardFunction, TapeTensor>& vspace,
gtl::ArraySlice<int64> target_tensor_ids,
gtl::ArraySlice<int64> source_tensor_ids,
gtl::ArraySlice<Gradient*> output_gradients,
std::vector<Gradient*>* result) {
gtl::FlatSet<int64> sources_set(source_tensor_ids.begin(),
source_tensor_ids.end());
- BackpropInitialState<BackwardFunction> state = PrepareBackprop(
+ BackpropInitialState<BackwardFunction, TapeTensor> state = PrepareBackprop(
target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_);
std::vector<int64> op_stack =
InitialStack(state.op_tape, state.op_missing_tensor);
@@ -522,7 +518,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
out_gradients.reserve(trace.output_tensor_info.size());
bool any_gradient_nonzero = false;
for (int i = 0; i < trace.output_tensor_info.size(); ++i) {
- const int64 id = trace.output_tensor_info[i].id;
+ const int64 id = trace.output_tensor_info[i].GetID();
auto grad_it = gradients.find(id);
if (grad_it == gradients.end()) {
auto func_name_it =
@@ -531,9 +527,7 @@ Status GradientTape<Gradient, BackwardFunction>::ComputeGradient(
func_name_it->second.find(i) != func_name_it->second.end()) {
out_gradients.push_back(nullptr);
} else {
- out_gradients.push_back(
- vspace.Zeros(trace.output_tensor_info[i].shape,
- trace.output_tensor_info[i].dtype));
+ out_gradients.push_back(vspace.Zeros(trace.output_tensor_info[i]));
}
} else {
any_gradient_nonzero = true;