aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/sksl/ir/SkSLSymbolTable.cpp
blob: 6d8e9a7ea62cf5a113165be9fa3de2c2105eeccb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/*
 * Copyright 2016 Google Inc.
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */

#include "SkSLSymbolTable.h"
#include "SkSLUnresolvedFunction.h"

namespace SkSL {

std::vector<const FunctionDeclaration*> SymbolTable::GetFunctions(const Symbol& s) {
    switch (s.fKind) {
        case Symbol::kFunctionDeclaration_Kind:
            return { &((FunctionDeclaration&) s) };
        case Symbol::kUnresolvedFunction_Kind:
            return ((UnresolvedFunction&) s).fFunctions;
        default:
            return std::vector<const FunctionDeclaration*>();
    }
}

/**
 * Looks up 'name' in this table, falling back to the parent table (if any) when the name is not
 * bound locally. Returns nullptr when the name is not found anywhere.
 *
 * When the local binding is function-like (GetFunctions returns a non-empty list) and the parent
 * also resolves the name, every parent function that does not match() one already in the list is
 * appended. If anything was appended, a new UnresolvedFunction holding the combined list is
 * allocated, handed to takeOwnership(), and returned. Note that this happens on every such
 * lookup; the result is not cached in fSymbols.
 */
const Symbol* SymbolTable::operator[](const std::string& name) {
    const auto& found = fSymbols.find(name);
    if (found == fSymbols.end()) {
        // Not bound here: defer entirely to the enclosing scope, if there is one.
        if (!fParent) {
            return nullptr;
        }
        return (*fParent)[name];
    }
    const Symbol* local = found->second;
    if (!fParent) {
        return local;
    }
    auto merged = GetFunctions(*local);
    if (merged.empty()) {
        // Non-function symbols simply shadow whatever the parent defines.
        return local;
    }
    const Symbol* inherited = (*fParent)[name];
    if (!inherited) {
        return local;
    }
    const size_t localCount = merged.size();
    for (const FunctionDeclaration* candidate : GetFunctions(*inherited)) {
        // Scan the whole list, including parent functions appended on earlier iterations.
        auto it = merged.begin();
        while (it != merged.end() && !(*it)->matches(*candidate)) {
            ++it;
        }
        if (it == merged.end()) {
            merged.push_back(candidate);
        }
    }
    if (merged.size() == localCount) {
        // Every parent function is hidden by a matching local one.
        return local;
    }
    ASSERT(merged.size() > 1);
    return this->takeOwnership(new UnresolvedFunction(merged));
}

/**
 * Stores 's' in fOwnedPointers wrapped in a std::unique_ptr and returns the same raw pointer so
 * the call can be used inline in an expression.
 */
Symbol* SymbolTable::takeOwnership(Symbol* s) {
    std::unique_ptr<Symbol> owned(s);
    fOwnedPointers.push_back(std::move(owned));
    return s;
}

void SymbolTable::add(const std::string& name, std::unique_ptr<Symbol> symbol) {
    this->addWithoutOwnership(name, symbol.get());
    fOwnedPointers.push_back(std::move(symbol));
}

void SymbolTable::addWithoutOwnership(const std::string& name, const Symbol* symbol) {
    const auto& existing = fSymbols.find(name);
    if (existing == fSymbols.end()) {
        fSymbols[name] = symbol;
    } else if (symbol->fKind == Symbol::kFunctionDeclaration_Kind) {
        const Symbol* oldSymbol = existing->second;
        if (oldSymbol->fKind == Symbol::kFunctionDeclaration_Kind) {
            std::vector<const FunctionDeclaration*> functions;
            functions.push_back((const FunctionDeclaration*) oldSymbol);
            functions.push_back((const FunctionDeclaration*) symbol);
            UnresolvedFunction* u = new UnresolvedFunction(std::move(functions));
            fSymbols[name] = u;
            this->takeOwnership(u);
        } else if (oldSymbol->fKind == Symbol::kUnresolvedFunction_Kind) {
            std::vector<const FunctionDeclaration*> functions;
            for (const auto* f : ((UnresolvedFunction&) *oldSymbol).fFunctions) {
                functions.push_back(f);
            }
            functions.push_back((const FunctionDeclaration*) symbol);
            UnresolvedFunction* u = new UnresolvedFunction(std::move(functions));
            fSymbols[name] = u;
            this->takeOwnership(u);
        }
    } else {
        fErrorReporter.error(symbol->fPosition, "symbol '" + name + "' was already defined");
    }
}


void SymbolTable::markAllFunctionsBuiltin() {
    for (const auto& pair : fSymbols) {
        switch (pair.second->fKind) {
            case Symbol::kFunctionDeclaration_Kind:
                ((FunctionDeclaration&) *pair.second).fBuiltin = true;
                break;
            case Symbol::kUnresolvedFunction_Kind:
                for (auto& f : ((UnresolvedFunction&) *pair.second).fFunctions) {
                    ((FunctionDeclaration*) f)->fBuiltin = true;
                }
                break;
            default:
                break;
        }
    }
}

} // namespace SkSL