0%

ugo-lab7-生成LLVM IR

作用域处理

AST 语法树翻译到 LLVM-IR 需要处理嵌套的词法域问题。LLVM-IR 本身没有嵌套词法域的语句,因此编译器需要自己维护一个 Scope 结构来处理作用域问题

作用域属于语义解析,只有 compiler 包产生 LLVM 汇编代码时需要,因此在 compiler 包定义 Scope 管理词法域:

1
2
3
4
5
6
7
8
9
10
// Scope is one lexical scope: a table of the objects declared in it plus a
// link to the enclosing scope.
type Scope struct {
// Outer is the enclosing scope, or nil when there is none.
Outer *Scope
// Objects maps a source-level name to the object declared under that name.
Objects map[string]*Object
}

// Object is a named entity recorded in a Scope.
type Object struct {
// Name is the identifier as written in the source; Scope uses it as the map key.
Name string
// MangledName is the identifier used for this object in the generated LLVM IR.
MangledName string
// The embedded AST node is the declaration this object was created from.
ast.Node
}
  • Scope 内的 Outer 指向外层的 Scope,比如 main 函数的外层 Scope 是文件
  • Object 表示一个命名对象,其中 Name 是实体在 uGo 代码中词法域的名字,MangledName 是映射到 LLVM 汇编语言的名字,其中还有指向的 AST 节点,可以用于识别更多的信息

对应的辅助函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// NewScope returns an empty scope nested inside outer. outer may be nil, in
// which case the new scope has no enclosing scope.
func NewScope(outer *Scope) *Scope {
	s := new(Scope)
	s.Outer = outer
	s.Objects = map[string]*Object{}
	return s
}

// HasName reports whether name has an entry in this scope's own table.
// Enclosing scopes are not consulted.
func (s *Scope) HasName(name string) bool {
	if _, found := s.Objects[name]; found {
		return true
	}
	return false
}

// Lookup searches for name starting at s and walking outward through Outer.
// It returns the scope that holds the object together with the object, or
// (nil, nil) when no scope on the chain has a non-nil entry for name.
func (s *Scope) Lookup(name string) (*Scope, *Object) {
	cur := s
	for cur != nil {
		if found := cur.Objects[name]; found != nil {
			return cur, found
		}
		cur = cur.Outer
	}
	return nil, nil
}

// Insert declares obj in s under obj.Name.
//
// If the name is not yet declared in s, obj is stored and nil is returned.
// If the name is already declared, the scope is left unchanged and the
// existing object is returned so the caller can detect the redeclaration.
// Enclosing scopes are not consulted.
func (s *Scope) Insert(obj *Object) (alt *Object) {
	// Store only when the slot is free; an existing declaration wins.
	if alt = s.Objects[obj.Name]; alt == nil {
		s.Objects[obj.Name] = obj
	}
	return alt
}

// String renders the scope for debugging as
//
//	scope 0x… {
//		<node type> <name>
//	}
//
// An empty or nil scope prints as "scope 0x… {}". Entries appear in map
// iteration order, which Go does not fix, so the line order may vary between
// calls.
func (s *Scope) String() string {
	var buf bytes.Buffer
	// Open the brace here so it pairs with the "}" that is always written below.
	fmt.Fprintf(&buf, "scope %p {", s)
	if s != nil && len(s.Objects) > 0 {
		fmt.Fprintln(&buf)
		for _, obj := range s.Objects {
			fmt.Fprintf(&buf, "\t%T %s\n", obj.Node, obj.Name)
		}
	}
	fmt.Fprintf(&buf, "}\n")
	return buf.String()
}

基于 Scope,我们可以构造出 Compiler 类用于描述编译器:

1
2
3
4
5
// Compiler holds the state used while translating one file to LLVM IR.
type Compiler struct {
// file is the AST being compiled; Run stores its argument here.
file *ast.File
// scope is the innermost lexical scope at the current point of the translation.
scope *Scope
// nextId is not referenced by the code shown here; presumably the counter
// behind generated temporary/label ids — TODO confirm.
nextId int
}

通过如下两个函数来进入和退出内层的 Scope:

1
2
3
4
5
6
7
8
9
10
11
// enterScope makes a fresh, empty scope current, nested inside the scope that
// was current before the call.
func (p *Compiler) enterScope() {
	inner := NewScope(p.scope)
	p.scope = inner
}

// leaveScope makes the enclosing scope current again. It dereferences the
// current scope, so it must not be called when no scope is active.
func (p *Compiler) leaveScope() {
	outer := p.scope.Outer
	p.scope = outer
}

// restoreScope makes scope the current scope. Callers in this file use it as
// `defer p.restoreScope(p.scope)`: the argument of a deferred call is
// evaluated when the defer statement executes, so the scope that was current
// at that moment is reinstated when the surrounding function returns.
func (p *Compiler) restoreScope(scope *Scope) {
p.scope = scope
}
  • 我们需要在编译文件、块语句时进入新的 Scope,以便于存储新词法域定义的新变量

至于全局变量则在编译文件时处理,具体而言就是 compileFile 函数

翻译为 LLVM IR

接下来的工作就是把 AST 翻译为 LLVM IR(这里我把类型检查的操作提前到了语法分析中,因此不需要考虑报错检查的问题)

核心函数如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// genHeader writes a "; package <name>" comment line followed by the Header
// text to w.
func (p *Compiler) genHeader(w io.Writer, file *ast.File) {
	pkgName := file.Pkg.Name
	fmt.Fprintf(w, "; package %s\n", pkgName)
	fmt.Fprint(w, Header)
}

// genMain writes the MainMain text to w once, and only when the file belongs
// to package "main" and declares a function named "main". Otherwise it writes
// nothing.
func (p *Compiler) genMain(w io.Writer, file *ast.File) {
	if file.Pkg.Name == "main" {
		for _, decl := range file.Funcs {
			if decl.Name.Name != "main" {
				continue
			}
			fmt.Fprint(w, MainMain)
			break
		}
	}
}

// Run translates file and returns the generated text. It records file on the
// compiler, then emits, in this order, the header, the file body and the main
// entry glue.
func (p *Compiler) Run(file *ast.File) string {
	p.file = file

	out := new(bytes.Buffer)
	p.genHeader(out, file)
	p.compileFile(out, file)
	p.genMain(out, file)
	return out.String()
}

函数 compileFile 就是 LLVM IR 翻译的起点,其详细代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// compileFile emits the IR for a whole file to w.
//
// It opens the file-level scope, registers every global variable and every
// function in it, and writes one "global" line per variable and one "declare"
// line per function. It then compiles each function with compileFunc and
// finally calls genInit. The scope that was current on entry is restored on
// return.
func (p *Compiler) compileFile(w io.Writer, file *ast.File) {
	// The deferred call's argument is evaluated here, before enterScope, so
	// it captures the scope to go back to.
	defer p.restoreScope(p.scope)
	p.enterScope()

	for _, g := range file.Globals {
		mangledName := fmt.Sprintf("@ugo_%s_%s", file.Pkg.Name, g.Name.Name)
		p.scope.Insert(&Object{
			Name:        g.Name.Name,
			MangledName: mangledName,
			Node:        g,
		})
		fmt.Fprintf(w, "%s = global i32 0\n", mangledName)
	}

	for _, fn := range file.Funcs {
		// Record only the bare symbol as the mangled name; the
		// "declare i32" prefix belongs to the emitted line, not the symbol.
		mangledName := fmt.Sprintf("@ugo_%s_%s", file.Pkg.Name, fn.Name.Name)
		p.scope.Insert(&Object{
			Name:        fn.Name.Name,
			MangledName: mangledName,
			Node:        fn,
		})
		// NOTE(review): compileFunc later emits a define (or another
		// declare) for the same symbol — confirm the LLVM toolchain in use
		// accepts both in one module.
		fmt.Fprintf(w, "declare i32 %s(%s)\n", mangledName, p.changeType(fn.Type.Name))
	}

	if len(file.Globals) != 0 {
		fmt.Fprintln(w)
	}
	for _, fn := range file.Funcs {
		p.compileFunc(w, file, fn)
	}

	p.genInit(w, file)
}
  • 在函数开始之前执行 enterScope 代表进入新的命名空间
  • 在函数结束时执行 restoreScope 退出到上一级的命名空间
  • 对于一个 ugo 代码而言,File 层就是最外层的命名空间,函数和全局变量都在此记录

下面是对函数的处理:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// compileFunc emits the IR for one function declaration to w.
//
// It opens a scope for the function and registers the function's own symbol
// in it. A declaration without a body becomes a single "declare" line. A
// declaration with a body becomes a blank line, a "define" header, the
// compiled body, a trailing "ret i32 0" and the closing brace. The scope that
// was current on entry is restored on return.
func (p *Compiler) compileFunc(w io.Writer, file *ast.File, fn *ast.FuncDecl) {
	// Capture the scope to go back to before opening the new one.
	defer p.restoreScope(p.scope)
	p.enterScope()

	symbol := fmt.Sprintf("@ugo_%s_%s", file.Pkg.Name, fn.Name.Name)
	self := &Object{
		Name:        fn.Name.Name,
		MangledName: symbol,
		Node:        fn,
	}
	p.scope.Insert(self)

	if fn.Body != nil {
		fmt.Fprintln(w)
		fmt.Fprintf(w, "define i32 %s() {\n", symbol)
		p.compileStmt(w, fn.Body)
		fmt.Fprintln(w, "\tret i32 0")
		fmt.Fprintln(w, "}")
		return
	}

	fmt.Fprintf(w, "declare i32 %s()\n", symbol)
}
  • 每一个函数都要使用 scope.Insert 在当前命名空间中注册函数符号
  • 基础的处理过后就交给 compileStmt 处理语句块

函数 compileStmt 的核心步骤就是处理如下几种语句:

  • VarSpec:变量定义
  • AssignStmt:赋值语句
  • BlockStmt:语句块
  • ExprStmt:表达式语句
  • RetStmt:返回语句
  • IfStmt:if 语句
  • ForStmt:for 语句

除了 if 语句和 for 语句之外,其余语句的实现代码都比较简单(表达式的求值交给 compileExpr 处理),这里直接给出:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
case *ast.VarSpec:
var localName = "0"
if stmt.Value != nil {
localName = p.compileExpr(w, stmt.Value)
}

var mangledName = fmt.Sprintf("%%local_%s.pos.%d", stmt.Name.Name, stmt.VarPos)
p.scope.Insert(&Object{
Name: stmt.Name.Name,
MangledName: mangledName,
Node: stmt,
})

fmt.Fprintf(w, "\t%s = alloca i32, align 4\n", mangledName)
fmt.Fprintf(
w, "\tstore i32 %s, i32* %s\n",
localName, mangledName,
)

case *ast.AssignStmt:
p.compileStmt_assign(w, stmt)

case *ast.BlockStmt:
defer p.restoreScope(p.scope)
p.enterScope()

for _, x := range stmt.List {
p.compileStmt(w, x)
}

case *ast.ExprStmt:
p.compileExpr(w, stmt.X)

case *ast.RetStmt:
var retValue = p.compileExpr(w, stmt.Expr)
fmt.Fprintf(w, "\tret i32 %s\n", retValue)

处理 if 语句时需要考虑各个部分的作用域范围,因此使用匿名函数立即执行的写法,在匿名函数中添加 enterScope 和 restoreScope 来控制作用域范围

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
case *ast.IfStmt:
	// Emitted layout, one basic block per label:
	//
	//	if.init: optional init statement, then jump to if.cond
	//	if.cond: evaluate the condition, branch to if.body or to
	//	         if.else (if.end when there is no else branch)
	//	if.body: then-branch, then jump to if.end
	//	if.else: else-branch (empty when absent), then jump to if.end
	//	if.end:  code after the if statement continues here
	//
	// Scopes nest the same way: init and cond share one scope, and the
	// then- and else-branches each get their own scope inside it.
	defer p.restoreScope(p.scope)
	p.enterScope()

	ifPos := fmt.Sprintf("%d", p.posLine())
	ifInit := p.genLabelId("if.init.line" + ifPos)
	ifCond := p.genLabelId("if.cond.line" + ifPos)
	ifBody := p.genLabelId("if.body.line" + ifPos)
	ifElse := p.genLabelId("if.else.line" + ifPos)
	ifEnd := p.genLabelId("if.end.line" + ifPos)

	// br if.init
	fmt.Fprintf(w, "\tbr label %%%s\n", ifInit)

	// if.init
	fmt.Fprintf(w, "\n%s:\n", ifInit)
	func() {
		defer p.restoreScope(p.scope)
		p.enterScope()

		if stmt.Init != nil {
			p.compileStmt(w, stmt.Init)
		}
		fmt.Fprintf(w, "\tbr label %%%s\n", ifCond)

		// if.cond
		fmt.Fprintf(w, "\n%s:\n", ifCond)
		condValue := p.compileExpr(w, stmt.Cond)
		// A false condition goes to the else block when there is one,
		// otherwise straight to the end.
		falseTarget := ifEnd
		if stmt.Else != nil {
			falseTarget = ifElse
		}
		fmt.Fprintf(w, "\tbr i1 %s , label %%%s, label %%%s\n", condValue, ifBody, falseTarget)

		// if.body
		func() {
			defer p.restoreScope(p.scope)
			p.enterScope()

			fmt.Fprintf(w, "\n%s:\n", ifBody)
			p.compileStmt(w, stmt.Body)
			// Always skip past the else block: jumping to if.else here
			// would run the else-branch after the then-branch.
			fmt.Fprintf(w, "\tbr label %%%s\n", ifEnd)
		}()

		// if.else
		func() {
			defer p.restoreScope(p.scope)
			p.enterScope()

			// The label is emitted even without an else branch; in that
			// case the block only jumps to if.end.
			fmt.Fprintf(w, "\n%s:\n", ifElse)
			if stmt.Else != nil {
				p.compileStmt(w, stmt.Else)
			}
			fmt.Fprintf(w, "\tbr label %%%s\n", ifEnd)
		}()
	}()

	// end
	fmt.Fprintf(w, "\n%s:\n", ifEnd)

处理 for 语句时也是同样的思路:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
case *ast.ForStmt:
	// Emitted layout, one basic block per label:
	//
	//	for.init: optional init statement, then jump to for.cond
	//	for.cond: branch to for.body or for.end on the condition; with no
	//	          condition, jump to for.body unconditionally
	//	for.body: loop body, then jump to for.post
	//	for.post: optional post statement, then jump back to for.cond
	//	for.end:  code after the loop continues here
	//
	// init, cond and post share one scope; the body gets its own scope
	// nested inside it.
	defer p.restoreScope(p.scope)
	p.enterScope()

	line := fmt.Sprintf("%d", p.posLine())
	initLabel := p.genLabelId("for.init.line" + line)
	condLabel := p.genLabelId("for.cond.line" + line)
	postLabel := p.genLabelId("for.post.line" + line)
	bodyLabel := p.genLabelId("for.body.line" + line)
	endLabel := p.genLabelId("for.end.line" + line)

	// br for.init
	fmt.Fprintf(w, "\tbr label %%%s\n", initLabel)

	func() {
		defer p.restoreScope(p.scope)
		p.enterScope()

		// for.init
		fmt.Fprintf(w, "\n%s:\n", initLabel)
		if stmt.Init != nil {
			p.compileStmt(w, stmt.Init)
		}
		fmt.Fprintf(w, "\tbr label %%%s\n", condLabel)

		// for.cond
		fmt.Fprintf(w, "\n%s:\n", condLabel)
		if stmt.Cond == nil {
			fmt.Fprintf(w, "\tbr label %%%s\n", bodyLabel)
		} else {
			cond := p.compileExpr(w, stmt.Cond)
			fmt.Fprintf(w, "\tbr i1 %s , label %%%s, label %%%s\n", cond, bodyLabel, endLabel)
		}

		// for.body
		func() {
			defer p.restoreScope(p.scope)
			p.enterScope()

			fmt.Fprintf(w, "\n%s:\n", bodyLabel)
			p.compileStmt(w, stmt.Body)
			fmt.Fprintf(w, "\tbr label %%%s\n", postLabel)
		}()

		// for.post
		fmt.Fprintf(w, "\n%s:\n", postLabel)
		if stmt.Post != nil {
			p.compileStmt(w, stmt.Post)
		}
		fmt.Fprintf(w, "\tbr label %%%s\n", condLabel)
	}()

	// end
	fmt.Fprintf(w, "\n%s:\n", endLabel)

最后是表达式语句的处理,具体而言就是 compileExpr 函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// compileExpr writes to w the instructions that evaluate expr and returns the
// name of the value holding the result.
//
// Operands are compiled, and their instructions written, before the
// instruction that uses them. The function panics when an identifier or
// callee cannot be found in the current scope chain, and when it meets an
// expression type or binary operator it does not handle.
func (p *Compiler) compileExpr(w io.Writer, expr ast.Expr) (localName string) {
	switch expr := expr.(type) {
	case *ast.CallExpr:
		_, obj := p.scope.Lookup(expr.FuncName.Name)
		if obj == nil {
			panic(fmt.Sprintf("func %s undefined", expr.FuncName.Name))
		}

		localName = p.genId()
		// NOTE(review): the callee type "i32(i32)" and the "@ugo_main_"
		// prefix are hard-coded, so only one-argument functions of package
		// main get a matching call; every argument is also typed from the
		// callee's Type.Name. Confirm this is intended.
		varStr := fmt.Sprintf("\t%s = call i32(i32) @ugo_main_%s(",
			localName, expr.FuncName.Name,
		)
		for i, arg := range expr.Args {
			if i > 0 {
				varStr += ", "
			}
			// Compiling the argument writes its instructions to w now; the
			// call line itself is written after the loop.
			varStr += p.changeType(obj.Node.(*ast.FuncDecl).Type.Name) + " " + p.compileExpr(w, arg)
		}
		varStr += ")\n"

		fmt.Fprint(w, varStr)
		return localName

	case *ast.Ident:
		_, obj := p.scope.Lookup(expr.Name)
		if obj == nil {
			panic(fmt.Sprintf("var %s undefined", expr.Name))
		}
		varName := obj.MangledName

		localName = p.genId()
		fmt.Fprintf(w, "\t%s = load i32, i32* %s, align 4\n",
			localName, varName,
		)
		return localName

	case *ast.IdentAS:
		_, obj := p.scope.Lookup(expr.Name)
		if obj == nil {
			panic(fmt.Sprintf("var %s undefined", expr.Name))
		}
		varName := obj.MangledName

		// When the offset is a call, the call's result replaces the
		// variable's own name as the address that is loaded from.
		if call, ok := expr.Offset.(*ast.CallExpr); ok {
			varName = p.compileExpr(w, call)
		}

		localName = p.genId()
		fmt.Fprintf(w, "\t%s = load i32, i32* %s, align 4\n",
			localName, varName,
		)
		return localName

	case *ast.Number:
		// A literal is materialized as "0 + value" so it has a name like
		// any other result.
		localName = p.genId()
		fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
			localName, "add", `0`, expr.Value,
		)
		return localName

	case *ast.BinaryExpr:
		// The result id is taken before the operands are compiled.
		localName = p.genId()

		var op string
		switch expr.Op {
		case token.ADD:
			op = "add"
		case token.SUB:
			op = "sub"
		case token.MUL:
			op = "mul"
		case token.DIV:
			// LLVM has no plain "div" instruction; signed division is
			// "sdiv", matching the signed "srem" and "icmp s*" below.
			op = "sdiv"
		case token.MOD:
			op = "srem"
		case token.EQL: // ==
			op = "icmp eq"
		case token.NEQ: // !=
			op = "icmp ne"
		case token.LSS: // <
			op = "icmp slt"
		case token.LEQ: // <=
			op = "icmp sle"
		case token.GTR: // >
			op = "icmp sgt"
		case token.GEQ: // >=
			op = "icmp sge"
		default:
			panic(fmt.Sprintf("unknown: %[1]T, %[1]v", expr))
		}

		// Left operand first, then right, then the instruction itself.
		x := p.compileExpr(w, expr.X)
		y := p.compileExpr(w, expr.Y)
		fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n", localName, op, x, y)
		return localName

	case *ast.UnaryExpr:
		if expr.Op == token.SUB {
			// Negation is emitted as "0 - x".
			localName = p.genId()
			fmt.Fprintf(w, "\t%s = %s i32 %v, %v\n",
				localName, "sub", `0`, p.compileExpr(w, expr.X),
			)
			return localName
		}
		// Any other unary operator passes the operand through unchanged.
		return p.compileExpr(w, expr.X)

	case *ast.ParenExpr:
		return p.compileExpr(w, expr.X)

	default:
		panic(fmt.Sprintf("unknown: %[1]T, %[1]v", expr))
	}
}