1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
|
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tensorflow
// #include <stdlib.h>
// #include "tensorflow/c/c_api.h"
import "C"
import "unsafe"
// Operation is a node that has been added to a Graph.
type Operation struct {
	c *C.TF_Operation // underlying C handle for this operation
}
// Output identifies a single output of an Operation in the graph.
//
// An Output has a DataType (and eventually a Shape). It may be used as an
// input argument when adding further operations to a graph, or handed to a
// Session's Run() method to fetch the value it produces as a tensor.
type Output struct {
	// Op is the Operation that produces this Output.
	Op *Operation

	// Index is the position of this output among Op's outputs.
	Index int
}
// c converts p into the C TF_Port value that names the same operation
// output (operation handle plus output index).
func (p *Output) c() C.TF_Port {
	var port C.TF_Port
	port.oper = p.Op.c
	port.index = C.int(p.Index)
	return port
}
// opBuilder accumulates the description of a new Operation and is intended
// for use by the generated op-construction code.
//
// Build() must eventually be called on every opBuilder that is created;
// otherwise the in-progress operation description is leaked.
type opBuilder struct {
	c *C.TF_OperationDescription // in-progress C operation description
}
// newOpBuilder starts describing a new operation of type typ, named name,
// in graph g. The returned builder must be finished with Build.
//
// The temporary C copies of typ and name are released before returning, so
// TF_NewOperation is relied upon not to retain those pointers.
func newOpBuilder(g *Graph, typ string, name string) *opBuilder {
	cType, cName := C.CString(typ), C.CString(name)
	defer C.free(unsafe.Pointer(cType))
	defer C.free(unsafe.Pointer(cName))
	return &opBuilder{c: C.TF_NewOperation(g.c, cType, cName)}
}
// SetAttrTensor sets the attribute called name on the operation being built
// to the tensor t, converted via t.c(). It returns any error reported by the
// C API through the status object.
func (b *opBuilder) SetAttrTensor(name string, t *Tensor) error {
	s := newStatus()
	cAttr := C.CString(name)
	defer C.free(unsafe.Pointer(cAttr))
	C.TF_SetAttrTensor(b.c, cAttr, t.c(), s.c)
	return s.Err()
}
// SetAttrType sets the attribute called name on the operation being built to
// the data type typ.
func (b *opBuilder) SetAttrType(name string, typ DataType) {
	cAttr := C.CString(name)
	defer C.free(unsafe.Pointer(cAttr))
	C.TF_SetAttrType(b.c, cAttr, C.TF_DataType(typ))
}
// AddInput adds the output identified by port as an input to the operation
// being built.
func (b *opBuilder) AddInput(port Output) {
	in := port.c()
	C.TF_AddInput(b.c, in)
}
// Build finishes the operation description and adds the resulting Operation
// to the graph.
//
// It returns the new Operation, or a nil Operation together with the error
// reported by the C API. The builder must not be used again after Build
// returns, whether or not it succeeded.
func (b *opBuilder) Build() (*Operation, error) {
	status := newStatus()
	cOp := C.TF_FinishOperation(b.c, status.c)
	// TF_FinishOperation deletes the operation description on success AND
	// on failure, so drop our reference unconditionally. Leaving it set on
	// the error path would keep a dangling pointer that a later call could
	// reuse or free a second time.
	b.c = nil
	if err := status.Err(); err != nil {
		return nil, err
	}
	return &Operation{c: cOp}, nil
}
|