diff options
Diffstat (limited to 'tensorflow/python/ops/attention_ops.py')
-rw-r--r-- | tensorflow/python/ops/attention_ops.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/python/ops/attention_ops.py b/tensorflow/python/ops/attention_ops.py new file mode 100644 index 0000000000..4829bcd7cd --- /dev/null +++ b/tensorflow/python/ops/attention_ops.py @@ -0,0 +1,34 @@ +"""Operations for implementing attention. +""" +import tensorflow.python.platform + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import gen_attention_ops +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_attention_ops import * + + +# TODO(bsteiner): Implement the gradient function for extract_glimpse +ops.NoGradient("ExtractGlimpse") + + +@ops.RegisterShape("ExtractGlimpse") +def _ExtractGlimpseShape(op): + """Shape function for ExtractGlimpse op.""" + input_shape = op.inputs[0].get_shape().with_rank(4) + unused_size_shape = op.inputs[1].get_shape().merge_with( + tensor_shape.vector(2)) + offsets_shape = op.inputs[2].get_shape().merge_with( + input_shape[:1].concatenate([2])) + offsets_shape = offsets_shape + size_value = tensor_util.ConstantValue(op.inputs[1]) + if size_value is not None: + height = size_value[0] + width = size_value[1] + else: + height = None + width = None + return [tensor_shape.TensorShape( + [input_shape[0], height, width, input_shape[3]])] |